In [2]:
import tensorflow as tf
tf.config.experimental.list_physical_devices("GPU")

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [116]:
import torch
import spacy
import pandas as pd
import numpy as np
from collections import Counter
from Token import Clean
from Token import Tokenise,paddingString
from nltk.util import ngrams
from torch.utils.data import DataLoader
import argparse

filePath = '../DATA/Pride and Prejudice - Jane Austen.txt'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
class Dataset(torch.utils.data.Dataset):
    def __init__(self, train_data, batch_size, min_freq=5):
        self.data = train_data
        self.max_len = 20
        self.min_freq = min_freq
        self.vocab = []
        self.batch_size = batch_size
        self.ngramList = []
        sents = self.loadingWords()
        self.wordToIndex = {w: i for i, w in enumerate(self.vocab)}
        self.indexToWord = {i: w for i, w in enumerate(self.vocab)}
        self.padIndex = self.wordToIndex['<PAD>']
        self.unKnownIndex = self.wordToIndex['<UNK>']
        self.startIndex = self.wordToIndex['<START>']
        self.endIndex = self.wordToIndex['<END>']
        for sent in sents:
            tokens = sent
            prefix_seqs = []
            try:
                pfx = [tokens[0]]
                for token in tokens[1:]:
                    pfx.append(token)
                    prefix_seqs.append(pfx.copy())
                for i in range(len(prefix_seqs)):
                    currSeq = [self.wordToIndex.get(w,self.unKnownIndex) for w in prefix_seqs[i]]
                    pref_sq = [self.startIndex]+[self.padIndex]*(self.max_len-len(currSeq)) + [w for w in currSeq[1:]]
                    
                    self.ngramList.append(list(pref_sq))
            except IndexError:
                continue

                
    def loadingWords(self):
        text = [line for line in self.data if line.strip()]
        sentences = []
        wordFreq = {}
        mx = 0
        for line in text:
            tokens = Tokenise(line)
            tokens = ['<START>'] + tokens
            sentences.append(tokens)
            self.vocab += tokens
            mx = max(mx, len(tokens))
            for token in tokens:
                if token in wordFreq:
                    wordFreq[token] += 1
                else:
                    wordFreq[token] = 1

        # wordCount = Counter(wordFreq)
        wordCount = {}
        self.vocab = list(filter(lambda w: wordFreq[w] >= self.min_freq, self.vocab))
        self.vocab = ['<PAD>', '<UNK>','<END>'] + self.vocab
        self.vocab = set(self.vocab)
        self.vocabSize = len(self.vocab)
        print(self.vocabSize)
        self.max_len = max(mx,self.max_len)
        # print(sentences)
        return sentences



cuda


In [130]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [117]:
def splitData(corpus,train_ratio,valid_ratio,test_ratio):
    with open(corpus, 'r') as f:
        text = f.readlines()
    text = [line.strip() for line in text if line.strip()]
    train_size = int(len(text) * train_ratio)
    valid_size = int(len(text) * valid_ratio)
    test_size = int(len(text) * test_ratio)
    train_data = text[:train_size]
    valid_data = text[train_size:train_size + valid_size]
    test_data = text[train_size + valid_size:]
    return train_data, valid_data, test_data

train_data, valid_data, test_data = splitData(filePath,0.7,0.15,0.15)
BATCH_SIZE = 256
TRAINSET = Dataset(train_data, BATCH_SIZE)
VLAIDSET = Dataset(valid_data, BATCH_SIZE)
TESTSET = Dataset(test_data, BATCH_SIZE)

#TRAINSET.ngramList[500]

1697
516
522


In [118]:
def generateBatch(dataset):
    input_ngram,trg = [],[]
    for tg in dataset:
        input_ngram.append(tg[:-1])
        trg.append(tg[-1])
    return torch.tensor(input_ngram,dtype=torch.long),torch.tensor(trg,dtype=torch.long)

train_loader = DataLoader(TRAINSET.ngramList, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generateBatch)
valid_loader = DataLoader(VLAIDSET.ngramList, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generateBatch)
test_loader = DataLoader(TESTSET.ngramList, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generateBatch)


In [81]:
print(valid_loader) 

<torch.utils.data.dataloader.DataLoader object at 0x7efbcab409d0>


In [119]:
import torch
from torch import nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device {device}")

