In [1]:
import glob

import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
# Load the path-annotated Super Mario Bros levels from TheVGLC and turn them
# into fixed-width training strings.
LEVEL_GLOB = 'TheVGLC/Super Mario Bros/Paths/*txt'


def read_level(path):
    """Read one level file into a list of rows, each row a list of tile characters.

    Trailing whitespace (including the newline) is stripped from every line and
    blank lines are skipped so they cannot turn into empty rows.
    """
    with open(path) as input_file:
        stripped = (line.rstrip() for line in input_file)
        return [list(row) for row in stripped if row]


def transpose_level(level):
    """Turn a list of rows into a list of columns (left to right).

    The width is taken from the first row: a later row that is shorter raises
    IndexError, extra tiles in a longer row are ignored. An empty level gives [].
    """
    if not level:
        return []
    width = len(level[0])
    return [[row[x] for row in level] for x in range(width)]


def sliding_chunks(columns, chunk_size):
    """Serialise every window of `chunk_size` consecutive columns as one string.

    The tiles of a column are concatenated, columns are separated by ';' and
    each window is terminated by '}'. The window that starts at the last valid
    offset (len(columns) - chunk_size) is included; fewer than `chunk_size`
    columns yield no chunks.
    """
    joined = [''.join(column) for column in columns]
    last_start = len(joined) - chunk_size
    return [';'.join(joined[x:x+chunk_size])+'}' for x in range(last_start + 1)]


# Sorted so the file order does not depend on the file system.
mario_path_files = sorted(glob.glob(LEVEL_GLOB))
if not mario_path_files:
    print('WARNING: no level files matched ' + LEVEL_GLOB)

levels = [read_level(path) for path in mario_path_files]

# ';' separates columns and '}' ends a chunk; '{' is looked up later in the
# notebook as the decoder start token, so all three are always in the vocabulary.
vocab = set([';','{','}'])
for level in levels:
    for row in level:
        vocab.update(row)
print(vocab)

# Token <-> id maps; sorting makes the ids reproducible across runs.
v2i = {v:i for i,v in enumerate(sorted(vocab))}
i2v = {i:v for v,i in v2i.items()}

# The files store one text line per screen row; the model reads levels column by column.
t_levels = [transpose_level(level) for level in levels]

chunk_size = 32
chunks = []
for level in t_levels:
    chunks.extend(sliding_chunks(level, chunk_size))


        

{'?', 'S', 'E', 'o', '[', '}', 'B', 'b', ']', 'X', ';', '-', 'Q', '<', '{', '>', 'x'}


In [3]:
class RNNVAE(nn.Module):
    """Sequence variational autoencoder over token ids.

    A bidirectional RNN encodes the embedded sequence into the mean and
    log-variance of a diagonal Gaussian. A sample from that Gaussian initialises
    the state of a unidirectional RNN decoder, which is trained with teacher
    forcing to reproduce the input.

    Args:
        e_layers: number of stacked encoder layers.
        d_layers: number of stacked decoder layers.
        dropout: dropout value handed to both RNNs.
        vocab_size: number of distinct token ids.
        hidden_size: embedding width and RNN hidden width.
        latent_size: dimensionality of the latent vector z.
        rnn_type: recurrent layer class. `forward` hands the decoder an
            ``(h, c)`` tuple, so it must accept LSTM-style state (e.g. ``nn.LSTM``).
    """
    def __init__(self,e_layers,d_layers,
                 dropout,vocab_size,
                 hidden_size,latent_size,
                 rnn_type):
        super(RNNVAE, self).__init__()
        self.latent_dim = latent_size
        self.hidden_size = hidden_size
        # Encoder and decoder share one embedding table.
        self.embed = nn.Embedding(vocab_size,hidden_size)
        self.encoder = rnn_type(hidden_size,hidden_size=hidden_size,
                                num_layers=e_layers,batch_first=True,
                                dropout=dropout, bidirectional=True)

        # The encoder summary concatenates both directions, hence hidden_size*2.
        self.to_mu = nn.Linear(hidden_size*2,latent_size)
        self.to_logvar = nn.Linear(hidden_size*2,latent_size)
        # One hidden_size slice of initial state per decoder layer.
        self.latent_to_h = nn.Linear(latent_size,hidden_size*d_layers)
        self.latent_to_c = nn.Linear(latent_size,hidden_size*d_layers)

        self.decoder = rnn_type(hidden_size,hidden_size=hidden_size,
                                num_layers=d_layers,dropout=dropout,
                                batch_first=True,)
        self.hidden_to_vocab = nn.Linear(hidden_size,vocab_size)
        self.CE = nn.CrossEntropyLoss()

    def forward(self,input_sequence,start_tok):
        """Encode, sample a latent, and decode with teacher forcing.

        Args:
            input_sequence: integer token ids indexed as (batch, time).
            start_tok: token id fed to the decoder at the first step.

        Returns:
            (loss, KL_div, dec_logit): the cross-entropy averaged over every
            token of the batch, the KL divergence to a standard normal summed
            over batch and latent dimensions, and the decoder logits indexed as
            (batch, time, vocab).
        """
        batch_size, seq_length = input_sequence.size(0), input_sequence.size(1)
        device = input_sequence.device
        embedded = self.embed(input_sequence)
        out , _= self.encoder(embedded)

        # `out` holds [forward | backward] features for every time step. The
        # forward direction has read the whole sequence only at the LAST step
        # and the backward direction only at the FIRST step, so those are the
        # two positions that summarise the input. (Taking forward@0 and
        # backward@-1 would summarise one token each.)
        # The standalone `autoencode` helper repeats this encoding and must
        # read the same positions.
        half = out.size(-1)//2
        f_enc = out[:,-1,:half]
        r_enc = out[:,0,half:]
        enc = torch.cat((f_enc,r_enc),dim=-1)
        mu,logvar = self.to_mu(enc), self.to_logvar(enc)

        # Reparameterisation trick: z = mu + sigma * eps with eps ~ N(0, I).
        std = torch.exp(0.5 * logvar)
        z = torch.randn([batch_size,self.latent_dim],device=device)
        z = z * std + mu
        # (batch, layers*hidden) -> (layers, batch, hidden), the layout the RNN expects.
        dec_h = self.latent_to_h(z)
        dec_h = dec_h.view(batch_size,-1, self.hidden_size).permute(1,0,2).contiguous()
        dec_c = self.latent_to_c(z)
        dec_c = dec_c.view(batch_size,-1, self.hidden_size).permute(1,0,2).contiguous()

        # Teacher forcing: the decoder sees <start>, x_0 .. x_{T-2} and predicts x_0 .. x_{T-1}.
        dec_inp = torch.cat((
            self.embed(torch.ones(batch_size,1,dtype=torch.long,device=device)*start_tok),
            embedded[:,:-1,:]),dim=1)

        dec_out, _ =self.decoder(dec_inp,(dec_h,dec_c))
        dec_logit = self.hidden_to_vocab(dec_out)

        # reshape (not view) so a non-contiguous input_sequence is accepted too.
        loss = self.CE(dec_logit.reshape(batch_size*seq_length,-1),input_sequence.reshape(-1))
        KL_div =  -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return loss, KL_div, dec_logit
     
