In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
# Expose only physical GPU 1 to this process; this must happen before CUDA is
# first initialised, which is why it sits directly after the imports.
os.environ['CUDA_VISIBLE_DEVICES']='1'
# Fall back to the CPU when no GPU is visible instead of failing later on the
# first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Paths of the level files to load, one per line of sampled_files.txt.
# Blank lines are skipped so a stray empty line cannot turn into open('') later.
with open('./sampled_files.txt') as infile:
    segments = [path for path in (line.rstrip() for line in infile) if path]


In [None]:
# Read every listed level file as a grid of characters: one list per text line,
# trailing whitespace stripped. Collect the character vocabulary along the way.
# ';' (the column separator used below) and '{' (used as the decoder start token)
# are always part of the vocabulary.
levels = {}
vocab = {';', '{'}
for path in segments:
    with open(path) as handle:
        grid = [list(text_line.rstrip()) for text_line in handle]
    levels[path] = grid
    for line_chars in grid:
        vocab.update(line_chars)

# Token <-> id lookup tables; ids follow sorted character order.
v2i = {v: i for i, v in enumerate(sorted(vocab))}
i2v = {i: v for v, i in v2i.items()}

# Transpose each grid so that the outer list runs over columns (x) and the inner
# list over text lines (y). The width comes from the first line: a shorter later
# line raises IndexError, and characters beyond that width are ignored.
t_levels = {
    path: [[text_line[x] for text_line in grid] for x in range(len(grid[0]))]
    for path, grid in levels.items()
}

# Serialise each level column by column, with columns separated by ';'.
chunks = {
    path: ';'.join(''.join(column) for column in columns)
    for path, columns in t_levels.items()
}

In [None]:

class RNNVAE(nn.Module):
    """Sequence VAE whose sampled latent z is concatenated to every decoder input.

    A bidirectional ``rnn_type`` encoder produces ``mu``/``logvar``. The decoder is
    teacher-forced on the input shifted right by one step (``start_tok`` first),
    and z is appended to every step's embedding.

    NOTE(review): a second ``class RNNVAE`` further down this cell rebinds the name.
    Once the cell has run, this variant is therefore unreachable by name; delete or
    rename it if it is no longer needed.
    """
    def __init__(self,e_layers,d_layers,
                 dropout,vocab_size,
                 enc_size,
                 dec_size,latent_size,
                 rnn_type):
        # Zero-argument super() binds to this class object itself. Spelling it as
        # super(RNNVAE, self) resolves the *global* name RNNVAE at call time. The
        # later definition in this cell rebinds that name, so construction would
        # raise "TypeError: super(type, obj): obj must be an instance or subtype of type".
        super().__init__()
        self.latent_dim = latent_size
        self.enc_size = enc_size
        self.dec_size = dec_size
        self.embed = nn.Embedding(vocab_size, enc_size)
        self.encoder = rnn_type(enc_size, hidden_size=enc_size,
                                num_layers=e_layers, batch_first=True,
                                dropout=dropout, bidirectional=True)

        # The encoder is bidirectional, so its pooled features are 2 * enc_size wide.
        self.to_mu = nn.Linear(enc_size*2, latent_size)
        self.to_logvar = nn.Linear(enc_size*2, latent_size)

        # Decoder input = token embedding (enc_size) concatenated with z (latent_size).
        self.decoder = rnn_type(enc_size+latent_size, hidden_size=dec_size,
                                num_layers=d_layers, dropout=dropout,
                                batch_first=True,)
        self.hidden_to_vocab = nn.Linear(dec_size, vocab_size)
        self.CE = nn.CrossEntropyLoss()

    def forward(self,input_sequence,start_tok):
        """Return (mean token cross-entropy, summed KL divergence, decoder logits).

        ``input_sequence`` is indexed as (batch, time) of token ids. The logits have
        shape (batch, time, vocab_size).
        """
        batch_size, seq_length = input_sequence.size(0), input_sequence.size(1)
        device = input_sequence.device
        embedded = self.embed(input_sequence)
        out, _ = self.encoder(embedded)

        # NOTE(review): this takes the forward direction at t=0 and the backward
        # direction at t=-1, i.e. the top-layer states that have seen the *least* of
        # the sequence. It is kept so this class pools exactly like the other RNNVAE
        # definition and interpolate().
        half = out.size()[-1] // 2
        f_enc = out[:, 0, :half]
        r_enc = out[:, -1, half:]
        enc = torch.cat((f_enc, r_enc), dim=-1)
        mu, logvar = self.to_mu(enc), self.to_logvar(enc)

        # Reparameterisation trick: z = mu + sigma * eps.
        std = torch.exp(0.5 * logvar)
        z = torch.randn([batch_size, self.latent_dim], device=device)
        z = z * std + mu
        # Same z for every time step: (batch, latent) -> (batch, seq_length, latent).
        z = z.repeat(1, seq_length).view(batch_size, seq_length, -1)
        # Teacher forcing: start_tok followed by the input without its last token.
        start = self.embed(torch.ones(batch_size, 1, dtype=torch.long, device=device) * start_tok)
        dec_inp = torch.cat((start, embedded[:, :-1, :]), dim=1)
        dec_inp = torch.cat((dec_inp, z), dim=-1)
        dec_out, _ = self.decoder(dec_inp)
        dec_logit = self.hidden_to_vocab(dec_out)

        loss = self.CE(dec_logit.view(batch_size*seq_length, -1), input_sequence.view(-1))
        KL_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return loss, KL_div, dec_logit
    