class LSTMmodel(nn.Module):
    def __init__(self, embedding_size, hidden_size, num_layers, vocabSize,dropout):
        super(LSTMmodel, self).__init__()
        self.vocab_size = vocabSize
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout

        self.embedding = nn.Embedding(
            self.vocab_size, self.embedding_size, device=device)
        self.lstm = nn.LSTM(input_size=self.embedding_size,
                            hidden_size=self.hidden_size, batch_first=True, device=device)
        self.dropLayer = nn.Dropout(p=self.dropout)
        self.output = nn.Linear(
            self.hidden_size, self.vocab_size, bias=False, device=device)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, xContext):
        xembed = self.dropLayer(self.embedding(xContext))
        out, hidden = self.lstm(xembed)
        out = self.log_softmax(self.output(out[:,-1]))
        return out, hidden



Using device cuda


In [120]:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
# from dataset import Dataset
from torch.utils.data import DataLoader
# from Model import LSTMmodel
import sys
# from Token import Clean
# from Token import Tokenise

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device {device}")


class Evaluation:
    def __init__(self, model:nn.Module,epochs,datasetTrain:torch.utils.data.DataLoader,datasetValid:torch.utils.data.DataLoader,datasetTest:torch.utils.data.DataLoader):
        self.model = model
        self.datasetTrain = datasetTrain
        self.datasetValid = datasetValid
        self.datasetTest = datasetTest
        self.criterion = nn.CrossEntropyLoss()
        self.epochs = epochs
        self.clip = 1
        self.patience = 10
        self.learning_rate = 0.005
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.learning_rate,amsgrad = True)
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=6, gamma=0.1,last_epoch=-1,verbose=False)

    def train(self):
        maxValidLoss = math.inf
        ctr = 0
        def trainModel():
            epochAcc = 0
            epochLoss = 0
            self.model.train()
            # hidden = self.model.init_hidden(24)
            for i, (x, y) in enumerate(tqdm(self.datasetTrain)):
                x = x.to(device)
                y = y.to(device)
                self.optimizer.zero_grad()
                outputs, hidden = self.model(x)
                y = y.view(-1)
                loss = self.criterion(outputs, y)
                loss.backward()

                epochAcc += 100*(outputs.argmax(dim=1)==y).sum().item()/y.shape[0]
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
                epochLoss += loss.item()
                self.optimizer.step()
                if i % 100 == 0:
                    print(
                        f"Epoch: {epoch}, Iteration: {i}, Loss: {loss.item()}")
                    
            print(f"Epoch: {epoch}, Loss: {epochLoss/len(self.datasetTrain)}, Accuracy: {epochAcc/len(self.datasetTrain)}")
        
        def validate():
            self.model.eval()
            epochAcc = 0
            epochLoss = 0
            with torch.no_grad():
                for i, (x, y) in enumerate(tqdm(self.datasetValid)):
                    x = x.to(device)
                    y = y.to(device)
                    outputs, hidden = self.model(x)
                    y = y.view(-1)
                    loss = self.criterion(outputs, y)
                    epochAcc += 100*(outputs.argmax(dim=1)==y).float().mean()
                    epochLoss += loss.item()

            print(f"Validation Loss: {epochLoss/len(self.datasetValid)}, Validation Accuracy: {epochAcc/len(self.datasetValid)}")
            return epochLoss/len(self.datasetValid)

        for epoch in range(self.epochs):
            trainModel()
            valid_loss = validate()

            #valid_loss = self.validate()
            self.scheduler.step()
            if valid_loss < maxValidLoss:
                maxValidLoss = valid_loss
                torch.save(self.model.state_dict(), 'model1.pt')
                print("Model saved")
                ctr = 0
            else:
                ctr += 1
                print(f"Validation loss not improved for {ctr} epochs")
            if ctr > self.patience:
                print("Early stopping")
                break

    
    
    def test(self):
        self.model.eval()
        epochAcc = 0
        epochLoss = 0
        with torch.no_grad():
            for i, (x, y) in enumerate(tqdm(self.datasetTest)):
                x = x.to(device)
                y = y.to(device)
                outputs, hidden = self.model(x)
                y = y.view(-1)
                loss = self.criterion(outputs, y)
                epochAcc += 100*(outputs.argmax(dim=1)==y).sum().item()/y.shape[0]
                epochLoss += loss.item()

        print(f"Test Loss: {epochLoss/len(self.datasetTest)}, Test Accuracy: {epochAcc/len(self.datasetTest)}")
        return epochLoss/len(self.datasetTest)





