In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
# Expose only CUDA device index 1 to this process. CUDA reads this variable
# when it initialises, so it must be set before any tensor touches the GPU.
os.environ.update({'CUDA_VISIBLE_DEVICES': '1'})

In [None]:
# Every entry under by_screen/ (each one is opened as a text level file later).
# Pure-Python replacement for `!ls by_screen/* -d`: when nothing matched, the
# shell version captured ls's *error message* as if it were a path; glob
# returns an empty list instead, and sorted() keeps the order deterministic.
import glob
segments = sorted(glob.glob('by_screen/*'))


In [None]:
# Group segment paths by the prefix of the file name before the first '_'
# (e.g. 'by_screen/<key>_...'); presumably the key identifies the game.
g2segs = {}
for path in segments:
    key = path.split('/')[1].split('_')[0]
    g2segs.setdefault(key, []).append(path)

# Size of the largest group; every other group is oversampled up to it.
max_segs = max((len(group) for group in g2segs.values()), default=0)

import random
segments = []
running_total = 0
for key, group in g2segs.items():
    num_segs = len(group)
    sampling_prob = max_segs / num_segs
    random.shuffle(group)

    # Error-diffusion oversampling: each segment earns `sampling_prob`
    # credits and is emitted once per credit spent; the (possibly negative)
    # remainder carries over to the next segment of the same group.
    credit = 0
    for path in group:
        credit += sampling_prob
        while credit > 0:
            credit -= 1
            segments.append(path)
    # Report original group size and how many entries it contributed.
    print(num_segs, len(segments) - running_total)
    running_total = len(segments)
    

In [None]:
# Read every segment file into a grid of characters (one list per line) and
# collect the character vocabulary. ';' is the column separator used when
# levels are flattened into strings and '{' is the decoder start token, so
# both are added even if no file contains them.
levels = []
vocab = {';', '{'}
for path in segments:
    with open(path) as handle:
        grid = []
        for line in handle:
            stripped = line.rstrip()
            grid.append(list(stripped))
            vocab |= set(stripped)
        levels.append(grid)
print(vocab)

# Token <-> index maps over the sorted vocabulary.
v2i = {symbol: index for index, symbol in enumerate(sorted(vocab))}
i2v = {index: symbol for symbol, index in v2i.items()}

# Transpose each grid into a list of columns: for each x position (width
# taken from the first line) collect grid[y][x] for every line y.
t_levels = [
    [[grid[y][x] for y in range(len(grid))] for x in range(len(grid[0]))]
    for grid in levels
]
            

In [None]:

# Flatten each transposed level into one training string: the characters of
# each column are concatenated and columns are separated by ';'.
chunks = [';'.join(''.join(column) for column in level) for level in t_levels]
        

In [1]:

