# 1. DATA 

In [1]:
import torch 
from torch import nn 
import numpy as np
import argparse
from Dataset import FairytalesDataset
from torch import nn, optim
from torch.utils.data import DataLoader
from train import save_checkpoint, load_checkpoint, train
from models import basic_model
from test import predict 

In [2]:
DIR_PATH = "data/fairytales.txt"
START_TOKEN = "<s>"
END_TOKEN = "</s>"

parser = argparse.ArgumentParser()
parser.add_argument('--max-epochs', type=int, default=2)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--sequence_length', type=int, default=7)
args, unknown = parser.parse_known_args()

dataset = FairytalesDataset(DIR_PATH, START_TOKEN, END_TOKEN, args)
print(dataset.__getitem__(3))

(tensor([ 53,   9,  73,   6,  19, 149,  40]), tensor([  9,  73,   6,  19, 149,  40,  93]))


# 2. MODEL

Standard lstm model 

In [3]:
model = basic_model(dataset)
train(dataset, model, args, load_model=True)

-> Loading checkpoint


100%|██████████| 146/146 [01:08<00:00,  2.13it/s, loss=0.521]


[34m EPOCH 1: [0m Mean loss 1.939, [0m Perplexity 6.952


100%|██████████| 146/146 [01:09<00:00,  2.11it/s, loss=0.493]

[34m EPOCH 2: [0m Mean loss 1.932, [0m Perplexity 6.906





([1.93909981683509, 1.9323197333780053],
 [6.952489638760287e+00, 6.9055106195771945])

# 3. TESTING 

In [4]:
print(predict(dataset, model, text='<s> there was once'))

<s> there was once a king son who stood then the girl are talking carefully into the neighbouring swineherd barred </s> <s> but even will me you all she heard in a advice aches that paperarello asked on the lame seas </s> <s> the did would did a spot who is in kungla swimming </s> <s> then finding dreamed himself </s> <s> then you a ship replied you wish him you will never be had anything should be </s> <s> day they held him and fro out to cry where paperarello had gave it a thousand ropes and wife and he counted the princess </s>


# 4. METRIC 

## 4.1. BLEU Score

BLEU is a precision focused metric that. calculates n-gram overlap of the reference and generated texts

In [5]:
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu

def bleu(
    ref, 
    gen,
    weights=(1, 0, 0, 0)
): 
    
    """
    Implements the BLEU evaluation metric using nltk 
    
    Input: 
        ref(list): reference sentences
        gen(list): generated sentences
        weights(tuple): customized weights to evaluate the text with
    higher/lower order ngrams.
        
    Output: 
        bleu_score(float): the bleu score
        
    """
    reference = ref.split(" ")
    generated = gen.split(" ")
    ref_bleu = [] 
    gen_bleu = [] 
    for sentence in generated: 
        gen_bleu.append(sentence.split()) 
        
    for idx, sentence in enumerate(reference): 
        ref_bleu.append([sentence.split()]) 
    
    print(len(gen_bleu))
    print(len(ref_bleu))
    
    chencherry = SmoothingFunction() 
    score_bleu = corpus_bleu(ref_bleu, gen_bleu, weights, 
                             smoothing_function=chencherry.method4)
    
    return score_bleu
    
        
    

In [6]:
generated = predict(dataset, model, text='<s> lovely ilonka', next_words=16)

print('Generated:')
print(' ') 
print(generated)



Generated:
 
<s> lovely ilonka pursued said silent first sprang on her long ago before still this stream man in your </s>


In [7]:
vocab = dataset.words
words = vocab[:19]
reference =  ' '.join(words)
reference2 = '<s> lovely ilonka there was once a man son who which so it so there will herself a desert </s>'
print('Reference:')
print(' ') 
print(reference2)

Reference:
 
<s> lovely ilonka there was once a man son who which so it so there will herself a desert </s>


In [8]:
b = bleu(generated, reference2)
print(b)

20
20
0.2


## 4.2. Perplexity 

Training: 80%
Validation: 10% 
Test: 10%

**Perplexity** refers to the power of a probability distribution to predict, or assign probabilities, to a sample. 

Lower the perplexity value, the better the model. 

If the model is completely dumb(worst possible), perplexity = |v| i.e. size of the vocabulary.

Perplexity is model dependent

In [None]:
perplexity = 1 # log cannot take 0 
sent_perplexity = []