Using device cuda


In [121]:
VOCAB_SIZE = TRAINSET.vocabSize
EMBEDDING_DIM = 512
HIDDEN_DIM = 256
NUM_LAYERS = 2
DROP_OUT = 0.5


lngMOD = LSTMmodel(EMBEDDING_DIM,HIDDEN_DIM,NUM_LAYERS, VOCAB_SIZE, DROP_OUT)
eval = Evaluation(lngMOD,20,train_loader,valid_loader,test_loader)


In [122]:
eval.train()

  1%|▍                                          | 3/342 [00:00<00:15, 22.07it/s]

Epoch: 0, Iteration: 0, Loss: 7.445207118988037


 31%|████████████▌                            | 105/342 [00:04<00:09, 24.50it/s]

Epoch: 0, Iteration: 100, Loss: 5.192803859710693


 60%|████████████████████████▍                | 204/342 [00:08<00:05, 24.64it/s]

Epoch: 0, Iteration: 200, Loss: 5.3461527824401855


 89%|████████████████████████████████████▋    | 306/342 [00:12<00:01, 24.66it/s]

Epoch: 0, Iteration: 300, Loss: 4.864221096038818


100%|█████████████████████████████████████████| 342/342 [00:13<00:00, 24.65it/s]


Epoch: 0, Loss: 5.173751306812665, Accuracy: 12.925628218314962


100%|███████████████████████████████████████████| 75/75 [00:00<00:00, 83.60it/s]


Validation Loss: 9.67641342163086, Validation Accuracy: 0.010416666977107525
Model saved


  1%|▍                                          | 3/342 [00:00<00:14, 23.66it/s]

Epoch: 1, Iteration: 0, Loss: 4.659819602966309


 31%|████████████▌                            | 105/342 [00:04<00:09, 24.02it/s]

Epoch: 1, Iteration: 100, Loss: 4.820914268493652


 60%|████████████████████████▍                | 204/342 [00:08<00:05, 23.64it/s]

Epoch: 1, Iteration: 200, Loss: 4.682744026184082


 89%|████████████████████████████████████▎    | 303/342 [00:12<00:01, 23.23it/s]

Epoch: 1, Iteration: 300, Loss: 4.804134845733643


100%|█████████████████████████████████████████| 342/342 [00:14<00:00, 23.81it/s]


Epoch: 1, Loss: 4.68869405880309, Accuracy: 15.346278156212804


100%|███████████████████████████████████████████| 75/75 [00:00<00:00, 76.22it/s]


Validation Loss: 10.428387247721354, Validation Accuracy: 0.0052083334885537624
Validation loss not improved for 1 epochs


  1%|▍                                          | 3/342 [00:00<00:14, 22.95it/s]

Epoch: 2, Iteration: 0, Loss: 4.440420627593994


 31%|████████████▌                            | 105/342 [00:04<00:10, 22.86it/s]

Epoch: 2, Iteration: 100, Loss: 4.528829574584961


 60%|████████████████████████▍                | 204/342 [00:08<00:06, 22.70it/s]

Epoch: 2, Iteration: 200, Loss: 4.544564723968506


 89%|████████████████████████████████████▎    | 303/342 [00:13<00:01, 22.61it/s]

Epoch: 2, Iteration: 300, Loss: 4.6070733070373535


100%|█████████████████████████████████████████| 342/342 [00:15<00:00, 22.72it/s]


Epoch: 2, Loss: 4.5136420043588386, Accuracy: 16.133097539201987


100%|███████████████████████████████████████████| 75/75 [00:01<00:00, 73.95it/s]


Validation Loss: 10.850994567871094, Validation Accuracy: 0.0052083334885537624
Validation loss not improved for 2 epochs


  1%|▍                                          | 3/342 [00:00<00:15, 22.42it/s]

Epoch: 3, Iteration: 0, Loss: 4.190368175506592


 31%|████████████▌                            | 105/342 [00:04<00:10, 22.13it/s]