class RNNVAE(nn.Module):
    """Sequence VAE whose latent code is fed to the decoder at every step.

    A bidirectional ``rnn_type`` encoder summarises the token sequence; linear
    heads produce mu / logvar of q(z|x). A sampled z is concatenated to every
    teacher-forced decoder input embedding.

    NOTE(review): a second ``class RNNVAE`` follows in this cell and rebinds
    the name, so this variant is never instantiated as written. Rename it if
    it should remain usable.

    Args:
        e_layers: number of layers of the bidirectional encoder RNN.
        d_layers: number of layers of the decoder RNN.
        dropout: passed as ``dropout`` to both RNNs.
        vocab_size: number of distinct token ids.
        enc_size: embedding size and per-direction encoder hidden size.
        dec_size: decoder hidden size.
        latent_size: dimensionality of z.
        rnn_type: RNN class to build, e.g. ``nn.GRU`` or ``nn.LSTM``.
    """

    def __init__(self,e_layers,d_layers,
                 dropout,vocab_size,
                 enc_size,
                 dec_size,latent_size,
                 rnn_type):
        # Zero-argument super(): `super(RNNVAE, self)` would look up the
        # module-level name at call time, which the later class rebinds.
        super().__init__()
        self.latent_dim = latent_size
        self.enc_size = enc_size
        self.dec_size = dec_size
        self.embed = nn.Embedding(vocab_size,enc_size)
        self.encoder = rnn_type(enc_size,hidden_size=enc_size,
                                num_layers=e_layers,batch_first=True,
                                dropout=dropout, bidirectional=True)

        # Forward and backward summaries are concatenated -> enc_size * 2.
        self.to_mu = nn.Linear(enc_size*2,latent_size)
        self.to_logvar = nn.Linear(enc_size*2,latent_size)

        # Decoder input = token embedding (enc_size) ++ z (latent_size).
        self.decoder = rnn_type(enc_size+latent_size,hidden_size=dec_size,
                                num_layers=d_layers,dropout=dropout,
                                batch_first=True,)
        self.hidden_to_vocab = nn.Linear(dec_size,vocab_size)
        self.CE = nn.CrossEntropyLoss()

    def _encode(self, input_sequence):
        """Embed ``input_sequence`` (batch, seq) and compute q(z|x) parameters.

        Returns ``(embedded, mu, logvar)``: ``embedded`` is (batch, seq,
        enc_size); ``mu`` and ``logvar`` are (batch, latent_size).
        """
        embedded = self.embed(input_sequence)
        out, _ = self.encoder(embedded)
        half = out.size(-1) // 2
        # In batch_first bidirectional output, out[:, t, :half] is the forward
        # direction after reading tokens 0..t and out[:, t, half:] is the
        # backward direction after reading tokens t..end. The whole-sequence
        # summary is therefore forward at the last step + backward at the first.
        f_enc = out[:, -1, :half]
        r_enc = out[:, 0, half:]
        enc = torch.cat((f_enc, r_enc), dim=-1)
        return embedded, self.to_mu(enc), self.to_logvar(enc)

    def forward(self,input_sequence,start_tok):
        """One teacher-forced VAE pass.

        Args:
            input_sequence: LongTensor of token ids, (batch, seq).
            start_tok: token id prepended to the decoder input (inputs are
                the targets shifted right by one).

        Returns:
            (loss, KL_div, dec_logit): mean token cross-entropy, KL(q(z|x) ||
            N(0, I)) summed over batch and latent dims, and decoder logits of
            shape (batch, seq, vocab_size).
        """
        batch_size, seq_length = input_sequence.size(0), input_sequence.size(1)
        device = input_sequence.device
        embedded, mu, logvar = self._encode(input_sequence)

        # Reparameterisation: z = mu + sigma * eps.
        std = torch.exp(0.5 * logvar)
        z = torch.randn([batch_size,self.latent_dim],device=device)
        z = z * std + mu
        # The same z is appended to every decoder step.
        z = z.unsqueeze(1).expand(-1, seq_length, -1)

        start = self.embed(torch.ones(batch_size,1,dtype=torch.long,device=device)*start_tok)
        dec_inp = torch.cat((start, embedded[:,:-1,:]),dim=1)
        dec_inp = torch.cat((dec_inp,z),dim=-1)
        dec_out, _ = self.decoder(dec_inp)
        dec_logit = self.hidden_to_vocab(dec_out)

        loss = self.CE(dec_logit.reshape(batch_size*seq_length,-1),input_sequence.reshape(-1))
        # NOTE(review): summed over the batch while `loss` is a per-token mean,
        # so the effective KL weight scales with batch size.
        KL_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return loss, KL_div, dec_logit
    
