## Neural and N-Gram Language Model for CPSC 503 Assignment 2
#### The neural model notebook is modified from Yunjey Choi's Github repository - pytorch-tutorial.
#### Familiarize yourself with pytorch, start with: https://pytorch.org/tutorials/beginner/basics/intro.html

#### The N-gram model notebook is from Josh Loehr's Github repository - ngram-language-model.
#### https://github.com/joshualoehr/ngram-language-model

#### ========================================================================================

### Let's load a number of dependencies:

In [35]:
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
import torch.nn.functional as F
import numpy as np
import os
import math
from collections import defaultdict
from tqdm import tqdm

# check if GPU is available to pytorch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Here are two classes needed for the data loading and formating:

In [36]:
class Dictionary(object):
    # use to generate and return the word-to-index (index-to-word) vocabulary dictionary
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0
    
    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def return_dict(self):
        return self.idx2word
        
    def __len__(self):
        return len(self.word2idx)


class Corpus(object):
    # load and prepare the corpus  the language models input format
    def __init__(self):
        self.dictionary = Dictionary()

    def get_data(self, path, n_gram=2):
        with open(path, 'r') as f:
            tokens = 0;
            sample_list = []
            for line in f:
                # add <start> tokens based on the number of n-grams.
                words = ['<start>'] * (n_gram - 1) + line.split() + ['<end>']
                tokens += len(words)
                sample_list.append(words)
                for word in words: 
                    self.dictionary.add_word(word)  

        # # Read corpus and store the each line (word sequence) into its corresponding index sequence.
        ids_list = [[0]*len(s) for s in sample_list if len(s) > n_gram] 
        with open(path, 'r') as f:
            sample_num = 0
            for line in f:
                token = 0
                words = ['<start>'] * (n_gram - 1) + line.split() + ['<end>']
                if len(words) >= n_gram:
                    for word in words:
                        ids_list[sample_num][token] = self.dictionary.word2idx[word]
                        token += 1
                    sample_num += 1

        # FOR THE NEURAL MODEL Convert the flat index sequences into the n-gram tensors which are ready for neural model.
        for n in range(len(ids_list)):
            flat_ids = ids_list[n]
            ids_list[n] = torch.LongTensor([flat_ids[i:i+n_gram] for i in range(len(flat_ids)-(n_gram - 1))])
        return ids_list

### Hyper-parameters for both Language Models: 

In [37]:
# Bigram model
n_gram = 2

# m_gram is the number of preceding/conditioning words 
m_gram = n_gram - 1

In [78]:
corpus = Corpus()
ids = corpus.get_data('data/train_mini.txt', n_gram)
# ids = corpus.get_data('data/train.txt', n_gram)


# Use 70% for training, 15% for development, and 15% for testing 
n_train = round(len(ids) * .75)
n_dev = round(len(ids) * .15)

train_ids = ids[:n_train]
dev_ids = ids[n_train:n_train + n_dev]
test_ids = ids[n_train + n_dev:]

print(f"Number of sentences: {len(train_ids)} train, {len(dev_ids)} dev, {len(test_ids)} test")
vocab_size = len(corpus.dictionary)
print(f"Vocab size: {vocab_size}")

Number of sentences: 786 train, 157 dev, 105 test
Vocab size: 3580


### Load the "Penn Treebank" dataset and split it into train/dev/test

### The class of count-based language model:

In [79]:
class CountLM(object):
    def __init__(self, vocab_size, x_n, x_m, laplace=1):
        self.vocab_size = vocab_size
        self.laplace = laplace
        
        # Dictionaries for tracking the count of n-grams
        self.n_gram_count = self.count_ngrams(x_n)
        self.m_gram_count = self.count_ngrams(x_m)
    
    def count_ngrams(self, x):
        """
        Populate the dictionary with the number of occurrences of each n-gram
        """
        count_list = defaultdict(int)
        for example in x:
            for n_gram in example.tolist():
                count_list[tuple(n_gram)] += 1 
                
        return count_list
    
    def compute_mle(self, n_gram):
        """
        Compute the MLE of P(w_n|w_{n−1}, ...) with add-one Laplacian smoothing
        
        Please see chapter 3.5.1 of J&M 3rd Ed. for more information
        """
        n_count = self.n_gram_count[n_gram]
        
        m_gram = n_gram[:-1]
        m_count = self.m_gram_count[m_gram]
        prob = (n_count + self.laplace) / (m_count + self.laplace * self.vocab_size) 
        return prob


### Train the n-gram model based on MLE

In [103]:
m_gram_train_ids = corpus.get_data('data/train_mini.txt', m_gram)[:-200]
# m_gram_train_ids = corpus.get_data('data/train.txt', m_gram)[:n_train]
n_gram_train_ids = train_ids

# Populate the dictionaries with counts from training corpus
count_model = CountLM(vocab_size, n_gram_train_ids, m_gram_train_ids)

# Compute average perplexity on training set
train_ppl = 0
for i in range(0, len(train_ids)):
    
    probabilities = list(map(lambda x: count_model.compute_mle(tuple(x)), train_ids[i].tolist()))
    perplexity = np.exp(sum(-np.log(probabilities)) / len(train_ids[i]))
    train_ppl += perplexity
#     print(probabilities)
print('The average training perplexity for count-based LM: '+str(train_ppl/len(train_ids)))

The average training perplexity for count-based LM: 1026.5023506565244


In [92]:
we = corpus.dictionary.word2idx["we"]
want = corpus.dictionary.word2idx["want"]
count_model.compute_mle((we, want))
# count_model.n_gram_count

0.000278473962684489

conditional prob of the word given the previous context
get all the words in the vocab...
given train IDs how do I map it to a word?
How do I loop through all the words and get all the probabilities?
normalize the sum of the count to 1

In [153]:
the = corpus.dictionary.word2idx["the"]
zero = corpus.dictionary.idx2word[0]
count_model.compute_mle((0, the))
x = count_model.n_gram_count

In [168]:
the = corpus.dictionary.word2idx["the"]
zero = corpus.dictionary.idx2word[0]
count_model.compute_mle((0, the))
x = count_model.n_gram_count
while True:
    value = 0
    key
    for pair in x:
    #     print(pair)
        try:
            prob = count_model.compute_mle(pair)
            if prob > value:
                value = prob
                key = pair

        except TypeError:
            # hit end of dict
            
            break
    prob = count_model.compute_mle(key)
    x = count_model.n_gram_count
    print(prob)
    
    
# print(key, value)
# print(corpus.dictionary.idx2word[key[1]])
# prob = count_model.compute_mle(key)
# print(prob)

0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665
0.03910614525139665


KeyboardInterrupt: 

In [8]:
# Compute average perplexity on testing set
test_ppl = 0
for i in range(0, len(test_ids)):

    probabilities = list(map(lambda x: count_model.compute_mle(tuple(x)), test_ids[i].tolist()))
    perplexity = np.exp(sum(-np.log(probabilities)) / len(test_ids[i]))
    print('Perplexity for test sample '+str(i)+' :', perplexity)
    test_ppl += perplexity
    
print('The average testing perplexity for count-based LM: '+str(test_ppl/len(test_ids)))

Perplexity for test sample 0 : 481.93397924150236
Perplexity for test sample 1 : 1791.0646372289625
Perplexity for test sample 2 : 295.4281859753384
Perplexity for test sample 3 : 336.8972086530059
Perplexity for test sample 4 : 188.3063912073697
Perplexity for test sample 5 : 1580.4694541715269
Perplexity for test sample 6 : 908.3518812063347
Perplexity for test sample 7 : 587.567516576945
Perplexity for test sample 8 : 466.1199712311848
Perplexity for test sample 9 : 710.2569626653177
Perplexity for test sample 10 : 3201.2726581592956
Perplexity for test sample 11 : 731.7403780164503
Perplexity for test sample 12 : 902.777634164664
Perplexity for test sample 13 : 2107.650302714663
Perplexity for test sample 14 : 867.2708592238805
Perplexity for test sample 15 : 1282.4375182196761
Perplexity for test sample 16 : 516.5791519391194
Perplexity for test sample 17 : 1637.3482160822728
Perplexity for test sample 18 : 891.4693914511101
Perplexity for test sample 19 : 1841.6775228827773
Perpl

Perplexity for test sample 522 : 383.2279527359837
Perplexity for test sample 523 : 1163.9777024650646
Perplexity for test sample 524 : 465.934322506438
Perplexity for test sample 525 : 1334.0055729360536
Perplexity for test sample 526 : 2180.096615174918
Perplexity for test sample 527 : 1108.7488559401008
Perplexity for test sample 528 : 462.42303691332467
Perplexity for test sample 529 : 1829.5391982113467
Perplexity for test sample 530 : 680.7873439128163
Perplexity for test sample 531 : 1075.720921572503
Perplexity for test sample 532 : 1439.8581346877356
Perplexity for test sample 533 : 774.3979944399077
Perplexity for test sample 534 : 703.7149869261153
Perplexity for test sample 535 : 1840.1099473425838
Perplexity for test sample 536 : 1048.4176752574454
Perplexity for test sample 537 : 1971.689039941437
Perplexity for test sample 538 : 1139.6771184979086
Perplexity for test sample 539 : 472.15129462701213
Perplexity for test sample 540 : 269.29719943722847
Perplexity for test s

Perplexity for test sample 1013 : 1711.7082329941663
Perplexity for test sample 1014 : 379.5192028280277
Perplexity for test sample 1015 : 2401.26656732838
Perplexity for test sample 1016 : 843.8697037125777
Perplexity for test sample 1017 : 284.7970686294181
Perplexity for test sample 1018 : 798.2743772249205
Perplexity for test sample 1019 : 454.1174549175305
Perplexity for test sample 1020 : 422.26318313376515
Perplexity for test sample 1021 : 685.499143250206
Perplexity for test sample 1022 : 844.0352499114954
Perplexity for test sample 1023 : 1612.4572589188404
Perplexity for test sample 1024 : 1360.7913047868365
Perplexity for test sample 1025 : 2522.872881530421
Perplexity for test sample 1026 : 1094.4399862834046
Perplexity for test sample 1027 : 508.29628617229054
Perplexity for test sample 1028 : 1106.0300800743587
Perplexity for test sample 1029 : 948.5602366707981
Perplexity for test sample 1030 : 770.1203563949285
Perplexity for test sample 1031 : 1801.8496264164366
Perple

Perplexity for test sample 1531 : 577.1488602828803
Perplexity for test sample 1532 : 299.9255184957937
Perplexity for test sample 1533 : 827.8076804668558
Perplexity for test sample 1534 : 705.3064289852103
Perplexity for test sample 1535 : 663.849079561899
Perplexity for test sample 1536 : 711.7755184372891
Perplexity for test sample 1537 : 2288.6995492379706
Perplexity for test sample 1538 : 402.1920779200347
Perplexity for test sample 1539 : 374.4837679885483
Perplexity for test sample 1540 : 301.62385188074904
Perplexity for test sample 1541 : 268.5909645429137
Perplexity for test sample 1542 : 480.63808049208467
Perplexity for test sample 1543 : 1405.4494167726066
Perplexity for test sample 1544 : 1233.6836836396255
Perplexity for test sample 1545 : 964.158641688546
Perplexity for test sample 1546 : 2632.0574769360564
Perplexity for test sample 1547 : 736.6576720700164
Perplexity for test sample 1548 : 768.7408603227124
Perplexity for test sample 1549 : 1089.5486884190204
Perplex

Perplexity for test sample 2043 : 861.2297931704337
Perplexity for test sample 2044 : 110.09632460794144
Perplexity for test sample 2045 : 140.54405500883854
Perplexity for test sample 2046 : 1048.592085992945
Perplexity for test sample 2047 : 612.4826962422168
Perplexity for test sample 2048 : 946.1144442717643
Perplexity for test sample 2049 : 370.91632651931025
Perplexity for test sample 2050 : 255.6946883265172
Perplexity for test sample 2051 : 271.5976210243736
Perplexity for test sample 2052 : 322.3281938989347
Perplexity for test sample 2053 : 423.3829105781211
Perplexity for test sample 2054 : 495.4090954098413
Perplexity for test sample 2055 : 187.84196154116012
Perplexity for test sample 2056 : 683.4962631886007
Perplexity for test sample 2057 : 503.32404625354593
Perplexity for test sample 2058 : 1276.2226596241794
Perplexity for test sample 2059 : 634.3246230358413
Perplexity for test sample 2060 : 382.6625850446478
Perplexity for test sample 2061 : 3283.5156372008087
Perpl

Perplexity for test sample 2579 : 605.0602644009936
Perplexity for test sample 2580 : 1442.6644536325182
Perplexity for test sample 2581 : 182.68606183602927
Perplexity for test sample 2582 : 425.6359296698872
Perplexity for test sample 2583 : 748.6310597242529
Perplexity for test sample 2584 : 1141.9374107897643
Perplexity for test sample 2585 : 820.109041897057
Perplexity for test sample 2586 : 732.5077261378736
Perplexity for test sample 2587 : 1582.8616487707743
Perplexity for test sample 2588 : 1028.554216184425
Perplexity for test sample 2589 : 977.2569026619584
Perplexity for test sample 2590 : 1513.6263758858752
Perplexity for test sample 2591 : 214.30865853296692
Perplexity for test sample 2592 : 4331.1556271852505
Perplexity for test sample 2593 : 1030.9518452066873
Perplexity for test sample 2594 : 2282.1208493960876
Perplexity for test sample 2595 : 248.37353419604037
Perplexity for test sample 2596 : 381.4080662096307
Perplexity for test sample 2597 : 550.7581864087533
Per