class RNNVAE(nn.Module):
    def __init__(self,e_layers,d_layers,
                 dropout,vocab_size,
                 enc_size,
                 dec_size,latent_size,
                 rnn_type):
        super(RNNVAE, self).__init__()
        self.latent_dim = latent_size
        self.enc_size = enc_size
        self.dec_size = dec_size
        self.rnn_type = rnn_type
        self.embed = nn.Embedding(vocab_size,enc_size)
        self.encoder = rnn_type(enc_size,hidden_size=enc_size,
                              num_layers=e_layers,batch_first=True,
                                dropout=dropout, bidirectional=True)
        
        self.to_mu = nn.Linear(enc_size*2,latent_size)
        self.to_logvar = nn.Linear(enc_size*2,latent_size)
        self.latent_to_h = nn.Linear(latent_size,dec_size*d_layers)
        self.latent_to_c = nn.Linear(latent_size,dec_size*d_layers)
        
        self.decoder = rnn_type(enc_size,hidden_size=dec_size,
                              num_layers=d_layers,dropout=dropout,
                                batch_first=True,)
        self.hidden_to_vocab = nn.Linear(dec_size,vocab_size)
        self.CE = nn.CrossEntropyLoss()
    def forward(self,input_sequence,start_tok):
        
        batch_size, seq_length = input_sequence.size(0), input_sequence.size(1)
        device = input_sequence.device
        embedded = self.embed(input_sequence)
        out , _= self.encoder(embedded)
        
        f_enc = out[:,0,:out.size()[-1]//2]
        r_enc = out[:,-1,out.size()[-1]//2:]
        enc = torch.cat((f_enc,r_enc),dim=-1)
        mu,logvar = self.to_mu(enc), self.to_logvar(enc)
        
        std = torch.exp(0.5 * logvar)
        z = torch.randn([batch_size,self.latent_dim],device=device)
        z = z * std + mu        
        dec_h = self.latent_to_h(z)
        dec_h = dec_h.view(batch_size,-1, self.dec_size).permute(1,0,2).contiguous()
        dec_c = self.latent_to_c(z)
        dec_c = dec_c.view(batch_size,-1, self.dec_size).permute(1,0,2).contiguous()

        dec_inp = torch.cat((
             self.embed(torch.ones(batch_size,1,dtype=torch.long,device=device)*start_tok),
            embedded[:,:-1,:]),dim=1)        
        if self.rnn_type == nn.LSTM:
            dec_out, _ =self.decoder(dec_inp,(dec_h,dec_c))
        else:
            dec_out, _ =self.decoder(dec_inp,dec_h)
            
        dec_logit = self.hidden_to_vocab(dec_out)
        
        loss = self.CE(dec_logit.view(batch_size*seq_length,-1),input_sequence.view(-1))
        KL_div =  -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return loss, KL_div, dec_logit
    
import inspect

# Architecture of the checkpoint to load. e_layer, e_size, d_layer, d_size and
# l_size form its file name below; dropout is not part of the name and is not
# needed for loading.
e_layer = 3
d_layer = 2
e_size = 1024
l_size = 128
d_size = 256
dropout = 0.5

# The checkpoint is a whole pickled RNNVAE, so it is loaded directly instead of
# first building and then discarding a freshly initialised (large) model.
# map_location lets a checkpoint saved on another GPU, or loaded on a CPU-only
# machine, come up on `device`.
# NOTE: unpickling a full model executes arbitrary code; only load trusted files.
checkpoint_path = f'rnn_vaeEnc{e_layer}_{e_size}_Dec{d_layer}_{d_size}_Lat{l_size}.model'
load_kwargs = {'map_location': device}
if 'weights_only' in inspect.signature(torch.load).parameters:
    # Recent PyTorch releases default to weights_only=True, which refuses to
    # unpickle a full nn.Module. This file is a full-model pickle.
    load_kwargs['weights_only'] = False
rnn_vae = torch.load(checkpoint_path, **load_kwargs).to(device)

In [None]:
# Group the serialised levels by game. Paths are presumably of the form
# '<dir>/<game>_<id>.<ext>': the second path component is kept as the level name,
# and the game is the text before its first '_'.
# (The previously computed level id was never used and shadowed the builtin `id`.)
chunks_by_game = {}
for name, chunk in chunks.items():
    name = name.split('/')[1]
    game = name.split('_')[0]
    chunks_by_game.setdefault(game, []).append((name, chunk))

In [None]:
import torch.nn.functional as F
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits whose last dimension is the vocabulary, e.g.
                (batch size x vocabulary size) or a single 1-D vector.
                The tensor is modified in place and also returned.
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            filter_value: value written into every removed position.
        Returns:
            ``logits`` with the filtered positions set to ``filter_value``.
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove every token whose logit is below the k-th largest (ties are kept).
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token that crosses the threshold is also kept;
        # the most likely token is therefore never removed.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        # Map the mask back to vocabulary order. Scattering along the last dim
        # (rather than dim=1) supports 1-D logits and any number of batch dims.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits

def _encode_to_mean(model, input_sequence):
    """Return the posterior mean ``mu`` for a batch of token ids (no sampling).

    The bidirectional encoder output is pooled exactly as in ``RNNVAE.forward``:
    forward-direction features at the first step are concatenated with
    backward-direction features at the last step.
    """
    out, _ = model.encoder(model.embed(input_sequence))
    half = out.size()[-1] // 2
    enc = torch.cat((out[:, 0, :half], out[:, -1, half:]), dim=-1)
    return model.to_mu(enc)


def interpolate(self, input_sequence1, input_sequence2, val, start_tok, temp, top_k=5):
    """Sample a batch of sequences from a blend of two inputs' latent means.

    Args:
        self: the RNNVAE model (named ``self`` because this was written as a method).
        input_sequence1, input_sequence2: token-id batches indexed (batch, time).
            The batch size and number of decoded steps are taken from
            ``input_sequence1``.
        val: blend weight, ``z = mu1 * val + mu2 * (1 - val)``; 1.0 decodes input 1.
        start_tok: token id embedded as the first decoder input.
        temp: softmax temperature applied after top-k filtering.
        top_k: number of highest-logit tokens kept at each step (default 5).

    Returns:
        ``(tokens, scores)``. ``tokens`` is a list with one (batch, 1) tensor of
        sampled ids per step. ``scores`` is a list with one entry per step, each a
        list of per-sample Python floats holding the (pre-temperature) logit of the
        sampled token.
    """
    # Inference only: without no_grad every decoding step would be recorded for
    # autograd, even though only detached tokens and Python floats are returned.
    with torch.no_grad():
        batch_size, seq_length = input_sequence1.size(0), input_sequence1.size(1)
        device = input_sequence1.device

        z1 = _encode_to_mean(self, input_sequence1)
        z2 = _encode_to_mean(self, input_sequence2)
        z = z1*val + z2*(1-val)

        # (batch, layers * dec_size) -> (layers, batch, dec_size) initial states.
        dec_h = self.latent_to_h(z).view(batch_size, -1, self.dec_size).permute(1, 0, 2).contiguous()
        dec_c = self.latent_to_c(z).view(batch_size, -1, self.dec_size).permute(1, 0, 2).contiguous()

        dec_inp = self.embed(torch.ones(batch_size, 1, dtype=torch.long, device=device) * start_tok)
        tokens = []
        scores = []
        for _ in range(seq_length):
            if self.rnn_type == nn.LSTM:
                dec_out, (dec_h, dec_c) = self.decoder(dec_inp, (dec_h, dec_c))
            else:
                dec_out, dec_h = self.decoder(dec_inp, dec_h)

            dec_logit = self.hidden_to_vocab(dec_out[:, -1, :]).view(batch_size, -1)
            dec_logit = top_k_top_p_filtering(dec_logit, top_k=top_k)
            next_token = torch.multinomial(F.softmax(dec_logit/temp, dim=-1), num_samples=1)

            tokens.append(next_token.detach())
            # Logit of each sampled token: one gather and one host transfer per
            # step instead of batch_size separate .item() calls.
            scores.append(dec_logit.gather(1, next_token).squeeze(1).tolist())
            # Feed only the sampled token back in; the recurrent state carries history.
            dec_inp = self.embed(next_token).view(batch_size, 1, -1)

    return tokens, scores
def prettify(level):
    """Render a ';'-separated string of columns as newline-terminated rows.

    The first column's length fixes the height. Any other column whose length
    differs from it is dropped before transposing. An empty first column yields ''.
    """
    columns = level.split(';')
    expected_height = len(columns[0])
    columns = [col for col in columns if len(col) == expected_height]
    # zip(*columns) walks the equal-length columns in lockstep, yielding one row each.
    return ''.join(''.join(row_chars) + '\n' for row_chars in zip(*columns))


In [None]:
import itertools
import os
import random
from glob import glob

import numpy as np
from tqdm import tqdm_notebook

out_folder = f'gru_interpolation_{l_size}'

# Start from an empty output folder. This is pure Python, and it stays silent
# when the folder already exists or is empty.
os.makedirs(out_folder, exist_ok=True)
for stale_path in glob(os.path.join(out_folder, '*')):
    if os.path.isfile(stale_path):
        os.remove(stale_path)

seq1s = []
seq2s = []
level1s = []
level2s = []
batch_size = 128
to_do = 10              # only the first `to_do` levels of each game are paired
acceptance_temp = 0.5   # temperature of the Metropolis-style acceptance test
MIN_ROUNDS = 5          # sampling rounds always run per interpolation weight
MAX_ROUNDS = 50         # hard cap so a pair that never renders a full level cannot hang
MIN_LEVEL_CHARS = 495   # renderings shorter than this are rejected


def sample_best_levels(input_sequence1, input_sequence2, interp,
                       min_rounds=MIN_ROUNDS, max_rounds=MAX_ROUNDS):
    """Sample interpolated levels for a batch and keep one accepted level per pair.

    Each round decodes the whole batch once. A pair's candidate is scored by its
    mean sampled-token logit. It replaces the pair's current best with probability
    min(1, exp((new - best) / acceptance_temp)), unless its rendering is shorter
    than MIN_LEVEL_CHARS. At least ``min_rounds`` rounds run; sampling then continues
    until every pair has an accepted level, stopping after ``max_rounds`` rounds.
    Pairs still without a level are returned as ``None``.
    """
    n_pairs = input_sequence1.size(0)
    best_scores = np.full(n_pairs, -np.inf)
    best_samples = [None] * n_pairs
    rnn_vae.eval()
    for round_idx in range(1, max_rounds + 1):
        with torch.no_grad():
            encoded, scores = interpolate(rnn_vae, input_sequence1, input_sequence2,
                                          interp, v2i['{'], 1.0)
        # Mean logit of the sampled tokens, one value per pair.
        mean_scores = np.mean(np.array(scores), axis=0)
        # (batch, seq_length) token ids, moved to the host in one transfer.
        generations = torch.cat(encoded, dim=1).tolist()
        for ii, (sc, be) in enumerate(zip(mean_scores, best_scores)):
            # A pair with nothing accepted yet (be == -inf) always accepts.
            if random.random() < np.exp((sc - be) / acceptance_temp):
                pretty = prettify(''.join([i2v[t] for t in generations[ii]]))
                if len(pretty) < MIN_LEVEL_CHARS:
                    continue
                best_scores[ii] = sc
                best_samples[ii] = pretty
        if round_idx >= min_rounds and not np.isneginf(best_scores).any():
            break
    return best_samples


def output_name(lvl1, lvl2, interp):
    """Output path for one pair at weight ``interp`` (1.0 decodes lvl1, 0.0 decodes lvl2)."""
    if interp == 0.0:
        return f'{out_folder}/{lvl2}-100%.txt'
    if interp == 1.0:
        return f'{out_folder}/{lvl1}-100%.txt'
    return f'{out_folder}/{lvl1}-{lvl2}-{int(100*interp)}%.txt'


for game1, game2 in tqdm_notebook(list(itertools.combinations(chunks_by_game.keys(), 2))):
    for level1 in tqdm_notebook(chunks_by_game[game1][:to_do]):
        for level2 in tqdm_notebook(chunks_by_game[game2][:to_do]):
            seq1 = [v2i[t] for t in level1[1]]
            seq2 = [v2i[t] for t in level2[1]]
            level1s.append(level1[0])
            level2s.append(level2[0])
            seq1s.append(seq1)
            seq2s.append(seq2)

            if len(seq1s) == batch_size:
                input_sequence1 = torch.tensor(seq1s).to(device)
                input_sequence2 = torch.tensor(seq2s).to(device)
                for interp in [1.0, 0.75, 0.5, 0.25, 0.0]:
                    best_samples = sample_best_levels(input_sequence1, input_sequence2, interp)
                    for lvl1, lvl2, out_level in zip(level1s, level2s, best_samples):
                        if out_level is None:
                            print(f'No acceptable sample for {lvl1} / {lvl2} at {interp}; skipped.')
                            continue
                        with open(output_name(lvl1, lvl2, interp), 'w') as outfile:
                            outfile.write(out_level)

                # Start a new batch. A final partial batch stays in seq1s/seq2s/
                # level1s/level2s for the next cell to flush.
                seq1s = []
                seq2s = []
                level1s = []
                level2s = []

In [None]:
import random

import numpy as np

# Flush the partial batch (fewer than batch_size pairs) left by the cross-game loop.
# Skipped when the pairs divided evenly into batches: torch.tensor([]) is 1-D and
# cannot be encoded.
if seq1s:
    input_sequence1 = torch.tensor(seq1s).to(device)
    input_sequence2 = torch.tensor(seq2s).to(device)
    n_pairs = input_sequence1.size(0)
    rnn_vae.eval()

    for interp in [1.0, 0.75, 0.5, 0.25, 0.0]:
        best_scores = np.full(n_pairs, -np.inf)
        best_samples = [None] * n_pairs
        # Run at least 5 rounds, then continue until every pair has an accepted
        # sample. The 50-round cap keeps a pair that never renders a full-size level
        # from hanging the cell.
        for round_idx in range(1, 51):
            with torch.no_grad():
                encoded, scores = interpolate(rnn_vae, input_sequence1, input_sequence2,
                                              interp, v2i['{'], 1.0)
            mean_scores = np.mean(np.array(scores), axis=0)
            # (batch, seq_length) token ids, moved to the host in one transfer.
            generations = torch.cat(encoded, dim=1).tolist()
            for ii, (sc, be) in enumerate(zip(mean_scores, best_scores)):
                # Metropolis-style acceptance; always accepts while be == -inf.
                if random.random() < np.exp((sc - be) / acceptance_temp):
                    pretty = prettify(''.join([i2v[t] for t in generations[ii]]))
                    if len(pretty) < 495:
                        continue
                    best_scores[ii] = sc
                    best_samples[ii] = pretty
            if round_idx >= 5 and not np.isneginf(best_scores).any():
                break

        for lvl1, lvl2, out_level in zip(level1s, level2s, best_samples):
            if out_level is None:
                print(f'No acceptable sample for {lvl1} / {lvl2} at {interp}; skipped.')
                continue
            if interp == 0.0:
                name = f'{out_folder}/{lvl2}-100%.txt'
            elif interp == 1.0:
                name = f'{out_folder}/{lvl1}-100%.txt'
            else:
                name = f'{out_folder}/{lvl1}-{lvl2}-{int(100*interp)}%.txt'
            with open(name, 'w') as outfile:
                outfile.write(out_level)

In [None]:
import itertools
import random

import numpy as np
from tqdm import tqdm_notebook

# Within-game interpolations: every pair among the first `to_do` levels of each game.
seq1s = []
seq2s = []
level1s = []
level2s = []
batch_size = 64


def write_within_game_interpolations(batch_seqs1, batch_seqs2, batch_names1, batch_names2,
                                     interps=(0.75, 0.5, 0.25),
                                     min_rounds=5, max_rounds=50):
    """Sample and write one interpolated level per pair for each weight in ``interps``.

    The input tensors are built from ``batch_seqs1``/``batch_seqs2`` on every call,
    so each output file is written under the names of the pair that produced it.
    Per weight, a pair's candidate is scored by its mean sampled-token logit. It
    replaces the current best with probability
    min(1, exp((new - best) / acceptance_temp)); renderings shorter than 495
    characters are rejected. At least ``min_rounds`` rounds run, then sampling
    continues until every pair has an accepted level or ``max_rounds`` is reached.
    Pairs left without one are reported and skipped. An empty batch is a no-op.
    """
    if not batch_seqs1:
        return
    input_sequence1 = torch.tensor(batch_seqs1).to(device)
    input_sequence2 = torch.tensor(batch_seqs2).to(device)
    n_pairs = input_sequence1.size(0)
    rnn_vae.eval()
    for interp in interps:
        best_scores = np.full(n_pairs, -np.inf)
        best_samples = [None] * n_pairs
        for round_idx in range(1, max_rounds + 1):
            with torch.no_grad():
                encoded, scores = interpolate(rnn_vae, input_sequence1, input_sequence2,
                                              interp, v2i['{'], 1.0)
            mean_scores = np.mean(np.array(scores), axis=0)
            # (batch, seq_length) token ids, moved to the host in one transfer.
            generations = torch.cat(encoded, dim=1).tolist()
            for ii, (sc, be) in enumerate(zip(mean_scores, best_scores)):
                # Metropolis-style acceptance; always accepts while be == -inf.
                if random.random() < np.exp((sc - be) / acceptance_temp):
                    pretty = prettify(''.join([i2v[t] for t in generations[ii]]))
                    if len(pretty) < 495:
                        continue
                    best_scores[ii] = sc
                    best_samples[ii] = pretty
            if round_idx >= min_rounds and not np.isneginf(best_scores).any():
                break

        for lvl1, lvl2, out_level in zip(batch_names1, batch_names2, best_samples):
            if out_level is None:
                print(f'No acceptable sample for {lvl1} / {lvl2} at {interp}; skipped.')
                continue
            name = f'{out_folder}/{lvl1}-{lvl2}-{int(100*interp)}%.txt'
            with open(name, 'w') as outfile:
                outfile.write(out_level)


for game in tqdm_notebook(chunks_by_game.keys()):
    for level1, level2 in tqdm_notebook(list(itertools.combinations(chunks_by_game[game][:to_do], 2))):
        seq1 = [v2i[t] for t in level1[1]]
        seq2 = [v2i[t] for t in level2[1]]
        level1s.append(level1[0])
        level2s.append(level2[0])
        seq1s.append(seq1)
        seq2s.append(seq2)

        if len(seq1s) == batch_size:
            write_within_game_interpolations(seq1s, seq2s, level1s, level2s)
            seq1s = []
            seq2s = []
            level1s = []
            level2s = []

# Flush the final partial batch from its *own* sequences. Reusing the tensors of
# the last full batch would write those levels under the leftover pairs' names.
write_within_game_interpolations(seq1s, seq2s, level1s, level2s)

In [None]:
import random

import numpy as np

# Preview one interpolation in the notebook output. `level1`/`level2` are the
# last pair visited by the loop in the previous cell, so that cell must run first.
seq1 = [v2i[t] for t in level1[1]]
seq2 = [v2i[t] for t in level2[1]]
input_sequence1 = torch.tensor([seq1]).to(device)
input_sequence2 = torch.tensor([seq2]).to(device)

samples = 5
# Local name so that re-running the batch cells later still uses their own
# acceptance_temp.
preview_acceptance_temp = 1.0
rnn_vae.eval()

print(1)
print(prettify(''.join([i2v[t] for t in seq1])))
# A bare `torch.no_grad()` call is a no-op; it only takes effect as a context manager.
with torch.no_grad():
    for interp in [0.75, 0.5, 0.25]:
        # Starting at -inf makes the first candidate always accepted, so
        # best_result is never left empty (as in the batch cells).
        best_score = -np.inf
        best_result = []
        for _ in range(samples):
            tokens, scores = interpolate(rnn_vae, input_sequence1, input_sequence2,
                                         interp, v2i['{'], 1.0)
            score = np.mean(scores)
            if random.random() < np.exp((score - best_score) / preview_acceptance_temp):
                best_score = score
                best_result = [t.item() for t in tokens]
        print(interp)
        print(prettify(''.join([i2v[t] for t in best_result])))

print(0)
print(prettify(''.join([i2v[t] for t in seq2])))