class RNNVAE(nn.Module):
    """Sequence VAE whose latent code initialises the decoder's hidden state.

    A bidirectional ``rnn_type`` encoder summarises the token sequence; linear
    heads produce mu / logvar of q(z|x). A sampled z is projected to the
    decoder's initial hidden state (and, for ``nn.LSTM``, its initial cell
    state); the decoder is trained with teacher forcing on the input shifted
    right by ``start_tok``.

    Args:
        e_layers: number of layers of the bidirectional encoder RNN.
        d_layers: number of layers of the decoder RNN.
        dropout: passed as ``dropout`` to both RNNs.
        vocab_size: number of distinct token ids.
        enc_size: embedding size and per-direction encoder hidden size.
        dec_size: decoder hidden size.
        latent_size: dimensionality of z.
        rnn_type: RNN class to build, e.g. ``nn.GRU`` or ``nn.LSTM``.
    """

    def __init__(self,e_layers,d_layers,
                 dropout,vocab_size,
                 enc_size,
                 dec_size,latent_size,
                 rnn_type):
        super().__init__()
        self.latent_dim = latent_size
        self.enc_size = enc_size
        self.dec_size = dec_size
        self.rnn_type = rnn_type
        self.embed = nn.Embedding(vocab_size,enc_size)
        self.encoder = rnn_type(enc_size,hidden_size=enc_size,
                                num_layers=e_layers,batch_first=True,
                                dropout=dropout, bidirectional=True)

        # Forward and backward summaries are concatenated -> enc_size * 2.
        self.to_mu = nn.Linear(enc_size*2,latent_size)
        self.to_logvar = nn.Linear(enc_size*2,latent_size)
        # z -> one dec_size vector per decoder layer for h0 (and c0).
        # latent_to_c is created for every rnn_type but only used by nn.LSTM.
        self.latent_to_h = nn.Linear(latent_size,dec_size*d_layers)
        self.latent_to_c = nn.Linear(latent_size,dec_size*d_layers)

        self.decoder = rnn_type(enc_size,hidden_size=dec_size,
                                num_layers=d_layers,dropout=dropout,
                                batch_first=True,)
        self.hidden_to_vocab = nn.Linear(dec_size,vocab_size)
        self.CE = nn.CrossEntropyLoss()

    def _encode(self, input_sequence):
        """Embed ``input_sequence`` (batch, seq) and compute q(z|x) parameters.

        Returns ``(embedded, mu, logvar)``: ``embedded`` is (batch, seq,
        enc_size); ``mu`` and ``logvar`` are (batch, latent_size).
        """
        embedded = self.embed(input_sequence)
        out, _ = self.encoder(embedded)
        half = out.size(-1) // 2
        # In batch_first bidirectional output, out[:, t, :half] is the forward
        # direction after reading tokens 0..t and out[:, t, half:] is the
        # backward direction after reading tokens t..end. The whole-sequence
        # summary is therefore forward at the last step + backward at the first.
        f_enc = out[:, -1, :half]
        r_enc = out[:, 0, half:]
        enc = torch.cat((f_enc, r_enc), dim=-1)
        return embedded, self.to_mu(enc), self.to_logvar(enc)

    def _initial_state(self, z):
        """Map z (batch, latent_size) to the decoder's initial state.

        Returns h0 of shape (d_layers, batch, dec_size), or ``(h0, c0)`` when
        the decoder is an ``nn.LSTM``.
        """
        batch_size = z.size(0)
        # (batch, d_layers*dec_size) -> (d_layers, batch, dec_size)
        dec_h = self.latent_to_h(z).view(batch_size,-1, self.dec_size).permute(1,0,2).contiguous()
        if self.rnn_type == nn.LSTM:
            dec_c = self.latent_to_c(z).view(batch_size,-1, self.dec_size).permute(1,0,2).contiguous()
            return (dec_h, dec_c)
        return dec_h

    def forward(self,input_sequence,start_tok):
        """One teacher-forced VAE pass.

        Args:
            input_sequence: LongTensor of token ids, (batch, seq).
            start_tok: token id prepended to the decoder input (inputs are
                the targets shifted right by one).

        Returns:
            (loss, KL_div, dec_logit): mean token cross-entropy, KL(q(z|x) ||
            N(0, I)) summed over batch and latent dims, and decoder logits of
            shape (batch, seq, vocab_size).
        """
        batch_size, seq_length = input_sequence.size(0), input_sequence.size(1)
        device = input_sequence.device
        embedded, mu, logvar = self._encode(input_sequence)

        # Reparameterisation: z = mu + sigma * eps.
        std = torch.exp(0.5 * logvar)
        z = torch.randn([batch_size,self.latent_dim],device=device)
        z = z * std + mu
        hidden = self._initial_state(z)

        start = self.embed(torch.ones(batch_size,1,dtype=torch.long,device=device)*start_tok)
        dec_inp = torch.cat((start, embedded[:,:-1,:]),dim=1)
        dec_out, _ = self.decoder(dec_inp, hidden)
        dec_logit = self.hidden_to_vocab(dec_out)

        loss = self.CE(dec_logit.reshape(batch_size*seq_length,-1),input_sequence.reshape(-1))
        # NOTE(review): summed over the batch while `loss` is a per-token mean,
        # so the effective KL weight scales with batch size.
        KL_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return loss, KL_div, dec_logit
