# Lab5 - Conditional Sequence-to-sequence VAE

In [4]:
import string
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import matplotlib.pyplot as plt
from tqdm import trange
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu

%matplotlib inline

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Compute BLEU-4 score
def compute_bleu(output, reference):
    cc = SmoothingFunction()
    if len(reference) == 3:
        weights = (0.33, 0.33, 0.33)
    else:
        weights = (0.25, 0.25, 0.25, 0.25)
    return sentence_bleu([reference], output, weights=weights, smoothing_function=cc.method1)

# Compute Gaussian score
def Gaussian_score(words):
    words_list = []
    score = 0
    yourpath = 'data/train.txt'#should be your directory of train.txt
    with open(yourpath,'r') as fp:
        for line in fp:
            word = line.split(' ')
            word[3] = word[3].strip('\n')
            words_list.extend([word])
        for t in words:
            for i in words_list:
                if t == i:
                    score += 1
    return score/len(words)

In [6]:
# Define characteristic to vector dictionary for embedding
SOS_token = 0
EOS_token = 1
ch_to_ix = {ch: i+2 for i, ch in enumerate(string.ascii_lowercase)}
ix_to_ch = {i+2: ch for i, ch in enumerate(string.ascii_lowercase)}

# Define condition to dictionary
conditions = ['sp', 'tp', 'pg', 'p']
cond_to_ix = {conditions[i]: i for i in range(len(conditions))}

# Load the datasets
def load_data(filename):
    data = []
    with open(filename, 'r') as f:
        for line in f.readlines():
            data.append(line.rstrip().split(' '))
    return data


train_data = load_data('./data/train.txt')
train_cond = ['sp', 'tp', 'pg', 'p']

test_data = load_data('./data/test.txt')
test_cond = [
    ['sp', 'p'],
    ['sp', 'pg'],
    ['sp', 'tp'],
    ['sp', 'tp'],
    ['p', 'tp'],
    ['sp', 'pg'],
    ['p', 'sp'],
    ['pg', 'sp'],
    ['pg', 'p'],
    ['pg', 'tp'],
] # Given