Perplexity for test sample 3028 : 1261.8112982701539
Perplexity for test sample 3029 : 2006.224132054
Perplexity for test sample 3030 : 1739.668878912662
Perplexity for test sample 3031 : 2617.620305319237
Perplexity for test sample 3032 : 1103.0667006056615
Perplexity for test sample 3033 : 727.801165168394
Perplexity for test sample 3034 : 113.86124725832116
Perplexity for test sample 3035 : 373.8212565838409
Perplexity for test sample 3036 : 974.4019105063632
Perplexity for test sample 3037 : 298.32207411577417
Perplexity for test sample 3038 : 492.3418672006072
Perplexity for test sample 3039 : 443.67791797312304
Perplexity for test sample 3040 : 2408.3122701450493
Perplexity for test sample 3041 : 460.6088781904691
Perplexity for test sample 3042 : 1301.0712495012476
Perplexity for test sample 3043 : 485.5108436196283
Perplexity for test sample 3044 : 201.9283952336438
Perplexity for test sample 3045 : 955.7219220031121
Perplexity for test sample 3046 : 1202.7081907056463
Perplexi

Perplexity for test sample 3530 : 997.4662625708547
Perplexity for test sample 3531 : 451.1659927828725
Perplexity for test sample 3532 : 297.71909643495354
Perplexity for test sample 3533 : 872.0652912531395
Perplexity for test sample 3534 : 516.204174117343
Perplexity for test sample 3535 : 1377.026592717125
Perplexity for test sample 3536 : 1020.3817831975803
Perplexity for test sample 3537 : 5196.320377062466
Perplexity for test sample 3538 : 1718.4717987308995
Perplexity for test sample 3539 : 484.7088777285546
Perplexity for test sample 3540 : 416.3101844030728
Perplexity for test sample 3541 : 1025.1644117532715
Perplexity for test sample 3542 : 633.4810840069381
Perplexity for test sample 3543 : 647.7869467597407
Perplexity for test sample 3544 : 1924.3657481362886
Perplexity for test sample 3545 : 1700.785021048634
Perplexity for test sample 3546 : 2640.525586784412
Perplexity for test sample 3547 : 1823.0024346119217
Perplexity for test sample 3548 : 1126.1181619967729
Perple

Perplexity for test sample 4021 : 82.65844465400002
Perplexity for test sample 4022 : 282.59626961095927
Perplexity for test sample 4023 : 270.44256557604655
Perplexity for test sample 4024 : 370.74600106092873
Perplexity for test sample 4025 : 284.59926331014435
Perplexity for test sample 4026 : 818.6174979486494
Perplexity for test sample 4027 : 855.6481067946715
Perplexity for test sample 4028 : 1104.6814809002547
Perplexity for test sample 4029 : 2221.80807749992
Perplexity for test sample 4030 : 1088.9258127562407
Perplexity for test sample 4031 : 2101.7298594087565
Perplexity for test sample 4032 : 666.4617562975096
Perplexity for test sample 4033 : 323.15513871994847
Perplexity for test sample 4034 : 36.08325561884498
Perplexity for test sample 4035 : 47.01492395478094
Perplexity for test sample 4036 : 644.1534902951437
Perplexity for test sample 4037 : 218.29294127090128
Perplexity for test sample 4038 : 846.3648158499624
Perplexity for test sample 4039 : 127.3814455523921
Perp

### Hyper-parameters for the neural language model

In [14]:
# FOR THE NEURAL MODEL
embed_size = 128
intermediate_size = 1024
num_epochs = 3
learning_rate = 1e-1

### The class for the neural language model:

In [69]:
class NeuralLM(nn.Module):
    def __init__(self, vocab_size, embed_size, intermediate_size, m_gram):
        super(NeuralLM, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.intermediate = nn.Linear(m_gram * embed_size, intermediate_size)
        self.final = nn.Linear(intermediate_size, vocab_size)
        
    def forward(self, x):
        x = self.embed(x) # Embed word id(s) to vectors
        print(x)
        conc_emb = x.view(x.size(0), x.size(1)*x.size(2))
        intermediate_output = self.intermediate(conc_emb) # one layer of MLP
        intermediate_output = F.relu(intermediate_output) # ReLU non-linear function
        final_out = self.final(intermediate_output) # Map to the vocabulary size output
        return final_out
    
    # x is the context vector and output is the prob distrbution over all the tokens
    # use softmax to make everything into 1
    # choose word with highest probability
    # at each step make the modl take in the context x again that was generated previously
    # apply softmax to create a conditional distribution and take the argmax from that
    # do this iteratively until reaching the end of the sentence

In [70]:
neural_model = NeuralLM(vocab_size, embed_size, intermediate_size, m_gram).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(neural_model.parameters(), lr=learning_rate, momentum=0.9)

### Train the neural model

In [71]:
# Reduce batch size if you are running out of memory
batch_size = 64
training_data = torch.cat(train_ids, dim=0)
neural_model.train()

for epoch in range(num_epochs):
    total_loss = 0
    for i in tqdm(range(0, len(training_data), batch_size)):
        batch = training_data[i:i + batch_size]
        inputs = batch[:, 0:n_gram-1].to(device)
        targets = batch[:, n_gram-1:].to(device)
        
        # Forward pass
        outputs = neural_model(inputs)
        loss = criterion(outputs, targets.reshape(-1))
        total_loss += loss.item();
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    
    # Calculate the performance (perplexity) of the current trained model on dev set.
    total_ppl = 0
    for i in range(0, len(dev_ids)):
        dev_inputs = dev_ids[i][:, 0:n_gram-1].to(device)
        dev_targets = dev_ids[i][:, n_gram-1:].to(device)
        dev_outputs = neural_model(dev_inputs)
        ce = criterion(dev_outputs, dev_targets.reshape(-1))
        total_ppl += np.exp(ce.item());
    
    print ('Epoch [{}/{}], Training Loss: {:.4f}, Dev Perplexity: {:5.2f}'
        .format(epoch + 1, num_epochs, total_loss/len(train_ids), total_ppl/len(dev_ids)))

  0%|          | 1/10910 [00:00<27:34,  6.59it/s]

tensor([[[ 0.3366, -0.6849, -0.6231,  ...,  0.4395,  0.2915,  0.8121]],

        [[-2.9914,  1.7914, -0.2272,  ..., -0.8930,  0.2622,  0.6680]],

        [[ 0.2563,  0.6231,  0.0235,  ..., -0.1732,  1.9910,  0.6156]],

        ...,

        [[ 0.9145, -0.7784,  0.3912,  ..., -1.1965,  1.0594,  1.0600]],

        [[-0.4304,  0.2398,  0.0237,  ..., -0.3762,  0.7208, -0.1786]],

        [[ 0.2462,  0.7818, -0.2613,  ...,  0.8993,  0.1481, -0.9627]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 1.7959,  0.6012, -0.4454,  ...,  0.0339, -1.8028,  1.4657]],

        [[-0.8503,  1.5785,  0.5406,  ...,  0.0375, -0.0828, -0.4686]],

        [[-0.2936,  0.5210,  1.4298,  ...,  0.2086, -0.4422,  0.9478]],

        ...,

        [[ 0.3705, -0.4952,  0.4017,  ...,  0.6291,  0.8509, -1.6369]],

        [[ 0.6665, -0.9370,  1.3804,  ..., -1.2028, -0.0238,  1.3010]],

        [[ 0.4892,  0.5927, -0.0739,  ...,  0.6776, -0.3975, -1.1005]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.2697,  0.

  0%|          | 3/10910 [00:00<23:57,  7.59it/s]


tensor([[[-1.7487, -0.1242,  0.7775,  ..., -1.2465, -1.2272,  0.9909]],

        [[-1.1443, -0.5819, -0.2509,  ..., -0.3759, -0.4278,  0.4879]],

        [[-1.4152, -1.7207, -0.1968,  ...,  0.7618, -0.8367,  0.6615]],

        ...,

        [[ 0.3361, -0.6847, -0.6231,  ...,  0.4395,  0.2912,  0.8119]],

        [[-0.1149, -1.3540, -1.8638,  ..., -0.1145,  0.6369, -0.8028]],

        [[-2.1575,  0.1363,  0.5428,  ...,  0.4330, -0.2222,  0.1070]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.3274, -2.0989,  1.3299,  ..., -0.2913,  1.3246,  0.8078]],

        [[ 1.0870, -0.6926, -1.3763,  ..., -0.2518,  0.2096, -1.1757]],

        [[-0.3789,  1.2668,  1.8092,  ...,  0.7507, -0.2402,  1.1666]],

        ...,

        [[-1.3094,  0.0925,  0.8392,  ...,  0.5384, -1.1800,  0.1376]],

        [[-0.1798,  0.7391, -1.0055,  ..., -0.8215, -0.2188,  1.4452]],

        [[-0.8567,  0.5846,  1.2398,  ..., -0.6659,  0.1959, -0.1028]]],
       grad_fn=<EmbeddingBackward0>)


  0%|          | 7/10910 [00:00<20:02,  9.07it/s]

tensor([[[-0.5325, -0.6321,  0.1193,  ...,  0.1317,  2.2458, -1.0840]],

        [[-1.2288,  0.2436,  0.1553,  ...,  0.9258, -1.1570,  0.7287]],

        [[ 0.3707, -0.4951,  0.4017,  ...,  0.6291,  0.8508, -1.6369]],

        ...,

        [[-1.1302, -1.1681, -0.5897,  ..., -0.5806, -1.7073,  0.4863]],

        [[ 0.9243, -0.3556,  0.7373,  ...,  1.1951,  1.4302, -1.4401]],

        [[ 0.3357, -0.6847, -0.6230,  ...,  0.4394,  0.2908,  0.8121]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-1.3097,  0.0927,  0.8392,  ...,  0.5386, -1.1805,  0.1375]],

        [[-1.9502, -0.1188,  0.4771,  ...,  0.6775, -0.0453,  0.1683]],

        [[-0.4307,  0.2399,  0.0236,  ..., -0.3761,  0.7204, -0.1785]],

        ...,

        [[ 1.4644,  0.0978, -2.8943,  ..., -1.3664, -1.6057,  2.5424]],

        [[ 0.8015, -1.5560, -1.5362,  ...,  0.3195,  0.1219,  1.1461]],

        [[ 0.6798,  0.1730, -0.3893,  ..., -0.7821, -1.1608, -1.3413]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.8105, -0.

  0%|          | 9/10910 [00:00<18:58,  9.58it/s]

tensor([[[ 0.8975,  0.8944,  2.0132,  ...,  0.3943, -1.6518, -0.0351]],

        [[ 0.6607, -1.7268,  1.6045,  ...,  0.9735, -1.0114,  0.3943]],

        [[-1.1306, -1.1679, -0.5893,  ..., -0.5805, -1.7078,  0.4862]],

        ...,

        [[-0.3244,  0.1470,  0.5059,  ...,  0.2186, -0.6072,  0.5297]],

        [[-0.9940, -1.2286, -1.5375,  ..., -0.3707,  0.0296, -1.2057]],

        [[-1.2376,  0.5106,  1.6851,  ...,  0.1223, -0.4911,  0.1709]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.0906,  1.2153,  1.3003,  ..., -1.3120,  0.3127, -1.0449]],

        [[ 1.1525, -0.3100, -1.9362,  ...,  0.4026,  0.6406,  0.4583]],

        [[-1.3081,  0.0928,  0.8386,  ...,  0.5388, -1.1803,  0.1370]],

        ...,

        [[-1.3081,  0.0928,  0.8386,  ...,  0.5388, -1.1803,  0.1370]],

        [[-0.3650, -0.3804,  1.8045,  ..., -0.4678, -0.5419, -0.2934]],

        [[-0.7926,  0.4279,  0.7042,  ...,  0.1936,  0.4994,  0.8285]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-2.4014,  0.

  0%|          | 13/10910 [00:01<17:39, 10.29it/s]

tensor([[[ 0.3275, -2.0988,  1.3302,  ..., -0.2910,  1.3248,  0.8079]],

        [[-0.4476, -0.3359, -0.5782,  ..., -0.7680,  0.4108,  0.2380]],

        [[-0.3903, -0.3176,  0.1844,  ...,  0.0679,  0.3893,  0.1393]],

        ...,

        [[-0.5063,  0.0202, -1.6998,  ..., -1.5522,  1.2185,  0.4823]],

        [[ 0.6746,  0.7821, -0.1496,  ...,  0.5787,  0.5020, -0.3163]],

        [[-0.9819, -0.2607,  0.3725,  ...,  0.9188,  0.1192,  0.5857]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-0.3385, -0.1735,  0.0464,  ...,  1.8719,  0.6370, -0.4785]],

        [[-1.4798,  0.1253,  0.9716,  ...,  2.1980, -1.8745,  2.3657]],

        [[-0.0527,  0.5852,  2.3677,  ...,  0.8098, -1.0011,  1.2449]],

        ...,

        [[ 1.0101,  0.5083,  1.6106,  ...,  0.6336,  1.4291, -0.5343]],

        [[-0.9877,  0.4370,  0.4655,  ...,  1.3071,  0.3887, -0.6337]],

        [[ 0.6470,  0.7603,  0.4534,  ..., -0.1140,  1.1576, -0.9588]]],
       grad_fn=<EmbeddingBackward0>)


  0%|          | 15/10910 [00:01<17:04, 10.63it/s]

tensor([[[-1.1253, -1.1665, -0.5874,  ..., -0.5787, -1.7043,  0.4806]],

        [[ 1.1521, -0.3098, -1.9368,  ...,  0.4030,  0.6402,  0.4584]],

        [[-0.3847,  0.7212,  0.0925,  ..., -0.2116,  1.3481,  0.7272]],

        ...,

        [[-1.0448,  0.5540, -0.3509,  ..., -0.2210,  0.9196, -2.5127]],

        [[ 1.4649,  0.0982, -2.8951,  ..., -1.3671, -1.6064,  2.5433]],

        [[ 1.0100,  0.5084,  1.6106,  ...,  0.6335,  1.4289, -0.5343]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.9886, -0.4317,  0.1063,  ...,  0.4561,  0.0093,  0.1192]],

        [[ 0.6154,  0.1264,  0.1270,  ..., -0.2527,  0.9260,  0.6128]],

        [[-1.4149, -1.7207, -0.1973,  ...,  0.7615, -0.8370,  0.6597]],

        ...,

        [[ 0.0755, -1.1357, -1.0940,  ..., -2.0204,  0.0240, -1.4677]],

        [[-0.4596, -0.4592, -0.1535,  ...,  0.2921, -0.0686, -1.9050]],

        [[-2.7279, -1.0845, -0.8919,  ..., -1.2057, -2.0289, -0.6976]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.5454,  0.

  0%|          | 17/10910 [00:01<16:50, 10.78it/s]

tensor([[[-1.2976,  0.0924,  0.8351,  ...,  0.5388, -1.1773,  0.1365]],

        [[ 0.1326, -0.0871,  2.5296,  ...,  1.0864, -0.4762, -0.6303]],

        [[ 1.0418, -1.0451, -1.1479,  ...,  0.9300, -0.7086,  2.0325]],

        ...,

        [[-1.5734, -0.0648, -1.5663,  ...,  0.3925,  0.2446, -0.8065]],

        [[ 1.3917, -0.8673, -0.1984,  ..., -0.5217, -0.2081,  1.0591]],

        [[-0.9353,  0.4506, -0.7668,  ..., -1.0178,  1.8601, -1.2289]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.4039, -0.1857, -1.1765,  ...,  1.0066,  2.5936, -0.8383]],

        [[-0.4287,  0.2401,  0.0238,  ..., -0.3783,  0.7174, -0.1782]],

        [[-1.2969,  0.0924,  0.8348,  ...,  0.5390, -1.1771,  0.1364]],

        ...,

        [[-0.3838,  0.7198,  0.0917,  ..., -0.2115,  1.3467,  0.7252]],

        [[-1.3312,  0.1562,  0.0621,  ...,  0.4931, -1.3059, -1.0332]],

        [[ 0.3364, -0.6809, -0.6216,  ...,  0.4378,  0.2884,  0.8085]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.8730, -2.

  0%|          | 21/10910 [00:01<16:26, 11.04it/s]

tensor([[[-1.2962,  0.0925,  0.8343,  ...,  0.5395, -1.1769,  0.1363]],

        [[-1.0585, -0.2090,  0.1060,  ..., -0.1784, -0.3843, -0.9810]],

        [[ 1.1519, -0.3095, -1.9372,  ...,  0.4039,  0.6398,  0.4589]],

        ...,

        [[-0.1777,  0.6066, -0.4838,  ..., -0.6211,  0.7047,  0.4590]],

        [[-0.5328, -0.6311,  0.1203,  ...,  0.1317,  2.2457, -1.0837]],

        [[-1.3949, -1.2906, -0.8457,  ...,  1.5127,  0.7468, -0.0253]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-0.7436, -0.5201, -0.9459,  ...,  0.8549, -0.5941, -1.9359]],

        [[ 1.7037,  0.2484,  0.0684,  ...,  0.7077, -0.4031, -1.2684]],

        [[-0.4281,  0.2402,  0.0232,  ..., -0.3788,  0.7164, -0.1780]],

        ...,

        [[ 0.1223,  0.7648,  0.3236,  ..., -0.7039, -0.4317,  0.1626]],

        [[-0.2912,  0.7668,  1.1776,  ...,  0.0075, -1.2565,  2.1229]],

        [[-0.1778, -1.7999,  1.3179,  ..., -1.4480,  0.1161, -0.9392]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-2.0182,  0.

  0%|          | 23/10910 [00:02<16:42, 10.86it/s]

tensor([[[-0.2786,  0.1763, -1.4993,  ..., -2.0164,  0.0100,  0.7242]],

        [[ 0.3952, -2.2033,  0.5197,  ...,  0.9356,  1.0217, -0.5435]],

        [[-1.1195, -1.1667, -0.5870,  ..., -0.5760, -1.6999,  0.4727]],

        ...,

        [[ 0.0179,  0.8786, -0.1268,  ...,  0.3850, -0.6146,  0.6272]],

        [[ 1.3429, -1.4739,  1.9608,  ...,  1.4718, -0.3162,  1.0439]],

        [[-0.3826,  0.7185,  0.0913,  ..., -0.2113,  1.3462,  0.7242]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-0.5542,  0.0699, -0.3386,  ...,  1.3398, -0.6031, -0.7083]],

        [[-0.5329, -0.6309,  0.1206,  ...,  0.1316,  2.2444, -1.0831]],

        [[-0.2649, -0.4421, -0.1974,  ..., -0.4103,  0.6432, -1.4499]],

        ...,

        [[ 0.3412, -0.6808, -0.6258,  ...,  0.4388,  0.2887,  0.8082]],

        [[-0.5329, -0.6309,  0.1206,  ...,  0.1316,  2.2444, -1.0831]],

        [[-1.2945,  0.0923,  0.8337,  ...,  0.5402, -1.1764,  0.1364]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.0763, -0.

  0%|          | 25/10910 [00:02<16:26, 11.04it/s]


tensor([[[ 1.3784, -0.8424,  0.2761,  ..., -0.3649,  1.2024,  0.0050]],

        [[ 0.0265, -0.4193,  1.5628,  ...,  0.7692,  1.6050,  0.5189]],

        [[-0.5335, -0.6307,  0.1205,  ...,  0.1312,  2.2423, -1.0819]],

        ...,

        [[-1.1183, -1.1668, -0.5882,  ..., -0.5767, -1.7002,  0.4719]],

        [[-1.4135, -1.7208, -0.1995,  ...,  0.7616, -0.8363,  0.6576]],

        [[-1.2940,  0.0920,  0.8334,  ...,  0.5405, -1.1760,  0.1358]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.7801, -0.3130,  0.5125,  ..., -0.6557,  0.5081, -2.4534]],

        [[ 0.4682,  1.5845,  0.1472,  ..., -0.9012, -1.4850,  1.8205]],

        [[-1.1180, -1.1670, -0.5891,  ..., -0.5773, -1.7007,  0.4721]],

        ...,

        [[-2.0149,  0.3452,  0.0978,  ..., -1.0143, -0.7678, -0.7577]],

        [[-0.0199, -0.8522, -1.8343,  ..., -1.0013,  0.0782, -2.4833]],

        [[ 1.4610,  0.0952, -2.8919,  ..., -1.3641, -1.6049,  2.5385]]],
       grad_fn=<EmbeddingBackward0>)


  0%|          | 29/10910 [00:02<15:41, 11.56it/s]

tensor([[[-0.3821,  0.7173,  0.0914,  ..., -0.2109,  1.3456,  0.7240]],

        [[-1.1178, -1.1673, -0.5903,  ..., -0.5778, -1.7013,  0.4722]],

        [[-0.2457,  1.5274, -0.0118,  ...,  0.0246, -0.7915,  1.1356]],

        ...,

        [[-0.1704, -0.4285,  0.4479,  ...,  0.4740,  0.7203, -1.7959]],

        [[ 1.0750,  1.2956,  1.0157,  ...,  0.0673, -0.5767, -0.9832]],

        [[ 1.4596,  0.0944, -2.8908,  ..., -1.3634, -1.6045,  2.5371]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.2801, -0.9153,  2.5773,  ...,  1.4976, -0.8364, -0.1284]],

        [[-0.5339, -0.6307,  0.1202,  ...,  0.1312,  2.2400, -1.0807]],

        [[-1.1174, -1.1673, -0.5913,  ..., -0.5783, -1.7018,  0.4726]],

        ...,

        [[-0.9941, -1.2281, -1.5371,  ..., -0.3697,  0.0293, -1.2054]],

        [[-2.0538, -0.1523,  0.5200,  ..., -0.6917, -0.6974,  1.2826]],

        [[-0.3821,  0.7171,  0.0915,  ..., -0.2108,  1.3454,  0.7242]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-1.3636,  1.

  0%|          | 31/10910 [00:02<15:18, 11.84it/s]

tensor([[[-2.1486, -0.6268, -0.8674,  ..., -0.1664,  0.2819, -2.3964]],

        [[ 0.0139,  0.2325, -0.9079,  ...,  0.3686,  0.0674,  1.1910]],

        [[ 0.1685,  1.7012, -0.2518,  ..., -0.0978, -0.5045,  0.9856]],

        ...,

        [[-0.3534,  1.2900, -0.3687,  ...,  0.8138, -0.3317, -1.3838]],

        [[-0.3895, -1.4889,  0.0049,  ..., -1.0076, -2.7434,  0.0147]],

        [[-1.4709,  0.1257,  0.9623,  ...,  2.1919, -1.8703,  2.3594]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-2.0129,  0.3454,  0.0986,  ..., -1.0141, -0.7686, -0.7579]],

        [[-1.7116,  1.1476, -0.4428,  ...,  0.1034,  0.5390,  0.5377]],

        [[ 0.3407, -1.4163,  0.7184,  ..., -2.6509,  0.5388, -0.8170]],

        ...,

        [[-0.0183, -2.5890, -0.0039,  ..., -1.6066, -0.5304, -0.3028]],

        [[ 0.5355, -0.3501,  0.9811,  ...,  2.8659, -0.8849,  0.2364]],

        [[-0.8223, -0.4533, -0.9619,  ...,  0.2304, -0.1204,  0.2492]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-0.6174,  0.

  0%|          | 35/10910 [00:03<15:26, 11.74it/s]

tensor([[[ 0.9214,  0.2364,  0.5447,  ...,  0.1730, -2.2118,  1.0158]],

        [[ 0.0793,  1.0789,  0.3845,  ...,  1.0450, -0.1164,  0.1262]],

        [[-1.2948,  0.0908,  0.8325,  ...,  0.5425, -1.1754,  0.1332]],

        ...,

        [[ 0.5278, -0.4038, -0.9401,  ..., -0.2126, -0.0272, -0.2954]],

        [[ 0.2169,  1.1271, -1.1634,  ...,  0.0164, -1.4045, -0.2381]],

        [[ 1.1820, -0.6793, -0.8764,  ...,  0.6045, -0.4679, -0.4667]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.2575, -1.9963, -0.4651,  ..., -1.0807,  1.1406,  0.5239]],

        [[ 0.2170,  1.1271, -1.1633,  ...,  0.0167, -1.4047, -0.2381]],

        [[ 0.5348, -0.3508,  0.9805,  ...,  2.8651, -0.8851,  0.2364]],

        ...,

        [[-1.4680,  0.1259,  0.9601,  ...,  2.1890, -1.8684,  2.3549]],

        [[ 0.7975,  1.7687, -1.2404,  ..., -0.7026, -0.5042,  0.8329]],

        [[-1.7219,  0.2663,  0.3497,  ..., -0.8322, -0.5400,  0.4097]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.9404,  0.

  0%|          | 37/10910 [00:03<15:19, 11.82it/s]

tensor([[[-2.0141,  0.3478,  0.1020,  ..., -1.0146, -0.7693, -0.7610]],

        [[ 0.3429, -0.6802, -0.6245,  ...,  0.4361,  0.2915,  0.7962]],

        [[ 0.3821,  0.7688, -0.4335,  ...,  2.1023,  0.2948,  0.3241]],

        ...,

        [[ 1.4241,  0.0056, -0.8531,  ..., -0.4702,  0.2219, -0.6607]],

        [[ 0.4800, -0.1770,  0.7895,  ...,  0.3541, -1.7287,  1.0499]],

        [[ 0.3402, -1.4154,  0.7171,  ..., -2.6488,  0.5390, -0.8153]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.5324,  1.0633, -0.7857,  ...,  0.5494, -0.4940, -0.4199]],

        [[ 0.2572, -1.9964, -0.4651,  ..., -1.0800,  1.1407,  0.5240]],

        [[ 0.3822,  0.7688, -0.4336,  ...,  2.1024,  0.2950,  0.3242]],

        ...,

        [[-2.0145,  0.3482,  0.1025,  ..., -1.0146, -0.7699, -0.7611]],

        [[-2.0145,  0.3482,  0.1025,  ..., -1.0146, -0.7699, -0.7611]],

        [[ 0.3428, -0.6801, -0.6240,  ...,  0.4363,  0.2917,  0.7956]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.2181,  1.

  0%|          | 41/10910 [00:03<15:12, 11.91it/s]

tensor([[[-1.2932,  0.0905,  0.8295,  ...,  0.5418, -1.1750,  0.1322]],

        [[ 0.2664, -1.0219,  0.9134,  ..., -1.2102, -0.2219, -1.1908]],

        [[-1.4115, -1.7177, -0.2026,  ...,  0.7601, -0.8389,  0.6599]],

        ...,

        [[-1.4115, -1.7177, -0.2026,  ...,  0.7601, -0.8389,  0.6599]],

        [[ 0.9151, -0.7785,  0.3361,  ..., -0.0338, -0.4899, -0.8331]],

        [[-0.5966,  0.2371,  1.6299,  ...,  0.3950, -0.3931, -0.4894]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-0.4357,  0.2374,  0.0308,  ..., -0.3756,  0.7178, -0.1793]],

        [[-1.7478, -0.1234,  0.7768,  ..., -1.2438, -1.2245,  0.9905]],

        [[-0.5240, -0.4805, -0.0283,  ...,  0.4825,  0.8320,  0.3304]],

        ...,

        [[-1.0515,  0.4501,  1.5516,  ..., -0.2796,  0.0650,  0.7009]],

        [[ 1.1088,  0.5582,  0.0172,  ...,  0.0747,  0.3716, -0.6467]],

        [[ 1.1814, -0.6797, -0.8773,  ...,  0.6046, -0.4677, -0.4659]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-0.5350, -0.

  0%|          | 43/10910 [00:03<15:23, 11.77it/s]

tensor([[[-0.5353, -0.6321,  0.1178,  ...,  0.1347,  2.2370, -1.0787]],

        [[-1.2923,  0.0901,  0.8278,  ...,  0.5417, -1.1755,  0.1320]],

        [[ 0.2848,  0.9092,  0.4567,  ...,  0.4454, -0.1327, -1.3281]],

        ...,

        [[-0.1689, -0.4277,  0.4485,  ...,  0.4733,  0.7191, -1.7949]],

        [[ 1.0732,  1.2943,  1.0180,  ...,  0.0677, -0.5780, -0.9836]],

        [[-0.7057,  1.7453, -0.6819,  ..., -0.3317,  0.0789,  1.5492]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-1.1384, -0.3847,  0.0209,  ...,  0.1999, -0.4583,  0.2915]],

        [[-0.2459,  0.3900, -0.3527,  ..., -1.1251, -1.7276, -0.5068]],

        [[-0.4363,  0.2379,  0.0332,  ..., -0.3750,  0.7195, -0.1804]],

        ...,

        [[-2.0135,  0.3496,  0.1021,  ..., -1.0134, -0.7722, -0.7594]],

        [[-0.0144, -0.9683,  0.0722,  ..., -0.0472,  3.4144,  2.0140]],

        [[-1.9448, -0.5700,  1.0743,  ..., -0.5264, -2.3148, -0.4716]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.0147, -0.

  0%|          | 45/10910 [00:03<15:31, 11.66it/s]

tensor([[[-0.5360, -0.6327,  0.1177,  ...,  0.1356,  2.2353, -1.0783]],

        [[-0.3776,  1.5962, -0.7180,  ...,  2.7065, -0.9556,  1.2236]],

        [[-0.1216,  1.9673,  0.1818,  ..., -0.0196, -1.2121, -0.6637]],

        ...,

        [[-0.5360, -0.6327,  0.1177,  ...,  0.1356,  2.2353, -1.0783]],

        [[-1.1267,  0.2607, -1.4123,  ..., -0.1711,  0.3668,  0.0634]],

        [[ 0.6556,  0.4193, -0.1408,  ..., -1.1677, -0.3884, -0.8640]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-0.5544,  0.9279, -0.8634,  ..., -2.2416,  1.0415, -1.5713]],

        [[-0.4365,  0.2376,  0.0341,  ..., -0.3752,  0.7200, -0.1804]],

        [[ 0.6703, -1.1339,  1.3067,  ...,  0.8043,  0.1352, -1.6883]],

        ...,

        [[-0.2406,  0.0430,  0.7074,  ..., -0.4696, -0.1731,  1.9606]],

        [[-1.4114, -1.7144, -0.2029,  ...,  0.7595, -0.8376,  0.6603]],

        [[ 1.4959,  2.1982, -0.3536,  ..., -0.2550, -0.4814, -0.5589]]],
       grad_fn=<EmbeddingBackward0>)


  0%|          | 47/10910 [00:04<17:04, 10.61it/s]

tensor([[[-0.9836, -0.4238, -1.5968,  ...,  1.0070,  0.0884,  0.4717]],

        [[-0.5370, -0.6330,  0.1167,  ...,  0.1346,  2.2328, -1.0765]],

        [[ 0.9495,  0.4841, -0.4124,  ...,  0.4645, -0.6300,  1.1749]],

        ...,

        [[-0.3261,  0.9897,  0.0200,  ...,  0.3426, -0.9924,  2.1374]],

        [[-0.4368,  0.2371,  0.0337,  ..., -0.3756,  0.7198, -0.1800]],

        [[ 0.9420,  0.9674, -0.8787,  ..., -0.2158, -0.5076, -0.4179]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-0.1907, -1.8057,  1.3239,  ..., -1.4507,  0.1067, -0.9387]],

        [[-2.0108,  0.3468,  0.1002,  ..., -1.0119, -0.7710, -0.7535]],

        [[-1.4113, -1.7142, -0.2025,  ...,  0.7595, -0.8374,  0.6600]],

        ...,

        [[-1.7033,  0.3105, -0.3341,  ..., -1.0807,  0.9461, -0.4973]],

        [[-1.5688,  0.3264,  2.0595,  ..., -1.3385,  0.2040, -2.1558]],

        [[-0.3599, -0.3330,  0.2898,  ..., -2.8255,  1.4637, -0.8293]]],
       grad_fn=<EmbeddingBackward0>)


  0%|          | 49/10910 [00:04<18:03, 10.03it/s]

tensor([[[-0.7394,  1.1963,  1.6066,  ..., -0.2736, -0.7437,  0.0937]],

        [[-1.1073, -1.1676, -0.5951,  ..., -0.5799, -1.6928,  0.4696]],

        [[-0.4903,  0.9073,  1.5035,  ..., -0.0374, -1.4419,  0.0819]],

        ...,

        [[-0.1684, -0.4274,  0.4476,  ...,  0.4731,  0.7181, -1.7937]],

        [[-1.4111, -1.7141, -0.2023,  ...,  0.7595, -0.8374,  0.6598]],

        [[-0.4455, -2.2356, -1.4101,  ...,  0.7077,  0.8444, -0.1567]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 1.4958,  2.1978, -0.3534,  ..., -0.2549, -0.4817, -0.5584]],

        [[-2.0095,  0.3459,  0.1000,  ..., -1.0118, -0.7704, -0.7520]],

        [[-0.3621, -1.4303,  0.4052,  ...,  0.5295,  0.2181, -0.0519]],

        ...,

        [[ 0.0029, -0.0254, -0.4320,  ..., -2.2688,  0.2658,  1.1635]],

        [[ 1.8880, -0.3027,  1.0860,  ..., -1.5946,  1.3908, -0.0935]],

        [[ 2.0124,  1.7649,  0.5561,  ..., -0.5989,  1.6154, -1.7085]]],
       grad_fn=<EmbeddingBackward0>)


  0%|          | 52/10910 [00:04<19:57,  9.07it/s]

tensor([[[ 0.5457,  0.7128, -0.5352,  ...,  0.2656, -1.3014,  0.1993]],

        [[-0.1911, -1.8059,  1.3240,  ..., -1.4507,  0.1064, -0.9387]],

        [[-2.0088,  0.3457,  0.1002,  ..., -1.0116, -0.7704, -0.7521]],

        ...,

        [[ 0.1118,  0.0816, -0.7791,  ...,  0.4821, -1.4997, -1.2364]],

        [[-0.7437,  0.6251, -1.4483,  ...,  0.0138, -1.5243,  1.1959]],

        [[-0.5380, -0.6332,  0.1142,  ...,  0.1319,  2.2283, -1.0730]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-0.6411, -1.1480,  0.0373,  ..., -0.4259,  1.3181, -0.1491]],

        [[-1.2918,  0.0899,  0.8239,  ...,  0.5413, -1.1752,  0.1308]],

        [[ 1.0428, -1.0439, -1.1472,  ...,  0.9273, -0.7078,  2.0327]],

        ...,

        [[-0.2395,  0.0437,  0.7075,  ..., -0.4708, -0.1723,  1.9615]],

        [[-0.4383,  0.2346,  0.0311,  ..., -0.3771,  0.7174, -0.1774]],

        [[-1.2918,  0.0899,  0.8239,  ...,  0.5413, -1.1752,  0.1308]]],
       grad_fn=<EmbeddingBackward0>)


  0%|          | 54/10910 [00:05<21:07,  8.57it/s]

tensor([[[-0.5796, -0.0049,  1.4896,  ...,  0.9538,  0.4636, -0.9959]],

        [[ 0.5314, -0.3533,  0.9757,  ...,  2.8616, -0.8863,  0.2361]],

        [[ 0.9632,  0.6510,  0.8851,  ..., -0.4336, -1.1024,  0.8171]],

        ...,

        [[-0.0210,  0.0390,  0.2081,  ..., -0.9132, -0.0976, -0.0746]],

        [[-1.7260, -0.8500,  0.3810,  ...,  0.4279,  1.2229, -0.3605]],

        [[ 0.8470, -1.1697, -0.6963,  ..., -0.3969, -0.9894,  2.1150]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-1.3940, -1.2902, -0.8474,  ...,  1.5131,  0.7448, -0.0250]],

        [[-1.8586, -0.1420,  0.0444,  ...,  0.6427, -1.6349,  0.4560]],

        [[ 0.5461,  0.7127, -0.5353,  ...,  0.2654, -1.3017,  0.1996]],

        ...,

        [[-1.1059, -1.1681, -0.5949,  ..., -0.5791, -1.6926,  0.4700]],

        [[-2.0070,  0.3459,  0.1010,  ..., -1.0106, -0.7713, -0.7553]],

        [[ 0.3216, -1.5437, -0.2363,  ...,  0.4062,  0.8544,  0.1651]]],
       grad_fn=<EmbeddingBackward0>)


  1%|          | 57/10910 [00:05<19:21,  9.34it/s]

tensor([[[ 1.1493, -0.3072, -1.9354,  ...,  0.4052,  0.6372,  0.4595]],

        [[ 0.3380, -0.5225,  0.1291,  ...,  0.4554,  0.0113,  0.2923]],

        [[ 0.4299,  0.0815,  1.4583,  ...,  0.3722,  0.5304, -0.3289]],

        ...,

        [[-0.5186, -0.5004,  1.3902,  ...,  0.4788, -1.2979,  1.0007]],

        [[-0.0856, -0.7889,  0.8110,  ..., -0.0424, -2.1383,  0.3478]],

        [[-0.6136,  0.1650, -0.0207,  ..., -0.4954,  0.4856,  0.3444]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-0.4395,  0.2332,  0.0304,  ..., -0.3776,  0.7162, -0.1759]],

        [[-2.0057,  0.3460,  0.1008,  ..., -1.0103, -0.7716, -0.7563]],

        [[-2.0057,  0.3460,  0.1008,  ..., -1.0103, -0.7716, -0.7563]],

        ...,

        [[-1.1940,  0.1300, -1.4802,  ..., -0.3159, -0.3376,  0.7082]],

        [[-0.4395,  0.2332,  0.0304,  ..., -0.3776,  0.7162, -0.1759]],

        [[-1.2918,  0.0902,  0.8231,  ...,  0.5400, -1.1758,  0.1305]]],
       grad_fn=<EmbeddingBackward0>)


  1%|          | 59/10910 [00:05<18:02, 10.03it/s]

tensor([[[ 0.1721,  0.0344, -0.9280,  ...,  0.2671,  0.1187, -1.1835]],

        [[ 0.9584, -0.6543, -1.7383,  ..., -0.2905,  1.0420, -0.2397]],

        [[-0.9939, -1.2273, -1.5368,  ..., -0.3704,  0.0271, -1.2063]],

        ...,

        [[-1.2919,  0.0902,  0.8229,  ...,  0.5396, -1.1758,  0.1305]],

        [[-0.0755,  0.7308,  1.1713,  ...,  1.9010, -0.7419, -0.9252]],

        [[ 0.2539, -1.9953, -0.4652,  ..., -1.0794,  1.1427,  0.5227]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.3410, -0.6770, -0.6168,  ...,  0.4461,  0.3013,  0.7883]],

        [[-1.2919,  0.0903,  0.8229,  ...,  0.5396, -1.1759,  0.1305]],

        [[-0.5789, -0.0050,  1.4894,  ...,  0.9553,  0.4614, -0.9961]],

        ...,

        [[ 0.3815,  0.7691, -0.4342,  ...,  2.1027,  0.2968,  0.3261]],

        [[-0.9130,  0.7394,  0.1719,  ...,  0.1310, -0.0861,  0.2571]],

        [[-0.4791, -0.1032, -0.1943,  ...,  1.0350,  1.0313, -1.0881]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.0751,  1.

  1%|          | 61/10910 [00:05<17:28, 10.35it/s]

tensor([[[-1.5470,  0.0650,  0.4074,  ..., -0.6247, -1.3670,  0.1822]],

        [[-1.2921,  0.0907,  0.8228,  ...,  0.5403, -1.1764,  0.1309]],

        [[ 1.9522,  0.1021,  0.3400,  ...,  2.3681, -0.7296,  1.2430]],

        ...,

        [[-1.7038,  0.3103, -0.3346,  ..., -1.0806,  0.9456, -0.4968]],

        [[-1.2921,  0.0907,  0.8228,  ...,  0.5403, -1.1764,  0.1309]],

        [[-1.1036, -1.1688, -0.5935,  ..., -0.5785, -1.6958,  0.4716]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-0.1645, -0.4248,  0.4489,  ...,  0.4719,  0.7180, -1.7954]],

        [[-1.6749, -0.3171,  0.4527,  ...,  0.1279,  0.1860,  0.3565]],

        [[-1.4099, -1.7154, -0.1988,  ...,  0.7610, -0.8416,  0.6604]],

        ...,

        [[-1.2781,  0.5922, -0.5242,  ..., -2.1245, -1.6811,  1.4560]],

        [[-0.3871,  0.7094,  0.0899,  ..., -0.2098,  1.3338,  0.7303]],

        [[ 1.9521,  0.1020,  0.3400,  ...,  2.3683, -0.7295,  1.2428]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-0.3859, -0.

  1%|          | 63/10910 [00:05<17:21, 10.41it/s]


tensor([[[ 0.3375, -0.6736, -0.6148,  ...,  0.4422,  0.3011,  0.7880]],

        [[ 0.9360,  1.2230,  1.0179,  ...,  0.0944,  2.4099, -0.8758]],

        [[-1.4820,  0.8803,  1.0060,  ...,  0.9108, -1.9256,  1.5964]],

        ...,

        [[ 0.2534, -1.9954, -0.4660,  ..., -1.0801,  1.1438,  0.5229]],

        [[ 0.3375, -0.6736, -0.6148,  ...,  0.4422,  0.3011,  0.7880]],

        [[-0.3895, -0.8776,  1.5658,  ..., -0.7341, -0.6064,  0.6485]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-1.2343e-01,  3.1290e-01, -1.3677e-03,  ...,  9.1344e-01,
          -1.3483e+00, -2.9668e-01]],

        [[ 4.7630e-01, -1.5580e+00,  1.4072e-01,  ...,  1.6703e+00,
           1.3040e+00,  1.0901e+00]],

        [[-1.0532e+00,  4.3574e-01, -2.1075e-01,  ..., -4.7831e-01,
          -9.2992e-01, -3.1497e-01]],

        ...,

        [[-1.4707e+00,  1.1972e-01,  9.5951e-01,  ...,  2.1872e+00,
          -1.8586e+00,  2.3457e+00]],

        [[ 1.3791e+00, -2.0631e+00,  8.1548e-01,  ...,  1.1705e-01,


  1%|          | 67/10910 [00:06<16:26, 10.99it/s]

tensor([[[ 0.3141,  0.0291, -0.0831,  ..., -0.8426,  2.6301,  1.8605]],

        [[-0.5383, -1.6305,  0.3778,  ...,  0.1910,  0.0686, -0.2343]],

        [[ 1.0838, -0.2095,  1.4054,  ...,  1.3976, -1.1841,  1.0117]],

        ...,

        [[-1.0059,  1.2532,  0.4144,  ...,  0.1537, -0.4651,  0.4950]],

        [[-0.6281, -0.1036, -0.2445,  ...,  0.6216, -1.2304,  0.4099]],

        [[-0.1297,  0.3320, -1.3119,  ...,  2.7852,  0.8900, -0.4015]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 6.6968e-02, -6.1704e-01,  5.7792e-01,  ...,  8.5308e-01,
           5.5752e-02,  2.0095e+00]],

        [[-5.3660e-01, -6.3354e-01,  1.1054e-01,  ...,  1.2767e-01,
           2.2214e+00, -1.0680e+00]],

        [[ 2.5232e-01,  6.7322e-01,  9.3514e-01,  ..., -2.5415e-01,
           1.2630e+00, -1.9777e-01]],

        ...,

        [[-5.6109e-01, -9.3701e-02,  1.0789e+00,  ..., -5.8486e-01,
           7.8241e-01, -5.5713e-01]],

        [[-9.1278e-01,  7.3862e-01,  1.6827e-01,  ...,  1.3058e-01,
 

  1%|          | 69/10910 [00:06<16:36, 10.88it/s]

tensor([[[ 1.8455e+00,  3.7229e-01,  1.1019e-01,  ...,  2.5365e-01,
           7.0744e-01,  3.5101e-01]],

        [[-9.1209e-01,  7.3906e-01,  1.6711e-01,  ...,  1.3058e-01,
          -8.4400e-02,  2.5922e-01]],

        [[-4.3591e-03,  2.8674e-02,  1.1123e-03,  ..., -1.6785e+00,
          -7.3611e-01, -9.8160e-01]],

        ...,

        [[-2.3470e-01,  4.4014e-02,  7.1235e-01,  ..., -4.7072e-01,
          -1.6701e-01,  1.9594e+00]],

        [[ 1.4446e+00,  8.3813e-02, -2.8741e+00,  ..., -1.3535e+00,
          -1.5985e+00,  2.5238e+00]],

        [[-1.2924e+00,  9.0814e-02,  8.2288e-01,  ...,  5.4274e-01,
          -1.1751e+00,  1.3474e-01]]], grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.2214, -0.7007, -0.4847,  ...,  0.5186,  0.1085, -0.2407]],

        [[-0.1176,  0.7589, -0.2115,  ...,  0.9176, -0.8732,  1.6356]],

        [[-0.4454,  0.2321,  0.0318,  ..., -0.3785,  0.7141, -0.1733]],

        ...,

        [[ 0.3315, -0.6685, -0.6140,  ...,  0.4375,  0.3002,  0.7883]],

        

  1%|          | 73/10910 [00:06<15:40, 11.52it/s]

tensor([[[-1.1022, -1.1635, -0.5930,  ..., -0.5740, -1.6865,  0.4729]],

        [[-1.9961,  0.3458,  0.0964,  ..., -1.0124, -0.7645, -0.7478]],

        [[-0.1557, -0.0487, -0.4956,  ...,  0.2212, -1.0425,  2.2875]],

        ...,

        [[ 1.1808, -0.6810, -0.8761,  ...,  0.6074, -0.4677, -0.4696]],

        [[-0.5354, -0.6339,  0.1100,  ...,  0.1270,  2.2217, -1.0687]],

        [[-1.1277,  0.2600, -1.4117,  ..., -0.1719,  0.3675,  0.0650]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 1.0941,  0.0617,  2.1384,  ..., -0.0892,  0.0254, -1.9452]],

        [[-0.5497,  0.9282, -0.8640,  ..., -2.2425,  1.0447, -1.5733]],

        [[ 0.3300, -0.6673, -0.6145,  ...,  0.4365,  0.3007,  0.7877]],

        ...,

        [[ 1.1795, -0.6808, -0.8757,  ...,  0.6069, -0.4677, -0.4690]],

        [[-0.4463,  0.2317,  0.0324,  ..., -0.3788,  0.7137, -0.1731]],

        [[-0.5754, -0.1259,  0.7942,  ...,  0.0577, -0.4897, -0.3513]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 1.8935, -0.

  1%|          | 75/10910 [00:06<15:32, 11.62it/s]

tensor([[[ 1.2506, -1.1188,  0.5474,  ..., -0.0166,  0.2915, -1.1322]],

        [[-1.2514, -1.2449, -0.2233,  ...,  0.1977,  0.7125, -0.2832]],

        [[-0.1400,  0.8874, -0.0301,  ...,  0.4475,  1.0692, -1.0010]],

        ...,

        [[-2.1687, -0.5204, -1.0711,  ...,  1.6053, -1.0312,  1.3092]],

        [[-0.4474,  0.2305,  0.0326,  ..., -0.3786,  0.7127, -0.1719]],

        [[-1.2931,  0.0900,  0.8228,  ...,  0.5433, -1.1745,  0.1352]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-1.9938,  0.3473,  0.0948,  ..., -1.0133, -0.7638, -0.7486]],

        [[ 0.9114, -0.7382,  1.3245,  ..., -0.0605, -0.1365,  0.1773]],

        [[-0.3871, -0.3216,  0.1813,  ...,  0.0715,  0.3845,  0.1478]],

        ...,

        [[-1.1009, -1.1662, -0.5923,  ..., -0.5750, -1.6883,  0.4722]],

        [[-1.0586, -0.4059,  0.5027,  ..., -0.5057, -2.1211,  0.2610]],

        [[ 0.7868,  0.9823,  1.2732,  ...,  0.8320, -1.0801,  0.0078]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-0.3631, -0.

  1%|          | 79/10910 [00:07<15:24, 11.71it/s]

tensor([[[ 0.9227,  0.0318,  1.2826,  ...,  1.1955, -0.3798, -0.3683]],

        [[ 0.2515, -1.9935, -0.4659,  ..., -1.0789,  1.1450,  0.5257]],

        [[ 0.3293, -0.6679, -0.6164,  ...,  0.4360,  0.3003,  0.7861]],

        ...,

        [[-0.9392,  1.5142, -0.7397,  ...,  0.0215,  0.3872, -0.4789]],

        [[-1.3957, -1.2913, -0.8457,  ...,  1.5126,  0.7451, -0.0252]],

        [[-0.1661, -1.3404, -0.0055,  ..., -1.2214,  0.8811,  0.3328]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-1.4109, -1.7116, -0.1992,  ...,  0.7593, -0.8393,  0.6609]],

        [[-0.6908, -1.6224, -1.8540,  ..., -0.2914, -1.3429, -0.2939]],

        [[-0.3907,  0.7079,  0.0883,  ..., -0.2125,  1.3307,  0.7303]],

        ...,

        [[-0.3786,  1.5952, -0.7174,  ...,  2.7057, -0.9565,  1.2242]],

        [[-0.4494,  0.2283,  0.0323,  ..., -0.3784,  0.7113, -0.1699]],

        [[ 0.6231, -0.0952,  0.8231,  ..., -0.2679, -0.0035, -0.6632]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-0.0410,  0.

  1%|          | 81/10910 [00:07<15:29, 11.64it/s]

tensor([[[ 0.2781, -0.8465,  1.4850,  ...,  0.3383, -0.5511, -0.3946]],

        [[-0.7542, -3.5614, -1.1109,  ...,  1.0455,  0.4903,  0.7087]],

        [[ 0.3293, -0.6682, -0.6168,  ...,  0.4357,  0.3002,  0.7855]],

        ...,

        [[ 0.4103, -2.1086, -1.3262,  ..., -0.8448, -0.1237,  0.1305]],

        [[-1.0512,  0.5124, -0.7421,  ..., -0.5086,  0.2963, -0.4546]],

        [[-1.4711,  0.1184,  0.9568,  ...,  2.1850, -1.8563,  2.3440]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 1.5084,  0.2200,  0.2976,  ...,  0.4836, -0.5845, -0.0131]],

        [[ 0.3294, -0.6683, -0.6170,  ...,  0.4356,  0.3001,  0.7853]],

        [[ 0.9224,  0.0311,  1.2823,  ...,  1.1950, -0.3801, -0.3680]],

        ...,

        [[ 0.2145,  1.0968, -0.5757,  ...,  0.9648,  0.2381,  0.3142]],

        [[ 0.5805,  0.0199, -0.1064,  ..., -3.0345, -1.4489,  1.6280]],

        [[-0.8407,  1.0641, -0.8456,  ...,  0.2745,  1.3762,  0.6335]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.1088, -0.

  1%|          | 85/10910 [00:07<15:12, 11.87it/s]

tensor([[[-1.6694,  0.6581, -1.3621,  ..., -0.9377,  0.9456,  0.5629]],

        [[-0.0177, -2.5917, -0.0031,  ..., -1.6071, -0.5318, -0.3037]],

        [[-1.3961, -1.2907, -0.8449,  ...,  1.5121,  0.7451, -0.0250]],

        ...,

        [[-0.1711,  0.2038,  1.5019,  ..., -0.6243,  0.6178, -0.7495]],

        [[ 0.3371, -1.4113,  0.7138,  ..., -2.6435,  0.5382, -0.8093]],

        [[ 1.9928, -1.3066, -0.3004,  ...,  0.0386,  0.9894, -0.1582]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-1.4122, -1.7105, -0.1988,  ...,  0.7582, -0.8376,  0.6610]],

        [[-0.6202,  0.1069, -1.1267,  ...,  0.9994,  0.3057, -0.0542]],

        [[ 0.3582, -0.4958,  0.4057,  ...,  0.6294,  0.8381, -1.6282]],

        ...,

        [[-1.0096, -0.6038,  1.3145,  ..., -0.6009,  0.5783, -1.2710]],

        [[-0.9270,  0.1634,  1.2121,  ...,  1.5192,  0.3759, -0.7350]],

        [[-0.5413, -1.6277,  0.3751,  ...,  0.1918,  0.0677, -0.2296]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.9841, -0.

  1%|          | 87/10910 [00:07<15:19, 11.77it/s]

tensor([[[ 0.6227,  0.0680,  1.8933,  ..., -0.5262, -0.5244, -0.0796]],

        [[-0.5359, -0.6325,  0.1118,  ...,  0.1259,  2.2169, -1.0668]],

        [[-1.0996, -1.1658, -0.5916,  ..., -0.5755, -1.6862,  0.4749]],

        ...,

        [[ 0.2458, -0.3946,  1.0235,  ..., -0.5735,  0.4849,  0.3989]],

        [[ 0.7768,  0.0073, -0.1360,  ...,  0.9113, -0.1509, -0.9519]],

        [[ 1.4372,  0.0824, -2.8678,  ..., -1.3473, -1.5964,  2.5229]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-1.2935,  0.0827,  0.8193,  ...,  0.5412, -1.1699,  0.1376]],

        [[ 0.0142,  2.2723,  0.3958,  ..., -1.6317,  0.1507,  1.7632]],

        [[-0.5361, -0.6323,  0.1121,  ...,  0.1259,  2.2166, -1.0665]],

        ...,

        [[-2.0528, -0.1551,  0.5212,  ..., -0.6929, -0.6981,  1.2814]],

        [[ 1.0282, -0.1423,  0.3160,  ..., -0.1928,  0.8345, -0.4905]],

        [[-1.9906,  0.3499,  0.0927,  ..., -1.0141, -0.7624, -0.7485]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-0.6623, -1.

  1%|          | 91/10910 [00:08<15:55, 11.32it/s]

tensor([[[-1.0988, -1.1658, -0.5931,  ..., -0.5761, -1.6869,  0.4764]],

        [[-0.5624, -0.0929,  1.0780,  ..., -0.5849,  0.7819, -0.5551]],

        [[-1.0506,  0.5546,  1.6137,  ..., -0.7581,  0.7209,  0.1583]],

        ...,

        [[-0.0787,  0.7282,  1.1690,  ...,  1.8949, -0.7410, -0.9223]],

        [[ 0.2513, -1.9888, -0.4661,  ..., -1.0770,  1.1425,  0.5205]],

        [[ 1.4875,  2.1933, -0.3492,  ..., -0.2475, -0.4817, -0.5539]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-1.0987, -1.1653, -0.5936,  ..., -0.5761, -1.6868,  0.4769]],

        [[ 0.2072, -0.0149, -1.8029,  ..., -0.9555,  0.8077,  2.2120]],

        [[-0.9524, -0.3525, -1.3266,  ..., -0.3298, -0.0423, -1.2259]],

        ...,

        [[-1.4125, -1.7101, -0.1981,  ...,  0.7564, -0.8361,  0.6609]],

        [[-0.7438,  0.6258, -1.4474,  ...,  0.0156, -1.5256,  1.1940]],

        [[ 1.4870,  2.1931, -0.3490,  ..., -0.2473, -0.4818, -0.5537]]],
       grad_fn=<EmbeddingBackward0>)


  1%|          | 93/10910 [00:08<15:53, 11.34it/s]

tensor([[[-0.0573,  0.5803,  2.3662,  ...,  0.8111, -1.0015,  1.2419]],

        [[-1.4125, -1.7100, -0.1979,  ...,  0.7563, -0.8357,  0.6610]],

        [[-1.9903,  0.3496,  0.0932,  ..., -1.0139, -0.7608, -0.7467]],

        ...,

        [[-1.7434, -0.1178,  0.7729,  ..., -1.2326, -1.2236,  0.9924]],

        [[ 0.3410, -0.3583, -0.1256,  ...,  0.2754, -0.3060, -1.9504]],

        [[ 1.7244,  0.4438,  0.6031,  ...,  1.1348, -0.1283,  0.6666]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.0680, -0.6172,  0.5756,  ...,  0.8594,  0.0564,  2.0103]],

        [[-0.5362, -0.6314,  0.1125,  ...,  0.1263,  2.2164, -1.0662]],

        [[ 0.5634, -0.1810, -1.7062,  ..., -1.4699, -0.4653, -0.0218]],

        ...,

        [[-0.3230,  0.1465,  0.5072,  ...,  0.2187, -0.6048,  0.5322]],

        [[-0.6495, -1.6264, -0.7117,  ..., -1.1585, -0.5097,  1.9976]],

        [[-1.4127, -1.7104, -0.1977,  ...,  0.7563, -0.8359,  0.6609]]],
       grad_fn=<EmbeddingBackward0>)


  1%|          | 95/10910 [00:08<15:44, 11.45it/s]

tensor([[[ 1.3276, -0.5960,  1.1372,  ...,  0.8422,  0.8780, -1.8005]],

        [[ 0.8179,  1.9800, -1.9201,  ..., -0.2639,  0.2657, -0.8489]],

        [[-0.1804,  0.7389, -1.0047,  ..., -0.8216, -0.2183,  1.4448]],

        ...,

        [[ 0.3632,  1.6329,  1.1072,  ..., -0.6442, -0.1617,  1.2404]],

        [[ 1.1384,  0.4320,  1.3342,  ..., -0.1637,  0.7718, -1.4315]],

        [[ 1.1448, -0.3067, -1.9300,  ...,  0.3983,  0.6356,  0.4584]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-1.2936,  0.0804,  0.8191,  ...,  0.5450, -1.1704,  0.1379]],

        [[ 0.5162,  1.1143, -0.1550,  ..., -0.7403, -0.6323, -0.5173]],

        [[ 0.5254, -0.3505,  0.9686,  ...,  2.8544, -0.8844,  0.2401]],

        ...,

        [[ 1.1506, -0.6707, -0.8678,  ...,  0.6087, -0.4619, -0.4659]],

        [[-0.7713,  0.1006,  0.3584,  ..., -0.4705, -0.4694,  1.1380]],

        [[-1.3959, -1.2906, -0.8442,  ...,  1.5115,  0.7455, -0.0255]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-1.1574, -0.

  1%|          | 97/10910 [00:08<15:37, 11.54it/s]

tensor([[[-1.0968, -1.1587, -0.5943,  ..., -0.5747, -1.6810,  0.4759]],

        [[ 0.0213,  0.8797, -0.1269,  ...,  0.3838, -0.6129,  0.6280]],

        [[-0.5748,  1.6079, -1.5353,  ..., -0.7515, -0.3067, -3.0948]],

        ...,

        [[-1.0968, -1.1587, -0.5943,  ..., -0.5747, -1.6810,  0.4759]],

        [[ 0.0736,  1.0752,  0.3856,  ...,  1.0446, -0.1173,  0.1295]],

        [[-0.3947,  0.7118,  0.0890,  ..., -0.2136,  1.3299,  0.7330]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.8506,  0.1514, -0.6592,  ...,  0.1631,  1.1936, -0.4304]],

        [[ 0.4538,  1.4980,  0.9129,  ...,  0.4056,  1.1238, -0.3206]],

        [[ 1.9013,  0.3426, -0.3361,  ...,  1.7319,  1.1289,  1.2356]],

        ...,

        [[-0.1882, -1.8075,  1.3284,  ..., -1.4544,  0.1038, -0.9392]],

        [[-1.9906,  0.3489,  0.0936,  ..., -1.0134, -0.7592, -0.7472]],

        [[-0.2371,  0.0431,  0.7079,  ..., -0.4676, -0.1684,  1.9549]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-1.4720,  0.

  1%|          | 101/10910 [00:09<15:41, 11.49it/s]

tensor([[[ 0.1295, -0.4721, -0.4598,  ..., -1.5921,  0.4118, -1.0547]],

        [[ 1.6932,  1.2153, -0.0809,  ..., -1.7524,  0.4872,  0.4760]],

        [[-0.5364, -0.6313,  0.1116,  ...,  0.1265,  2.2160, -1.0680]],

        ...,

        [[-1.4719,  0.1132,  0.9571,  ...,  2.1825, -1.8582,  2.3388]],

        [[-1.5534, -0.2927, -0.9262,  ..., -0.0173, -0.1407, -1.5474]],

        [[-0.4508,  0.2269,  0.0307,  ..., -0.3738,  0.7086, -0.1671]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-1.2912,  0.0811,  0.8183,  ...,  0.5466, -1.1683,  0.1394]],

        [[-0.8891, -0.8233,  0.7832,  ...,  0.1050, -0.1467, -0.3869]],

        [[-2.2781,  1.4609, -0.1848,  ..., -1.5307,  0.7403, -0.0354]],

        ...,

        [[-0.4504,  0.2273,  0.0308,  ..., -0.3733,  0.7087, -0.1665]],

        [[ 0.9734,  0.8317, -0.4412,  ..., -0.5247,  1.1247, -0.4153]],

        [[-0.4180,  0.8064, -1.0305,  ..., -1.3214, -0.6060, -1.1334]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-0.7044,  1.

  1%|          | 103/10910 [00:09<15:53, 11.34it/s]

tensor([[[ 0.3691,  0.7627, -0.4301,  ...,  2.0937,  0.2941,  0.3334]],

        [[ 1.4136,  0.7540,  1.2608,  ..., -2.3383,  2.1072,  0.6525]],

        [[ 0.7863,  0.9849,  1.2730,  ...,  0.8301, -1.0789,  0.0075]],

        ...,

        [[-0.2370,  0.0427,  0.7077,  ..., -0.4675, -0.1689,  1.9536]],

        [[-1.2906,  0.0818,  0.8184,  ...,  0.5471, -1.1675,  0.1399]],

        [[ 0.4523,  1.4972,  0.9120,  ...,  0.4059,  1.1235, -0.3198]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[-1.0260, -0.4982,  0.5264,  ...,  0.1421, -1.4814,  0.0974]],

        [[-1.3131,  1.1988, -0.2671,  ..., -0.2855,  1.0091,  0.8064]],

        [[-0.1530,  0.4134, -2.1864,  ..., -0.4470,  0.5353,  0.6312]],

        ...,

        [[ 0.3407, -0.5136,  0.1256,  ...,  0.4515,  0.0081,  0.2883]],

        [[-0.1533,  1.1557, -0.9706,  ..., -0.5088,  0.0134,  0.3478]],

        [[-0.6179, -1.4169,  2.1333,  ...,  2.1228,  1.5674,  0.3188]]],
       grad_fn=<EmbeddingBackward0>)
tensor([[[ 0.2752,  1.

  1%|          | 104/10910 [00:09<16:19, 11.03it/s]







KeyboardInterrupt: 

### Save the trained model

In [13]:
torch.save(neural_model, 'model.ckpt')

### Neural Model testing

In [17]:
neural_model = torch.load('model.ckpt')
neural_model.eval()
test_ppl = 0
with torch.no_grad():
    for i in range(0, len(test_ids)):
        inputs = test_ids[i][:, 0:n_gram-1].to(device)
        gold = test_ids[i][:, n_gram-1:].to(device)
        output = neural_model(inputs)
        cross_entropy = criterion(output, gold.reshape(-1)).item()
        print('Perplexity for test sample '+str(i)+' :', np.exp(cross_entropy))
        test_ppl += np.exp(cross_entropy)
    print('The average testing perplexity for neural LM: '+str(test_ppl/len(test_ids)))

Perplexity for test sample 0 : 131.72969878621925
Perplexity for test sample 1 : 1574.7212829407829
Perplexity for test sample 2 : 218.10068995888324
Perplexity for test sample 3 : 120.06119094099284
Perplexity for test sample 4 : 85.2189484101646
Perplexity for test sample 5 : 515.2834415076813
Perplexity for test sample 6 : 371.81643851171054
Perplexity for test sample 7 : 287.75292782528777
Perplexity for test sample 8 : 119.88790696081533
Perplexity for test sample 9 : 356.07168889966863
Perplexity for test sample 10 : 1160.8481994215201
Perplexity for test sample 11 : 279.610504443154
Perplexity for test sample 12 : 206.2456031778053
Perplexity for test sample 13 : 682.8737300757066
Perplexity for test sample 14 : 440.3478334698897
Perplexity for test sample 15 : 452.6268063150043
Perplexity for test sample 16 : 168.31670265709465
Perplexity for test sample 17 : 1125.5227618518334
Perplexity for test sample 18 : 843.6833387365839
Perplexity for test sample 19 : 551.4993537902428
P

Perplexity for test sample 179 : 390.15148680103533
Perplexity for test sample 180 : 257.0361527765083
Perplexity for test sample 181 : 208.62771096613014
Perplexity for test sample 182 : 117.51346527575234
Perplexity for test sample 183 : 173.71215614960136
Perplexity for test sample 184 : 699.3643507298634
Perplexity for test sample 185 : 402.58037402664274
Perplexity for test sample 186 : 292.3491922671582
Perplexity for test sample 187 : 203.49547878925674
Perplexity for test sample 188 : 317.8527091221824
Perplexity for test sample 189 : 292.04447691064973
Perplexity for test sample 190 : 438.7068008240047
Perplexity for test sample 191 : 372.5351806726931
Perplexity for test sample 192 : 379.72485554992653
Perplexity for test sample 193 : 89.87452370040987
Perplexity for test sample 194 : 816.3362612784208
Perplexity for test sample 195 : 75.12929250008855
Perplexity for test sample 196 : 1583.7841712815425
Perplexity for test sample 197 : 19.003800158585406
Perplexity for test s

Perplexity for test sample 344 : 299.98657862785797
Perplexity for test sample 345 : 516.7430345591669
Perplexity for test sample 346 : 459.70466007034145
Perplexity for test sample 347 : 155.76774282418833
Perplexity for test sample 348 : 416.55883886160694
Perplexity for test sample 349 : 131.0485562675835
Perplexity for test sample 350 : 309.52688560125193
Perplexity for test sample 351 : 808.0927779041558
Perplexity for test sample 352 : 381.5454885471305
Perplexity for test sample 353 : 417.8623025283808
Perplexity for test sample 354 : 42.31317870548912
Perplexity for test sample 355 : 837.4211662048051
Perplexity for test sample 356 : 47.72140544733154
Perplexity for test sample 357 : 295.66585996151775
Perplexity for test sample 358 : 558.8113077468736
Perplexity for test sample 359 : 451.13615935025337
Perplexity for test sample 360 : 348.5721083715625
Perplexity for test sample 361 : 431.1408609375156
Perplexity for test sample 362 : 368.61802747261174
Perplexity for test sam

Perplexity for test sample 506 : 456.9742818824559
Perplexity for test sample 507 : 518.1988549099124
Perplexity for test sample 508 : 902.4547593359952
Perplexity for test sample 509 : 234.34105931101485
Perplexity for test sample 510 : 607.9692481038168
Perplexity for test sample 511 : 682.8470297930626
Perplexity for test sample 512 : 219.69988562656167
Perplexity for test sample 513 : 739.3627804507234
Perplexity for test sample 514 : 259.10004470337
Perplexity for test sample 515 : 270.64542629295744
Perplexity for test sample 516 : 321.7170976750267
Perplexity for test sample 517 : 802.2225281645957
Perplexity for test sample 518 : 242.91464185022613
Perplexity for test sample 519 : 265.21558968793096
Perplexity for test sample 520 : 370.4809677913031
Perplexity for test sample 521 : 487.4680354600938
Perplexity for test sample 522 : 86.67060736792448
Perplexity for test sample 523 : 304.623804404759
Perplexity for test sample 524 : 184.02501108140643
Perplexity for test sample 5

Perplexity for test sample 666 : 115.84913894587257
Perplexity for test sample 667 : 280.91107450601083
Perplexity for test sample 668 : 1054.4221762685115
Perplexity for test sample 669 : 308.18331078297496
Perplexity for test sample 670 : 288.7085411815176
Perplexity for test sample 671 : 476.0244267501216
Perplexity for test sample 672 : 816.1295906692438
Perplexity for test sample 673 : 228.77451633018677
Perplexity for test sample 674 : 910.223076708887
Perplexity for test sample 675 : 307.7988272005231
Perplexity for test sample 676 : 276.9513232878786
Perplexity for test sample 677 : 549.2410868161959
Perplexity for test sample 678 : 404.8303282170591
Perplexity for test sample 679 : 331.2383892947579
Perplexity for test sample 680 : 73.37195391165386
Perplexity for test sample 681 : 86.0901277751469
Perplexity for test sample 682 : 162.61188550778746
Perplexity for test sample 683 : 481.96693822259925
Perplexity for test sample 684 : 510.7924506440719
Perplexity for test sample

Perplexity for test sample 841 : 209.40381296942604
Perplexity for test sample 842 : 125.06857542151207
Perplexity for test sample 843 : 377.34983549748415
Perplexity for test sample 844 : 643.8629575947732
Perplexity for test sample 845 : 231.5645228311551
Perplexity for test sample 846 : 218.75811743926644
Perplexity for test sample 847 : 351.54562251518627
Perplexity for test sample 848 : 303.42779685712725
Perplexity for test sample 849 : 553.0670702712513
Perplexity for test sample 850 : 304.89729880078943
Perplexity for test sample 851 : 134.76319857418895
Perplexity for test sample 852 : 335.8264809003096
Perplexity for test sample 853 : 292.7651875757262
Perplexity for test sample 854 : 294.843515412373
Perplexity for test sample 855 : 485.15122184322877
Perplexity for test sample 856 : 556.4135125626375
Perplexity for test sample 857 : 171.7098807019989
Perplexity for test sample 858 : 567.5958780706483
Perplexity for test sample 859 : 1135.3861022160306
Perplexity for test sa

Perplexity for test sample 1002 : 551.627700710296
Perplexity for test sample 1003 : 253.20651908164854
Perplexity for test sample 1004 : 273.77460663361006
Perplexity for test sample 1005 : 405.92052865955395
Perplexity for test sample 1006 : 227.05718291048447
Perplexity for test sample 1007 : 401.78584693237093
Perplexity for test sample 1008 : 667.3406711054267
Perplexity for test sample 1009 : 230.96891942690584
Perplexity for test sample 1010 : 291.56749160328786
Perplexity for test sample 1011 : 107.00590152998107
Perplexity for test sample 1012 : 170.60857062882624
Perplexity for test sample 1013 : 362.5103605846845
Perplexity for test sample 1014 : 56.42760603945141
Perplexity for test sample 1015 : 449.59767979595085
Perplexity for test sample 1016 : 189.47199645517128
Perplexity for test sample 1017 : 52.72150925888967
Perplexity for test sample 1018 : 98.46739659839982
Perplexity for test sample 1019 : 91.59388409719135
Perplexity for test sample 1020 : 95.2282195208738
Per

Perplexity for test sample 1162 : 50.937952406033915
Perplexity for test sample 1163 : 244.3673156541102
Perplexity for test sample 1164 : 285.41683515240436
Perplexity for test sample 1165 : 541.1651698130802
Perplexity for test sample 1166 : 303.8007360871377
Perplexity for test sample 1167 : 67.48735780393977
Perplexity for test sample 1168 : 231.5709271967185
Perplexity for test sample 1169 : 529.4699094243473
Perplexity for test sample 1170 : 374.33312252193855
Perplexity for test sample 1171 : 449.99317936808984
Perplexity for test sample 1172 : 305.9509856471932
Perplexity for test sample 1173 : 164.99034156007625
Perplexity for test sample 1174 : 223.94154530881926
Perplexity for test sample 1175 : 145.71515025648384
Perplexity for test sample 1176 : 247.70452271099833
Perplexity for test sample 1177 : 184.38927021156803
Perplexity for test sample 1178 : 286.257109804068
Perplexity for test sample 1179 : 487.3506659413777
Perplexity for test sample 1180 : 398.9207818824176
Perp

Perplexity for test sample 1328 : 91.62625326270039
Perplexity for test sample 1329 : 169.8237837836329
Perplexity for test sample 1330 : 97.83836722052901
Perplexity for test sample 1331 : 174.4305520653056
Perplexity for test sample 1332 : 189.0873273223109
Perplexity for test sample 1333 : 341.5093224287235
Perplexity for test sample 1334 : 367.6115344577577
Perplexity for test sample 1335 : 619.5608174436364
Perplexity for test sample 1336 : 195.032121825578
Perplexity for test sample 1337 : 192.7113118280382
Perplexity for test sample 1338 : 867.6488868901657
Perplexity for test sample 1339 : 1357.0875048685027
Perplexity for test sample 1340 : 389.41212620087634
Perplexity for test sample 1341 : 1042.9689841192735
Perplexity for test sample 1342 : 2367.338750797006
Perplexity for test sample 1343 : 544.85798811702
Perplexity for test sample 1344 : 203.6182640878723
Perplexity for test sample 1345 : 531.5107680622536
Perplexity for test sample 1346 : 294.87697828901486
Perplexity 

Perplexity for test sample 1496 : 490.9622612261127
Perplexity for test sample 1497 : 328.00358557238195
Perplexity for test sample 1498 : 232.7977475017291
Perplexity for test sample 1499 : 833.7245184717831
Perplexity for test sample 1500 : 185.52547700572603
Perplexity for test sample 1501 : 237.909523854581
Perplexity for test sample 1502 : 185.45144620571577
Perplexity for test sample 1503 : 206.81452876671804
Perplexity for test sample 1504 : 285.36185711767826
Perplexity for test sample 1505 : 174.07924254827097
Perplexity for test sample 1506 : 699.0366138459823
Perplexity for test sample 1507 : 407.7969452133125
Perplexity for test sample 1508 : 218.68187865592887
Perplexity for test sample 1509 : 459.9638325407894
Perplexity for test sample 1510 : 561.8773397142073
Perplexity for test sample 1511 : 308.72280698141594
Perplexity for test sample 1512 : 221.8387654359096
Perplexity for test sample 1513 : 629.7959931217086
Perplexity for test sample 1514 : 292.3618782115735
Perpl

Perplexity for test sample 1659 : 878.1046262987734
Perplexity for test sample 1660 : 396.30776185683584
Perplexity for test sample 1661 : 330.78160658086125
Perplexity for test sample 1662 : 1056.046424605917
Perplexity for test sample 1663 : 477.7362489396691
Perplexity for test sample 1664 : 250.6322728283422
Perplexity for test sample 1665 : 387.78835312889254
Perplexity for test sample 1666 : 214.19199538031643
Perplexity for test sample 1667 : 278.8102132301471
Perplexity for test sample 1668 : 259.39598876687415
Perplexity for test sample 1669 : 208.41950078659758
Perplexity for test sample 1670 : 210.7289237731233
Perplexity for test sample 1671 : 172.17278476343685
Perplexity for test sample 1672 : 191.5302751318045
Perplexity for test sample 1673 : 267.51884334226526
Perplexity for test sample 1674 : 87.6553382549052
Perplexity for test sample 1675 : 176.257089421371
Perplexity for test sample 1676 : 79.98493637933498
Perplexity for test sample 1677 : 231.76678751567826
Perpl

Perplexity for test sample 1829 : 19.14966629764543
Perplexity for test sample 1830 : 19.916998680572117
Perplexity for test sample 1831 : 100.64298152733562
Perplexity for test sample 1832 : 40.02135035853686
Perplexity for test sample 1833 : 201.04759283583536
Perplexity for test sample 1834 : 11.752772323634925
Perplexity for test sample 1835 : 25.266774711488754
Perplexity for test sample 1836 : 172.80964750361306
Perplexity for test sample 1837 : 11.31764206478881
Perplexity for test sample 1838 : 28.22336994767506
Perplexity for test sample 1839 : 277.5393525775208
Perplexity for test sample 1840 : 112.36640909777525
Perplexity for test sample 1841 : 46.63960821932585
Perplexity for test sample 1842 : 274.47666732899734
Perplexity for test sample 1843 : 150.45775775426984
Perplexity for test sample 1844 : 97.43578952440512
Perplexity for test sample 1845 : 93.48868064667116
Perplexity for test sample 1846 : 528.2480996855222
Perplexity for test sample 1847 : 153.62058320177542
Pe

Perplexity for test sample 1996 : 30.220577989879576
Perplexity for test sample 1997 : 323.76280911433025
Perplexity for test sample 1998 : 329.2892395116451
Perplexity for test sample 1999 : 2099.9447405819287
Perplexity for test sample 2000 : 662.1562254818155
Perplexity for test sample 2001 : 233.6024901169632
Perplexity for test sample 2002 : 239.31024150876254
Perplexity for test sample 2003 : 86.88901003428525
Perplexity for test sample 2004 : 457.4458477405905
Perplexity for test sample 2005 : 535.7813856947919
Perplexity for test sample 2006 : 356.29723905166213
Perplexity for test sample 2007 : 345.95203044782824
Perplexity for test sample 2008 : 228.59306557784112
Perplexity for test sample 2009 : 354.7773744443004
Perplexity for test sample 2010 : 259.11993678009543
Perplexity for test sample 2011 : 457.5804520093496
Perplexity for test sample 2012 : 942.4493067726208
Perplexity for test sample 2013 : 477.9280970671602
Perplexity for test sample 2014 : 481.2116534986371
Perp

Perplexity for test sample 2163 : 42.735131499065545
Perplexity for test sample 2164 : 28.14575718915661
Perplexity for test sample 2165 : 36.02874915294177
Perplexity for test sample 2166 : 552.833197494532
Perplexity for test sample 2167 : 9.469744715857077
Perplexity for test sample 2168 : 32.625807202561035
Perplexity for test sample 2169 : 298.0994849172126
Perplexity for test sample 2170 : 79.75699514967857
Perplexity for test sample 2171 : 51.11075825766078
Perplexity for test sample 2172 : 91.46600462381397
Perplexity for test sample 2173 : 580.8280432899224
Perplexity for test sample 2174 : 166.7127001731088
Perplexity for test sample 2175 : 83.2313586490183
Perplexity for test sample 2176 : 639.6811985149885
Perplexity for test sample 2177 : 399.678577334466
Perplexity for test sample 2178 : 604.2781556722514
Perplexity for test sample 2179 : 218.00669556305328
Perplexity for test sample 2180 : 3897.0786327892356
Perplexity for test sample 2181 : 645.4543520961337
Perplexity 

Perplexity for test sample 2325 : 151.07083498840532
Perplexity for test sample 2326 : 451.0486146224858
Perplexity for test sample 2327 : 126.42746579156824
Perplexity for test sample 2328 : 180.64170881710473
Perplexity for test sample 2329 : 379.968468670495
Perplexity for test sample 2330 : 127.84579779030457
Perplexity for test sample 2331 : 233.169805655809
Perplexity for test sample 2332 : 164.18488763139277
Perplexity for test sample 2333 : 241.9312657096678
Perplexity for test sample 2334 : 387.6554242638212
Perplexity for test sample 2335 : 176.1994433424697
Perplexity for test sample 2336 : 454.18561347508347
Perplexity for test sample 2337 : 292.1390481696869
Perplexity for test sample 2338 : 113.4909197735615
Perplexity for test sample 2339 : 318.3189584345658
Perplexity for test sample 2340 : 480.4275412080096
Perplexity for test sample 2341 : 145.3845811264012
Perplexity for test sample 2342 : 2657.537446411084
Perplexity for test sample 2343 : 389.9381587253607
Perplexi

Perplexity for test sample 2486 : 284.1575904616635
Perplexity for test sample 2487 : 722.4242097584179
Perplexity for test sample 2488 : 864.7926746213942
Perplexity for test sample 2489 : 242.50390623925796
Perplexity for test sample 2490 : 58.52110368363079
Perplexity for test sample 2491 : 235.31265781458163
Perplexity for test sample 2492 : 140.09908338377502
Perplexity for test sample 2493 : 380.0603395992393
Perplexity for test sample 2494 : 388.49943579024927
Perplexity for test sample 2495 : 554.533199838405
Perplexity for test sample 2496 : 821.1875575703863
Perplexity for test sample 2497 : 579.5200495300147
Perplexity for test sample 2498 : 425.2594978995974
Perplexity for test sample 2499 : 344.11757639028684
Perplexity for test sample 2500 : 407.4853565764044
Perplexity for test sample 2501 : 341.14555921173735
Perplexity for test sample 2502 : 141.23506428315068
Perplexity for test sample 2503 : 198.06045480732638
Perplexity for test sample 2504 : 643.6045005098405
Perpl

Perplexity for test sample 2651 : 434.15141028320977
Perplexity for test sample 2652 : 344.7714313126114
Perplexity for test sample 2653 : 524.1622337669652
Perplexity for test sample 2654 : 482.3651504367407
Perplexity for test sample 2655 : 325.58142473469763
Perplexity for test sample 2656 : 199.7555452261349
Perplexity for test sample 2657 : 530.12572977693
Perplexity for test sample 2658 : 201.8765330789004
Perplexity for test sample 2659 : 1330.9393260057764
Perplexity for test sample 2660 : 224.60672556270794
Perplexity for test sample 2661 : 194.81416207814416
Perplexity for test sample 2662 : 221.010463969588
Perplexity for test sample 2663 : 159.3131608057615
Perplexity for test sample 2664 : 317.28939087141856
Perplexity for test sample 2665 : 207.1964295405689
Perplexity for test sample 2666 : 97.60119154947684
Perplexity for test sample 2667 : 233.57119152622076
Perplexity for test sample 2668 : 197.02543354755034
Perplexity for test sample 2669 : 72.0098916422405
Perplexi

Perplexity for test sample 2815 : 609.6567880666879
Perplexity for test sample 2816 : 273.6868938390273
Perplexity for test sample 2817 : 219.72838248533938
Perplexity for test sample 2818 : 984.1712601320131
Perplexity for test sample 2819 : 414.86044253784723
Perplexity for test sample 2820 : 160.2988232440106
Perplexity for test sample 2821 : 681.8598529074734
Perplexity for test sample 2822 : 425.89629637297617
Perplexity for test sample 2823 : 480.91000870981156
Perplexity for test sample 2824 : 400.9102944234697
Perplexity for test sample 2825 : 303.9294020957635
Perplexity for test sample 2826 : 479.4661154498495
Perplexity for test sample 2827 : 147.90231033322266
Perplexity for test sample 2828 : 564.236257807964
Perplexity for test sample 2829 : 213.713518849927
Perplexity for test sample 2830 : 367.6842873543184
Perplexity for test sample 2831 : 1157.3302868575909
Perplexity for test sample 2832 : 652.3761663999504
Perplexity for test sample 2833 : 1187.1105395383138
Perplex

Perplexity for test sample 2972 : 187.25868747626362
Perplexity for test sample 2973 : 267.4975411819621
Perplexity for test sample 2974 : 320.7268507027058
Perplexity for test sample 2975 : 244.57994682021848
Perplexity for test sample 2976 : 218.47561449035177
Perplexity for test sample 2977 : 115.68905025376037
Perplexity for test sample 2978 : 62.08969874048472
Perplexity for test sample 2979 : 538.4331712902207
Perplexity for test sample 2980 : 379.7355386485419
Perplexity for test sample 2981 : 107.99686547993946
Perplexity for test sample 2982 : 235.99855984781797
Perplexity for test sample 2983 : 1041.6310404471951
Perplexity for test sample 2984 : 158.07541109751932
Perplexity for test sample 2985 : 566.6172374555539
Perplexity for test sample 2986 : 652.9690376946422
Perplexity for test sample 2987 : 16.409620438814905
Perplexity for test sample 2988 : 422.596856568211
Perplexity for test sample 2989 : 54.51265489807105
Perplexity for test sample 2990 : 70.63521257827658
Perp

Perplexity for test sample 3135 : 127.02067406773818
Perplexity for test sample 3136 : 290.9835940260357
Perplexity for test sample 3137 : 261.1165028024313
Perplexity for test sample 3138 : 149.9296476208465
Perplexity for test sample 3139 : 239.75557811113066
Perplexity for test sample 3140 : 360.4493735594759
Perplexity for test sample 3141 : 138.33238368449537
Perplexity for test sample 3142 : 780.1400221811165
Perplexity for test sample 3143 : 86.83334338851779
Perplexity for test sample 3144 : 194.15294800698933
Perplexity for test sample 3145 : 236.41856827033553
Perplexity for test sample 3146 : 525.0420103548654
Perplexity for test sample 3147 : 330.7778211074515
Perplexity for test sample 3148 : 607.8683704635001
Perplexity for test sample 3149 : 163.71348026650068
Perplexity for test sample 3150 : 170.9916823241568
Perplexity for test sample 3151 : 266.2697385163267
Perplexity for test sample 3152 : 614.6514240299603
Perplexity for test sample 3153 : 112.03229672006725
Perpl

Perplexity for test sample 3295 : 621.8166789315345
Perplexity for test sample 3296 : 561.3096278439074
Perplexity for test sample 3297 : 941.098047298797
Perplexity for test sample 3298 : 232.2164667325371
Perplexity for test sample 3299 : 164.6869410550976
Perplexity for test sample 3300 : 867.6108248199267
Perplexity for test sample 3301 : 141.16753226855607
Perplexity for test sample 3302 : 1286.791694554485
Perplexity for test sample 3303 : 297.4286062588036
Perplexity for test sample 3304 : 383.4594999296865
Perplexity for test sample 3305 : 512.1709357420046
Perplexity for test sample 3306 : 631.7959238920117
Perplexity for test sample 3307 : 571.2598150915524
Perplexity for test sample 3308 : 215.880673737401
Perplexity for test sample 3309 : 556.962727701612
Perplexity for test sample 3310 : 271.07125129709397
Perplexity for test sample 3311 : 139.04700882461947
Perplexity for test sample 3312 : 825.5774243153642
Perplexity for test sample 3313 : 425.3073565615926
Perplexity f

Perplexity for test sample 3456 : 346.96969506160514
Perplexity for test sample 3457 : 779.5350133034922
Perplexity for test sample 3458 : 215.9938343651339
Perplexity for test sample 3459 : 369.1668404254327
Perplexity for test sample 3460 : 383.55971358292027
Perplexity for test sample 3461 : 542.4678114097662
Perplexity for test sample 3462 : 382.5162533824486
Perplexity for test sample 3463 : 922.3575135277409
Perplexity for test sample 3464 : 341.338215911939
Perplexity for test sample 3465 : 854.8744602085543
Perplexity for test sample 3466 : 305.95828017407763
Perplexity for test sample 3467 : 433.19004249426155
Perplexity for test sample 3468 : 496.0385249459678
Perplexity for test sample 3469 : 258.6098976928039
Perplexity for test sample 3470 : 297.3079376587382
Perplexity for test sample 3471 : 174.56243422190045
Perplexity for test sample 3472 : 576.7277837485801
Perplexity for test sample 3473 : 583.066307435759
Perplexity for test sample 3474 : 1537.4897924876639
Perplexi

Perplexity for test sample 3615 : 396.31154136019666
Perplexity for test sample 3616 : 361.9068961458811
Perplexity for test sample 3617 : 417.61370966709814
Perplexity for test sample 3618 : 464.0913489746982
Perplexity for test sample 3619 : 143.02328513326034
Perplexity for test sample 3620 : 200.90346061306929
Perplexity for test sample 3621 : 168.06167035108956
Perplexity for test sample 3622 : 112.90241665426706
Perplexity for test sample 3623 : 432.53841966508634
Perplexity for test sample 3624 : 332.9476617767729
Perplexity for test sample 3625 : 97.37824134067287
Perplexity for test sample 3626 : 409.9833786957401
Perplexity for test sample 3627 : 992.0193402375203
Perplexity for test sample 3628 : 103.21691184419983
Perplexity for test sample 3629 : 624.6145520302589
Perplexity for test sample 3630 : 606.6594328414052
Perplexity for test sample 3631 : 458.94990244901703
Perplexity for test sample 3632 : 495.61862619872005
Perplexity for test sample 3633 : 191.64538393553008
P

Perplexity for test sample 3781 : 151.22478267162074
Perplexity for test sample 3782 : 62.107731824086265
Perplexity for test sample 3783 : 58.02876495725955
Perplexity for test sample 3784 : 95.23843694852437
Perplexity for test sample 3785 : 258.57611164239927
Perplexity for test sample 3786 : 108.34657958744512
Perplexity for test sample 3787 : 43.941568201059354
Perplexity for test sample 3788 : 65.31284525411203
Perplexity for test sample 3789 : 1149.841075829033
Perplexity for test sample 3790 : 158.38536012183388
Perplexity for test sample 3791 : 97.51917622866702
Perplexity for test sample 3792 : 165.56850117786917
Perplexity for test sample 3793 : 649.4415518187487
Perplexity for test sample 3794 : 92.00636307305375
Perplexity for test sample 3795 : 144.30822806189587
Perplexity for test sample 3796 : 94.34227188360477
Perplexity for test sample 3797 : 199.83519082428762
Perplexity for test sample 3798 : 127.05302159255983
Perplexity for test sample 3799 : 408.62401246034005
P

Perplexity for test sample 3944 : 353.82014504386063
Perplexity for test sample 3945 : 134.9482017539889
Perplexity for test sample 3946 : 295.0333761797733
Perplexity for test sample 3947 : 151.39072572189576
Perplexity for test sample 3948 : 48.84482859705798
Perplexity for test sample 3949 : 110.63604046926572
Perplexity for test sample 3950 : 80.83085649226966
Perplexity for test sample 3951 : 693.3787491738024
Perplexity for test sample 3952 : 252.95032141914848
Perplexity for test sample 3953 : 531.1965902902167
Perplexity for test sample 3954 : 489.8872405899352
Perplexity for test sample 3955 : 143.8103429957771
Perplexity for test sample 3956 : 398.57587153755344
Perplexity for test sample 3957 : 302.142534987776
Perplexity for test sample 3958 : 199.4078966807271
Perplexity for test sample 3959 : 310.4246705781878
Perplexity for test sample 3960 : 359.0709949087227
Perplexity for test sample 3961 : 1980.321303905847
Perplexity for test sample 3962 : 157.11840347627768
Perplex

Perplexity for test sample 4106 : 652.1382357429038
Perplexity for test sample 4107 : 213.60511778514325
Perplexity for test sample 4108 : 309.3934894500949
Perplexity for test sample 4109 : 291.1590291526794
Perplexity for test sample 4110 : 281.22616480733814
Perplexity for test sample 4111 : 81.70808386257194
Perplexity for test sample 4112 : 142.81713242873408
Perplexity for test sample 4113 : 285.51729267336304
Perplexity for test sample 4114 : 197.62150060104045
Perplexity for test sample 4115 : 121.58503325319178
Perplexity for test sample 4116 : 64.5318923954727
Perplexity for test sample 4117 : 508.689226262845
Perplexity for test sample 4118 : 95.4068871579686
Perplexity for test sample 4119 : 47.46131784363054
Perplexity for test sample 4120 : 199.78002619883685
Perplexity for test sample 4121 : 74.93787601807989
Perplexity for test sample 4122 : 11.752772323634925
Perplexity for test sample 4123 : 254.7890100295439
Perplexity for test sample 4124 : 680.7195796483549
Perplex

In [170]:
neural_model(the)

TypeError: embedding(): argument 'indices' (position 2) must be Tensor, not int

In [57]:
    value = 0
    key
    for pair in x:
    #     print(pair)
        try:
            prob = neural_model.compute_mle(pair)
            if prob > value:
                value = prob
                key = pair

        except TypeError:
            # hit end of dict
            
            break
    prob = count_model.compute_mle(key)
    x = count_model.n_gram_count
    print(prob)
    
    
    for i in range(0, len(test_ids)):
        inputs = test_ids[i][:, 0:n_gram-1].to(device)
        gold = test_ids[i][:, n_gram-1:].to(device)
        output = neural_model(inputs)
        cross_entropy = criterion(output, gold.reshape(-1)).item()
        print('Perplexity for test sample '+str(i)+' :', np.exp(cross_entropy))
        test_ppl += np.exp(cross_entropy)

TypeError: embedding(): argument 'indices' (position 2) must be Tensor, not str