# Model hyper-parameters; they are also encoded in the checkpoint file name.
e_layer = 3
d_layer = 2
e_size = 1024
l_size = 256
d_size = 256
dropout = 0.5
rnn_vae = RNNVAE(e_layer,d_layer,dropout,len(vocab),e_size,d_size,l_size,nn.GRU)

# Resume from an existing checkpoint. Only a missing file means "start fresh";
# any other failure (corrupt file, unpickling or device error) is raised rather
# than silently swallowed, so it cannot be mistaken for "no checkpoint".
# torch.load unpickles the whole module object: only load files you created.
model_path = f'rnn_vaeEnc{e_layer}_{e_size}_Dec{d_layer}_{d_size}_Lat{l_size}.model'
try:
    rnn_vae = torch.load(model_path)
except FileNotFoundError:
    print(f'no checkpoint at {model_path}; training from scratch')

optimizer = optim.Adam(rnn_vae.parameters(),lr=1e-4)
# (total, CE, KL) per training step, appended by the training loop.
losses = []

NameError: name 'nn' is not defined

In [None]:
# Fine-tuning loop: re-creates Adam at a lower learning rate (discarding the
# optimizer state built with the model) and trains for 10 rounds of `epochs`
# epochs. The KL term is weighted by (1 - annealing).
optimizer = optim.Adam(rnn_vae.parameters(),lr=1e-5)
from tqdm.auto import tqdm
import random
import numpy as np

device = 'cuda' if torch.cuda.is_available() else 'cpu'
epochs = 100
batch_size = 32
show_every = 16        # report (and, with use_annealing, re-tune KL weight) every N steps
warmup = 98            # epoch index within a round after which annealing decays per epoch
ratio = 1              # adaptive annealing targets weighted-KL == ratio * CE
use_annealing = False
annealing_rate = 0.99995
rnn_vae.to(device)
rnn_vae.train()


def _adaptive_annealing(history, window, ratio, fallback):
    """Return a with (1 - a) * mean(KL) == ratio * mean(CE).

    Means are over the last ``window`` entries of ``history``, a sequence of
    ``(total, CE, KL)`` tuples as appended to ``losses`` below. A negative
    result is replaced by ``fallback``.
    """
    recent = np.asarray(history)[-window:]
    value = 1 - ratio*np.mean(recent[:, 1])/np.mean(recent[:, 2])
    return fallback if value < 0 else value