# 2-layer bidirectional encoder, 4-layer decoder, 512 hidden units, 32-d latent.
rnn_vae = RNNVAE(
    e_layers=2,
    d_layers=4,
    dropout=0.5,
    vocab_size=len(vocab),
    hidden_size=512,
    latent_size=32,
    rnn_type=nn.LSTM,
)

In [None]:
from tqdm import tqdm_notebook
import random
import numpy as np
# Fall back to the CPU when no GPU is present instead of crashing in `.to('cuda')`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
epochs = 100
batch_size = 32
show_every = 16
rnn_vae.to(device)

optimizer = optim.Adam(rnn_vae.parameters(),lr=1e-4)
losses = []  # one (total, reconstruction, KL) triple per optimisation step
# The KL term is weighted by (1 - annealing): 0 during the first epoch, growing
# as `annealing` decays.
# NOTE(review): the decay is applied once per epoch, so after 100 epochs the KL
# weight is only 1 - 0.999**100 ~= 0.095 -- confirm that per-epoch (rather than
# per-batch) decay is what was intended.
annealing = 1.0
annealing_rate = 0.999
for epoch in tqdm_notebook(range(epochs)):
    random.shuffle(chunks)
    rnn_vae.train()
    for start in tqdm_notebook(range(0,len(chunks),batch_size)):
        optimizer.zero_grad()
        # All chunks have the same length, so the batch is a rectangular id matrix.
        batch = [[v2i[t] for t in c] for c in chunks[start:start+batch_size]]
        # `input_sequence` is deliberately left at cell scope: later cells
        # reconstruct the last training batch from it.
        input_sequence = torch.tensor(batch,device=device)
        loss, KL_div, dec_logit = rnn_vae(input_sequence,v2i['{'])
        loss_KL_div = loss + KL_div*(1.0-annealing)
        losses.append((loss_KL_div.item(),loss.item(),KL_div.item()))
        loss_KL_div.backward()
        optimizer.step()
        if len(losses) % show_every == 0:
            # Convert only the reported window, not the whole history, on every report.
            recent = np.array(losses[-show_every*2:])
            print(np.mean(recent[:,0]),
                  np.mean(recent[:,1]),
                  np.mean(recent[:,2]))
    annealing *= annealing_rate

HBox(children=(IntProgress(value=0), HTML(value='')))

HBox(children=(IntProgress(value=0, max=77), HTML(value='')))

0.13879427104257047 0.13879427104257047 0.013472948223352432