Epoch: 3, Iteration: 100, Loss: 4.448470115661621


 60%|████████████████████████▍                | 204/342 [00:09<00:06, 21.58it/s]

Epoch: 3, Iteration: 200, Loss: 4.313256740570068


 89%|████████████████████████████████████▎    | 303/342 [00:13<00:01, 22.17it/s]

Epoch: 3, Iteration: 300, Loss: 4.448286533355713


100%|█████████████████████████████████████████| 342/342 [00:15<00:00, 22.17it/s]


Epoch: 3, Loss: 4.39848185979832, Accuracy: 16.757020050587382


100%|███████████████████████████████████████████| 75/75 [00:01<00:00, 71.91it/s]


Validation Loss: 11.478456598917644, Validation Accuracy: 0.0052083334885537624
Validation loss not improved for 3 epochs


  1%|▍                                          | 3/342 [00:00<00:15, 21.78it/s]

Epoch: 4, Iteration: 0, Loss: 4.134676933288574


 31%|████████████▌                            | 105/342 [00:04<00:10, 22.00it/s]

Epoch: 4, Iteration: 100, Loss: 4.323306560516357


 60%|████████████████████████▍                | 204/342 [00:09<00:06, 21.83it/s]

Epoch: 4, Iteration: 200, Loss: 4.429746150970459


 89%|████████████████████████████████████▎    | 303/342 [00:13<00:01, 21.52it/s]

Epoch: 4, Iteration: 300, Loss: 4.326157093048096


100%|█████████████████████████████████████████| 342/342 [00:15<00:00, 21.50it/s]


Epoch: 4, Loss: 4.297818169259188, Accuracy: 17.000688137194018


100%|███████████████████████████████████████████| 75/75 [00:01<00:00, 63.61it/s]


Validation Loss: 11.522048174540203, Validation Accuracy: 0.015625
Validation loss not improved for 4 epochs


  1%|▎                                          | 2/342 [00:00<00:20, 16.61it/s]

Epoch: 5, Iteration: 0, Loss: 4.149287223815918


 30%|████████████▎                            | 103/342 [00:05<00:12, 19.32it/s]

Epoch: 5, Iteration: 100, Loss: 4.198856353759766


 59%|████████████████████████▎                | 203/342 [00:10<00:07, 18.99it/s]

Epoch: 5, Iteration: 200, Loss: 4.370410919189453


 89%|████████████████████████████████████▎    | 303/342 [00:16<00:02, 19.26it/s]

Epoch: 5, Iteration: 300, Loss: 4.233919143676758


100%|█████████████████████████████████████████| 342/342 [00:18<00:00, 18.71it/s]


Epoch: 5, Loss: 4.2291547568917975, Accuracy: 17.418806280080734


100%|███████████████████████████████████████████| 75/75 [00:01<00:00, 58.89it/s]


Validation Loss: 11.771859537760417, Validation Accuracy: 0.015625
Validation loss not improved for 5 epochs


  1%|▎                                          | 2/342 [00:00<00:20, 16.74it/s]

Epoch: 6, Iteration: 0, Loss: 3.8611955642700195


 30%|████████████▎                            | 103/342 [00:05<00:12, 18.86it/s]

Epoch: 6, Iteration: 100, Loss: 4.006894111633301


 59%|████████████████████████▎                | 203/342 [00:11<00:07, 18.84it/s]

Epoch: 6, Iteration: 200, Loss: 4.021483898162842


 89%|████████████████████████████████████▎    | 303/342 [00:16<00:02, 19.14it/s]

Epoch: 6, Iteration: 300, Loss: 4.080804824829102


100%|█████████████████████████████████████████| 342/342 [00:18<00:00, 18.06it/s]


Epoch: 6, Loss: 4.013866158256754, Accuracy: 18.714794626222634


100%|███████████████████████████████████████████| 75/75 [00:01<00:00, 60.31it/s]


Validation Loss: 11.879722162882487, Validation Accuracy: 0.015625
Validation loss not improved for 6 epochs


  1%|▎                                          | 2/342 [00:00<00:18, 18.39it/s]

Epoch: 7, Iteration: 0, Loss: 3.9744203090667725


 30%|████████████▍                            | 104/342 [00:05<00:12, 18.38it/s]

