In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
os.environ['CUDA_VISIBLE_DEVICES']='0'
device = 'cuda'

In [2]:
segments = []
with open('./sampled_files.txt') as infile:
    for line in infile:
        segments.append(line.rstrip())  
levels = {}
vocab = set([';','{'])
for file in segments:
    with open(file) as input_file:
        level = []
        for row in input_file:
            level.append(list(row.rstrip()))
            vocab |= set(row.rstrip())
        levels[file] = level

v2i = {v:i for i,v in enumerate(sorted(vocab))}
i2v = {i:v for v,i in v2i.items()}


t_levels = {}
for file,level in levels.items():
    t_level = []
    for row in range(len(level[0])):
        column = []
        for col in range(len(level)):
            column.append(level[col][row])
        t_level.append(column)
    t_levels[file] = t_level
    
chunks = {}
for file,level in t_levels.items():
    chunk = level
    chunk = [''.join(c) for c in chunk]
    chunks[file] = ';'.join(chunk)
    
chunks_by_game = {}
for name,chunk in chunks.items():
    name = name.split('/')[1]
    game = name.split('_')[0]
    id = name.split('_')[1].split('.')[0]
    if game not in chunks_by_game:
        chunks_by_game[game] = []
    chunks_by_game[game].append((name,chunk)) 
    
    