for _ in range(10):
    # Starting KL weight for this round. Computed from `losses` directly so
    # it no longer depends on a `losses_` snapshot that may not exist yet.
    annealing = annealing_rate
    if use_annealing and len(losses) > 0:
        # NOTE(review): falls back to 1 (KL weight 0) here but to
        # annealing_rate inside the loop below - confirm which is intended.
        annealing = _adaptive_annealing(losses, show_every*2, ratio, fallback=1)

    for epoch in tqdm(range(epochs)):
        random.shuffle(chunks)
        for start in tqdm(range(0,len(chunks),batch_size)):
            rnn_vae.zero_grad()
            batch = [[v2i[t] for t in c] for c in chunks[start:start+batch_size]]
            # assumes all chunks have equal length (torch.tensor needs a rectangular batch)
            input_sequence = torch.tensor(batch).to(device)
            loss, KL_div, dec_logit = rnn_vae(input_sequence,v2i['{'])
            loss_KL_div = loss + KL_div*(1.0-annealing)
            losses.append((loss_KL_div.item(),loss.item(),KL_div.item()))
            loss_KL_div.backward()
            torch.nn.utils.clip_grad_norm_(rnn_vae.parameters(), 1)
            optimizer.step()
            if len(losses) % show_every == 0:
                losses_ = np.array(losses)
                window = losses_[-show_every*2:]
                print(epoch,np.mean(window[:,0]),
                      np.mean(window[:,1]),
                      np.mean(window[:,2]))
                annealing = annealing_rate
                if use_annealing:
                    annealing = _adaptive_annealing(losses_, show_every*2, ratio,
                                                    fallback=annealing_rate)

        # NOTE(review): this per-epoch decay is overwritten at the next
        # show_every boundary above, so it only lasts until that step.
        if epoch >= warmup:
            annealing *= annealing_rate

In [None]:
import os
# Rotate the previous checkpoint to *.model_old, then save the current model.
# os.replace instead of shelling out to `mv`: no shell, no error output when
# there is no previous checkpoint yet, and an existing *_old is overwritten
# just as `mv` did. torch.save here pickles the whole module object.
model_path = f'rnn_vaeEnc{e_layer}_{e_size}_Dec{d_layer}_{d_size}_Lat{l_size}.model'
if os.path.exists(model_path):
    os.replace(model_path, model_path + '_old')
torch.save(rnn_vae, model_path)

In [None]:
import torch.nn.functional as F
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits tensor whose last dimension is the vocabulary, e.g.
                (batch size x vocabulary size) or (vocabulary size,).
                It is modified in place and also returned.
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the highest-probability tokens up to and including the
                first one at which the cumulative probability exceeds top_p
                (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            filter_value: value written into filtered positions (default -inf, i.e.
                zero probability after softmax).
        Returns:
            ``logits`` with filtered entries set to ``filter_value``. With the
            defaults (top_k=0, top_p=0.0) nothing is filtered.
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        # Scatter back to original order along the last (vocabulary) dim, like
        # every other op here; the former dim=1 only worked for 2-D input.
        indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)

        logits[indices_to_remove] = filter_value
    return logits