Epoch: 7, Iteration: 100, Loss: 4.110064506530762


 60%|████████████████████████▍                | 204/342 [00:11<00:07, 18.42it/s]

Epoch: 7, Iteration: 200, Loss: 3.9358701705932617


 89%|████████████████████████████████████▍    | 304/342 [00:16<00:02, 18.50it/s]

Epoch: 7, Iteration: 300, Loss: 3.841094970703125


100%|█████████████████████████████████████████| 342/342 [00:18<00:00, 18.28it/s]


Epoch: 7, Loss: 3.968020526289243, Accuracy: 19.098404831677275


100%|███████████████████████████████████████████| 75/75 [00:01<00:00, 57.19it/s]


Validation Loss: 11.950617688496907, Validation Accuracy: 0.015625
Validation loss not improved for 7 epochs


  1%|▎                                          | 2/342 [00:00<00:19, 17.39it/s]

Epoch: 8, Iteration: 0, Loss: 3.9498209953308105


 30%|████████████▍                            | 104/342 [00:05<00:12, 18.51it/s]

Epoch: 8, Iteration: 100, Loss: 3.788978099822998


 60%|████████████████████████▍                | 204/342 [00:11<00:07, 18.41it/s]

Epoch: 8, Iteration: 200, Loss: 4.0999531745910645


 89%|████████████████████████████████████▍    | 304/342 [00:16<00:02, 17.62it/s]

Epoch: 8, Iteration: 300, Loss: 4.120283126831055


100%|█████████████████████████████████████████| 342/342 [00:19<00:00, 17.92it/s]


Epoch: 8, Loss: 3.941975834773995, Accuracy: 19.469016440899445


100%|███████████████████████████████████████████| 75/75 [00:01<00:00, 55.43it/s]


Validation Loss: 12.031474011739094, Validation Accuracy: 0.02083333395421505
Validation loss not improved for 8 epochs


  1%|▎                                          | 2/342 [00:00<00:19, 17.56it/s]

Epoch: 9, Iteration: 0, Loss: 3.876734733581543


 30%|████████████▍                            | 104/342 [00:05<00:13, 17.83it/s]

Epoch: 9, Iteration: 100, Loss: 3.886780261993408


 60%|████████████████████████▍                | 204/342 [00:11<00:08, 16.44it/s]

Epoch: 9, Iteration: 200, Loss: 3.740150213241577


 89%|████████████████████████████████████▍    | 304/342 [00:17<00:02, 16.65it/s]

Epoch: 9, Iteration: 300, Loss: 3.907977819442749


100%|█████████████████████████████████████████| 342/342 [00:20<00:00, 16.96it/s]


Epoch: 9, Loss: 3.9210219683005794, Accuracy: 19.506698219091238


100%|███████████████████████████████████████████| 75/75 [00:01<00:00, 51.75it/s]


Validation Loss: 12.087802429199218, Validation Accuracy: 0.02083333395421505
Validation loss not improved for 9 epochs


  1%|▎                                          | 2/342 [00:00<00:20, 16.91it/s]

Epoch: 10, Iteration: 0, Loss: 3.7973880767822266


 30%|████████████▍                            | 104/342 [00:06<00:14, 16.84it/s]

Epoch: 10, Iteration: 100, Loss: 3.8286118507385254


 60%|████████████████████████▍                | 204/342 [00:12<00:08, 17.10it/s]

Epoch: 10, Iteration: 200, Loss: 3.8157031536102295


 89%|████████████████████████████████████▍    | 304/342 [00:18<00:02, 17.13it/s]

Epoch: 10, Iteration: 300, Loss: 3.846782684326172


100%|█████████████████████████████████████████| 342/342 [00:20<00:00, 16.45it/s]


Epoch: 10, Loss: 3.901076406066181, Accuracy: 19.55384087292346


100%|███████████████████████████████████████████| 75/75 [00:01<00:00, 53.65it/s]


Validation Loss: 12.125290807088216, Validation Accuracy: 0.03125
Validation loss not improved for 10 epochs


  1%|▎                                          | 2/342 [00:00<00:18, 18.21it/s]

Epoch: 11, Iteration: 0, Loss: 3.6682705879211426


 30%|████████████▍                            | 104/342 [00:06<00:13, 17.18it/s]