In [7]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size, num_conds, cond_size):
        super(EncoderRNN, self).__init__()
        
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.cond_size = cond_size
        
        # Embedding the input
        self.word_embedding = nn.Embedding(input_size, hidden_size)
        self.cond_embedding = nn.Embedding(num_conds, cond_size)
        
        # LSTM
        self.lstm = nn.LSTM(hidden_size, hidden_size)
        
        # Fully connected layer to generate mean and var for hidden and cell
        self.hidden_mean = nn.Linear(hidden_size, latent_size)
        self.hidden_logvar = nn.Linear(hidden_size, latent_size)
        self.cell_mean = nn.Linear(hidden_size, latent_size)
        self.cell_logvar = nn.Linear(hidden_size, latent_size)
        
    def forward(self, input, condition, hidden):
        condition = self.cond_embedding(condition).view(1, 1, -1)
        hidden = (torch.cat((hidden[0], condition), dim=2), torch.cat((hidden[1], condition), dim=2))
        output = self.word_embedding(input).view(input.size(0), 1, -1)
        output, hidden = self.lstm(output, hidden)
        
        mean = (self.hidden_mean(hidden[0]), self.cell_mean(hidden[1]))
        logvar = (self.hidden_logvar(hidden[0]), self.cell_logvar(hidden[1]))
        
        # Reparameterize
        std = (torch.exp(logvar[0]), torch.exp(logvar[1]))
        eps = (torch.rand_like(std[0]), torch.rand_like(std[1]))
        latent = (mean[0]+eps[0]*std[0], mean[1]+eps[1]*std[1])
        
        return latent, mean, logvar
        
    def init_hidden(self):
        return (torch.zeros(1, 1, self.hidden_size-self.cond_size, device=device),
                torch.zeros(1, 1, self.hidden_size-self.cond_size, device=device))
    
    def condition_embedding(self, condition):
        return self.cond_embedding(condition)
    
    
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        
        self.hidden_size = hidden_size
        self.output_size = output_size
        
        self.embedding = nn.Embedding(output_size, hidden_size)
        
        self.lstm = nn.LSTM(hidden_size, hidden_size)
        
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)
        
    def forward(self, input, hidden):
        output = self.embedding(input).view(1, 1, -1)
        output = F.relu(output)
        output, hidden = self.lstm(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden
    
    def init_hidden(self):
        return (torch.zeros(1, 1, self.hidden_size, device=device),
                torch.zeros(1, 1, self.hidden_size, device=device))

In [8]:
def decode(decoder, hidden, target_tensor, teacher_forcing):
    decoder_input = torch.tensor([[SOS_token]], device=device)
    decoder_hidden = hidden
    
    target_length = target_tensor.size(0)
    
    outputs = []
    if teacher_forcing:
        for di in range(target_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            outputs.append(decoder_output)
            decoder_input = target_tensor[di]
    else:
        for di in range(target_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            outputs.append(decoder_output)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()
            
            if decoder_input.item() == EOS_token:
                break
                
    outputs = torch.cat(outputs, dim=0)
    return outputs
    

def train_pair(input_data_tensor, input_cond_tensor, target_data_tensor, target_cond_tensor,
               encoder, encoder_optimizer, decoder, decoder_optimizer, kl_weight,
               criterion, teacher_forcing_ratio=0.5, max_length=MAX_LENGTH):
    
    encoder_hidden = encoder.init_hidden()
    
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    
    input_length = input_data_tensor.size(0)
    target_length = target_data_tensor.size(0)
    
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
    
    loss = 0
    
    #----------sequence to sequence part for encoder----------#
    latent, mean, logvar = encoder(input_data_tensor, input_cond_tensor, encoder_hidden)
    print(latent)
    exit(0)
    
    #----------sequence to sequence part for decoder----------#
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
    
    decoder_cond = encoder.condition_embedding(target_cond_tensor).view(1, 1, -1)
    decoder_hidden = (torch.cat((latent[0], decoder_cond), dim=2), torch.cat((latent[1], decoder_cond), dim=2))
    output = decode(decoder, decoder_hidden, target_data_tensor, teacher_forcing=use_teacher_forcing)
    
    crossEntropy_loss = criterion(output, target_data_tensor[:output.size(0)].view(-1))
    kl_loss = torch.sum(0.5*(-logvar[0]+(mean[0]**2)+torch.exp(logvar[0])-1)) + torch.sum(0.5*(-logvar[1]+(mean[1]**2)+torch.exp(logvar[1])-1))
    (crossEntropy_loss+kl_weight*kl_loss).backward()
    
    encoder_optimizer.step()
    decoder_optimizer.step()
    
    return crossEntropy_loss.item(), kl_loss.item()
    
    
def predict(input_data_tensor, input_cond_tensor, target_data_tensor, target_cond_tensor, 
            encoder, decoder, criterion, max_length=MAX_LENGTH):
    
    encoder_hidden = encoder.init_hidden()
    input_data_length = input_data_tensor.size(0)
    
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
    
    with torch.no_grad():
        
        latent, mean, logvar = encoder(input_data_tensor, input_cond_tensor, encoder_hidden)
        decoder_input = torch.tensor([[SOS_token]], device=device)
        decoder_cond = encoder.cond_embedding(target_cond_tensor).view(1, 1, -1)
        decoder_hidden = (torch.cat((latent[0], decoder_cond), dim=2), torch.cat((latent[1], decoder_cond), dim=2))
        
        crossEntropy_loss = 0
        kl_loss = torch.sum(0.5*(-logvar[0]+(mean[0]**2)+torch.exp(logvar[0])-1)) + torch.sum(0.5*(-logvar[1]+(mean[1]**2)+torch.exp(logvar[1])-1))
        
        pred = ''
        for di in range(target_data_tensor.size(0)):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()
            
            crossEntropy_loss += criterion(decoder_output, target_data_tensor[di])
            if decoder_input == EOS_token:
                break
            
            pred += ix_to_ch[decoder_input.item()]
    
    return pred, crossEntropy_loss, kl_loss
    

def train_data_combination(n=4):
    for i in range(n):
        for j in range(i+1, n):
            yield (i, j)
            yield (j, i)
            

def get_kl_anealing_func(method, slope=0.0002, max_val=1, period=None):
    if method == 'monotonic':
        return lambda iteration: min(iteration*slope, max_val)
    elif method == 'cyclical' and period:
        return lambda iteration: min((iteration%period)*slope, max_val)
    else:
        return None
    

def train(encoder, decoder, train_dataset, train_condition, test_dataset, test_condition,
          kl_anealing_func, teacher_forcing_ratio, num_epochs=20, learning_rate=1e-02):
    
    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    
    criterion = nn.CrossEntropyLoss()
    
    train_result_list = {'cross_entropy': [], 'kl': []}
    test_result_list = {'cross_entropy': [], 'kl': [], 'score': []}
    iteration = 0
    for epoch in range(num_epochs):

        # Train
        encoder.train()
        decoder.train()
        
        shuffled_idx = np.random.permutation(len(train_dataset))
        train_loss = {'cross_entropy': 0, 'kl': 0}
        for i in trange(len(train_dataset)):
            data = train_dataset[shuffled_idx[i]]
            cond = train_condition
            for a, b in train_data_combination(4):
                input_data_tensor = torch.tensor([ch_to_ix[ch] for ch in data[a]], device=device).view(-1, 1)
                input_cond_tensor = torch.tensor(cond_to_ix[train_cond[a]], device=device)
                target_data_tensor = torch.tensor([ch_to_ix[ch] for ch in data[b]]+[EOS_token], device=device).view(-1, 1)
                target_cond_tensor = torch.tensor(cond_to_ix[train_cond[b]], device=device)
                loss = train_pair(input_data_tensor, input_cond_tensor, target_data_tensor, target_cond_tensor,
                           encoder, encoder_optimizer, decoder, decoder_optimizer, kl_anealing_func(iteration),
                           criterion, teacher_forcing_ratio)
                train_loss['cross_entropy'] += loss[0]
                train_loss['kl'] += loss[1]
                
        train_loss['cross_entropy'] /= (len(train_dataset)*12)
        train_loss['kl'] /= (len(train_dataset)*12)
        train_result_list['cross_entropy'].append(train_loss['cross_entropy'])
        train_result_list['kl'].append(train_loss['kl'])
                
        # Test
        encoder.eval()
        decoder.eval()
        
        test_loss = {'cross_entropy': 0, 'kl': 0}
        test_score = 0
        for i in trange(len(test_dataset)):
            input_data_tensor = torch.tensor([ch_to_ix[ch] for ch in test_dataset[i][0]], device=device).view(-1, 1)
            input_cond_tensor = torch.tensor(cond_to_ix[test_condition[i][0]], device=device)
            target_data_tensor = torch.tensor([ch_to_ix[ch] for ch in test_dataset[i][1]], device=device).view(-1, 1)
            target_cond_tensor = torch.tensor(cond_to_ix[test_condition[i][1]], device=device)
            pred, crossEntropy_loss, kl_loss = predict(input_data_tensor, input_cond_tensor, target_data_tensor, target_cond_tensor, encoder, decoder, criterion)
            test_loss['cross_entropy'] += crossEntropy_loss
            test_loss['kl'] += kl_loss
            score = compute_bleu(pred, test_dataset[i][1])
            test_score += score
        
        test_loss['cross_entropy'] /= len(test_dataset)
        test_loss['kl'] /= len(test_dataset)
        test_result_list['cross_entropy'].append(test_loss['cross_entropy'])
        test_result_list['kl'].append(test_loss['kl'])
        test_score /= len(test_dataset)
        test_result_list['score'].append(test_score)
        
        # Print the result
        print('=====================================================')
        print('Epoch: {} / {}'.format(epoch+1, num_epochs))
        print()
        print('Train Cross Entropy Loss: {}'.format(train_loss['cross_entropy']))
        print('Train KL Loss: {}'.format(train_loss['kl']))
        print()
        print('Test Cross Entropy Loss: {}'.format(test_loss['cross_entropy']))
        print('Test KL Loss: {}'.format(test_loss['kl']))
        print('Test BLEU-4 Score: {}'.format(test_score))
        print()
        print()
         
    return encoder, decoder, train_result_list, test_result_list


def evaluate(test_dataset, test_condition, encoder, decoder, max_length=MAX_LENGTH):
    
    score = 0
    for i in range(len(test_dataset)):
        input_data_tensor = torch.tensor([ch_to_ix[ch] for ch in test_dataset[i][0]], device=device).view(-1, 1)
        input_cond_tensor = torch.tensor(cond_to_ix[test_condition[i][0]], device=device)
        target_cond_tensor = torch.tensor(cond_to_ix[test_condition[i][1]], device=device)
        pred, _, _ = predict(input_data_tensor, input_cond_tensor, target_cond_tensor, encoder, decoder, criterion)
        score += compute_bleu(pred, test_dataset[i][1])
        
        print('Input:      {}'.format(test_dataset[i][0]))
        print('Target:     {}'.format(test_dataset[i][1]))
        print('Prediction: {}'.format(pred))
        print()
    print('Average BLEU-4 Score: {}'.format(score/len(test_dataset)))

In [31]:
# Define hyperparameters
hidden_size = 256
latent_size = 128
vocab_size = 28
num_conds = 4
cond_size = 32
teacher_forcing_ratio = 0.5
learning_rate = 0.05
MAX_LENGTH = 40
num_epochs = 20

# Encoder & Decoder
encoder = EncoderRNN(vocab_size, hidden_size, latent_size, num_conds, cond_size).to(device)
decoder = DecoderRNN(latent_size+cond_size, vocab_size).to(device)

# Train
encoder, decoder, train_loss_list, test_score_list = train(encoder, decoder,
                                                           train_data, train_cond, test_data, test_cond,
                                                           kl_anealing_func=get_kl_anealing_func('cyclical', period=10000),
                                                           teacher_forcing_ratio=teacher_forcing_ratio,
                                                           num_epochs=num_epochs)

  0%|          | 1/1227 [00:00<03:30,  5.82it/s]

(tensor([[[ 0.0671,  0.4449,  0.0715,  0.6617,  0.2406,  0.4819,  0.9446,
           0.5709,  0.4930,  0.6902,  0.1864,  0.3946,  1.0738,  0.4902,
           0.0428,  0.3898,  0.8097,  0.1008,  0.1545,  0.0890,  0.7683,
           0.3007,  0.9745,  0.4894,  0.5402,  0.9945,  0.3667,  0.7451,
           0.2960,  0.9450,  0.4693,  0.6595,  0.6082,  0.6110,  0.1822,
           0.4577,  0.5219,  0.9737,  0.0662,  0.2775,  0.9171,  0.1050,
           0.6765,  0.7380,  0.7103,  0.6825,  0.3935,  0.1138,  0.7962,
          -0.1546,  0.2812,  0.0600,  0.7602,  1.0026,  0.8182,  0.5284,
           0.3488,  0.5304,  0.8471,  0.4774,  0.7818,  0.8644,  0.4314,
           0.4265,  0.4708,  0.1536,  0.3532,  0.0349,  0.6228,  0.9451,
           0.7726,  0.1605,  0.8209,  1.0051, -0.0506,  0.7538,  0.8767,
           0.6414,  0.6804,  0.3210,  0.7893,  0.2976,  0.9821,  0.7653,
           0.5381,  0.5445, -0.0137,  0.4260,  0.4036,  0.2460,  0.3226,
           0.1396,  0.7200,  0.6522,  0.0095,  0.6

  0%|          | 2/1227 [00:00<03:28,  5.86it/s]

(tensor([[[ 0.7680,  0.3549,  0.5028,  0.1547,  0.0922,  0.8049,  0.1847,
           0.5825,  0.2767,  0.3484,  0.0909,  0.2842,  0.6325,  0.1189,
           0.6327,  0.7280,  0.4446,  0.5737,  0.5235,  0.6244,  0.7580,
           0.6156,  0.0308,  0.3480,  0.6364,  0.8769,  0.9182,  0.2792,
           0.4498,  0.4578,  0.6850,  0.6476,  0.1685, -0.0429,  0.0241,
           0.4015,  0.2159,  0.6336,  0.5072,  0.2178,  0.0996,  0.8189,
           0.9469,  0.1432,  0.1960,  0.0495,  1.0482,  0.3656,  0.3159,
           0.0658,  0.5632,  0.2299,  0.1357,  0.9011,  0.6470,  0.7376,
           0.4004,  0.4832,  0.9095,  0.2230,  0.2210,  0.6531, -0.0307,
           0.6362,  0.5972,  0.8923,  0.8350,  0.9406,  0.8791,  0.0263,
           0.2602,  0.2891,  0.2747,  0.6332,  0.6878,  0.0962,  0.3933,
           0.2982,  0.1845,  0.0191,  0.5451,  0.9604,  0.4479,  0.1165,
           1.0621,  1.1266,  0.0545,  0.7332,  1.0919,  0.0024,  0.6712,
           0.1804,  0.4679,  0.1224,  0.2150,  0.7

  0%|          | 3/1227 [00:00<03:24,  5.99it/s]

(tensor([[[ 0.3253,  0.7728,  0.3570,  0.3486,  0.7646,  0.1962,  0.3555,
           0.7259,  0.2489,  0.5704,  0.0670,  0.9386,  0.3691,  0.4683,
           0.0975,  0.2674,  0.7711,  0.4711,  0.8033,  0.3403,  0.2830,
           0.6403,  0.1058,  0.1948, -0.0490,  0.3301,  0.1480,  1.0152,
           0.8976, -0.0948,  1.0885,  0.9040,  0.3957,  0.7178, -0.0285,
           0.4403,  0.2117,  0.7110,  0.6140,  0.7065,  0.4663,  0.4885,
           0.1188,  0.8344,  0.5950,  1.0009,  0.2057,  0.3101,  0.2611,
           0.2009,  0.8345,  0.2493,  0.4215,  0.8004,  0.1309,  0.4573,
           0.5290,  0.2236,  0.5531,  0.2492,  0.5541,  0.2923,  0.0829,
           0.1473,  1.0735,  0.6606,  0.0947,  0.3745,  0.8993,  0.0630,
           0.8240,  0.5735,  0.2246,  0.2231,  0.7008,  0.5533,  0.2779,
          -0.0925,  0.0883,  0.7635,  0.4138, -0.0891,  1.1269,  0.1871,
           0.8647,  0.7564,  0.8593,  0.4369,  0.1742,  0.2773,  0.8018,
           0.8974,  0.5389,  0.1654,  0.2403,  0.8

  0%|          | 5/1227 [00:00<03:13,  6.31it/s]

(tensor([[[ 1.3214e-01,  9.7027e-01,  8.2956e-01,  1.0504e+00,  4.9056e-01,
           8.5345e-01,  5.3599e-01,  4.4356e-01,  1.0622e-01,  5.0621e-01,
           6.2689e-01,  4.8349e-01,  7.0785e-01,  9.6100e-01, -4.8855e-02,
           1.0050e+00,  1.1397e+00,  8.7541e-01,  6.6318e-02,  7.6745e-01,
           9.4856e-01,  1.0119e+00,  2.6905e-01,  7.9176e-01,  2.7368e-01,
           3.0968e-01, -6.5988e-02,  1.0406e-01,  5.3766e-01,  1.5813e-01,
           9.2007e-02,  2.1798e-01,  5.8926e-01,  5.6492e-01,  2.4358e-01,
           7.2465e-01,  5.4561e-01,  3.7328e-01,  6.8162e-01,  3.5923e-01,
           8.6942e-01,  8.4643e-01,  6.2576e-01,  5.1140e-02,  2.6707e-01,
           9.3025e-01,  3.6479e-01,  5.9712e-01,  8.7982e-01,  3.8028e-01,
           2.0095e-01,  4.3417e-01,  3.3324e-02,  6.8470e-01,  1.1528e-01,
           6.9484e-01,  3.8207e-01,  3.7698e-01,  7.2841e-01,  5.5364e-01,
           8.3576e-01,  2.0311e-01,  7.9302e-01,  3.8941e-01,  2.8868e-02,
           6.9134e-01,  

  0%|          | 6/1227 [00:00<03:09,  6.44it/s]

(tensor([[[ 1.0642,  0.4585, -0.1576,  0.8748,  0.3438,  0.5785,  0.4480,
           0.9920,  0.0205,  0.4777,  0.5111,  0.6716,  0.5765,  0.1342,
           0.0185,  0.4998,  0.6605,  0.2165,  0.6952,  0.3485,  0.0124,
           0.7751,  0.3758,  0.7005,  0.4285,  0.8379,  0.0477,  0.6121,
           0.7839,  0.6203,  0.8353,  0.3868,  0.0255,  0.4893,  0.3059,
           0.6785,  0.6488,  0.1271, -0.0276,  0.6777,  1.0399,  0.6080,
           0.7179,  0.7781,  1.0762,  0.3125,  0.8651,  0.2773,  0.9967,
           0.7221,  0.5892,  0.8202,  0.8919,  0.5643,  0.9629,  0.9655,
           0.4605,  0.4029,  0.6918,  0.8993,  0.9908,  0.4527,  0.4801,
           0.3553,  0.0845,  1.0536,  0.5481,  0.9851,  0.1088,  0.9708,
           0.9054,  0.0597,  0.4258,  0.9280,  0.2369,  0.8682,  0.5465,
           0.3792,  0.6152,  0.0406,  0.0630,  0.2454,  0.7427,  0.5125,
           0.7823,  0.1065,  0.2264,  0.2554,  0.3763, -0.1015,  0.8508,
           0.1030,  0.6450,  0.8635,  0.0868,  0.6

  1%|          | 7/1227 [00:01<03:02,  6.70it/s]

(tensor([[[-7.5895e-02,  8.9742e-01,  6.2283e-01,  9.0041e-01,  3.4510e-01,
           4.0370e-01,  7.5266e-01,  2.6552e-01, -4.4590e-02,  4.6996e-01,
           1.6536e-01,  5.7681e-01,  7.7070e-01,  5.6553e-01,  7.4192e-01,
           4.5782e-01, -3.8655e-02,  1.1103e-01,  4.5460e-01,  5.3951e-01,
           2.9179e-02,  9.0966e-02,  1.6428e-01,  1.7264e-01,  6.9883e-01,
           8.9718e-02,  1.0519e+00,  8.9457e-01,  5.0465e-01, -5.3212e-03,
           2.1201e-01,  7.3110e-01,  1.5129e-01,  7.4744e-01,  3.0070e-01,
           4.2997e-01,  3.3470e-01,  9.0357e-01,  6.3256e-01,  3.7233e-01,
           1.5434e-01,  1.2498e-01,  2.8300e-01,  9.1544e-01,  8.0946e-01,
           7.8220e-02,  7.0061e-01,  3.6941e-01,  3.7195e-01,  4.1793e-02,
           6.8365e-01,  1.9640e-01,  1.0363e+00,  5.3560e-01,  8.7266e-01,
           3.2540e-01,  7.6416e-01, -1.8624e-01,  3.9275e-01,  2.9924e-01,
           2.5498e-01,  4.6831e-01,  6.6069e-01,  2.6711e-01,  1.2984e-01,
           1.0335e+00,  

  1%|          | 9/1227 [00:01<03:11,  6.36it/s]

(tensor([[[ 0.8456,  0.2333,  0.3707,  0.5167,  0.2356,  0.9671,  0.4967,
           0.1799,  0.4768,  0.3253,  0.7855,  0.0818,  0.7030,  0.9802,
           0.2600,  0.7970,  0.4682,  0.6571,  0.2411,  0.5006,  0.6535,
           0.2880,  0.5311,  0.0856,  0.4702,  0.2075,  0.2781,  0.3865,
           0.1018,  0.6929,  0.8273,  0.0774,  0.1829,  0.8412,  0.4058,
           0.6258,  0.4027,  0.3780,  0.1803,  0.1955,  0.9969,  1.1388,
           0.1549, -0.0420,  0.5043,  0.5563,  0.2312,  0.1449,  1.1319,
           0.1309,  0.4271,  0.0922,  0.0591,  0.7192,  0.7228,  0.0769,
           0.2010,  0.7316,  0.8407,  0.1032,  1.0356,  0.5828,  0.7772,
           0.3532,  0.1002,  0.5814,  0.5842,  0.0822,  0.4425, -0.1160,
           0.2087, -0.1262,  0.4918,  0.5794,  0.8200,  0.5497, -0.0857,
           0.6940,  0.6164,  0.8980,  0.8511,  0.5802,  0.3808,  0.0299,
           0.3250,  1.0823,  0.4650,  0.6864,  0.5393,  0.0190,  0.4050,
           0.9254,  0.7335,  0.5312,  0.7705,  0.8

  1%|          | 10/1227 [00:01<03:08,  6.45it/s]

(tensor([[[ 1.0208,  0.6540,  0.6822,  0.1320,  0.2220,  0.9563,  1.1821,
           0.8649,  0.1685,  0.8889,  0.2664,  0.0789,  0.2962,  0.3809,
           0.3642,  1.0000,  0.6986,  0.2690,  0.3038,  0.2541,  0.9281,
           0.9166,  0.7037,  0.6516,  0.4137,  0.3635,  0.0235,  0.3683,
           0.7415,  0.6463,  0.4095,  0.8940,  0.5026,  0.1848,  0.0792,
           0.2547,  0.0806,  0.2155,  0.9211,  0.0213,  0.8499,  0.2437,
           0.5700,  0.3852,  0.6693,  0.2174,  0.5876,  0.5595,  0.3212,
           0.0138,  0.7789,  0.0918,  0.7408,  0.9550,  0.6573,  0.8317,
           0.3011,  0.7150,  0.6899,  0.2073,  0.7449,  0.6116, -0.0232,
           0.6020,  0.3116,  1.1134,  0.2300,  0.8388,  0.5116,  0.3365,
           0.2006,  0.5765,  0.1810,  1.0525,  0.9657,  1.0475,  0.6479,
           0.3144,  0.1649,  0.8009,  0.2710,  0.3308,  1.1698, -0.0181,
           0.6946,  0.8931,  0.0961,  0.7834,  0.1892, -0.0332,  0.6504,
           0.5171,  0.4813,  0.6320,  0.3925,  0.6

  1%|          | 12/1227 [00:01<03:05,  6.54it/s]

(tensor([[[ 0.0227,  0.4863,  0.8202,  1.0309,  0.7238,  0.1253,  0.4570,
           0.0700,  0.4068,  0.4346,  0.2408,  0.5703,  0.8268,  0.6082,
           0.1110,  0.7197,  0.0045,  0.5101,  0.4248,  0.1155,  0.9207,
           0.3371,  0.2282,  0.0224,  0.8059,  0.3868,  0.2860,  0.6012,
           0.4356,  0.9480,  0.9325,  0.8035,  0.8088,  0.6767,  0.4495,
           0.8758,  0.8474,  1.0464,  0.0756,  0.4578,  0.5410,  0.3766,
           0.9191,  0.0732,  0.1481,  0.2293,  0.7989,  0.9207,  0.8240,
           0.5806,  0.9563,  0.0351,  0.8634,  0.5678,  0.6742,  0.6814,
           0.7506,  0.3853,  0.1414,  0.1713,  0.9354,  1.0193,  0.1973,
           0.2870,  0.2980,  0.7440,  0.4058,  0.0262,  0.8967,  0.4335,
           0.7276,  0.8267,  0.0690,  0.1875,  0.5018,  0.8297, -0.0654,
           0.7884,  0.0437,  0.5944,  0.7394,  0.3236,  0.3957,  0.7887,
           0.6972,  0.3155,  0.0398,  0.1646,  0.1951,  0.1802,  1.0405,
           0.1735,  1.2340,  0.3236,  0.6600, -0.0

  1%|          | 13/1227 [00:02<03:07,  6.47it/s]

(tensor([[[ 0.0932,  0.0692,  0.5905,  0.3952,  0.5416,  0.1719,  0.5577,
           0.0650,  0.8519,  0.3564,  0.4363,  0.6612, -0.0139,  0.7908,
           0.3421,  0.1497,  0.7674,  0.4087,  0.4860,  0.4400,  0.5198,
           0.1951,  0.6875,  0.3353,  0.1072,  0.0872,  0.1548,  1.0223,
           0.2953,  0.6950,  0.2124,  0.5134,  0.1838,  0.5062,  0.7479,
           0.0972,  0.0886,  0.5109,  0.8136,  0.0401,  0.5387,  0.5909,
           0.9687,  0.7079,  1.1147,  0.4844,  0.3875,  0.6625,  0.8705,
           0.2160,  0.0141,  0.7565,  0.3827,  0.1387,  0.8178, -0.0167,
           0.5048,  0.7056,  0.8977,  0.1389,  0.8593, -0.1151,  0.5637,
           0.2510,  0.6130,  0.4294,  0.8127,  0.5462,  0.6686,  0.8858,
           0.1362,  0.8411,  0.4820,  0.2152,  0.5884,  0.0914,  0.5952,
           0.1025,  0.3042,  0.6140,  0.8387, -0.0333,  0.5759,  1.0471,
           0.2672,  0.6286,  0.2891,  0.2668,  0.4873,  0.9345,  0.1858,
           1.0027,  0.8992,  0.1964,  0.2244,  0.3

  1%|          | 14/1227 [00:02<03:11,  6.35it/s]

(tensor([[[ 0.8809,  0.9297,  0.8227,  1.3091,  0.5082,  0.1002,  0.5410,
           0.1091,  0.5138,  0.5616,  0.9576,  0.1986,  0.4311,  0.0400,
           0.3628,  0.6522,  0.8960,  0.6297,  1.0195,  0.4998,  0.3525,
           1.0408,  0.3794, -0.0172,  0.7474,  0.1208,  0.1443,  0.1820,
           0.3800,  0.3349,  0.0894,  0.2938,  0.0822,  0.2127,  0.7631,
           0.2081,  0.5311,  0.6879,  0.2475,  0.4055,  0.3684,  0.1616,
           0.7700,  0.8503,  0.6098,  0.8633, -0.0104,  0.6235,  0.4268,
           0.9283,  0.9268,  0.1635,  0.0736,  1.1220, -0.0117,  0.1876,
           0.3278,  0.0309,  0.3147,  0.5556,  0.1286,  0.2220,  0.7866,
           0.9139,  0.1906,  0.6526,  0.4825,  0.7452, -0.0111,  0.3306,
           0.9875,  0.8321,  0.9229,  0.2147,  0.5338,  0.6367, -0.0777,
           0.6528,  0.6887,  0.7650,  0.9068,  0.3904,  0.3824,  0.6681,
           0.5830,  0.8426,  0.6924,  0.2040,  0.6560,  0.5407,  0.6351,
           0.5628,  0.4887,  0.0948,  0.5290,  0.8

  1%|          | 15/1227 [00:02<03:12,  6.31it/s]

(tensor([[[ 0.7991,  0.5784,  0.4334,  0.5297,  0.9462,  0.5955,  0.4130,
           0.6007,  0.6028,  0.6531,  0.3347,  0.5915,  0.1322,  0.9224,
           0.2717,  0.2846,  0.9916,  0.3098,  0.2576,  0.6208,  0.6419,
           0.2292,  0.6851,  0.3933,  0.8147,  0.0334,  0.0850,  0.0584,
           0.4578,  0.2982,  0.0399,  0.8764,  0.2160, -0.0658,  0.8546,
           0.5713,  0.4601,  0.7282,  0.3090,  0.7281,  0.3291,  0.3468,
           0.5991,  1.0259,  0.5359,  0.6455,  0.8539,  0.3591,  0.5372,
           0.0681,  0.4345,  0.6538,  0.1694,  0.2879,  0.8335,  0.3896,
           0.8294,  0.7175,  0.2233,  0.0061,  0.8676,  0.0609,  0.0908,
           0.1368,  0.1647,  0.7870,  0.8279, -0.0133,  0.6146,  0.9384,
           0.8746, -0.0682,  0.7055,  0.6051,  0.2977,  0.6579,  0.3980,
           1.0183,  0.0060,  0.2420,  0.8209,  0.2487,  0.8822,  0.1957,
           0.4478,  1.0191,  0.6993,  0.4866,  0.5088,  0.5278,  0.6921,
           0.9199,  0.6438,  0.7383,  0.9289,  0.1

  1%|▏         | 17/1227 [00:02<03:10,  6.34it/s]

(tensor([[[ 0.1612,  0.0758,  0.5255,  0.8859,  1.0458,  0.7597,  0.3816,
           0.5398,  0.8171,  0.7046, -0.0091,  0.8744,  0.5801,  0.6487,
           0.2522,  1.0600,  0.1956,  0.8559,  0.7867,  0.7019,  0.4842,
           1.0636,  0.2700,  0.9846,  0.7415,  0.4502,  0.7828,  1.0144,
           0.8958,  0.4097,  0.4207,  0.3995,  1.1532,  0.2475,  0.8502,
           0.5613, -0.0510,  0.4482,  0.4439,  0.5380,  0.6363,  0.5376,
           0.8300,  1.0333,  0.7159,  0.1470,  0.4519,  0.1498,  0.0030,
           0.4277,  0.9755,  0.4523,  0.5838,  0.7816,  0.0117,  0.8412,
           0.6717,  0.8036,  0.5125,  0.3256,  0.4885,  0.7196,  0.1416,
           0.6928,  0.9855,  0.8908,  0.7850,  0.2197,  0.7496,  0.0910,
           0.7215,  0.1255,  0.4552,  1.1066,  0.4522,  0.3213,  0.2399,
           0.0603,  0.7021,  0.7334,  0.1766,  0.9744,  0.2599,  0.9229,
           0.0764,  0.3145,  0.7498,  0.1274,  0.6695,  0.8452,  0.6325,
           0.6047,  0.6423,  0.4414,  0.4945,  0.9

  1%|▏         | 18/1227 [00:02<03:08,  6.42it/s]

(tensor([[[ 4.9403e-01,  7.2248e-01,  3.5629e-01,  6.4129e-01,  3.7943e-01,
           5.6235e-01,  9.1096e-01,  2.4479e-01,  5.5773e-01,  6.7503e-01,
          -4.9139e-02,  1.2710e-01,  8.4263e-01,  2.1665e-01,  8.4198e-01,
           2.7529e-01,  7.5725e-01,  5.4956e-01,  1.8451e-01,  1.0849e-01,
           5.7454e-01,  2.1775e-01,  4.3674e-01,  3.9545e-01,  1.6274e-01,
           2.6709e-01,  1.3661e-01,  8.7314e-01,  4.8144e-01,  2.3979e-01,
           2.8274e-01,  1.5582e-01,  2.3294e-01,  2.2978e-04,  2.5041e-01,
           3.3822e-01,  7.7786e-01,  8.3493e-01,  4.7785e-01,  9.5035e-01,
           1.2275e-01,  7.8952e-01,  1.0523e+00,  2.1607e-01,  6.4768e-01,
           9.5733e-01,  5.1150e-01,  9.3279e-01,  7.3650e-01,  2.3735e-01,
           1.5020e-01,  3.1932e-01,  3.2551e-01,  1.4965e-01,  8.1877e-01,
           2.9205e-01,  2.9851e-01,  6.2276e-01,  4.7704e-01,  4.5968e-01,
           7.8072e-01,  1.8299e-01,  7.7229e-01,  4.8634e-01,  1.2454e-01,
           7.9290e-01, -

  2%|▏         | 19/1227 [00:02<03:05,  6.52it/s]

(tensor([[[ 0.4043,  0.2519,  0.1588,  0.2242,  0.5324,  0.6389,  0.8765,
           0.5466,  0.1236,  0.5663,  0.4949,  0.3190,  0.8382,  0.3301,
           0.9239,  0.0769,  0.4266,  0.9729,  0.7291,  0.3915,  0.1597,
           1.0528,  0.0107,  0.2319,  0.6856,  0.7004,  0.9748,  0.8933,
           0.5381,  0.5775,  0.3474,  0.2440,  0.8379,  0.3375,  0.0781,
           0.4347,  1.0178,  0.8482,  0.2263,  0.6088,  0.7940, -0.0868,
           0.7441,  0.7160,  0.3794,  0.8758,  0.1000,  0.2089,  0.5978,
           0.4459,  0.3096,  0.0917,  0.6179,  1.0163,  1.0123,  0.3638,
           0.9557,  0.2253,  0.2482,  0.9406,  0.8179,  0.4543,  0.2657,
           0.9416,  0.8051,  0.2557,  0.8352,  0.6630,  0.8306,  1.0107,
           0.8856,  0.4065,  0.9975,  0.9068,  0.4519,  0.2990,  0.6444,
           0.3552,  0.6037,  0.2152,  0.6105,  0.7312,  0.5352,  0.6616,
           0.7883,  0.1763,  0.0864,  0.9318,  0.5146,  0.6163,  0.4715,
           0.3497,  0.7765,  0.7948,  0.2331,  0.0

  2%|▏         | 21/1227 [00:03<03:13,  6.22it/s]

(tensor([[[-1.0420e-01,  7.1903e-01,  4.4506e-01,  6.4751e-01,  5.0897e-01,
           3.1731e-01,  5.4090e-01,  3.2463e-01,  7.8127e-01, -8.3354e-02,
           8.5110e-01,  1.4446e-01,  1.8130e-01,  8.5347e-01,  9.0424e-01,
           7.9623e-01,  7.4930e-01,  8.8584e-01,  4.9947e-01,  3.7239e-01,
           3.4552e-01,  6.4306e-01,  1.1139e-03,  5.7185e-01,  3.6655e-01,
           9.0170e-01,  1.0227e+00,  1.2617e-01,  8.3122e-01,  1.9271e-01,
           7.2600e-01,  4.0572e-01,  6.7519e-01,  3.9499e-01,  8.1690e-01,
           6.2561e-01,  9.5414e-01,  9.0463e-01,  2.1191e-01,  5.7505e-01,
           4.4344e-01,  2.2734e-01,  9.0306e-01,  1.0552e+00,  3.3421e-01,
           9.5537e-01,  6.5281e-01,  1.7445e-01,  3.0507e-01,  8.3209e-01,
           2.1529e-01,  4.0729e-01,  5.1728e-01,  2.4070e-01,  4.1583e-01,
           2.7784e-01,  5.3105e-01,  4.7354e-01,  1.3505e-01,  5.7409e-01,
           2.2180e-01,  6.5475e-01,  1.1549e-01,  3.3367e-01,  4.1412e-01,
           1.1730e+00,  

  2%|▏         | 22/1227 [00:03<03:22,  5.96it/s]

(tensor([[[ 0.5346,  0.6409,  0.0913,  0.5863,  0.2945,  0.7920,  0.2830,
           0.2399,  0.4233,  0.8521,  0.9319,  0.2584,  1.0189,  0.3501,
           0.1827,  0.4638,  0.3665,  0.1813,  1.0126,  0.0514,  0.2446,
           0.3842, -0.0332,  0.0362,  0.1886,  0.8469,  0.6482,  0.7359,
           1.0094,  0.1727,  0.5307,  0.2112,  0.1917,  0.6269,  0.4838,
           0.6618,  0.0646,  0.0909,  0.8281, -0.0593,  0.4347,  0.7522,
           0.0900,  0.0407,  0.9338,  1.1381,  0.6226,  0.8731,  0.3627,
           0.0883,  0.9876,  0.6405,  0.7149,  0.3713,  0.2861,  0.3427,
           0.1370, -0.1094,  0.3579,  0.5798,  0.5526,  0.7090,  0.3473,
           0.6335,  0.4882,  0.8026,  0.8622,  0.2687,  0.3735,  0.4357,
           0.3496,  0.2659,  0.5305,  0.6053,  0.4216,  0.9760,  0.8267,
           0.0171,  0.3494,  1.0480,  0.5466,  0.6465,  0.1062,  1.0337,
           0.6756,  0.5760,  0.0153,  0.4873,  0.3051,  0.5045,  0.6262,
           0.9662,  0.6762,  0.1794,  0.7121,  0.5

  2%|▏         | 23/1227 [00:03<03:20,  6.01it/s]

(tensor([[[ 0.6312,  0.4922,  0.1055,  1.1013,  0.6648,  0.8126,  0.7436,
           0.9648,  0.5313,  0.7613,  0.0168,  0.7710,  0.3356,  0.5438,
           0.7803,  0.5110,  1.0877,  0.6032,  0.0374,  0.2985,  0.1545,
           0.4209,  0.4186,  0.0210,  0.9157,  0.7108,  0.3514,  0.5558,
           0.6235,  0.8446,  0.8090,  0.1729,  0.3795,  0.7223,  0.6540,
           0.6586, -0.0839,  0.5608,  0.7340,  0.2465,  0.0385,  0.6409,
           0.9322,  0.3734,  0.6466,  0.7709,  0.2805,  0.9775,  0.7770,
           0.1459,  0.2024,  0.6218,  0.1078,  0.8428,  0.9878,  0.4626,
           0.6303,  0.9254,  0.2897,  0.8985,  0.8052,  0.8028,  0.8017,
           1.0889,  0.4177,  0.5255,  0.2184,  0.6186,  0.8635,  0.8947,
           1.0835,  0.7928,  1.1806,  0.2240,  0.4601,  0.2486,  0.6170,
           0.5748,  0.6399,  0.7020,  0.7207,  1.0454,  0.9355,  0.4138,
           0.1907,  0.9880,  0.3280,  0.4335,  0.5107,  0.0897,  0.6637,
           0.6717,  0.8205,  0.1768,  1.0280,  1.0

  2%|▏         | 24/1227 [00:03<03:18,  6.07it/s]

(tensor([[[ 0.6041,  0.2258,  0.5489,  0.7062,  0.5278,  0.4247,  0.6520,
           0.9433,  0.4562, -0.0083,  0.2215,  0.7879,  0.6651,  0.4864,
           0.7410,  0.7275,  0.9168,  0.2038,  0.0924,  0.7817,  0.7685,
           0.2900, -0.0655,  0.4174,  0.9255,  0.0783,  0.7055,  0.1816,
           0.7867,  0.2552,  0.8267,  0.4575,  0.0192,  0.4308,  0.0518,
           0.4466,  0.3413,  0.3069,  0.1539,  0.6605,  0.8818,  0.3994,
           0.4021,  0.8023,  0.7145,  0.7427,  0.3367,  1.0752,  0.6907,
           0.5882,  0.4181,  0.6333,  0.1471,  0.8805,  0.5688,  0.3016,
           0.7988,  0.8236,  0.3171,  0.2479,  1.1671,  0.5375,  0.4743,
           0.0681,  0.3317,  1.0262,  0.1048,  0.4465,  0.7390,  0.2860,
           0.0879,  0.1265,  0.9687,  0.5976,  0.2397,  0.5637,  0.8074,
           0.7787, -0.1335,  0.7036,  0.5062,  0.2972,  0.5992,  0.5160,
           0.9537,  0.1470,  0.6757,  0.6666,  0.9228,  0.0019,  0.7267,
           0.6950,  0.0905,  0.5094,  0.6110,  0.0

  2%|▏         | 26/1227 [00:04<03:25,  5.83it/s]

(tensor([[[ 0.4301,  0.1579,  0.1900,  0.6079,  0.5869,  0.8145,  0.6434,
           0.3167,  0.4165,  0.5947,  0.2202,  0.3129,  1.0337,  0.0830,
           0.5417,  0.2384,  0.3200,  0.3030,  0.8695,  0.6374,  0.4872,
           0.8853,  0.6847,  0.1414,  0.2432,  0.4988,  1.0099,  0.7664,
           0.2383,  0.7390,  0.2004, -0.0036, -0.0529,  0.7483,  0.6320,
           0.3103,  0.9094,  0.3119,  0.6255, -0.0954,  0.2766,  0.5362,
           0.6902,  0.0727,  0.3727,  0.6055, -0.0978,  0.1209,  1.0573,
           0.1138,  0.7231,  0.4233,  0.3797,  0.7943,  0.1823,  0.6268,
           0.5221,  1.0577,  0.3261,  0.4178,  0.3076,  0.7946,  0.5739,
           0.9136,  0.5409,  0.7750,  0.3634,  0.5046,  0.4368,  0.5674,
           0.2819,  0.0129,  0.8593,  0.5054,  0.6595,  0.5433,  0.3886,
           0.8794,  0.7566,  0.6387,  0.4161,  0.7125,  1.0373,  0.0832,
           0.2791,  1.2417,  0.1177,  0.8634,  1.0651,  0.6159,  0.5204,
           1.0642,  0.6823,  0.8607,  0.4018,  0.6

  2%|▏         | 27/1227 [00:04<03:18,  6.06it/s]

(tensor([[[ 0.3470,  0.1296,  0.4445,  0.1791,  0.3113,  0.0542,  0.1578,
           0.3930,  0.3708,  0.6289,  0.8300,  0.7550,  0.7644,  0.5142,
           0.0806,  0.4751,  0.8593,  0.1607,  0.3552,  0.5354,  0.7416,
          -0.0502,  0.8676, -0.0478, -0.0450,  0.8411,  0.1737,  0.0527,
           0.8497,  0.6993,  0.8825,  0.1155,  0.4764,  0.3222,  0.6455,
           0.3839,  0.1702,  0.6484,  0.2466,  0.2236,  0.5203, -0.0103,
           0.2103,  0.7437,  0.7666,  0.4216,  1.0533,  0.4422,  0.6797,
           0.4287,  0.9314,  0.5520,  0.2513,  0.9523,  1.0077, -0.0988,
           0.3837,  0.3125,  0.4640,  0.3619,  0.8285,  0.7313,  0.6389,
           0.1067,  0.1278,  0.7002,  0.5771,  0.5763,  0.8449,  0.6835,
          -0.0511,  0.3619,  0.9799,  0.9085,  0.2673,  0.5272,  0.2851,
           0.9512,  0.7899,  0.3537,  0.2640,  0.9824,  1.1427,  0.1259,
           0.6330,  0.5277,  0.4522,  0.1497,  0.6079,  0.5036,  0.4161,
           0.7142,  0.5238,  0.5759,  0.8189,  0.4

  2%|▏         | 28/1227 [00:04<03:11,  6.25it/s]

(tensor([[[-0.0958,  0.4787,  0.0807,  0.4809,  0.6557,  0.8687,  0.0731,
           0.2956,  0.6775,  0.3572,  0.0106,  0.4551,  0.2221,  0.4205,
           0.5914,  0.5480,  0.4510,  0.8864,  0.7160,  0.1150,  0.9053,
           0.5385,  0.4593,  0.4437,  0.5252,  0.5347, -0.0185,  0.8213,
           0.1934,  0.1835,  0.5541,  0.8180,  1.1432,  0.2409,  0.5657,
           0.4379,  0.0063,  0.7294,  0.0313,  0.5973,  0.2696,  0.2002,
           0.3664,  0.9617,  0.7294,  0.7018,  0.4980,  0.1665,  0.6772,
           0.7721,  0.6643,  0.4982,  0.6385,  0.4443,  0.6510,  0.6807,
           0.7412,  0.6780,  1.0607,  0.6456,  0.9092,  0.0771,  0.2903,
           0.9593,  0.2410,  0.2803,  0.7980,  0.0703,  0.2095, -0.0848,
           0.3578,  0.5083,  1.0975,  0.9911,  0.8077,  0.2559,  0.1162,
           0.7082,  0.4825,  0.4303,  0.3425,  0.4176,  0.8858,  0.5303,
           0.3620,  0.6410,  0.3672,  0.9145,  0.3753, -0.0560,  0.3037,
           0.9716,  0.1739,  0.9988,  0.2807,  0.3

  2%|▏         | 29/1227 [00:04<03:15,  6.13it/s]

(tensor([[[ 0.4031,  0.5475,  0.2765,  0.0961,  0.4547,  0.1987,  0.8750,
           0.7166,  0.2753,  0.3969,  0.6735,  0.6770,  0.8898,  0.1361,
           0.8204,  0.9754,  0.3506,  0.3244,  0.7521,  0.2629,  0.0870,
           0.3101,  0.0915,  0.0334,  0.7961,  0.3615,  0.1730,  0.2849,
           0.2381,  0.6796,  0.6176,  0.6576,  0.3976,  0.9519,  0.1627,
           0.5086,  0.9532,  0.2740,  0.5783,  0.5783,  0.8212,  0.1367,
           0.8599,  0.3316,  0.2660,  0.6352,  0.6821,  0.3590,  0.9638,
           0.8861,  0.9441,  0.7752,  0.2623,  0.9501,  0.8950,  0.3303,
           0.5858,  1.0723,  0.4246, -0.1385,  0.4505,  0.9663,  0.1231,
           0.5778,  0.3744,  0.4077,  0.1980,  0.3029,  0.4834,  0.7509,
           0.1423,  1.0430,  1.0474,  0.5700,  0.0203,  0.4333,  0.6389,
           0.3214,  0.6778,  0.9129,  0.6011,  0.5049,  0.8351,  0.7629,
           0.8857,  0.5259,  0.2288,  1.0221,  0.5198,  0.4588,  0.4816,
           0.6964,  0.6775,  0.3998,  0.7155,  0.4

  3%|▎         | 31/1227 [00:04<03:22,  5.92it/s]

(tensor([[[ 0.4069,  0.0687,  0.8188,  0.9989,  0.3098,  0.6583,  0.3716,
           0.4869,  0.2927,  0.0857,  0.9957,  0.8861,  0.0735,  0.2495,
           0.4562,  0.5352,  0.1250,  0.7508,  0.7702,  0.7871,  0.4380,
           0.7009,  0.5107,  0.5789, -0.0812,  0.0807,  0.5316,  0.8975,
           0.8986,  0.8590,  0.1441,  0.1583,  0.9733,  0.6289,  0.1165,
           0.4106,  0.6521,  0.0263,  0.7813,  0.3539,  0.9563,  0.6561,
           0.7694,  1.2083,  0.3680,  0.0235,  0.3135,  0.7658,  1.0129,
           0.2540,  0.9961,  0.1102,  0.4931,  0.0426,  0.8816,  0.1014,
           0.7659,  0.7807,  0.9863,  0.9914,  0.2516,  0.2917,  0.4587,
           0.0715,  0.3765,  0.6930,  0.0798,  0.3969,  0.7211,  0.2606,
           0.5196,  0.2199,  0.4004,  0.5936,  0.1671,  0.7440,  0.1319,
           0.5223,  0.6035,  0.0712,  0.6890,  0.9694,  0.1623,  0.3678,
           0.1328, -0.0058,  0.0210,  1.0117,  0.1899,  0.7211,  0.7183,
           1.0253,  0.0441,  0.0226,  0.8977,  0.5

  3%|▎         | 32/1227 [00:05<03:17,  6.06it/s]

(tensor([[[ 0.1079,  0.6585,  0.6186,  0.3263,  0.5522,  0.2787,  0.1150,
           0.3411,  0.5560,  0.3531,  0.8290,  0.9943,  0.9365,  0.1715,
           0.4993, -0.0109,  0.0826,  0.7079,  0.7621,  0.6016,  0.2459,
           0.3309,  0.8663,  0.0758,  0.5412,  0.4806,  0.8762,  0.7616,
           0.1493, -0.0258,  0.1628,  0.1273,  1.0293,  0.9499,  0.0593,
           0.2409,  0.3329,  0.8447,  0.6462,  0.4002,  0.3039,  0.0315,
           0.7363,  0.9911,  0.7686,  0.9280,  0.7741,  0.4052,  0.9023,
           0.1552,  0.7715,  0.1851,  0.3708,  1.0216,  0.6934,  0.1562,
           0.8136,  0.8675,  0.9029,  0.8550,  0.2374,  0.7877,  0.0829,
           0.5201,  0.6419,  0.7035,  0.0184,  0.0622,  0.9177,  0.7706,
           0.7921,  1.0434,  0.2210,  0.5070,  0.7149,  0.7156,  0.0537,
           0.0787,  0.6647,  0.3703,  0.2124,  0.5761,  0.2147,  0.5578,
           0.1792,  0.4291,  0.1043,  0.4770,  0.2670,  0.8928,  0.9854,
           0.9006,  1.2503,  1.0206,  0.3460,  0.3

  3%|▎         | 33/1227 [00:05<03:14,  6.12it/s]

(tensor([[[-0.1620,  0.1330,  0.6593,  0.6882,  0.3208,  0.6150,  0.3966,
           0.4402,  0.4961,  0.6163,  0.0712,  0.3323,  0.3079,  0.6146,
           0.1660,  1.0033,  0.4152,  0.4920,  0.3527,  0.5521,  0.9391,
           0.6394,  0.0842,  0.5822, -0.0457,  0.2175,  1.0734,  0.9119,
           0.9195,  0.6557,  0.5188,  0.8529,  0.7206,  0.8137,  0.9409,
           0.4410,  0.4941,  0.8443,  0.8587,  0.7989,  0.1451,  0.4205,
           0.9274,  0.7597,  0.8606,  1.0154,  0.7126,  0.8288,  0.7191,
           0.3196,  0.8531,  0.8694,  0.1061,  0.2855,  0.1339,  0.2549,
           0.0890,  0.7527,  0.2678,  0.8187,  0.4620,  0.8667,  0.0380,
           0.8920,  0.2532,  0.2000,  0.8389,  0.4770,  0.4070, -0.0711,
           0.1832,  0.8153,  0.1507,  0.1675,  1.0689,  0.0852,  0.5979,
           0.8760,  0.7676,  0.0201,  0.2068,  0.1990,  0.3722,  0.3433,
           0.1217,  1.0351,  0.3200,  0.6410,  0.5510,  0.2719,  0.0894,
           0.7056,  0.1825,  0.2841,  0.0795,  0.6

  3%|▎         | 34/1227 [00:05<03:15,  6.09it/s]

(tensor([[[ 6.3453e-01,  2.1744e-01,  8.4774e-02,  9.9937e-01,  1.0212e+00,
           7.8745e-01,  8.5044e-01,  7.1331e-01,  4.0678e-01,  4.0004e-01,
           3.4663e-01,  7.0851e-01,  8.7558e-01,  4.0959e-01,  3.8780e-01,
           1.5937e-01,  8.5123e-01,  2.7790e-01,  7.0005e-01,  7.3527e-01,
           4.3720e-01,  1.0112e+00,  4.1091e-01,  6.0373e-01,  9.5888e-01,
           3.3163e-01,  5.8130e-01,  5.9522e-01,  4.3871e-01,  4.3601e-02,
           6.7097e-01,  8.0787e-01,  7.2061e-01,  1.9276e-01,  6.0824e-01,
           5.2631e-01,  3.1127e-01,  8.6967e-01,  1.0556e+00,  5.2111e-01,
           2.6948e-01,  1.4403e-01,  6.2978e-01,  8.5155e-01,  2.5876e-01,
           8.6501e-01,  3.0235e-01,  9.2416e-01,  2.2284e-01,  2.5178e-01,
           1.0494e+00,  2.2977e-01,  7.0135e-02,  5.7535e-01,  1.1670e-01,
           6.4290e-01,  9.8216e-01,  6.2635e-01,  2.6022e-01,  3.1808e-02,
           1.6078e-01,  6.8946e-01,  1.1408e-01,  8.1733e-01,  5.1816e-01,
           7.4559e-01,  

  3%|▎         | 35/1227 [00:05<03:23,  5.85it/s]

(tensor([[[ 0.9606,  0.4647, -0.0449,  0.1227,  0.4818,  0.1125,  0.5247,
           0.7243,  0.3263,  0.3434,  0.6579,  0.0202,  0.8335,  0.9288,
           0.7938,  0.4986,  1.0556,  0.3805,  0.3604,  0.3781,  0.5626,
           0.7745,  0.6666,  0.6537,  0.0513, -0.0073,  0.2253,  0.7288,
           0.1101,  0.1875,  0.6058,  0.6061,  0.8086,  0.8184,  0.1411,
           0.1542,  0.6466,  0.1765,  0.2554,  0.0630,  0.6628,  0.9083,
           0.8767,  0.0192,  0.6824,  0.5438,  0.8168,  1.0995,  0.3747,
           0.3775,  0.7429,  0.4938,  0.1103,  0.0198, -0.0119,  0.7028,
           0.2266,  0.1386,  0.8018,  0.2919,  0.8494,  0.2675,  0.4424,
           0.4402,  0.2314,  0.1110,  0.4797,  0.6525,  0.9592,  0.2589,
           0.9684,  0.2940,  0.5037,  0.9865,  0.5477,  0.2986,  0.8051,
           0.6379,  0.3068,  0.5914,  0.6820,  0.1535,  0.2179,  0.8933,
           0.1455,  0.9398,  0.5984,  0.2359,  0.1691, -0.0436,  0.5868,
           0.8283,  0.8494,  0.3949,  0.7172,  0.9

  3%|▎         | 37/1227 [00:05<03:19,  5.95it/s]

(tensor([[[ 0.3636,  0.5565,  0.4275,  0.2511,  0.9588,  0.9588,  0.8987,
           0.3553, -0.1074,  0.4776,  0.3181,  0.9747,  0.4882,  0.6414,
          -0.0687,  0.9246,  0.6964,  0.8649,  0.0114,  0.7267,  0.9111,
           0.2835,  0.3009,  0.9456,  0.4293,  0.8961,  0.2945,  0.9728,
           0.8942,  0.8125,  0.0573,  0.7231,  0.1943,  0.4462,  0.5093,
           0.8393,  0.8215,  0.3860,  0.3748,  0.6103,  0.5097,  0.3980,
           0.1297,  0.2976,  0.8675,  0.9094,  0.5178,  0.7073,  0.7596,
           0.4142,  0.8227,  0.7110,  0.6420,  0.7754,  0.1473,  0.6890,
           1.0524,  0.8619,  0.4531,  0.7939,  0.6628,  0.0835,  0.1873,
           0.9922,  0.1585,  1.0324,  0.0985,  0.3465,  0.6308,  0.3216,
           0.3315,  0.0697,  0.1708,  0.4723,  0.2080,  0.8208,  0.4669,
           0.7399,  0.6514,  0.4172,  0.6128,  0.5251,  0.3354,  0.4076,
           0.8385,  1.1644,  0.6434,  0.1887,  0.9341, -0.0380,  0.7758,
           0.2408, -0.0212,  0.5691,  0.4670,  0.5

  3%|▎         | 38/1227 [00:06<03:14,  6.11it/s]

(tensor([[[ 0.5294,  0.6452,  0.6370,  0.3317,  0.2945,  0.5192,  0.8858,
           0.7145,  0.5472,  0.1462,  0.7741,  0.8115,  0.6133,  0.3708,
           0.4560,  0.4249,  0.4833,  0.5917,  0.0850,  0.1820,  0.6878,
           0.2897,  0.1349,  0.5388,  0.5437,  0.1458,  0.3962,  0.5177,
           0.3316, -0.0350,  0.3424,  0.9930,  0.7247,  0.7918,  0.7454,
           0.0270, -0.0147,  0.3758,  0.8284,  0.3196,  0.8007,  0.5286,
           1.0841,  0.7907,  0.7354,  0.5772,  0.3034,  0.7292,  1.0548,
           0.0986,  0.7571,  0.8307,  0.9737,  0.3833,  0.9120,  0.6833,
           0.4079,  0.7806,  0.6922,  0.3335,  0.1811,  0.5468,  0.0343,
           0.5090,  0.7424,  0.1557,  0.7938,  0.8032,  0.5962,  0.3966,
           0.3412,  0.8069,  1.0065,  0.6970,  0.6866,  0.1579,  0.1891,
           1.0260, -0.1033,  0.2743,  0.9252,  0.6244,  0.6880,  0.6382,
           0.4203,  1.1114, -0.0847,  0.1939,  1.0118, -0.0243,  0.6662,
           0.3248,  0.9296,  0.8095,  0.8563,  0.0

  3%|▎         | 39/1227 [00:06<03:12,  6.16it/s]

(tensor([[[ 0.6198,  0.6001,  0.1076,  0.6020,  0.7336,  0.0202,  0.4818,
           1.1481,  0.2377,  0.8558, -0.0674,  0.2251,  0.0812,  0.5264,
           0.0236,  0.3794,  1.1294,  0.9473,  0.1686,  0.3357,  0.5168,
           0.3314,  0.1892,  0.4906,  0.3789,  0.5993,  0.3705,  1.0492,
           0.2951,  0.9110,  0.3148,  0.6800,  0.2739, -0.0898,  0.8165,
           0.7361,  0.6094,  0.8238,  0.4493,  0.8014,  0.6267,  0.1272,
           0.2389,  0.0822,  0.9030,  0.5178,  0.7915,  0.7835,  0.7828,
           0.3054,  1.0428,  0.9030,  0.8542,  0.1302,  0.6784,  0.1886,
           0.0674,  0.2403,  0.2512,  0.0469,  0.3540, -0.0049,  0.6924,
          -0.0584,  0.5916,  0.5471,  0.4001,  0.8483,  0.2907,  0.2912,
           0.9252,  0.9044,  0.9927,  0.4047,  0.7606,  0.1497,  0.8356,
           0.1704, -0.0672,  0.5659,  0.8671,  0.4719,  0.5750,  0.9503,
           0.2619,  0.5767, -0.0871,  0.0577,  1.1426, -0.1575,  0.3364,
           0.3768,  0.5056,  0.1967,  0.0603,  0.5

  3%|▎         | 41/1227 [00:06<03:08,  6.31it/s]

(tensor([[[ 0.0346,  0.3773,  0.7471,  0.7077,  0.5045,  0.3150,  1.1112,
           0.4925, -0.0200,  0.2841,  0.5084,  0.1882,  0.7063,  0.4687,
           0.3957,  0.6100,  0.7048,  0.4931,  0.3154,  0.0091,  0.8515,
           0.5234,  0.3913,  0.8108,  0.7592,  0.1441,  1.0279,  0.4736,
           0.8328,  0.7265,  0.3610,  0.5656,  0.4192,  0.3074,  0.0175,
           0.1625,  0.3653,  0.3646,  0.5047,  0.6114,  0.4419,  0.2474,
           0.6628,  0.9225,  0.5474,  0.8374,  0.3140,  0.5403,  0.1653,
           0.6202,  0.2083,  0.4793,  0.3447,  0.4722,  0.5096,  0.1346,
           0.7963,  0.5421,  0.6672,  0.7927,  0.5076,  0.4975,  0.5782,
           0.2061,  0.1400,  1.0934, -0.1903,  0.7386,  0.6773,  0.8424,
           0.3942,  0.5540,  0.3612,  0.7388,  0.1043,  0.6679,  0.8879,
           0.8852,  0.8327,  0.4775,  0.7505,  0.2666,  0.2748,  0.1622,
           0.2597,  0.6285,  0.0862,  0.9621,  0.1778,  0.4752,  0.8584,
           0.3805,  0.4368,  0.0766,  0.1468,  0.7

  3%|▎         | 42/1227 [00:06<03:02,  6.48it/s]

(tensor([[[ 0.7085,  0.3837,  0.1196,  0.5286,  0.6695,  0.4523,  1.0610,
           0.6560,  0.0090,  0.3475,  0.2714,  0.0499,  0.4403,  0.6603,
           0.3727,  0.4530,  0.4157,  0.2454,  0.0201,  0.5087,  0.9585,
           0.8963,  0.6409,  0.1132,  0.3498,  0.2522,  0.1374,  0.6380,
           0.1310,  0.3838,  0.7586,  1.0414,  0.0852,  0.6165,  0.4699,
           0.6986,  0.1585,  0.2764,  0.1665,  0.1260,  0.1114,  0.3642,
           0.1465,  0.4444,  0.5033,  0.5782,  0.6025,  0.4806,  0.6518,
           0.6446,  1.0880,  0.2700,  0.2642,  0.8012,  0.2691,  0.8286,
           0.5474,  0.6234,  0.6817,  0.6262,  0.3127,  0.3095,  0.2205,
           0.1278,  0.1970,  0.5379, -0.0930,  0.3103,  0.8565,  0.9090,
           0.1428,  0.9480,  0.5018,  0.4743,  0.7394,  1.1212,  0.4649,
           0.1605,  0.3947,  0.5335,  0.8244,  0.3557,  0.8184,  0.2980,
           0.4687,  1.0785,  0.2307,  0.2052,  0.2496,  0.3388,  0.9359,
           1.0469,  0.8542,  0.4783,  0.0271,  0.9

  4%|▎         | 43/1227 [00:06<03:01,  6.51it/s]

(tensor([[[ 0.6193,  0.6845,  0.5716,  0.2893,  0.9082,  0.3071,  0.2231,
           0.0817,  0.6044,  0.6978,  1.0276,  0.6884,  0.6040,  0.7175,
           0.9521,  0.9002,  0.7397,  0.1045,  0.2779,  0.5283,  0.5723,
           0.2539,  0.1283,  0.7460,  0.0341,  0.9722,  0.7129,  0.9734,
           1.0371,  0.6903,  0.8613,  0.7274,  1.0672,  0.4437,  0.2905,
           0.4329,  0.8458,  0.7696,  0.2192,  0.3455,  0.0481,  0.3703,
           0.1252,  0.3735,  0.1830,  0.9726,  0.6684,  0.3263, -0.0087,
           0.1974,  0.9772,  0.4847,  0.7139,  0.4311,  1.0012,  0.5440,
           0.2451,  0.1986,  0.2291,  0.1205,  0.5257,  0.7002,  0.4112,
           0.6124,  0.6182,  1.0096,  0.6484,  0.3079,  0.0586,  0.2002,
           0.9863,  0.7218,  0.4831,  0.9239,  0.5235,  0.1416,  0.1515,
           0.1577,  0.2740,  0.2775,  0.8568,  0.3381,  1.2334,  0.5887,
           0.8406,  0.9827,  0.0832,  0.1318,  0.8930,  0.4314,  0.3449,
           0.9647,  1.2443,  0.6865,  0.4452,  0.3

  4%|▎         | 45/1227 [00:07<03:03,  6.44it/s]

(tensor([[[ 0.1337,  0.8075,  0.2580,  0.3235,  0.5927,  0.4770,  0.7252,
           0.2054,  0.4239,  0.5755,  0.4837,  0.7568,  0.1514,  0.4850,
           0.9632,  0.1599,  0.6170,  0.7393,  0.5221,  0.7896,  0.3766,
           0.1880,  0.8377,  0.7568,  0.3075,  0.2755,  0.6654,  0.3150,
           0.1944,  0.8541,  0.8445,  0.2450,  0.2057,  0.2483,  0.4412,
           0.4825,  1.0148,  0.7219,  0.2879,  0.2288,  0.3332,  0.7343,
           0.8594,  0.2836,  0.2995,  0.3241,  0.8436,  0.8706,  0.3043,
           0.1988,  0.1611,  0.6024,  0.3844,  0.6287,  0.1785,  1.0664,
           0.4492,  0.8022,  0.5562,  0.9586,  0.8479,  0.4197,  0.3451,
           0.5276,  0.3649,  0.0975,  0.7578,  0.0272,  0.7948,  0.9792,
           0.2401, -0.0335,  0.5261,  0.3360,  0.3931,  0.4383,  0.8852,
           0.4174,  0.8802,  0.3761,  0.3474,  0.7243,  0.2701,  0.2030,
           0.6020,  0.7315,  0.2183,  0.5360,  0.7246,  0.2766,  0.0704,
           1.3316,  0.6742,  0.0570,  0.5724,  0.7

  4%|▎         | 46/1227 [00:07<03:04,  6.39it/s]

(tensor([[[ 3.2851e-01,  3.1776e-01,  2.2314e-01,  8.4283e-01,  4.8617e-01,
           5.1202e-01,  2.1057e-01,  3.4840e-02,  5.2502e-01,  4.4396e-02,
           1.4222e-01,  7.4897e-01,  7.0520e-01,  1.2156e-02,  7.3852e-01,
           5.0189e-01,  7.3314e-01,  6.1734e-01,  8.2348e-01,  3.9106e-01,
           7.3050e-01,  1.5616e-01,  9.5218e-01,  1.8012e-01,  1.0338e-01,
           3.7336e-01,  6.8273e-01,  2.6965e-01,  8.3294e-01,  4.5801e-01,
           5.6972e-01,  1.0172e+00,  5.9803e-01,  2.0104e-01,  4.7398e-01,
           2.3432e-01,  3.7656e-01,  8.4717e-01,  7.6830e-01, -1.6366e-02,
           6.0185e-01,  6.9177e-03,  7.3445e-01,  2.7180e-01,  9.2437e-01,
           8.3375e-01,  6.7058e-01,  4.8724e-01,  7.0193e-01,  4.8172e-01,
          -2.0451e-04,  1.2964e-01,  4.2006e-01,  1.5195e-01,  2.5742e-01,
           3.2004e-02,  1.1084e+00,  8.5900e-01,  9.8395e-01,  2.4460e-01,
           9.0793e-01,  1.1570e-01,  9.0448e-01,  6.2393e-01,  8.1256e-01,
           2.0830e-01,  

  4%|▍         | 48/1227 [00:07<03:00,  6.52it/s]

(tensor([[[ 0.8226,  0.2648,  0.1711,  0.4079,  0.5927,  0.7367,  1.0393,
           0.7467,  0.3936,  0.9207,  0.0897,  0.5998,  0.1342,  0.2321,
           0.8131,  0.6702,  0.7933,  0.7446,  0.6983,  0.6521,  0.0764,
           0.9043,  0.9467,  0.3227,  0.5989,  0.2493,  0.5350,  0.1288,
           0.3919,  0.2901,  0.9717,  0.0133,  0.2152,  0.8763,  0.7554,
           0.7675,  0.4052,  0.4897,  0.0584,  0.3163,  0.4448,  1.0112,
           0.3226,  0.6177,  0.1251,  0.2441,  1.0290,  0.6353,  0.6132,
           0.6481,  0.4594,  0.3680,  0.7326,  0.5880,  0.1195,  0.3088,
           0.0641,  0.7103,  0.6719,  0.5158,  0.5313,  0.9071,  0.3926,
           0.4147,  1.0829,  0.2868,  0.2686,  0.2826,  0.9871,  0.5048,
           0.9044,  0.2136,  0.5605,  0.8088,  0.5695,  0.5312, -0.0418,
           0.1739,  0.0143,  0.1045,  1.0890,  0.5978,  1.0716,  0.4701,
          -0.0059,  0.6738,  0.1420,  0.1628,  0.1170,  0.4092,  0.2143,
           0.2865,  0.7598,  0.3248, -0.0704,  0.0

  4%|▍         | 49/1227 [00:07<03:05,  6.36it/s]

(tensor([[[ 0.0388,  0.0787,  0.2096,  0.3316,  0.3243,  0.6561,  0.8386,
           0.7944,  0.6643,  0.9327,  0.5613,  0.7542,  0.3969,  0.6626,
           0.7345,  0.3840,  0.5877,  0.3998,  0.1512,  0.5392,  0.3228,
           0.8805,  0.3050,  0.5043, -0.1241,  0.1949,  0.2059,  0.1637,
           0.5485,  0.3424,  0.5493,  0.2129,  0.1197,  0.9513,  0.5059,
           0.5286,  0.9727,  1.0814,  0.6476,  0.4932,  0.2305,  0.9127,
           0.9342,  0.3572,  0.8311,  0.6196,  0.9613,  0.3561,  0.2947,
           0.7103,  0.5263,  0.3213,  0.6927,  0.8877,  0.1146,  0.4741,
           0.1685,  0.3873,  0.3813,  0.1367,  0.8588,  0.1842,  0.1664,
           0.4886,  0.5423,  0.4269,  0.6478,  0.4376,  0.7781,  0.5864,
           0.4738,  0.5413,  0.4197,  0.7738,  0.0739,  0.0967,  0.1944,
           0.5564,  0.4771,  0.3199,  0.2630,  0.3309,  0.2452,  0.9673,
           0.9038,  0.2213,  0.0252,  1.0270,  0.4283,  0.4993,  0.2924,
           0.6437,  0.5835,  0.7983,  0.5079,  0.4

  4%|▍         | 50/1227 [00:07<03:04,  6.38it/s]

(tensor([[[ 0.6764,  0.6691,  0.3432,  0.0700,  0.1030,  0.2889,  0.8610,
           1.0838,  0.0617,  0.5112,  0.6452,  0.3554,  0.1199,  0.2458,
           0.3609, -0.0344,  0.6261,  0.8320,  0.5203,  0.0653,  0.5145,
           0.1731,  0.5800,  0.0743,  0.5391,  0.1088,  0.5629,  0.1017,
           0.1353,  0.9861,  0.2924,  0.0855,  0.0114,  0.5583,  0.7511,
           0.8744,  0.1039,  0.0303,  0.0559,  0.4235,  0.2353,  0.5133,
           0.3258,  0.4882,  0.4859,  0.6243,  0.6342,  0.5090,  0.3325,
           0.3744, -0.0558,  0.1740,  0.1080,  1.0161,  0.3825,  0.6170,
           0.7129,  0.3195,  0.8213,  0.0871,  0.2461,  0.5028,  0.9223,
           0.8716,  0.0314,  0.3321,  0.1092,  0.1374, -0.0924,  0.2176,
           0.1628,  0.1107,  0.4426,  0.8781,  0.4319,  0.3601,  0.7817,
           0.8115,  0.1854,  0.5111,  0.0462,  0.4015,  0.4730,  0.4758,
           0.7553,  0.0959,  0.2646,  0.5815,  0.1391,  0.8325,  0.2174,
           0.3412,  0.4519,  0.5247,  0.2529,  0.6

  4%|▍         | 52/1227 [00:08<03:02,  6.44it/s]

(tensor([[[ 0.0239, -0.0531,  0.8843,  0.2140,  0.9721,  0.0464,  0.7513,
           0.0570, -0.1864,  0.9450,  0.5892,  0.7834,  0.3372,  0.5736,
           0.5339,  0.2406,  0.9138,  0.4307,  0.2793,  0.9041,  0.9592,
           0.8070,  0.7891,  0.2575,  0.2688,  0.8736,  0.8301,  0.6400,
           0.4048,  0.9183,  0.1537,  0.5742,  0.6515,  0.1545,  0.4245,
           0.5030,  0.4669,  0.4518,  0.7567,  0.2799,  0.1601, -0.0689,
           0.9479,  0.9560,  0.6831, -0.0168,  0.9192,  0.1140,  0.7606,
          -0.0195,  0.9678,  0.1421,  0.2754,  0.2415,  0.6613,  0.3396,
           0.6253,  0.5081,  0.1568,  0.3377,  0.1673,  0.1924,  0.3989,
           0.2443,  0.1760,  0.4287,  0.4333, -0.0578,  0.8141,  0.9400,
           0.2344,  0.4301,  0.4904,  0.1644,  0.7843,  0.4724,  0.7092,
           0.4618,  0.6591,  0.6002,  1.0945, -0.1192,  0.7406,  1.0696,
           0.9289,  1.0891,  0.7191,  0.7024,  0.0813,  0.0467,  0.9977,
           0.4874,  1.1352,  0.8049,  0.4973,  0.1

  4%|▍         | 53/1227 [00:08<03:05,  6.33it/s]

(tensor([[[ 0.2288,  0.8220,  0.6844,  0.8544,  1.0638,  0.8408,  1.1136,
           0.9981,  0.3322,  0.1803,  0.0756,  0.2488,  0.9998,  0.9409,
           0.3429,  0.4381,  0.8484,  0.4406,  0.6972,  0.6358,  0.4181,
           0.2181,  0.8534,  0.3463,  0.3793,  0.3797,  0.6816,  0.0503,
           0.4989,  0.1446,  0.7406,  0.7121,  0.3798,  0.5237,  0.2816,
           0.0200,  0.6267,  0.2781,  1.0834,  0.9221,  0.7025,  0.7026,
           0.0756,  0.4747,  0.6076,  0.5147,  0.0830,  0.3216,  0.3596,
           0.5672,  0.8638,  0.9254,  0.4318,  0.3819,  0.2346,  0.4670,
           0.7664,  1.0117,  0.4332,  0.5186,  1.0972,  0.1628,  0.3617,
           0.1188,  0.4684,  0.4935, -0.0534,  0.7209,  0.5757,  0.9375,
           0.8303,  0.6839,  0.9072,  0.0299,  0.6600, -0.0248,  0.0241,
           0.6805,  0.2133,  0.3697,  0.3334,  0.2633,  0.2562,  0.3680,
           0.4052,  0.7829,  0.7256,  0.7979,  0.4528,  0.4239,  0.3030,
           0.2454,  0.8638,  0.4528,  0.3907,  1.0

  4%|▍         | 54/1227 [00:08<03:04,  6.36it/s]

(tensor([[[ 0.5815,  0.9996,  0.1444,  0.6280,  0.2717,  0.3252,  0.9459,
           1.0451,  0.6373,  0.6697, -0.0160,  0.2063,  0.2341,  1.0567,
           0.0594,  0.6659,  1.0061,  1.1078,  0.6144,  0.7723,  0.8710,
           0.7834,  0.4080,  0.0302,  0.7484,  0.2347,  0.3265,  0.7448,
           0.0857,  0.5534,  0.3814,  0.7909,  0.6727,  0.1174,  0.4786,
           0.6687,  0.3884,  0.8114,  0.4621,  0.2819,  0.4771,  0.8469,
           0.3884,  0.9979,  0.9889,  0.7202,  0.7205,  0.3010,  0.8049,
           0.1825,  0.1103,  0.2548,  0.8564,  0.1428,  0.2949,  0.4887,
           0.3978,  0.9468,  0.7622, -0.0970,  1.0353,  0.2273,  0.0709,
           0.1259,  0.9162,  0.7859,  0.0920,  0.9035,  0.2096,  0.6605,
           0.3867,  0.8857,  0.2996,  0.6402,  0.1084,  0.2669,  0.2787,
           0.9261,  0.5534,  0.5257,  0.3217,  0.1096,  1.1868, -0.0175,
           0.9388,  0.7994,  0.0736,  0.8251,  0.3334, -0.1070,  0.5464,
           1.0819,  0.5902,  0.0584,  0.9358,  0.3

  4%|▍         | 55/1227 [00:08<03:12,  6.08it/s]

(tensor([[[ 0.3100,  0.5906,  0.2431,  0.8243, -0.0317, -0.0551,  0.2219,
           0.9130,  0.2583,  0.4132,  0.5558,  1.1451,  0.8391,  0.5194,
           0.0032,  0.6224,  0.9131,  0.1176,  0.9201,  0.4496,  0.0134,
           0.1226,  0.8041,  0.5815,  0.3713,  0.9961,  0.8875,  0.8557,
           0.2681,  0.0921,  0.5238,  0.6519,  1.1048,  0.1476,  0.0363,
           0.2892,  0.1057,  0.9611,  0.0931,  0.2881,  0.6713,  0.9919,
           0.2772,  0.3449,  0.6707,  0.4483,  0.9224,  0.9390,  0.6659,
           0.3449,  0.3092,  0.1758,  0.7446,  0.8802,  0.5853,  0.6175,
           0.5995, -0.0489,  0.7954,  0.5702,  1.0008,  0.1685,  0.9019,
           0.4335,  0.7184,  0.3361,  0.7581,  0.1560,  0.8222,  0.1213,
           1.1189,  0.8138,  0.8394,  0.6780,  0.2962,  0.2316,  0.4582,
           0.4039,  0.3165,  0.8314,  1.0381, -0.1300,  0.6277,  0.6398,
           0.7401,  0.5875,  0.7787,  0.2021,  0.4003,  0.7796,  0.8181,
           0.7436,  0.3312,  0.8118,  0.5700,  0.2

  5%|▍         | 56/1227 [00:08<03:10,  6.15it/s]

(tensor([[[ 0.0472,  0.7635,  0.8397,  0.8038,  0.3372,  0.1390,  1.0654,
           0.2148,  0.5383, -0.0260,  0.3694,  0.7691,  0.3100,  0.3840,
           0.8904,  0.1133,  1.1325,  0.2703,  0.2430,  0.0692,  0.0813,
           0.5834,  0.2048,  0.1404,  0.5615,  1.0986,  0.4173,  0.6299,
           0.4995,  0.3380,  0.7171,  0.2806,  0.4268,  0.4920,  0.2508,
           0.8704,  0.6303,  0.3147,  0.7987,  0.0085, -0.0113,  0.2519,
           0.3583,  0.9667,  0.8462,  0.9575,  0.4292,  0.9384,  0.4457,
           0.9058,  1.0311,  0.6002, -0.0325,  0.1841,  0.9075,  0.8775,
           0.4161,  0.1816,  0.4659,  0.8138,  0.4161,  0.4678,  0.5828,
           0.7141,  0.1948,  0.9794, -0.1026,  0.7011,  0.9192,  0.7225,
           0.3654,  0.5595,  0.1461,  0.5275,  1.0229,  0.8322,  0.6192,
           0.5486,  0.7960, -0.0117,  0.3331,  0.3716,  0.1441,  0.5544,
           1.1611,  1.1040,  0.6289,  1.1191,  0.3823,  0.9648,  0.8715,
           0.4327,  0.8543,  0.6808,  0.8686,  0.0

  5%|▍         | 57/1227 [00:09<03:07,  6.23it/s]

(tensor([[[ 0.7797,  0.7136,  0.8735,  0.3865,  0.5601,  0.5458,  0.1699,
           0.1385,  0.6218,  0.2897,  0.8859,  0.5142,  0.1685,  0.8880,
           0.3183,  0.8289,  1.2757,  0.5494,  0.0283,  0.6286,  0.9168,
           0.5149,  0.3122,  0.3335,  0.0515,  0.4580,  0.3409,  0.0193,
           0.7435,  0.5731,  0.4036,  0.0843,  0.1628,  0.9220,  0.0279,
           0.6349,  0.0060,  0.4239,  0.5179,  0.9191,  0.8496,  1.0077,
           1.1453,  0.6490,  0.2922,  0.2637,  1.0449,  0.3517,  0.4411,
           0.2087,  0.5230,  0.5806,  0.7603,  0.1307,  0.1449,  0.1522,
           0.9955,  1.0965,  0.3415,  0.6035,  0.9894,  0.3294, -0.1133,
          -0.1802,  0.6061,  0.2512,  0.3650,  0.6656,  0.1416,  0.7757,
           0.2555,  0.4136,  0.7754,  0.6781,  0.4678,  0.7092,  0.6337,
           0.4153,  0.2094,  0.6247,  1.0041,  0.4786,  0.4758,  0.3844,
           0.4627,  0.4858,  0.3202,  0.7345,  0.3006,  0.7694,  0.8185,
           0.1444,  0.6026,  0.3322,  0.2459,  0.2




KeyboardInterrupt: 

In [289]:
torch.save(encoder.state_dict(), 'models/encoder_monotonic.pth')
torch.save(decoder.state_dict(), 'models/decoder_monotonic.pth')

SyntaxError: illegal expression for augmented assignment (<ipython-input-15-41ab0e08df21>, line 3)

In [7]:
# Define hyperparameters
hidden_size = 256
latent_size = 128
vocab_size = 28
num_conds = 4
cond_size = 32
teacher_forcing_ratio = 0.5
learning_rate = 0.05
MAX_LENGTH = 40
num_epochs = 20

# Encoder & Decoder
encoder = EncoderRNN(vocab_size, hidden_size, latent_size, num_conds, cond_size).to(device)
decoder = DecoderRNN(latent_size+cond_size, vocab_size).to(device)

encoder.load_state_dict(torch.load('models/encoder_cyclical.pth'))
decoder.load_state_dict(torch.load('models/decoder_cyclical.pth'))

<All keys matched successfully>

In [11]:
words_list = generate_words(encoder, decoder)

(tensor([[[-1.3554, -0.4302,  0.3166,  0.0586, -0.3410,  0.5489, -0.0815,
          -0.4609,  0.1342,  0.4785, -1.3191, -0.2082,  0.7813, -1.3948,
          -0.2952, -1.4528, -1.3076,  0.1294,  0.2870, -0.2742, -1.9832,
          -0.1390, -0.9281,  0.5579,  2.0234,  1.3393, -1.1813,  0.2878,
          -0.1760,  0.0969,  0.0297, -1.4681,  0.5995,  0.7840, -0.7095,
           0.4644, -0.0772,  0.0513, -0.2005,  0.0996, -0.1426,  0.4712,
           0.0667, -0.0848,  0.3033, -0.9096,  0.3957, -0.4060,  0.0515,
           0.0432,  0.9506, -0.8074,  0.2111, -1.0355,  0.8437,  2.4551,
          -0.4498, -0.1083, -1.5220,  1.1774, -0.9578,  1.8828,  0.9016,
          -1.1422,  0.0167,  1.3688, -0.5648,  0.0688, -2.5277,  0.8364,
          -2.0461, -0.1645,  1.2321, -1.3797,  0.6728,  1.3848,  0.2451,
          -1.2914,  0.4231,  0.0145,  0.2952,  1.4987, -0.2367, -0.5906,
          -0.2984, -0.6557, -1.1615,  0.3507, -1.0994,  0.9797, -1.9912,
           0.4515,  0.5153,  0.4650, -0.1690, -0.0

(tensor([[[-2.3550, -1.9872, -1.0729, -1.1634, -0.3705,  0.3174,  0.5894,
          -0.3252, -0.1879, -0.1672, -1.0352, -0.0373, -0.1886,  1.8623,
           0.6009, -0.5415,  0.8368, -1.4444,  2.0295, -0.0769, -1.0406,
           0.6985, -0.6286,  0.1927, -1.0598, -1.2090,  0.5353, -0.8489,
          -0.8582,  1.5506,  1.0503, -1.4356, -1.3509,  1.8993, -1.2497,
           0.2391,  1.3701, -0.0038,  1.0183, -0.7712,  0.2985,  0.0872,
           0.8486, -0.0615, -0.5202, -1.6247, -1.6553, -1.1476, -0.9570,
           1.6773,  0.4787, -0.9522,  0.8922,  1.0099, -0.6584, -0.1060,
          -0.6225,  0.5807,  0.3248, -0.3486,  0.7192, -0.3190,  0.0130,
           1.9678,  0.1860,  3.0203,  0.5907,  0.8928, -1.2865,  0.0528,
          -0.2003, -1.0825,  0.0316, -0.3443, -1.1096,  0.1107,  1.3911,
          -0.1826,  0.2013, -1.7944, -0.8927, -1.0361,  0.7301,  0.8463,
           0.9251,  0.8496, -0.8728, -0.5722,  0.5271,  0.0988,  1.4568,
           0.2953, -0.5592,  0.9553,  0.4077,  0.8

(tensor([[[ 1.7645e+00,  9.7721e-01, -2.4031e-01,  2.2376e+00, -1.4626e-01,
          -1.9992e+00, -6.3164e-01, -4.9811e-01, -5.8170e-01,  9.3369e-01,
          -6.8444e-01,  1.0617e+00,  8.0941e-02, -1.1306e+00,  6.7968e-01,
           2.8853e-01, -1.7421e+00,  1.7469e-01, -7.4127e-01, -2.0943e+00,
           1.9236e+00, -1.9765e+00,  1.3118e+00, -7.5656e-02,  1.8671e+00,
          -1.1252e-01,  2.1115e+00,  7.2884e-02,  7.1798e-01, -1.2476e+00,
          -1.0835e-01, -1.0873e-01, -5.8276e-01, -3.4436e-01,  6.8883e-02,
           1.8760e-01,  9.0691e-01, -1.7274e+00,  1.6635e-01, -9.2832e-02,
          -4.5538e-01, -4.1756e-01, -2.1702e+00,  1.8612e-01,  9.1699e-02,
          -1.2738e-01,  6.3400e-01,  1.0522e+00,  3.9063e-01, -9.0161e-01,
          -2.0238e+00, -7.6484e-01,  1.6699e+00, -1.7715e+00,  5.7862e-01,
          -1.7018e+00,  4.3745e-01, -7.0578e-01,  3.5399e-01, -1.1034e+00,
          -5.9486e-01,  1.1327e+00,  2.0601e+00, -2.5938e+00,  1.3607e+00,
           6.8587e-01, -

(tensor([[[ 0.1681, -1.2398,  0.8761, -0.3031,  0.1130,  1.0877, -0.2752,
           0.7034, -0.4465, -0.4027,  0.1365, -0.5055, -0.4281, -1.2603,
           0.6361,  1.9502,  0.6321, -0.4057, -1.0358, -1.7653,  0.4956,
          -0.5591, -0.2533, -2.7840, -1.3730,  0.9921, -0.3113,  0.0847,
           0.5533, -0.1358,  0.3246, -1.4159,  1.2251, -0.8310,  2.2451,
          -0.4034, -0.1934,  2.2379, -0.5139, -0.6489,  0.5120, -0.8119,
           0.6978,  0.5643, -0.9681,  1.5043,  0.1052, -1.1237, -0.4390,
           0.5091, -0.2627, -0.8159, -1.2607, -0.3522, -0.2672,  0.5923,
           0.6217,  1.1508,  0.7818,  0.9967, -0.0645,  1.4888,  0.5669,
          -1.5950, -2.1005,  1.2261,  1.5585,  0.6446,  0.8404, -0.6484,
          -0.7072,  1.5421,  0.4800, -0.5666, -0.7763, -0.2937, -1.5439,
           2.2993,  0.5042,  1.2806,  0.6974,  0.0688,  0.0387,  0.2810,
           2.4361, -1.0463,  0.7052,  1.4112, -0.5671, -0.9714, -0.6312,
           0.4871,  0.3862,  0.9245, -0.5786, -0.0

(tensor([[[-1.0736e+00, -8.9911e-01, -7.6442e-01,  2.4509e+00, -2.5580e-01,
          -7.6862e-02,  2.3797e+00,  7.6094e-01, -5.5715e-01,  3.4729e-01,
           4.6131e-01, -9.8068e-02,  8.4284e-01, -2.6123e-01, -6.8134e-01,
          -1.3084e+00,  9.8349e-01, -2.9111e-01,  6.9726e-02, -1.9866e+00,
          -2.7476e+00, -2.1026e-01,  2.3516e-01,  3.4338e-01, -5.1833e-01,
          -6.5271e-01,  4.8549e-01,  7.1102e-01,  1.2326e-01, -6.1937e-02,
           2.0231e-02,  7.7376e-01,  1.6573e+00,  1.0354e+00, -2.3845e-02,
           6.9863e-01,  1.3700e+00, -1.1061e+00, -1.0664e+00,  2.3430e+00,
           5.7530e-01, -1.5139e+00,  1.0287e+00,  7.6057e-02,  2.1930e-01,
           1.1695e+00,  4.4980e-01,  1.3793e+00, -1.2750e+00,  2.6100e-01,
           1.2590e+00, -2.9065e-01,  1.3169e+00, -4.6732e-01,  3.8857e-02,
           1.4383e-01,  1.0866e+00, -7.1082e-01,  8.7099e-01, -2.5276e-01,
          -4.0634e-01, -8.5508e-01,  2.5496e-04,  2.8586e-01,  1.8409e-01,
           6.2021e-01,  

(tensor([[[-0.1046, -0.7717, -0.7051,  0.0794, -0.1372,  0.4886,  0.4262,
           1.4693,  1.8927,  0.6461, -0.0283,  1.9158,  2.0722, -0.6380,
           0.2260,  0.4615,  1.5218,  0.3295,  0.0906, -1.5331, -0.3547,
           0.1923,  1.1295, -1.4964,  0.7527,  1.8824, -0.2326, -0.7891,
          -0.2527, -0.6343,  0.1307, -0.7696,  1.5088, -0.7934, -0.1061,
          -1.8848, -1.2913, -0.3719,  0.0890, -0.0315,  0.1122,  2.4612,
           2.1541,  0.1984, -0.2576, -0.0735,  0.9958,  1.0790,  1.8131,
          -0.3089, -0.5467, -0.3970,  0.3780, -0.9635,  0.7856, -0.8134,
           1.0309, -0.2551, -1.0557,  3.1865, -1.3136, -0.9180, -0.3115,
           0.8070,  0.7095, -0.3672, -0.8999,  0.9766,  0.0139, -0.1689,
           0.9032,  0.5627, -1.0869,  0.1127,  0.0944,  1.3116, -0.1308,
          -0.1514,  1.4585, -0.0927,  0.5047, -1.7857, -0.1387, -0.9693,
           1.5533,  0.7837, -2.0699,  1.1386, -2.0109, -1.2759,  0.1884,
          -2.4911, -1.0011, -0.1953,  1.5337,  0.0

(tensor([[[ 0.5280,  0.1977, -0.9139,  1.2154, -1.8671,  0.7108, -0.8292,
          -0.6192,  1.2786, -0.4039,  0.3597,  0.1001,  0.3876,  1.8151,
          -0.0339, -1.1934, -0.5468,  0.4443,  0.1703, -0.7241,  0.2211,
           0.4043, -0.9225,  0.5923, -0.0257,  0.9076,  0.6474, -0.0788,
           0.7288, -2.0963,  1.5977, -1.2958,  0.1445, -0.6328, -0.4092,
           1.4002,  0.4897, -1.2444,  0.1557, -0.6976, -1.4980,  0.1082,
          -1.2094,  0.2519,  0.3570, -0.6116,  0.2413, -0.2740,  0.0714,
          -2.0718,  1.3820, -0.5778, -2.1460, -0.6916,  0.2083,  0.3390,
          -0.5723, -0.9425,  1.5091, -0.6241, -2.0180,  0.2130, -2.3965,
           0.4337, -0.7361, -0.3373, -0.8385,  0.2026,  0.1774, -0.8011,
          -1.4350,  0.6403,  0.1710, -2.1165, -0.4760, -2.1182, -1.3856,
          -0.0062, -1.0456,  0.1081, -0.3161, -1.0330, -1.0643, -1.0472,
           0.2075, -0.7922, -0.5543, -0.6893,  0.8675,  1.5044, -2.0777,
          -0.4492, -0.9350,  1.1407,  0.1797,  1.0

(tensor([[[ 1.0339e+00, -1.7065e-01, -5.2140e-01, -7.6879e-01, -7.4087e-01,
           2.6270e-01,  5.8273e-01, -5.6769e-01,  8.4337e-01,  8.6672e-01,
           1.0446e+00, -2.6806e-01,  8.2051e-01, -1.4228e-01, -5.0451e-01,
           2.8023e-01, -4.6951e-01, -1.0056e+00,  8.9234e-01,  1.3501e+00,
          -7.7790e-01, -6.1073e-01, -6.9056e-01, -3.1755e-01, -3.2612e-01,
          -2.1377e+00, -7.1442e-02, -2.6416e-01,  3.7674e-01, -3.3580e-01,
           6.6122e-01,  6.5696e-01, -2.3035e-01,  7.2305e-01, -4.4764e-01,
          -2.3609e-01, -5.8633e-01,  9.1217e-01, -5.3026e-01, -1.0011e+00,
           5.9415e-01, -2.5612e-01, -9.6214e-02, -1.6572e+00,  4.4296e-01,
           8.8551e-01,  8.2455e-01, -5.0573e-01,  1.8161e+00, -2.5322e+00,
           1.7586e+00, -6.5030e-01, -7.9790e-01,  3.3104e-01,  5.4575e-01,
          -3.8904e-01,  1.4018e+00, -5.7039e-01,  6.0777e-01, -1.0350e+00,
           2.1969e+00, -2.9866e-01,  1.4991e+00,  9.8914e-01, -8.4459e-01,
          -5.6353e-01,  

(tensor([[[-1.7643,  0.4445,  0.6456, -0.2562,  0.4643, -1.2604, -2.3221,
           0.5782,  0.2715,  0.0869,  0.3769,  0.5166,  0.5832,  1.0869,
           0.4035,  0.1482,  2.2491,  0.0508, -0.9606, -0.1010,  1.3060,
           0.2568,  0.4240, -1.0356, -0.2206, -1.1542, -0.8616, -1.1224,
           0.2288,  0.7595,  2.0690, -0.2244,  0.3675,  1.7514, -1.3205,
          -1.0100,  0.0082,  0.3119, -0.4465,  1.7085,  0.3381,  0.6317,
           0.3264,  0.6911, -1.5886,  0.1831,  1.7066,  1.4554,  0.2164,
           0.0261, -1.8302,  0.0042, -0.0370, -1.6719,  0.6611,  0.2794,
          -1.1454, -0.5983,  2.0109,  0.2644,  0.6048,  1.5708,  0.5733,
          -0.1290, -0.5302, -0.1678,  0.1975,  0.3392,  1.1247,  0.7716,
          -1.5634, -1.2896,  0.6036, -0.2075,  0.0870,  0.2801, -0.6454,
           0.8955, -0.6699, -0.4564, -0.7274,  1.1773, -0.6084,  1.4909,
          -0.9738,  1.0462,  1.5481,  1.2203,  1.0704, -0.7123,  0.6575,
           0.8084,  0.0635, -0.3535,  0.6218,  0.2

(tensor([[[-1.2743, -0.6946, -0.5273, -0.3188, -0.0568,  0.7629, -0.0311,
          -0.3522, -1.2423,  1.3701,  0.1138, -0.3675, -2.5833, -0.1510,
          -0.3226,  0.5039, -0.9234, -0.6930, -0.1101,  0.1385,  0.0994,
           0.1067, -1.5186, -0.9162, -1.0179, -0.8773, -1.2718, -0.9422,
          -0.9982, -0.0606,  2.2314,  0.7206, -0.0284, -0.9897,  0.6720,
          -1.2970,  0.7827, -0.3521, -0.8377, -1.7055,  0.8513, -0.1569,
          -0.3647,  2.1459, -0.3455, -1.1245, -1.7652, -0.3608, -0.3234,
          -0.6926,  0.3240, -0.1719,  1.5578,  1.0727,  0.2631, -0.0419,
          -0.0145, -0.6945, -3.0233,  0.2355,  1.2840, -1.2781,  0.3812,
          -0.3352, -1.2108,  0.2954, -0.6798, -0.1873, -1.2544, -0.9876,
           1.2460, -0.1733,  0.2277,  0.2993,  1.5627,  0.6177, -2.7670,
          -0.4790,  0.6882,  1.3797, -0.6291,  1.3841,  0.5874,  0.2143,
           0.2696,  1.7254, -0.0603,  0.4085,  0.4743, -0.0381, -0.5731,
          -0.2912,  0.5569,  0.2776, -0.0409,  0.2

(tensor([[[ 0.2451,  1.0155, -1.2132, -0.1183, -0.3729,  0.8197, -0.3307,
          -0.7054,  2.0540,  0.6928,  0.3093,  0.0707,  0.0229, -0.4497,
          -1.9719, -0.2459, -0.5017,  0.6213, -1.1440,  0.3519,  1.0515,
           1.2309,  1.0686,  0.2587,  0.2435,  0.7558,  1.7629, -0.6452,
           1.4713,  0.5663,  1.5311,  1.5804, -0.5570,  0.9868,  0.1609,
          -0.2523, -0.8759,  0.7641, -0.6190,  1.1508,  0.8113, -0.5843,
          -1.6737, -0.3503,  0.4408,  0.8783, -0.7723, -0.9719, -1.5527,
           1.5790,  0.4931, -0.1290,  0.7938,  0.2049, -1.2594, -0.0447,
          -0.9786,  1.2151, -0.4813, -1.1090,  2.2127, -1.1518, -1.4773,
           0.0672,  0.9415,  1.7102, -0.5127,  0.9629,  0.3157, -0.5821,
           0.7984, -1.2003,  0.8660,  0.8621,  0.7764,  1.4267,  1.1908,
           0.2688,  0.7186,  0.8281, -0.5400, -0.2336,  0.4387, -1.3181,
           1.3434, -0.6203, -0.4894, -0.3921, -0.5203,  1.7610,  0.4928,
           0.7221,  1.2011,  0.7298,  1.0083, -1.2

(tensor([[[ 0.3125,  1.1110,  2.5300, -0.8677, -0.6701,  0.8130,  0.8854,
           1.1803, -0.0771, -0.1987,  0.0096, -0.6858, -0.7686, -1.9430,
          -0.1218, -0.6328, -1.0675,  0.0803,  0.7564, -0.6123, -1.1549,
           0.0140, -0.3958, -0.5268,  1.3671, -0.5576, -1.0964,  0.8003,
          -1.3204,  0.6286, -1.7599,  0.2477, -0.5289, -1.1711, -1.2622,
           1.8139,  0.3481, -0.7732, -0.3589,  0.0267, -1.4591,  1.0054,
           0.4345,  0.7271, -0.8810,  1.8014, -1.0001, -0.9877, -0.0271,
           1.4306, -0.5837,  2.2067, -1.5930, -0.3819,  0.7011,  0.4766,
          -0.6624, -0.4266, -0.2487, -1.2212, -0.0708, -0.8558, -0.2767,
          -1.2946,  0.3509,  0.3346, -0.1726, -0.7191,  0.4546,  1.1010,
          -0.5663,  1.0860,  1.2272,  1.2889,  0.1249, -0.5247,  0.0659,
           1.4471, -1.0483,  0.1425, -0.3376,  0.4345,  0.8079,  0.7084,
           1.0177, -1.2358,  1.4944,  0.6666,  1.0268, -1.2400,  0.3220,
           0.2100, -2.6804,  0.1378, -0.2494, -1.3

(tensor([[[-0.3586, -0.4935, -1.4580, -1.1335, -0.3257,  2.0823,  0.7269,
           0.5325, -0.7916,  0.9221,  0.0445, -0.0122,  0.7179, -0.6552,
          -1.7834, -0.5389,  1.4039,  0.8988, -0.3924, -1.5420,  1.7954,
           0.2983, -0.5502, -0.0369,  0.1749,  1.9214,  0.6152,  0.1443,
           0.0242, -0.9093,  0.9872,  0.3064, -0.1444, -0.3438, -0.0083,
          -0.2571, -2.2128,  1.1682, -0.6639,  1.2445, -0.4691, -0.2873,
           1.1801,  1.8503, -1.4499,  0.3650,  0.9163,  1.4555, -0.7756,
          -1.6544, -0.3930,  1.2703, -0.2433, -0.5616, -0.8512, -0.6068,
           0.6352, -0.4335,  1.2754, -0.0526, -1.2957, -1.1288, -1.7648,
          -1.1737,  0.6343, -0.2827,  0.4179,  0.4315,  0.0348,  1.5854,
          -0.8746, -0.2596, -0.8995, -0.3292,  0.1836, -1.1194,  0.7863,
          -0.6631,  0.3943, -0.1094, -0.1894,  0.5176,  1.1670,  0.3920,
          -0.3182, -0.3215,  0.7961,  0.6675,  1.1107,  1.2228,  0.8447,
          -0.8697, -1.4701, -1.4047,  1.5024, -1.8

(tensor([[[ 1.7461, -0.9538, -2.0993,  0.2771, -0.4368,  0.4160, -0.7784,
           2.5428,  0.0347, -0.8368, -0.3309, -0.4224,  0.2626, -1.5379,
          -0.2974, -0.0734,  1.0527, -1.4722,  0.5334,  1.5357, -0.4240,
          -1.0638, -0.5192, -0.3662, -1.0504,  0.3955,  0.2059,  0.5675,
          -1.6570,  0.8878, -0.0671, -0.0601,  1.0376,  0.0185,  0.0168,
          -0.0570,  0.5704,  0.9012,  0.7234, -0.2515,  2.3245,  0.0684,
          -0.3608, -0.1424,  0.4611,  0.3222, -0.5703,  1.0298,  0.3972,
          -1.0366, -2.2328,  1.3398, -0.4505,  0.3441, -0.8190, -0.8747,
          -0.3358,  0.6242, -0.2776, -0.2447, -0.5648, -0.5707, -0.1914,
           1.1617, -0.3034, -1.0681,  0.8335, -0.1779, -0.3712,  0.7676,
           0.4257,  0.7248,  0.1989, -0.2306, -0.7639, -1.0081,  0.6699,
           0.8190, -1.1139,  2.0954,  0.9737,  0.0135,  0.3959, -2.4827,
          -0.7346, -0.1479, -1.3060, -0.8524, -1.4398, -1.3762, -1.9594,
          -0.6014, -0.8378, -1.1022, -0.3282,  1.1

(tensor([[[ 0.2589,  2.0427, -0.9805,  0.0768,  1.0377, -0.3949,  0.0965,
           1.2495,  0.0143, -1.9658,  1.1838, -0.8295,  1.6260, -1.2727,
           0.0900,  1.2034,  1.1159, -2.2313, -0.7006,  0.2766,  0.0861,
           1.9647,  1.7513, -0.3231, -1.9829, -1.7489,  0.0631, -1.5741,
           0.2110,  0.3345,  1.2419, -0.4784,  0.0650,  0.2307, -0.4999,
          -0.2501,  0.8538, -0.5656, -0.7875,  0.7962,  0.1640, -0.8199,
          -2.0084,  0.1070, -0.9191,  1.3288, -0.4899, -0.4054, -1.6323,
           0.3049, -2.7075, -0.6550, -0.3366, -0.3271,  1.0126,  0.5318,
          -0.0605, -0.1088,  0.8303, -0.0152,  0.1872, -1.0559,  0.9556,
          -1.0070, -0.1560,  0.3835,  0.8108,  0.3381,  0.1732, -0.9277,
          -0.9031,  0.6379, -0.8772, -1.5627,  0.0833,  0.0937,  1.0575,
           2.0299, -2.0783, -0.9930,  0.3183,  0.8008,  0.4847, -0.2082,
          -0.5498,  1.8365,  0.3115, -1.0418,  1.4138, -0.9523, -0.9287,
          -1.0380,  0.6038,  0.1228,  0.4530, -0.8

In [12]:
for wl in words_list:
    print(wl)

['ad', 'es', 'anging', 'god']
['e', 'ens', 'enging', 'eng']
['de', 'ges', 'dirng', 'ged']
['a', 'ge', 'aging', 'ged']
['ro', 'les', 'ling', 'ling']
['y', 'as', 'ying', '']
['re', 'res', 'reing', 'or']
['e', 'es', 'inging', 'an']
['', 's', 'ging', 'ged']
['te', 'tes', 'tenging', 'tod']
['o', 'o', 'og', 'o']
['', 's', 'sing', 'wed']
['', 'se', 'ding', '']
['', 'se', 'ding', 'wed']
['se', 'ses', 'inging', 'ged']
['g', 'ges', 'ging', 'ged']
['a', 'as', 'aging', 'ad']
['se', 'es', 'seng', 'sed']
['pe', 'ges', 'ging', 'ged']
['g', 'gs', 'ging', 'g']
['', 'ers', 'ering', 'wed']
['', 'es', 'ing', 'wed']
['ce', 'ers', 'cering', 'ced']
['h', 'hes', 'hoing', 'hod']
['ge', 'ges', 'gering', 'ged']
['s', 'ses', 'sing', 'thed']
['', 'se', 'sing', 'wed']
['a', 'as', 'aring', 'ad']
['ad', 'ades', 'adging', 'ad']
['', 'res', 'ricng', '']
['d', 'des', 'ding', 'd']
['', 'ges', 'gering', 'god']
['as', 'ees', 'erming', 'wed']
['', 'tes', 'ting', 'ted']
['', 's', 'inging', 'wed']
['r', 'res', 'ro', 'ro']
[''

In [127]:
def evaluate(test_data, test_cond, encoder, decoder):
    
    test_score = 0
    for i in range(len(test_data)):
        input_data_tensor = torch.tensor([ch_to_ix[ch] for ch in test_data[i][0]], device=device)
        input_cond_tensor = torch.tensor(cond_to_ix[test_cond[i][0]], device=device)
        target_cond_tensor = torch.tensor(cond_to_ix[test_cond[i][1]], device=device)
        

In [128]:
a

tensor([[1]])

In [258]:
train_loss_list_cyclical = train_loss_list.copy()

In [259]:
test_score_list_cyclical = test_score_list.copy()

In [290]:
train_loss_list_monotonic = train_loss_list.copy()
test_score_list_monotonic = test_score_list.copy()

In [179]:
b.view(-1)

tensor([1, 1])

In [264]:
for i in trange(len(test_data)):
    input_data_tensor = torch.tensor([ch_to_ix[ch] for ch in test_data[i][0]], device=device).view(-1, 1)
    input_cond_tensor = torch.tensor(cond_to_ix[test_cond[i][0]], device=device)
    target_cond_tensor = torch.tensor(cond_to_ix[test_cond[i][1]], device=device)
    pred = predict(input_data_tensor, input_cond_tensor, target_cond_tensor, encoder, decoder)
    score = compute_bleu(pred, test_data[i][1])
    print(pred, test_data[i][1])

100%|██████████| 10/10 [00:00<00:00, 314.86it/s]

abandoned abandoned
abetting abetting
begins begins
expends expends
sends sends
splitting splitting
flare flare
function function
functioned functioned
heals heals





In [266]:
torch.empty(5).normal_()

tensor([ 0.7531, -1.5099,  0.5967,  0.7548, -0.4312])

In [268]:
t = encoder.init_hidden()

In [271]:
t[0].normal_()

tensor([[[ 0.4233, -0.2107,  1.3108, -0.4431, -1.2059,  0.1561, -0.6079,
          -1.0967,  0.0221,  1.1125, -1.6064,  0.3629,  1.0480, -0.8899,
          -0.8501, -0.0663, -0.3708, -1.1808,  2.3166, -1.3944, -1.2858,
           0.5931,  0.0908, -0.0376, -0.9505, -0.4532,  1.5534, -1.1819,
           0.1057, -0.7331,  1.4506, -0.1566,  0.2780,  0.3375, -0.1828,
          -0.8213,  0.5343, -1.4857,  0.2811,  1.7749, -0.4735,  0.5542,
          -1.4555, -0.6746, -0.8415,  0.9433, -0.7942, -0.0576,  0.9452,
          -1.9283, -0.1319, -1.5095,  0.5740, -0.3545, -0.2987, -0.4119,
           1.0933, -1.3812, -1.8943, -0.0090, -0.3866, -0.5335, -0.8637,
          -0.4162,  0.2573, -0.9612,  0.5147,  0.4524,  1.3464,  0.5403,
           1.2908, -1.7049, -0.4026,  1.9600,  0.6947, -0.0949, -0.9394,
          -0.7928,  2.4211,  0.9984,  1.1858, -1.0931, -0.2186,  0.4929,
           0.8675,  1.2076,  0.5002, -0.1441, -0.1893, -0.1253, -0.0392,
           0.5997, -0.3665,  0.3795, -2.2488,  1.17

In [305]:
# Encoder & Decoder
encoder_0 = EncoderRNN(vocab_size, hidden_size, latent_size, num_conds, cond_size).to(device)
decoder_0 = DecoderRNN(latent_size+cond_size, vocab_size).to(device)

# Train
encoder_0, decoder_0, train_loss_list_0, test_score_list_0 = train(encoder_0, decoder_0,
                                                           train_data, train_cond, test_data, test_cond,
                                                           kl_anealing_func=get_kl_anealing_func('cyclical', period=10000),
                                                           teacher_forcing_ratio=0,
                                                           num_epochs=num_epochs)

100%|██████████| 1227/1227 [01:43<00:00, 11.91it/s]
100%|██████████| 10/10 [00:00<00:00, 374.11it/s]
  0%|          | 2/1227 [00:00<01:40, 12.25it/s]

Epoch: 1 / 20

Train Cross Entropy Loss: 2.621463323750272
Train KL Loss: 11.985321844459648

Test Cross Entropy Loss: 14.8517427444458
Test KL Loss: 23.800907135009766
Test BLEU-4 Score: 0.06651283155781636




100%|██████████| 1227/1227 [01:48<00:00, 11.28it/s]
100%|██████████| 10/10 [00:00<00:00, 320.68it/s]
  0%|          | 1/1227 [00:00<02:03,  9.91it/s]

Epoch: 2 / 20

Train Cross Entropy Loss: 1.6178162841551882
Train KL Loss: 214.45321799296136

Test Cross Entropy Loss: 16.460832595825195
Test KL Loss: 348.7698669433594
Test BLEU-4 Score: 0.16390657654389604




100%|██████████| 1227/1227 [01:52<00:00, 10.92it/s]
100%|██████████| 10/10 [00:00<00:00, 319.36it/s]
  0%|          | 2/1227 [00:00<01:55, 10.58it/s]

Epoch: 3 / 20

Train Cross Entropy Loss: 0.8264309477619675
Train KL Loss: 696.0865751808993

Test Cross Entropy Loss: 11.405317306518555
Test KL Loss: 710.2158813476562
Test BLEU-4 Score: 0.34610668505824355




100%|██████████| 1227/1227 [01:51<00:00, 10.96it/s]
100%|██████████| 10/10 [00:00<00:00, 319.90it/s]
  0%|          | 2/1227 [00:00<01:42, 11.94it/s]

Epoch: 4 / 20

Train Cross Entropy Loss: 0.5043856120787013
Train KL Loss: 886.6391504702766

Test Cross Entropy Loss: 12.98766803741455
Test KL Loss: 668.0698852539062
Test BLEU-4 Score: 0.34238710874683637




100%|██████████| 1227/1227 [01:52<00:00, 10.94it/s]
100%|██████████| 10/10 [00:00<00:00, 304.87it/s]
  0%|          | 2/1227 [00:00<01:52, 10.88it/s]

Epoch: 5 / 20

Train Cross Entropy Loss: 0.34046243111603214
Train KL Loss: 999.5478166301948

Test Cross Entropy Loss: 6.3575758934021
Test KL Loss: 750.42822265625
Test BLEU-4 Score: 0.7126517329632864




100%|██████████| 1227/1227 [01:52<00:00, 10.91it/s]
100%|██████████| 10/10 [00:00<00:00, 304.89it/s]
  0%|          | 0/1227 [00:00<?, ?it/s]

Epoch: 6 / 20

Train Cross Entropy Loss: 0.24890896492648826
Train KL Loss: 1114.2334447465362

Test Cross Entropy Loss: 7.118893623352051
Test KL Loss: 845.7221069335938
Test BLEU-4 Score: 0.6578973522472031




100%|██████████| 1227/1227 [01:53<00:00, 10.82it/s]
100%|██████████| 10/10 [00:00<00:00, 304.46it/s]
  0%|          | 2/1227 [00:00<01:41, 12.01it/s]

Epoch: 7 / 20

Train Cross Entropy Loss: 0.19132963653169477
Train KL Loss: 1060.479935426098

Test Cross Entropy Loss: 6.423893928527832
Test KL Loss: 837.74072265625
Test BLEU-4 Score: 0.6625387587799281




100%|██████████| 1227/1227 [01:53<00:00, 10.82it/s]
100%|██████████| 10/10 [00:00<00:00, 315.03it/s]
  0%|          | 2/1227 [00:00<01:38, 12.41it/s]

Epoch: 8 / 20

Train Cross Entropy Loss: 0.14906951000851232
Train KL Loss: 1200.560304326164

Test Cross Entropy Loss: 3.622102737426758
Test KL Loss: 857.658203125
Test BLEU-4 Score: 0.7680911692785725




100%|██████████| 1227/1227 [01:53<00:00, 10.83it/s]
100%|██████████| 10/10 [00:00<00:00, 316.19it/s]
  0%|          | 2/1227 [00:00<01:31, 13.43it/s]

Epoch: 9 / 20

Train Cross Entropy Loss: 0.1178051424202001
Train KL Loss: 1303.7461278019505

Test Cross Entropy Loss: 3.413670301437378
Test KL Loss: 890.7767944335938
Test BLEU-4 Score: 0.7696875038579386




100%|██████████| 1227/1227 [01:53<00:00, 10.80it/s]
100%|██████████| 10/10 [00:00<00:00, 315.53it/s]
  0%|          | 2/1227 [00:00<01:42, 11.96it/s]

Epoch: 10 / 20

Train Cross Entropy Loss: 0.09401605211056993
Train KL Loss: 1281.4154630633548

Test Cross Entropy Loss: 2.247598171234131
Test KL Loss: 903.4953002929688
Test BLEU-4 Score: 0.7869228046212127




100%|██████████| 1227/1227 [01:53<00:00, 10.84it/s]
100%|██████████| 10/10 [00:00<00:00, 316.11it/s]
  0%|          | 1/1227 [00:00<02:09,  9.49it/s]

Epoch: 11 / 20

Train Cross Entropy Loss: 0.07688520355706009
Train KL Loss: 1261.7159842968117

Test Cross Entropy Loss: 2.621689558029175
Test KL Loss: 901.3770751953125
Test BLEU-4 Score: 0.8499992354272841




100%|██████████| 1227/1227 [01:56<00:00, 10.57it/s]
100%|██████████| 10/10 [00:00<00:00, 316.21it/s]
  0%|          | 2/1227 [00:00<01:45, 11.66it/s]

Epoch: 12 / 20

Train Cross Entropy Loss: 0.06294278175324432
Train KL Loss: 1275.8399910080923

Test Cross Entropy Loss: 2.1336379051208496
Test KL Loss: 916.1139526367188
Test BLEU-4 Score: 0.8755642034501456




100%|██████████| 1227/1227 [02:00<00:00, 10.15it/s]
100%|██████████| 10/10 [00:00<00:00, 316.92it/s]
  0%|          | 2/1227 [00:00<01:36, 12.70it/s]

Epoch: 13 / 20

Train Cross Entropy Loss: 0.0522277608977497
Train KL Loss: 1352.2467511847303

Test Cross Entropy Loss: 0.35409459471702576
Test KL Loss: 937.83056640625
Test BLEU-4 Score: 0.9778800783071404




100%|██████████| 1227/1227 [01:52<00:00, 10.88it/s]
100%|██████████| 10/10 [00:00<00:00, 319.20it/s]
  0%|          | 0/1227 [00:00<?, ?it/s]

Epoch: 14 / 20

Train Cross Entropy Loss: 0.04140846650010831
Train KL Loss: 1292.917116755346

Test Cross Entropy Loss: 0.16734234988689423
Test KL Loss: 973.7716674804688
Test BLEU-4 Score: 1.0




100%|██████████| 1227/1227 [01:52<00:00, 10.88it/s]
100%|██████████| 10/10 [00:00<00:00, 316.16it/s]
  0%|          | 2/1227 [00:00<01:40, 12.14it/s]

Epoch: 15 / 20

Train Cross Entropy Loss: 0.03007307338749804
Train KL Loss: 1423.9595105607236

Test Cross Entropy Loss: 0.5195269584655762
Test KL Loss: 1001.0723876953125
Test BLEU-4 Score: 0.9707106781186547




100%|██████████| 1227/1227 [01:54<00:00, 10.73it/s]
100%|██████████| 10/10 [00:00<00:00, 312.68it/s]
  0%|          | 2/1227 [00:00<01:47, 11.35it/s]

Epoch: 16 / 20

Train Cross Entropy Loss: 0.02268237265294022
Train KL Loss: 1410.627798905874

Test Cross Entropy Loss: 0.13444827497005463
Test KL Loss: 996.9894409179688
Test BLEU-4 Score: 1.0




100%|██████████| 1227/1227 [01:53<00:00, 10.84it/s]
100%|██████████| 10/10 [00:00<00:00, 308.19it/s]
  0%|          | 1/1227 [00:00<02:04,  9.87it/s]

Epoch: 17 / 20

Train Cross Entropy Loss: 0.017328348619484627
Train KL Loss: 1339.4184202449667

Test Cross Entropy Loss: 0.1312038004398346
Test KL Loss: 997.1080932617188
Test BLEU-4 Score: 1.0




100%|██████████| 1227/1227 [01:52<00:00, 10.86it/s]
100%|██████████| 10/10 [00:00<00:00, 310.23it/s]
  0%|          | 2/1227 [00:00<01:50, 11.09it/s]

Epoch: 18 / 20

Train Cross Entropy Loss: 0.015266353100239917
Train KL Loss: 1291.9690145934544

Test Cross Entropy Loss: 0.10948039591312408
Test KL Loss: 1045.0169677734375
Test BLEU-4 Score: 1.0




100%|██████████| 1227/1227 [01:53<00:00, 10.79it/s]
100%|██████████| 10/10 [00:00<00:00, 318.07it/s]
  0%|          | 2/1227 [00:00<01:45, 11.61it/s]

Epoch: 19 / 20

Train Cross Entropy Loss: 0.014009266868512523
Train KL Loss: 1318.8346170717139

Test Cross Entropy Loss: 0.10826517641544342
Test KL Loss: 1026.8331298828125
Test BLEU-4 Score: 1.0




100%|██████████| 1227/1227 [01:51<00:00, 10.96it/s]
100%|██████████| 10/10 [00:00<00:00, 320.06it/s]

Epoch: 20 / 20

Train Cross Entropy Loss: 0.010027530158666104
Train KL Loss: 1340.1527442787044

Test Cross Entropy Loss: 0.10461216419935226
Test KL Loss: 1019.2639770507812
Test BLEU-4 Score: 1.0







In [306]:
# Encoder & Decoder
encoder_25 = EncoderRNN(vocab_size, hidden_size, latent_size, num_conds, cond_size).to(device)
decoder_25 = DecoderRNN(latent_size+cond_size, vocab_size).to(device)

# Train
encoder_25, decoder_25, train_loss_list_25, test_score_list_25 = train(encoder_25, decoder_25,
                                                           train_data, train_cond, test_data, test_cond,
                                                           kl_anealing_func=get_kl_anealing_func('cyclical', period=10000),
                                                           teacher_forcing_ratio=0.25,
                                                           num_epochs=num_epochs)

100%|██████████| 1227/1227 [01:45<00:00, 11.66it/s]
100%|██████████| 10/10 [00:00<00:00, 318.78it/s]
  0%|          | 1/1227 [00:00<02:04,  9.81it/s]

Epoch: 1 / 20

Train Cross Entropy Loss: 2.5685994831667553
Train KL Loss: 12.619815396375223

Test Cross Entropy Loss: 16.137582778930664
Test KL Loss: 33.84439468383789
Test BLEU-4 Score: 0.09928773400135933




100%|██████████| 1227/1227 [01:50<00:00, 11.08it/s]
100%|██████████| 10/10 [00:00<00:00, 319.46it/s]
  0%|          | 2/1227 [00:00<01:24, 14.56it/s]

Epoch: 2 / 20

Train Cross Entropy Loss: 1.5874917947810423
Train KL Loss: 483.41315448941566

Test Cross Entropy Loss: 12.941168785095215
Test KL Loss: 328.1576843261719
Test BLEU-4 Score: 0.19489470099666256




100%|██████████| 1227/1227 [01:50<00:00, 11.08it/s]
100%|██████████| 10/10 [00:00<00:00, 318.75it/s]
  0%|          | 2/1227 [00:00<01:51, 10.94it/s]

Epoch: 3 / 20

Train Cross Entropy Loss: 0.7520689498848203
Train KL Loss: 101685.26882862249

Test Cross Entropy Loss: 9.478663444519043
Test KL Loss: 785.9896850585938
Test BLEU-4 Score: 0.3834410210779974




100%|██████████| 1227/1227 [01:50<00:00, 11.06it/s]
100%|██████████| 10/10 [00:00<00:00, 318.67it/s]
  0%|          | 1/1227 [00:00<02:17,  8.90it/s]

Epoch: 4 / 20

Train Cross Entropy Loss: 0.443525070052683
Train KL Loss: 398068.3824653377

Test Cross Entropy Loss: 10.653193473815918
Test KL Loss: 631.7128295898438
Test BLEU-4 Score: 0.4660549870980386




100%|██████████| 1227/1227 [01:51<00:00, 11.03it/s]
100%|██████████| 10/10 [00:00<00:00, 319.92it/s]
  0%|          | 2/1227 [00:00<01:42, 11.92it/s]

Epoch: 5 / 20

Train Cross Entropy Loss: 0.3029331322700565
Train KL Loss: 215565.5965370032

Test Cross Entropy Loss: 10.031100273132324
Test KL Loss: 758.125
Test BLEU-4 Score: 0.5809339673738876




100%|██████████| 1227/1227 [01:50<00:00, 11.14it/s]
100%|██████████| 10/10 [00:00<00:00, 318.68it/s]
  0%|          | 2/1227 [00:00<01:38, 12.40it/s]

Epoch: 6 / 20

Train Cross Entropy Loss: 0.22221282162400977
Train KL Loss: 349957.082692845

Test Cross Entropy Loss: 4.089372634887695
Test KL Loss: 5034.390625
Test BLEU-4 Score: 0.6066492299962756




100%|██████████| 1227/1227 [01:52<00:00, 10.91it/s]
100%|██████████| 10/10 [00:00<00:00, 304.26it/s]
  0%|          | 0/1227 [00:00<?, ?it/s]

Epoch: 7 / 20

Train Cross Entropy Loss: 0.1663562333025753
Train KL Loss: 167405.9690446369

Test Cross Entropy Loss: 5.625385761260986
Test KL Loss: 9439.4130859375
Test BLEU-4 Score: 0.8607682328855878




100%|██████████| 1227/1227 [01:52<00:00, 10.93it/s]
100%|██████████| 10/10 [00:00<00:00, 306.24it/s]
  0%|          | 2/1227 [00:00<01:44, 11.77it/s]

Epoch: 8 / 20

Train Cross Entropy Loss: 0.12931597599712671
Train KL Loss: 143006.0809116151

Test Cross Entropy Loss: 3.515899658203125
Test KL Loss: 1154.298095703125
Test BLEU-4 Score: 0.8410910972520058




100%|██████████| 1227/1227 [01:52<00:00, 10.92it/s]
100%|██████████| 10/10 [00:00<00:00, 307.25it/s]
  0%|          | 2/1227 [00:00<01:50, 11.08it/s]

Epoch: 9 / 20

Train Cross Entropy Loss: 0.1080077605705184
Train KL Loss: 100390.28651848502

Test Cross Entropy Loss: 1.7555128335952759
Test KL Loss: 1369.3758544921875
Test BLEU-4 Score: 0.8344030470095112




100%|██████████| 1227/1227 [01:52<00:00, 10.95it/s]
100%|██████████| 10/10 [00:00<00:00, 320.18it/s]
  0%|          | 1/1227 [00:00<02:15,  9.04it/s]

Epoch: 10 / 20

Train Cross Entropy Loss: 0.08262321467484743
Train KL Loss: 126266.77535346305

Test Cross Entropy Loss: 3.730701208114624
Test KL Loss: 3158.57568359375
Test BLEU-4 Score: 0.7888460035803996




100%|██████████| 1227/1227 [01:52<00:00, 10.95it/s]
100%|██████████| 10/10 [00:00<00:00, 310.75it/s]
  0%|          | 1/1227 [00:00<02:03,  9.94it/s]

Epoch: 11 / 20

Train Cross Entropy Loss: 0.06560634876807536
Train KL Loss: 171685.8791945524

Test Cross Entropy Loss: 3.4446377754211426
Test KL Loss: 2603.660888671875
Test BLEU-4 Score: 0.8923419757923693




100%|██████████| 1227/1227 [01:51<00:00, 11.03it/s]
100%|██████████| 10/10 [00:00<00:00, 311.51it/s]
  0%|          | 2/1227 [00:00<01:51, 10.96it/s]

Epoch: 12 / 20

Train Cross Entropy Loss: 0.051672571049667305
Train KL Loss: 172760.76819091398

Test Cross Entropy Loss: 5.586647033691406
Test KL Loss: 3461.158203125
Test BLEU-4 Score: 0.8332729998656294




100%|██████████| 1227/1227 [01:53<00:00, 10.77it/s]
100%|██████████| 10/10 [00:00<00:00, 312.77it/s]
  0%|          | 1/1227 [00:00<02:03,  9.91it/s]

Epoch: 13 / 20

Train Cross Entropy Loss: 0.043007535117245255
Train KL Loss: 164930.18295369542

Test Cross Entropy Loss: 0.25629639625549316
Test KL Loss: 2680.843505859375
Test BLEU-4 Score: 1.0




100%|██████████| 1227/1227 [01:52<00:00, 10.91it/s]
100%|██████████| 10/10 [00:00<00:00, 318.46it/s]
  0%|          | 2/1227 [00:00<01:53, 10.80it/s]

Epoch: 14 / 20

Train Cross Entropy Loss: 0.03447036396330054
Train KL Loss: 301522.57414801593

Test Cross Entropy Loss: 0.17907775938510895
Test KL Loss: 1947.7625732421875
Test BLEU-4 Score: 1.0




100%|██████████| 1227/1227 [01:51<00:00, 10.98it/s]
100%|██████████| 10/10 [00:00<00:00, 316.51it/s]
  0%|          | 2/1227 [00:00<01:28, 13.89it/s]

Epoch: 15 / 20

Train Cross Entropy Loss: 0.02487293364195401
Train KL Loss: 232618.01076109518

Test Cross Entropy Loss: 0.16274933516979218
Test KL Loss: 2250.348388671875
Test BLEU-4 Score: 1.0




100%|██████████| 1227/1227 [01:50<00:00, 11.11it/s]
100%|██████████| 10/10 [00:00<00:00, 318.32it/s]
  0%|          | 2/1227 [00:00<01:41, 12.09it/s]

Epoch: 16 / 20

Train Cross Entropy Loss: 0.020511650764866187
Train KL Loss: 227229.60537272037

Test Cross Entropy Loss: 0.12054647505283356
Test KL Loss: 3371.32666015625
Test BLEU-4 Score: 1.0




100%|██████████| 1227/1227 [01:51<00:00, 10.97it/s]
100%|██████████| 10/10 [00:00<00:00, 307.11it/s]
  0%|          | 2/1227 [00:00<01:42, 11.93it/s]

Epoch: 17 / 20

Train Cross Entropy Loss: 0.014768759800642132
Train KL Loss: 229109.37019875838

Test Cross Entropy Loss: 0.10595454275608063
Test KL Loss: 2591.80126953125
Test BLEU-4 Score: 1.0




100%|██████████| 1227/1227 [01:50<00:00, 11.10it/s]
100%|██████████| 10/10 [00:00<00:00, 316.19it/s]
  0%|          | 2/1227 [00:00<01:49, 11.19it/s]

Epoch: 18 / 20

Train Cross Entropy Loss: 0.01236524084364917
Train KL Loss: 286974.92428928986

Test Cross Entropy Loss: 0.08162911981344223
Test KL Loss: 3501.729248046875
Test BLEU-4 Score: 1.0




100%|██████████| 1227/1227 [01:53<00:00, 10.84it/s]
100%|██████████| 10/10 [00:00<00:00, 315.50it/s]
  0%|          | 2/1227 [00:00<01:42, 11.98it/s]

Epoch: 19 / 20

Train Cross Entropy Loss: 0.010259658841725949
Train KL Loss: 330630.84126113646

Test Cross Entropy Loss: 0.09209506213665009
Test KL Loss: 3839.728271484375
Test BLEU-4 Score: 1.0




100%|██████████| 1227/1227 [01:51<00:00, 10.96it/s]
100%|██████████| 10/10 [00:00<00:00, 308.16it/s]

Epoch: 20 / 20

Train Cross Entropy Loss: 0.009653625063210615
Train KL Loss: 315045.04113369406

Test Cross Entropy Loss: 0.055579137057065964
Test KL Loss: 3017.506591796875
Test BLEU-4 Score: 1.0







In [None]:
# Encoder & Decoder
encoder_50 = EncoderRNN(vocab_size, hidden_size, latent_size, num_conds, cond_size).to(device)
decoder_50 = DecoderRNN(latent_size+cond_size, vocab_size).to(device)

# Train
encoder_50, decoder_50, train_loss_list_50, test_score_list_50 = train(encoder_50, decoder_50,
                                                           train_data, train_cond, test_data, test_cond,
                                                           kl_anealing_func=get_kl_anealing_func('cyclical', period=10000),
                                                           teacher_forcing_ratio=0.50,
                                                           num_epochs=num_epochs)

100%|██████████| 1227/1227 [01:48<00:00, 11.31it/s]
100%|██████████| 10/10 [00:00<00:00, 324.86it/s]
  0%|          | 2/1227 [00:00<01:46, 11.53it/s]

Epoch: 1 / 20

Train Cross Entropy Loss: 2.4089470976754654
Train KL Loss: 21.66898718494917

Test Cross Entropy Loss: 16.570415496826172
Test KL Loss: 60.777400970458984
Test BLEU-4 Score: 0.14470810260043157




100%|██████████| 1227/1227 [01:53<00:00, 10.80it/s]
100%|██████████| 10/10 [00:00<00:00, 314.86it/s]
  0%|          | 2/1227 [00:00<01:33, 13.05it/s]

Epoch: 2 / 20

Train Cross Entropy Loss: 1.3119903259784897
Train KL Loss: 272.65956233908577

Test Cross Entropy Loss: 12.476791381835938
Test KL Loss: 368.546630859375
Test BLEU-4 Score: 0.2768408045530575




100%|██████████| 1227/1227 [01:51<00:00, 11.01it/s]
100%|██████████| 10/10 [00:00<00:00, 314.48it/s]
  0%|          | 2/1227 [00:00<01:41, 12.02it/s]

Epoch: 3 / 20

Train Cross Entropy Loss: 0.5790610207586062
Train KL Loss: 669.0296185256929

Test Cross Entropy Loss: 13.550272941589355
Test KL Loss: 513.8779296875
Test BLEU-4 Score: 0.4620740188731863




100%|██████████| 1227/1227 [01:50<00:00, 11.15it/s]
100%|██████████| 10/10 [00:00<00:00, 321.69it/s]
  0%|          | 1/1227 [00:00<02:04,  9.87it/s]

Epoch: 4 / 20

Train Cross Entropy Loss: 0.34994948927768
Train KL Loss: 1251.3374488413965

Test Cross Entropy Loss: 13.2019624710083
Test KL Loss: 567.2276611328125
Test BLEU-4 Score: 0.41598079182848435




  5%|▌         | 66/1227 [00:05<01:38, 11.74it/s]

In [None]:
# Encoder & Decoder
encoder_75 = EncoderRNN(vocab_size, hidden_size, latent_size, num_conds, cond_size).to(device)
decoder_75 = DecoderRNN(latent_size+cond_size, vocab_size).to(device)

# Train
encoder_75, decoder_75, train_loss_list_75, test_score_list_75 = train(encoder_75, decoder_75,
                                                           train_data, train_cond, test_data, test_cond,
                                                           kl_anealing_func=get_kl_anealing_func('cyclical', period=10000),
                                                           teacher_forcing_ratio=0.75,
                                                           num_epochs=num_epochs)

In [None]:
# Encoder & Decoder
encoder_100 = EncoderRNN(vocab_size, hidden_size, latent_size, num_conds, cond_size).to(device)
decoder_100 = DecoderRNN(latent_size+cond_size, vocab_size).to(device)

# Train
encoder_100, decoder_100, train_loss_list_100, test_score_list_100 = train(encoder_100, decoder_100,
                                                           train_data, train_cond, test_data, test_cond,
                                                           kl_anealing_func=get_kl_anealing_func('cyclical', period=10000),
                                                           teacher_forcing_ratio=1,
                                                           num_epochs=num_epochs)

In [14]:
Gaussian_score(words_list)

0.0

In [16]:
plt.plot(test_stest_score_list_0core_list_0)

NameError: name 'test_score_list_0' is not defined

In [17]:
test_score_list_0

NameError: name 'test_score_list_0' is not defined

In [13]:
def generate_words(encoder, decoder, num=100):
    
    encoder.eval()
    decoder.eval()
    
    cond_size = encoder.cond_size
    embedded_conditions = {}
    for cond in conditions:
        cond_tensor = torch.tensor(cond_to_ix[cond], device=device)
        embedded_conditions[cond] = encoder.condition_embedding(cond_tensor)
    
    words_list = []
    with torch.no_grad():
        for _ in range(num):
            noise = decoder.init_hidden()

            # Generate Gaussian noise
            noise = (noise[0].normal_(std=0.7), noise[1].normal_(std=0.7))

            # Generate words with 4 different tenses
            words = []
            for embedded_cond in embedded_conditions.values():
                
                noise[0][:, :, -cond_size:] = embedded_cond
                noise[1][:, :, -cond_size:] = embedded_cond
                
                decoder_input = torch.tensor([[SOS_token]], device=device)
                decoder_hidden = noise
            
                word = ''
                while True:
                    decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
                    topv, topi = decoder_output.topk(1)
                    decoder_input = topi.squeeze().detach()
                    
                    if decoder_input == EOS_token:
                        break
                        
                    word += ix_to_ch[decoder_input.item()]
                words.append(word)
            words_list.append(words)
    return words_list

In [14]:
# Define hyperparameters
hidden_size = 512
latent_size = 64
vocab_size = 28
num_conds = 4
cond_size = 32
teacher_forcing_ratio = 0.5
learning_rate = 0.005
MAX_LENGTH = 40
num_epochs = 20

# Encoder & Decoder
encoder = EncoderRNN(vocab_size, hidden_size, latent_size, num_conds, cond_size).to(device)
decoder = DecoderRNN(latent_size+cond_size, vocab_size).to(device)
E

# Evaluate


# Generate
words_list = generate_words(encoder, decoder, num=100)
for words in words_list:
    print(words)
print(Gaussian_score(words_list))

['u', 'ses', 'inding', 'aded']
['y', 'ses', 'ying', 'od']
['e', 'es', 'necing', 'ed']
['', 'es', 'iding', 'ked']
['y', 'ys', 'ying', 'ad']
['a', 'ds', 'zing', 'ad']
['de', 'ses', 'ding', 'ded']
['y', 'eds', 'ying', 'oded']
['fe', 'fes', 'ying', 'fed']
['e', 'es', 'ending', 'd']
['y', 'es', 'eading', 'ad']
['e', 'es', 'eying', 'ed']
['e', 'es', 'ling', 'ad']
['id', 'eds', 'eding', 'ud']
['e', 'eds', 'eding', 'ed']
['u', 'ies', 'ying', 'ud']
['ee', 'eds', 'eying', 'ad']
['y', 'es', 'eying', 'wed']
['ud', 'ses', 'ging', 'ad']
['ue', 'eds', 'eding', 'ued']
['e', 'es', 'eeing', 'ad']
['w', 'ses', 'wing', 'wed']
['y', 'eds', 'eaing', 'ded']
['e', 'es', 'ezing', 'oud']
['', 'ses', 'ying', 'wed']
['d', 'des', 'ding', 'wed']
['uy', 'is', 'ing', 'ung']
['o', 'es', 'eding', 'ad']
['ie', 'es', 'inging', 'od']
['ue', 'ses', 'neing', 'wed']
['e', 'es', 'eving', 'ad']
['ed', 'des', 'deing', 'od']
['', 's', 'sing', 'de']
['e', 'es', 'ening', 'ed']
['u', 'eds', 'ling', 'ad']
['y', 'ges', 'yeding', 'wed