In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
os.environ['CUDA_VISIBLE_DEVICES']='1'
device = 'cuda'

In [2]:
segments = []
with open('./sampled_files.txt') as infile:
    for line in infile:
        segments.append(line.rstrip())  
levels = {}
vocab = set([';','{'])
for file in segments:
    with open(file) as input_file:
        level = []
        for row in input_file:
            level.append(list(row.rstrip()))
            vocab |= set(row.rstrip())
        levels[file] = level

v2i = {v:i for i,v in enumerate(sorted(vocab))}
i2v = {i:v for v,i in v2i.items()}


t_levels = {}
for file,level in levels.items():
    t_level = []
    for row in range(len(level[0])):
        column = []
        for col in range(len(level)):
            column.append(level[col][row])
        t_level.append(column)
    t_levels[file] = t_level
    
chunks = {}
for file,level in t_levels.items():
    chunk = level
    chunk = [''.join(c) for c in chunk]
    chunks[file] = ';'.join(chunk)
    
    
class RNNVAE(nn.Module):
    def __init__(self,e_layers,d_layers,
                 dropout,vocab_size,
                 enc_size,
                 dec_size,latent_size,
                 rnn_type):
        super(RNNVAE, self).__init__()
        self.latent_dim = latent_size
        self.enc_size = enc_size
        self.dec_size = dec_size
        self.embed = nn.Embedding(vocab_size,enc_size)
        self.encoder = rnn_type(enc_size,hidden_size=enc_size,
                              num_layers=e_layers,batch_first=True,
                                dropout=dropout, bidirectional=True)
        
        self.to_mu = nn.Linear(enc_size*2,latent_size)
        self.to_logvar = nn.Linear(enc_size*2,latent_size)
        
        self.decoder = rnn_type(enc_size+latent_size,hidden_size=dec_size,
                              num_layers=d_layers,dropout=dropout,
                                batch_first=True,)
        self.hidden_to_vocab = nn.Linear(dec_size,vocab_size)
        self.CE = nn.CrossEntropyLoss()
    def forward(self,input_sequence,start_tok):
        
        batch_size, seq_length = input_sequence.size(0), input_sequence.size(1)
        device = input_sequence.device
        embedded = self.embed(input_sequence)
        out , _= self.encoder(embedded)
        
        f_enc = out[:,0,:out.size()[-1]//2]
        r_enc = out[:,-1,out.size()[-1]//2:]
        enc = torch.cat((f_enc,r_enc),dim=-1)
        mu,logvar = self.to_mu(enc), self.to_logvar(enc)
        
        std = torch.exp(0.5 * logvar)
        z = torch.randn([batch_size,self.latent_dim],device=device)
        z = z * std + mu       
        z = z.repeat(1,seq_length).view(batch_size,seq_length,-1)
        dec_inp = torch.cat((
             self.embed(torch.ones(batch_size,1,dtype=torch.long,device=device)*start_tok),
            embedded[:,:-1,:]),dim=1)        
        dec_inp = torch.cat((dec_inp,z),dim=-1)
        dec_out, _ =self.decoder(dec_inp)
        dec_logit = self.hidden_to_vocab(dec_out)
        
        loss = self.CE(dec_logit.view(batch_size*seq_length,-1),input_sequence.view(-1))
        KL_div =  -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return loss, KL_div, dec_logit
    
class RNNVAE(nn.Module):
    def __init__(self,e_layers,d_layers,
                 dropout,vocab_size,
                 enc_size,
                 dec_size,latent_size,
                 rnn_type):
        super(RNNVAE, self).__init__()
        self.latent_dim = latent_size
        self.enc_size = enc_size
        self.dec_size = dec_size
        self.rnn_type = rnn_type
        self.embed = nn.Embedding(vocab_size,enc_size)
        self.encoder = rnn_type(enc_size,hidden_size=enc_size,
                              num_layers=e_layers,batch_first=True,
                                dropout=dropout, bidirectional=True)
        
        self.to_mu = nn.Linear(enc_size*2,latent_size)
        self.to_logvar = nn.Linear(enc_size*2,latent_size)
        self.latent_to_h = nn.Linear(latent_size,dec_size*d_layers)
        self.latent_to_c = nn.Linear(latent_size,dec_size*d_layers)
        
        self.decoder = rnn_type(enc_size,hidden_size=dec_size,
                              num_layers=d_layers,dropout=dropout,
                                batch_first=True,)
        self.hidden_to_vocab = nn.Linear(dec_size,vocab_size)
        self.CE = nn.CrossEntropyLoss()
    def forward(self,input_sequence,start_tok):
        
        batch_size, seq_length = input_sequence.size(0), input_sequence.size(1)
        device = input_sequence.device
        embedded = self.embed(input_sequence)
        out , _= self.encoder(embedded)
        
        f_enc = out[:,0,:out.size()[-1]//2]
        r_enc = out[:,-1,out.size()[-1]//2:]
        enc = torch.cat((f_enc,r_enc),dim=-1)
        mu,logvar = self.to_mu(enc), self.to_logvar(enc)
        
        std = torch.exp(0.5 * logvar)
        z = torch.randn([batch_size,self.latent_dim],device=device)
        z = z * std + mu        
        dec_h = self.latent_to_h(z)
        dec_h = dec_h.view(batch_size,-1, self.dec_size).permute(1,0,2).contiguous()
        dec_c = self.latent_to_c(z)
        dec_c = dec_c.view(batch_size,-1, self.dec_size).permute(1,0,2).contiguous()

        dec_inp = torch.cat((
             self.embed(torch.ones(batch_size,1,dtype=torch.long,device=device)*start_tok),
            embedded[:,:-1,:]),dim=1)        
        if self.rnn_type == nn.LSTM:
            dec_out, _ =self.decoder(dec_inp,(dec_h,dec_c))
        else:
            dec_out, _ =self.decoder(dec_inp,dec_h)
            
        dec_logit = self.hidden_to_vocab(dec_out)
        
        loss = self.CE(dec_logit.view(batch_size*seq_length,-1),input_sequence.view(-1))
        KL_div =  -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return loss, KL_div, dec_logit

e_layer = 3
d_layer = 2
e_size = 1024
l_size = 128
d_size = 256
dropout = 0.5
rnn_vae = RNNVAE(e_layer,d_layer,dropout,len(vocab),e_size,d_size,l_size,nn.GRU)    

#rnn_vae = torch.load('rnn_vae.model')
rnn_vae = torch.load(f'rnn_vaeEnc{e_layer}_{e_size}_Dec{d_layer}_{d_size}_Lat{l_size}.model').to(device).eval()

In [3]:
def encode(self,input_sequence):
        
    batch_size, seq_length = input_sequence.size(0), input_sequence.size(1)
    device = input_sequence.device
    embedded1 = self.embed(input_sequence)
    out1 , _= self.encoder(embedded1)

    f_enc1 = out1[:,0,:out1.size()[-1]//2]
    r_enc1 = out1[:,-1,out1.size()[-1]//2:]
    enc1 = torch.cat((f_enc1,r_enc1),dim=-1)
    return self.to_mu(enc1)

all_vectors = []
batch_size = 128
with torch.no_grad():
    sequences = []
    for chunk in chunks.values():
        sequences.append([v2i[t] for t in chunk])
        if len(sequences) == batch_size:
            sequences = torch.tensor(sequences).to(device)
            vectors = encode(rnn_vae,sequences)
            print(vectors.size())
            all_vectors += vectors.tolist()
            sequences = []
            
            
    
    sequences = torch.tensor(sequences).to(device)
    vectors = encode(rnn_vae,sequences)
    print(vectors.size())
    all_vectors += vectors.tolist()
    sequences = []
    
    
import numpy as np

vecs = np.array(all_vectors)

print(np.mean(vecs,axis=0),np.std(vecs,axis=0))    

torch.Size([128, 128])
torch.Size([128, 128])
torch.Size([128, 128])
torch.Size([116, 128])
[-0.01022032  0.00784601  0.00492877  0.00456151 -0.01184349 -0.00215228
 -0.01276376 -0.01348369  0.00779179  0.00437778 -0.00500527 -0.0084651
 -0.00245407  0.00383932  0.01418269  0.0024195   0.01178668 -0.00063515
 -0.00943421  0.00891193  0.00752145  0.01465803  0.04864955  0.00453371
  0.02830202 -0.01125048  0.01113066  0.00389368 -0.02286244 -0.03127266
 -0.00973023  0.00711678  0.01390911 -0.0222053   0.00427276 -0.00774809
  0.00828575  0.01135813 -0.00387075  0.02228152 -0.01952129 -0.00789322
  0.01812268 -0.00361058 -0.01626958  0.00238269 -0.00471098 -0.00591152
 -0.00629522 -0.00756995 -0.03471468 -0.00529516 -0.00334545 -0.00128012
 -0.00042518  0.01180841  0.00055108  0.01747989 -0.00290926  0.00668062
 -0.00463295  0.00506237  0.00045063  0.01647843  0.00881275 -0.00482568
  0.02341403  0.00292657 -0.00656581 -0.00956701  0.00365607  0.01249447
 -0.00062665  0.00735776  0.00331

In [4]:
import numpy as np

vecs = np.array(all_vectors)


import matplotlib.pyplot as plt
for d in range(vecs.shape[1]):
    plt.hist(vecs[:,d])
    plt.show()
    

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

In [5]:
import random
import torch.nn.functional as F
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size x vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
        
        logits[indices_to_remove] = filter_value
    return logits


mus  = np.mean(vecs,axis=0)
stds = np.std(vecs,axis=0)


def decode(self,z,seq_length,start_tok,temp):
    dec_h = self.latent_to_h(z)
    dec_h = dec_h.view(batch_size,-1, self.dec_size).permute(1,0,2).contiguous()
    dec_c = self.latent_to_c(z)
    dec_c = dec_c.view(batch_size,-1, self.dec_size).permute(1,0,2).contiguous()

    dec_inp = self.embed(torch.ones(batch_size,1,dtype=torch.long,device=device)*start_tok)
    tokens = []    
    scores = []
    for ii in range(seq_length):   
        if self.rnn_type == nn.LSTM:
            dec_out, (dec_h,dec_c) =self.decoder(dec_inp,(dec_h,dec_c))
        else:
            dec_out, dec_h =self.decoder(dec_inp,dec_h)
            
        dec_logit = self.hidden_to_vocab(dec_out[:,-1,:]).view(batch_size,-1)
        dec_logit = top_k_top_p_filtering(dec_logit,top_k=5)
        next_token = torch.multinomial(F.softmax(dec_logit/temp, dim=-1), num_samples=1)
        
        tokens.append(next_token.detach())
        chosen_scores = []
        for ii,tok in enumerate(tokens[-1]):
            chosen_scores.append(dec_logit[ii,tok].item())
        scores.append(chosen_scores)
        dec_inp = self.embed(next_token).view(batch_size,1,-1)        

    return tokens,scores

def prettify(level):
    levelstr = ''
    level = level.split(';')
    
    level = [c for c in level if len(c) == len(level[0])]
    width = len(level)
    height = len(level[0]) 
    for y in range(height):
        for x in range(width):

            levelstr += level[x][y]
        levelstr += '\n'
    return levelstr


from tqdm import tqdm_notebook
out_folder = f'gru_generation_{l_size}'

!mkdir {out_folder}
!rm {out_folder}/*

batch_size = 32
total_count = 5000

count = 0
sequence_length = len(chunk)
acceptance_temp = 1.0
with torch.no_grad():
    while total_count > 0:

        samples = 5
        best_scores = np.ones(batch_size)*-np.inf
        best_samples = [None]*batch_size
        good = False
        Zs = torch.tensor(np.random.normal(mus,stds,(batch_size,mus.shape[0])),dtype=torch.float32).to(device)
        for _ in range(samples):   
            good = True

            with torch.no_grad():
                rnn_vae.eval()
                encoded, scores = decode(rnn_vae,Zs,sequence_length,v2i['{'],1.0)

            scores = np.array(scores)
            scores = np.mean(scores,axis=0)
            for ii,(sc,be) in enumerate(zip(scores,best_scores)):
                if random.random() < np.exp( (sc - be)/acceptance_temp):
                    generation = []
                    for batch in encoded:
                        token = batch[ii].item()
                        generation.append(token)
                    pretty = prettify(''.join([i2v[t] for t in generation]))
                    if len(pretty) < 495:
                        continue
                    best_scores[ii] = sc
                    best_samples[ii] = pretty
            if np.mean(best_scores) == -np.inf:
                good = False
            good = good and samples <= 0
            
        for out_level in best_samples:
            if out_level is None:
                continue
            if len(out_level) < 495:
                continue
            
            total_count -= 1
            count += 1
            with open(f'{out_folder}/{count}.txt','w') as outfile:
                outfile.write(out_level)

rm: cannot remove 'gru_generation_128/*': No such file or directory