Epoch: 11, Iteration: 100, Loss: 3.7639076709747314


 60%|████████████████████████▍                | 204/342 [00:11<00:08, 17.13it/s]

Epoch: 11, Iteration: 200, Loss: 3.7933802604675293


 89%|████████████████████████████████████▍    | 304/342 [00:17<00:02, 17.16it/s]

Epoch: 11, Iteration: 300, Loss: 3.774815797805786


100%|█████████████████████████████████████████| 342/342 [00:19<00:00, 17.13it/s]


Epoch: 11, Loss: 3.885681061716805, Accuracy: 19.6758113711639


100%|███████████████████████████████████████████| 75/75 [00:01<00:00, 53.76it/s]

Validation Loss: 12.231569582621256, Validation Accuracy: 0.0260416679084301
Validation loss not improved for 11 epochs
Early stopping





In [123]:
eval.test()

100%|███████████████████████████████████████████| 72/72 [00:00<00:00, 81.70it/s]

Test Loss: 13.974626845783657, Test Accuracy: 0.005425347222222222





13.974626845783657

In [124]:
globalVocab = TRAINSET.vocab
globalWordToIndex = TRAINSET.wordToIndex
PAD_INDEX = TRAINSET.padIndex
START_INDEX = TRAINSET.startIndex
UNK_INDEX = TRAINSET.unKnownIndex
MAX_LEN = TRAINSET.max_len

def getProbPerplexity(model,dataset):
    model.eval()
    perplexity_list = []
    with torch.no_grad():
        for line in dataset.data:
            perplexity = perpForSentence(model,line)
            #print(perplexity)
            if perplexity != -1:
                perplexity_list.append({'line':line,'perplexity':perplexity})
            
    # averagePerplexty
    avgPerplexity = sum([line['perplexity'] for line in perplexity_list])/len(perplexity_list)
    return perplexity_list,avgPerplexity

def writeToFile(filePath,perplexity_list,avg):
    with open(filePath,'w') as f:
        f.write(f"Average Perplexity: {avg}\n")
        for line in perplexity_list:
            f.write(f"{line['line']}\t {line['perplexity']}\n")
            
            
def perpForSentence(model,sentence):
    model.eval()
    with torch.no_grad():
        prob_gram = 1
        tokens = Tokenise(sentence)
        tokens = ['<START>'] + tokens 

        if len(tokens)==0 or len(tokens)==1:
            return -1
        
        elif len(tokens)>1:
            prefix_seqs = []
            gramList = []
            try:
                pfx = [tokens[0]]
                for token in tokens[1:]:
                    pfx.append(token)
                    prefix_seqs.append(pfx.copy())
                for i in range(len(prefix_seqs)):
                    currSeq = [globalWordToIndex.get(w,UNK_INDEX) for w in prefix_seqs[i]]
                    pref_sq = [START_INDEX]+ [PAD_INDEX]*(MAX_LEN-len(currSeq)) + [w for w in currSeq[1:]]
                    gramList.append(list(pref_sq))
            except IndexError:
                return -1
  
            if len(gramList)>0:
                for gram in gramList:
                    input_gram = torch.tensor(gram[:-1],dtype=torch.long).to(device)
                    output_gram = gram[-1]
                    output,hidden = model(input_gram.unsqueeze(dim=0))
                    output = torch.exp(output.view(-1))
                
                    prob_gram = prob_gram * output[output_gram].cpu().numpy()
              
                perplexity = (1/prob_gram)**(1/len(gramList))
                return perplexity
            else:
                return -1


In [126]:


perplexity_list, avgPerplexity = getProbPerplexity(lngMOD,TESTSET)
path = 'test1_perplexity.txt'
writeToFile(path,perplexity_list,avgPerplexity)

In [56]:
str1 = "mary king is safe"
str2 = "the warrior"

print(perpForSentence(lngMOD,str1))
print(perpForSentence(lngMOD,str2))

370.513098422167
13.476429683065973


In [108]:
MAX_LEN = TRAINSET.max_len
MAX_LEN

21

In [131]:
lngMOD.load_state_dict(torch.load('model1.pt'))


<All keys matched successfully>

In [129]:
str1 = "mary king is safe"
str2 = "the warrior"

print(perpForSentence(lngMOD,str1))
print(perpForSentence(lngMOD,str2))

1028.9875237221133
12.814801035080004