In [18]:
import torch.nn.functional as F
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering

        `logits` is modified in place and also returned. Filtering always acts
        on the last dimension, so any number of leading dimensions is accepted.

        Args:
            logits: logits distribution, shape (..., vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            filter_value: value written over every removed logit.
        Returns:
            The same tensor object, with removed entries set to `filter_value`.
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        # Scatter the mask back to the original token order. The sort ran over
        # the last dim, so the scatter must too (a hard-coded dim=1 only works
        # for 2-D input).
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=-1, index=sorted_indices, src=sorted_indices_to_remove)

        logits[indices_to_remove] = filter_value
    return logits

def autoencode(self,input_sequence,start_tok,top_p=0.9):
    """Encode `input_sequence`, sample a latent and decode it token by token.

    Unlike the teacher-forced training pass, the decoder is fed its own
    nucleus-sampled output from the previous step.

    Args:
        self: the model; it must provide embed, encoder, to_mu, to_logvar,
            latent_to_h, latent_to_c, decoder, hidden_to_vocab, latent_dim and
            hidden_size. The decoder is handed an (h, c) tuple.
        input_sequence: integer token ids indexed as (batch, time).
        start_tok: token id fed to the decoder at the first step.
        top_p: nucleus threshold passed to `top_k_top_p_filtering`.

    Returns:
        A list with one (batch, 1) tensor of sampled token ids per time step,
        `input_sequence.size(1)` entries long.
    """
    batch_size, seq_length = input_sequence.size(0), input_sequence.size(1)
    device = input_sequence.device

    # Pure inference: do not record an autograd graph across the whole rollout.
    with torch.no_grad():
        embedded = self.embed(input_sequence)
        out , _= self.encoder(embedded)

        # `out` holds [forward | backward] features per time step. The forward
        # direction has read the whole sequence only at the last step, the
        # backward direction only at the first step.
        # NOTE: these positions must match the ones RNNVAE.forward reads,
        # otherwise the latent heads see features they were not trained on.
        half = out.size(-1)//2
        f_enc = out[:,-1,:half]
        r_enc = out[:,0,half:]
        enc = torch.cat((f_enc,r_enc),dim=-1)
        mu,logvar = self.to_mu(enc), self.to_logvar(enc)

        # Sample z from the posterior (reparameterised form).
        std = torch.exp(0.5 * logvar)
        z = torch.randn([batch_size,self.latent_dim],device=device)
        z = z * std + mu
        # (batch, layers*hidden) -> (layers, batch, hidden)
        dec_h = self.latent_to_h(z)
        dec_h = dec_h.view(batch_size,-1, self.hidden_size).permute(1,0,2).contiguous()
        dec_c = self.latent_to_c(z)
        dec_c = dec_c.view(batch_size,-1, self.hidden_size).permute(1,0,2).contiguous()

        # Carry the recurrent state between steps and feed only the newest
        # token, instead of re-running the decoder over the whole growing
        # prefix at every step.
        state = (dec_h,dec_c)
        step_inp = self.embed(torch.ones(batch_size,1,dtype=torch.long,device=device)*start_tok)
        tokens = []
        for _ in range(seq_length):
            dec_out, state = self.decoder(step_inp,state)
            dec_logit = self.hidden_to_vocab(dec_out[:,-1,:])
            dec_logit = top_k_top_p_filtering(dec_logit,top_p = top_p)
            next_token = torch.multinomial(F.softmax(dec_logit, dim=-1), num_samples=1)

            tokens.append(next_token)
            step_inp = self.embed(next_token)

    return tokens

# Disable dropout, then reconstruct `input_sequence`, which is the batch left
# over from the last step of the training loop.
rnn_vae.eval()
start_token = v2i['{']
encoded = autoencode(rnn_vae, input_sequence, start_token)
    

In [19]:

def render_tokens(token_ids):
    """Map token ids back to tiles and put each ';'-separated column on its own line."""
    return ''.join([i2v[t] for t in token_ids]).replace(';','\n')


b = 4  # index of the batch element to display
# `encoded` has one tensor per decoding step; take element b from each step.
generation = [step[b].item() for step in encoded]

print(render_tokens([t.item() for t in input_sequence[b,:]]))
print('------')
print(render_tokens(generation))

------------xX
------------xX
------------xX
------------xX
------------xX
------------xX
------------xX
------------xX
------------xX
------------xX
------------xX
------------xX
------------xX
------------xX
------------xX
------------xX
------------xX
------------xX
-----Q---Q--xX
------------xX
-----Q---Q--xX
------------xX
------------xX
------------xX
------------xX
------------xX
------------xX
------------xX
------------xX
------------xX
------------xX
-----------xxX}
------
----------xx-X
---------xx--X
---------x<[[X
--------xx>]]X
--------x----X
--------x---EX
--------x----X
--------x----X
---------x-<[X
----------x>]X
----------x--X
-----------x-X
------------xX
------------xX
------------xX
-----------xxX
----------xx-X
---------xx--X
--------xx---X
--------x<[[[X
-------xx>]]]X
-------x-----X
-------x-----X
-------x-----X
-------x-S---X
-------xxS---X
-------x-----X
------xx------
-----xxX------
----xx--X-----
---xx---------
--xx----------}