class RNNVAE(nn.Module):
    def __init__(self,e_layers,d_layers,
                 dropout,vocab_size,
                 enc_size,
                 dec_size,latent_size,
                 rnn_type):
        super(RNNVAE, self).__init__()
        self.latent_dim = latent_size
        self.enc_size = enc_size
        self.dec_size = dec_size
        self.rnn_type = rnn_type
        self.embed = nn.Embedding(vocab_size,enc_size)
        self.encoder = rnn_type(enc_size,hidden_size=enc_size,
                              num_layers=e_layers,batch_first=True,
                                dropout=dropout, bidirectional=True)
        
        self.to_mu = nn.Linear(enc_size*2,latent_size)
        self.to_logvar = nn.Linear(enc_size*2,latent_size)
        self.latent_to_h = nn.Linear(latent_size,dec_size*d_layers)
        self.latent_to_c = nn.Linear(latent_size,dec_size*d_layers)
        
        self.decoder = rnn_type(enc_size,hidden_size=dec_size,
                              num_layers=d_layers,dropout=dropout,
                                batch_first=True,)
        self.hidden_to_vocab = nn.Linear(dec_size,vocab_size)
        self.CE = nn.CrossEntropyLoss()
    def forward(self,input_sequence,start_tok):
        
        batch_size, seq_length = input_sequence.size(0), input_sequence.size(1)
        device = input_sequence.device
        embedded = self.embed(input_sequence)
        out , _= self.encoder(embedded)
        
        f_enc = out[:,0,:out.size()[-1]//2]
        r_enc = out[:,-1,out.size()[-1]//2:]
        enc = torch.cat((f_enc,r_enc),dim=-1)
        mu,logvar = self.to_mu(enc), self.to_logvar(enc)
        
        std = torch.exp(0.5 * logvar)
        z = torch.randn([batch_size,self.latent_dim],device=device)
        z = z * std + mu        
        dec_h = self.latent_to_h(z)
        dec_h = dec_h.view(batch_size,-1, self.dec_size).permute(1,0,2).contiguous()
        dec_c = self.latent_to_c(z)
        dec_c = dec_c.view(batch_size,-1, self.dec_size).permute(1,0,2).contiguous()

        dec_inp = torch.cat((
             self.embed(torch.ones(batch_size,1,dtype=torch.long,device=device)*start_tok),
            embedded[:,:-1,:]),dim=1)        
        if self.rnn_type == nn.LSTM:
            dec_out, _ =self.decoder(dec_inp,(dec_h,dec_c))
        else:
            dec_out, _ =self.decoder(dec_inp,dec_h)
            
        dec_logit = self.hidden_to_vocab(dec_out)
        
        loss = self.CE(dec_logit.view(batch_size*seq_length,-1),input_sequence.view(-1))
        KL_div =  -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return loss, KL_div, dec_logit

e_layer = 3
d_layer = 2
e_size = 1024
l_size = 256
d_size = 256
dropout = 0.5
rnn_vae = RNNVAE(e_layer,d_layer,dropout,len(vocab),e_size,d_size,l_size,nn.GRU)    

#rnn_vae = torch.load('rnn_vae.model')
rnn_vae = torch.load(f'rnn_vaeEnc{e_layer}_{e_size}_Dec{d_layer}_{d_size}_Lat{l_size}.model').to(device)

In [3]:
def encode(self,input_sequence):
        
    batch_size, seq_length = input_sequence.size(0), input_sequence.size(1)
    device = input_sequence.device
    embedded1 = self.embed(input_sequence)
    out1 , _= self.encoder(embedded1)

    f_enc1 = out1[:,0,:out1.size()[-1]//2]
    r_enc1 = out1[:,-1,out1.size()[-1]//2:]
    enc1 = torch.cat((f_enc1,r_enc1),dim=-1)
    return self.to_mu(enc1)

vectors_by_game = {}
for game in chunks_by_game:
    all_vectors = []
    batch_size = 64
    with torch.no_grad():
        sequences = []
        for _,chunk in chunks_by_game[game]:
            sequences.append([v2i[t] for t in chunk])
            if len(sequences) == batch_size:
                sequences = torch.tensor(sequences).to(device)
                vectors = encode(rnn_vae,sequences)
                print(vectors.size())
                all_vectors += vectors.tolist()
                sequences = []



        sequences = torch.tensor(sequences).to(device)
        vectors = encode(rnn_vae,sequences)
        all_vectors += vectors.tolist()
        sequences = []
    
    
    import numpy as np

    vecs = np.array(all_vectors)

    vectors_by_game[game] = (np.mean(vecs,axis=0),np.std(vecs,axis=0),vecs)
    print(game,np.mean(vecs,axis=0),np.std(vecs,axis=0))

torch.Size([64, 256])
Castlevania [ 3.79131149e-02  1.77220646e-02  7.54070901e-04  5.80996552e-03
  1.04839727e-02  5.70420440e-03 -2.11501286e-02  3.31363779e-03
 -2.75188881e-02 -2.42385951e-02  8.51389291e-03  6.34562393e-03
 -1.53539544e-02 -3.63421267e-03 -1.26621332e-02 -2.15514952e-02
 -1.92219712e-02  3.10079748e-03 -2.28435013e-02  1.19770540e-02
 -8.73266731e-03 -9.29113133e-03 -4.89917850e-02 -3.23177342e-02
 -6.94690126e-02  4.27111727e-02 -5.97018576e-03 -1.77072258e-02
  2.91777560e-02 -3.55779454e-04  1.26903355e-02  7.61119948e-03
 -1.26401404e-02  3.83285959e-03  3.13505379e-02 -1.30819730e-02
 -6.55744592e-03  3.99904203e-02  4.77617192e-02 -4.79470344e-03
  1.61342756e-03 -7.55909090e-03 -5.77282493e-03  2.99171880e-02
  1.81315997e-02  2.43804349e-02  3.46764446e-03  2.14586954e-02
 -3.15514649e-03 -1.84116280e-02 -2.51059369e-02  3.80348328e-02
 -6.29632801e-02 -3.00150378e-03  5.66196596e-02  6.70380758e-03
  2.27257492e-03  5.16735789e-02  1.53588846e-02  3.9872

torch.Size([64, 256])
MegaMan [ 2.36397005e-03 -9.98714364e-03  2.40639074e-02 -7.78865999e-03
  1.37838562e-02  1.45394784e-02 -6.57527861e-04  1.47198440e-02
 -5.02019934e-03  5.90417659e-03  2.15936916e-02  9.75889818e-03
  2.26936524e-02 -6.40075569e-03  1.57411770e-02  1.06531399e-02
  1.14554804e-02 -1.69658734e-03 -1.26023042e-02  4.51994199e-03
  8.05559117e-03  1.09394998e-02 -1.71174039e-02  2.12299305e-02
  1.22486014e-02 -1.54068003e-02  4.88370974e-03  3.72987887e-03
  1.55764685e-03  1.18410361e-03  1.09820970e-02  1.26998538e-02
  3.17890844e-03 -2.92892854e-02  2.34855581e-02 -3.66236848e-02
 -1.11237550e-02 -6.60627050e-03 -2.24976854e-02 -1.74725646e-02
  7.59126227e-03 -1.72803868e-02  9.56196066e-03  3.43342460e-03
  1.86262507e-02 -2.24948150e-03  1.11639232e-02 -1.02904830e-02
 -5.82528423e-03 -8.07155669e-03  1.27410419e-02 -2.99527918e-02
  1.59284148e-02 -3.33690430e-03 -4.03063431e-02  2.93651868e-02
 -1.18086065e-02 -1.02572726e-02  1.96870028e-02 -2.74852846

torch.Size([64, 256])
Ninja Gaiden [ 2.68671468e-02  1.65012614e-02  2.40626425e-02  3.08748021e-02
  1.19572179e-02 -2.12278616e-02 -1.22985199e-02 -2.61594787e-02
 -3.80150537e-02 -3.03710026e-02 -1.35133657e-02 -3.02386398e-03
 -5.29222253e-03 -2.68709420e-02 -2.87383459e-02 -2.43067877e-02
 -1.35103764e-02  1.47204218e-02 -1.69626559e-03  4.51973647e-02
 -4.41335520e-03 -1.90362838e-02  2.19341675e-02 -2.66469457e-02
 -5.02853354e-02 -2.12128807e-02  1.02489448e-02 -8.51827790e-03
 -1.40273682e-02  3.00484595e-02 -1.34737896e-02 -2.55861713e-02
 -2.50684080e-02  2.47799907e-02 -2.48616152e-02  5.60497941e-03
 -1.11202523e-04  9.23391669e-03 -2.09666664e-02  9.12790084e-04
 -2.56895528e-03  1.05685689e-02 -7.46308855e-03 -2.12980687e-02
  3.72778497e-02 -5.90293520e-03 -1.59789727e-02  1.20410011e-02
 -1.05230854e-02 -9.11124563e-03 -2.75921210e-02  7.11553109e-02
 -1.48970664e-02  5.84487809e-02  4.39591617e-02 -4.35461984e-03
 -5.10890058e-03  2.45873474e-03  2.08771674e-02 -1.342

In [4]:
import numpy as np

vecs = np.array(all_vectors)


import matplotlib.pyplot as plt
for d in range(vecs.shape[1]):
    plt.hist(vecs[:,d])
    plt.show()
    

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

<Figure size 640x480 with 1 Axes>

In [5]:
import random
import torch.nn.functional as F
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size x vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
        
        logits[indices_to_remove] = filter_value
    return logits


mus  = np.mean(vecs,axis=0)
stds = np.std(vecs,axis=0)


def decode(self,z,seq_length,start_tok,temp):
    dec_h = self.latent_to_h(z)
    dec_h = dec_h.view(batch_size,-1, self.dec_size).permute(1,0,2).contiguous()
    dec_c = self.latent_to_c(z)
    dec_c = dec_c.view(batch_size,-1, self.dec_size).permute(1,0,2).contiguous()

    dec_inp = self.embed(torch.ones(batch_size,1,dtype=torch.long,device=device)*start_tok)
    tokens = []    
    scores = []
    for ii in range(seq_length):   
        if self.rnn_type == nn.LSTM:
            dec_out, (dec_h,dec_c) =self.decoder(dec_inp,(dec_h,dec_c))
        else:
            dec_out, dec_h =self.decoder(dec_inp,dec_h)
            
        dec_logit = self.hidden_to_vocab(dec_out[:,-1,:]).view(batch_size,-1)
        dec_logit = top_k_top_p_filtering(dec_logit,top_k=5)
        next_token = torch.multinomial(F.softmax(dec_logit/temp, dim=-1), num_samples=1)
        
        tokens.append(next_token.detach())
        chosen_scores = []
        for ii,tok in enumerate(tokens[-1]):
            chosen_scores.append(dec_logit[ii,tok].item())
        scores.append(chosen_scores)
        dec_inp = self.embed(next_token).view(batch_size,1,-1)        

    return tokens,scores

def prettify(level):
    levelstr = ''
    level = level.split(';')
    
    level = [c for c in level if len(c) == len(level[0])]
    width = len(level)
    height = len(level[0]) 
    for y in range(height):
        for x in range(width):

            levelstr += level[x][y]
        levelstr += '\n'
    return levelstr


from tqdm import tqdm_notebook



for game in tqdm_notebook(vectors_by_game):
    game_f = game.replace(' ','')
    mus,stds,_ = vectors_by_game[game]
    out_folder = f'gru_generation_{l_size}_{game_f}'

    !mkdir {out_folder}
    !rm {out_folder}/*

    batch_size = 32
    total_count = 1000

    count = 0
    sequence_length = len(chunk)
    acceptance_temp = 1.0
    with torch.no_grad():
        while total_count > 0:
            print(total_count)
            samples = 5
            best_scores = np.ones(batch_size)*-np.inf
            best_samples = [None]*batch_size
            good = False
            Zs = torch.tensor(np.random.normal(mus,stds,(batch_size,mus.shape[0])),dtype=torch.float32).to(device)
            for _ in range(samples):   
                good = True

                with torch.no_grad():
                    rnn_vae.eval()
                    encoded, scores = decode(rnn_vae,Zs,sequence_length,v2i['{'],1.0)

                scores = np.array(scores)
                scores = np.mean(scores,axis=0)
                for ii,(sc,be) in enumerate(zip(scores,best_scores)):
                    if random.random() < np.exp( (sc - be)/acceptance_temp):
                        generation = []
                        for batch in encoded:
                            token = batch[ii].item()
                            generation.append(token)
                        pretty = prettify(''.join([i2v[t] for t in generation]))
                        if len(pretty) < 495:
                            continue
                        best_scores[ii] = sc
                        best_samples[ii] = pretty
                if np.mean(best_scores) == -np.inf:
                    good = False
                good = good and samples <= 0

            for out_level in best_samples:
                if out_level is None:
                    continue
                if len(out_level) < 495:
                    continue

                total_count -= 1
                count += 1
                with open(f'{out_folder}/{count}.txt','w') as outfile:
                    outfile.write(out_level)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))

rm: cannot remove 'gru_generation_256_Castlevania/*': No such file or directory
1000
968
936
904
872
840
808
776
744
712
680
648
616
584
552
520
488
456
424
392
360
328
296
264
232
200
168
136
104
72
40
8
rm: cannot remove 'gru_generation_256_Mario/*': No such file or directory
1000
968
936
904
872
840
808
776
744
712
680
648
616
584
552
520
488
456
424
392
360
328
296
264
232
200
168
136
104
72
40
8
rm: cannot remove 'gru_generation_256_MegaMan/*': No such file or directory
1000
968
936
904
872
840
808
776
744
712
680
648
616
584
552
520
488
456
424
392
360
328
296
264
232
200
168
136
104
72
40
8
rm: cannot remove 'gru_generation_256_Metroid/*': No such file or directory
1000
968
936
904
872
840
808
776
745
713
681
649
617
585
553
521
489
457
425
393
361
329
297
265
233
201
169
137
105
73
41
9
rm: cannot remove 'gru_generation_256_NinjaGaiden/*': No such file or directory
1000
968
936
904
872
840
808
776
744
712
680
648
616
584
552
520
488
456
424
392
360
328
296
264
232
200
169
137
1