def autoencode(self,input_sequence,start_tok,temp):
    """Encode with the posterior mean, then decode by free-running sampling.

    ``self`` is expected to be the ``RNNVAE`` variant that maps z to the
    decoder's initial state (it uses ``latent_to_h``, ``latent_to_c``,
    ``dec_size`` and ``rnn_type``). Each decoding step feeds back the token
    sampled at the previous step, for as many steps as the input is long.

    Args:
        input_sequence: LongTensor of token ids, (batch, seq).
        start_tok: token id fed to the decoder at the first step.
        temp: softmax temperature used when sampling.

    Returns:
        (tokens, scores): ``tokens`` is a list of ``seq`` tensors of shape
        (batch, 1) holding the sampled ids; ``scores`` is a list of ``seq``
        lists holding, per batch element, the logit of the sampled token.
    """
    batch_size, seq_length = input_sequence.size(0), input_sequence.size(1)
    device = input_sequence.device
    embedded = self.embed(input_sequence)
    out, _ = self.encoder(embedded)

    # Forward direction at the last step + backward direction at the first
    # step: the only outputs that have read the whole sequence.
    half = out.size(-1) // 2
    enc = torch.cat((out[:, -1, :half], out[:, 0, half:]), dim=-1)
    # Deterministic encoding: posterior mean instead of a sampled z.
    z = self.to_mu(enc)
    # (batch, d_layers*dec_size) -> (d_layers, batch, dec_size)
    dec_h = self.latent_to_h(z).view(batch_size,-1, self.dec_size).permute(1,0,2).contiguous()
    dec_c = self.latent_to_c(z).view(batch_size,-1, self.dec_size).permute(1,0,2).contiguous()

    dec_inp = self.embed(torch.ones(batch_size,1,dtype=torch.long,device=device)*start_tok)
    tokens = []
    scores = []
    for _ in range(seq_length):
        if self.rnn_type == nn.LSTM:
            dec_out, (dec_h,dec_c) = self.decoder(dec_inp,(dec_h,dec_c))
        else:
            dec_out, dec_h = self.decoder(dec_inp,dec_h)

        dec_logit = self.hidden_to_vocab(dec_out[:,-1,:]).view(batch_size,-1)
        # Called with its defaults (top_k=0, top_p=0.0), i.e. no filtering.
        dec_logit = top_k_top_p_filtering(dec_logit)
        next_token = torch.multinomial(F.softmax(dec_logit/temp, dim=-1), num_samples=1)

        tokens.append(next_token.detach())
        # Logit of each row's sampled token in one gather (one host sync per
        # step instead of one .item() per row).
        # NOTE(review): these are raw, pre-temperature logits rather than
        # log-probabilities, so comparing them across samples is a heuristic.
        scores.append(dec_logit.gather(1, next_token).squeeze(1).tolist())
        dec_inp = self.embed(next_token).view(batch_size,1,-1)

    return tokens,scores


def prettify(level):
    """Turn a ';'-separated string of columns back into printable rows.

    Columns whose length differs from the first column are dropped. Output
    row y is the y-th character of every kept column, in order, followed by
    a newline; an empty string yields ''.
    """
    columns = level.split(';')
    expected = len(columns[0])
    kept = [column for column in columns if len(column) == expected]
    return ''.join(''.join(row) + '\n' for row in zip(*kept))
import numpy as np
# Qualitative check: autoencode a batch of levels `samples` times and keep,
# per level, the sample whose chosen tokens have the highest mean logit.
samples = 3
eval_batch = 32
rnn_vae.eval()
# Build the evaluation batch explicitly instead of reusing the
# `input_sequence` leaked from the last training step, which is undefined
# when the model was loaded without training and may be a short last batch.
# assumes all chunks have equal length (needed for a rectangular tensor)
eval_device = next(rnn_vae.parameters()).device
input_sequence = torch.tensor([[v2i[t] for t in c] for c in chunks[:eval_batch]],
                              device=eval_device)
best_scores = np.full(input_sequence.size(0), -np.inf)
best_samples = [None]*input_sequence.size(0)
for _ in range(samples):
    with torch.no_grad():
        encoded, scores = autoencode(rnn_vae,input_sequence,v2i['{'],0.8)
    # scores is (seq, batch): average over decoding steps for each level.
    mean_scores = np.mean(np.array(scores),axis=0)
    # encoded is a list of (batch, 1) id tensors -> (batch, seq) nested list.
    generated = torch.cat(encoded, dim=1).tolist()
    for ii,(sc,be) in enumerate(zip(mean_scores,best_scores)):
        if sc > be:
            best_scores[ii] = sc
            best_samples[ii] = prettify(''.join(i2v[t] for t in generated[ii]))
    print(best_scores)

# Print each level of the evaluation batch followed by its best
# reconstruction; every 'x' is displayed as '+'.
for idx, best in enumerate(best_samples):
    print(idx)
    print('original')
    token_ids = input_sequence[idx].tolist()
    print(prettify(''.join(i2v[t] for t in token_ids)).replace('x', '+'))
    print('generated')
    print(best.replace('x', '+'))
