In [105]:
# Import Libraries..

In [1]:
import pandas as pd
import numpy as np
import re
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

In [2]:
txt = '''I am honored to be with you today at your commencement from one of the finest universities in the world. I never graduated from college. Truth be told, this is the closest I’ve ever gotten to a college graduation. Today I want to tell you three stories from my life. That’s it. No big deal. Just three stories.

The first story is about connecting the dots.

I dropped out of Reed College after the first 6 months, but then stayed around as a drop-in for another 18 months or so before I really quit. So why did I drop out?

It started before I was born. My biological mother was a young, unwed college graduate student, and she decided to put me up for adoption. She felt very strongly that I should be adopted by college graduates, so everything was all set for me to be adopted at birth by a lawyer and his wife. Except that when I popped out they decided at the last minute that they really wanted a girl. So my parents, who were on a waiting list, got a call in the middle of the night asking: “We have an unexpected baby boy; do you want him?” They said: “Of course.” My biological mother later found out that my mother had never graduated from college and that my father had never graduated from high school. She refused to sign the final adoption papers. She only relented a few months later when my parents promised that I would someday go to college.

And 17 years later I did go to college. But I naively chose a college that was almost as expensive as Stanford, and all of my working-class parents’ savings were being spent on my college tuition. After six months, I couldn’t see the value in it. I had no idea what I wanted to do with my life and no idea how college was going to help me figure it out. And here I was spending all of the money my parents had saved their entire life. So I decided to drop out and trust that it would all work out OK. It was pretty scary at the time, but looking back it was one of the best decisions I ever made. The minute I dropped out I could stop taking the required classes that didn’t interest me, and begin dropping in on the ones that looked interesting.
'''

In [3]:
#lower case
sent_low = ' '.join([x.lower() for x in txt.split()])

In [4]:
sent_low

'i am honored to be with you today at your commencement from one of the finest universities in the world. i never graduated from college. truth be told, this is the closest i’ve ever gotten to a college graduation. today i want to tell you three stories from my life. that’s it. no big deal. just three stories. the first story is about connecting the dots. i dropped out of reed college after the first 6 months, but then stayed around as a drop-in for another 18 months or so before i really quit. so why did i drop out? it started before i was born. my biological mother was a young, unwed college graduate student, and she decided to put me up for adoption. she felt very strongly that i should be adopted by college graduates, so everything was all set for me to be adopted at birth by a lawyer and his wife. except that when i popped out they decided at the last minute that they really wanted a girl. so my parents, who were on a waiting list, got a call in the middle of the night asking: “we

In [5]:
#punctuation remove
non_punctuated_sent = ' '.join(re.sub('[^a-zA-Z]', ' ', sent_low).split())  #by splitting later it removes the spaces within the words

In [6]:
non_punctuated_sent

'i am honored to be with you today at your commencement from one of the finest universities in the world i never graduated from college truth be told this is the closest i ve ever gotten to a college graduation today i want to tell you three stories from my life that s it no big deal just three stories the first story is about connecting the dots i dropped out of reed college after the first months but then stayed around as a drop in for another months or so before i really quit so why did i drop out it started before i was born my biological mother was a young unwed college graduate student and she decided to put me up for adoption she felt very strongly that i should be adopted by college graduates so everything was all set for me to be adopted at birth by a lawyer and his wife except that when i popped out they decided at the last minute that they really wanted a girl so my parents who were on a waiting list got a call in the middle of the night asking we have an unexpected baby boy

In [7]:
vocab = set(non_punctuated_sent.split())      #vocabulary

In [8]:
vocab.add("<UNK>")
vocab.add("<PAD>")

In [9]:
vocab = sorted(list(vocab))
vocab[:5]

['<PAD>', '<UNK>', 'a', 'about', 'adopted']

In [10]:
# word to index

word_to_index = {}

for i, word in enumerate(vocab):
    word_to_index[word] = i

In [11]:
word_to_index

{'<PAD>': 0,
 '<UNK>': 1,
 'a': 2,
 'about': 3,
 'adopted': 4,
 'adoption': 5,
 'after': 6,
 'all': 7,
 'almost': 8,
 'am': 9,
 'an': 10,
 'and': 11,
 'another': 12,
 'around': 13,
 'as': 14,
 'asking': 15,
 'at': 16,
 'baby': 17,
 'back': 18,
 'be': 19,
 'before': 20,
 'begin': 21,
 'being': 22,
 'best': 23,
 'big': 24,
 'biological': 25,
 'birth': 26,
 'born': 27,
 'boy': 28,
 'but': 29,
 'by': 30,
 'call': 31,
 'chose': 32,
 'class': 33,
 'classes': 34,
 'closest': 35,
 'college': 36,
 'commencement': 37,
 'connecting': 38,
 'could': 39,
 'couldn': 40,
 'course': 41,
 'deal': 42,
 'decided': 43,
 'decisions': 44,
 'did': 45,
 'didn': 46,
 'do': 47,
 'dots': 48,
 'drop': 49,
 'dropped': 50,
 'dropping': 51,
 'entire': 52,
 'ever': 53,
 'everything': 54,
 'except': 55,
 'expensive': 56,
 'father': 57,
 'felt': 58,
 'few': 59,
 'figure': 60,
 'final': 61,
 'finest': 62,
 'first': 63,
 'for': 64,
 'found': 65,
 'from': 66,
 'girl': 67,
 'go': 68,
 'going': 69,
 'got': 70,
 'gotten': 71,

In [12]:
# index to word

index_to_word = {}

for i, word in enumerate(vocab):
    index_to_word[i] = word

In [13]:
index_to_word

{0: '<PAD>',
 1: '<UNK>',
 2: 'a',
 3: 'about',
 4: 'adopted',
 5: 'adoption',
 6: 'after',
 7: 'all',
 8: 'almost',
 9: 'am',
 10: 'an',
 11: 'and',
 12: 'another',
 13: 'around',
 14: 'as',
 15: 'asking',
 16: 'at',
 17: 'baby',
 18: 'back',
 19: 'be',
 20: 'before',
 21: 'begin',
 22: 'being',
 23: 'best',
 24: 'big',
 25: 'biological',
 26: 'birth',
 27: 'born',
 28: 'boy',
 29: 'but',
 30: 'by',
 31: 'call',
 32: 'chose',
 33: 'class',
 34: 'classes',
 35: 'closest',
 36: 'college',
 37: 'commencement',
 38: 'connecting',
 39: 'could',
 40: 'couldn',
 41: 'course',
 42: 'deal',
 43: 'decided',
 44: 'decisions',
 45: 'did',
 46: 'didn',
 47: 'do',
 48: 'dots',
 49: 'drop',
 50: 'dropped',
 51: 'dropping',
 52: 'entire',
 53: 'ever',
 54: 'everything',
 55: 'except',
 56: 'expensive',
 57: 'father',
 58: 'felt',
 59: 'few',
 60: 'figure',
 61: 'final',
 62: 'finest',
 63: 'first',
 64: 'for',
 65: 'found',
 66: 'from',
 67: 'girl',
 68: 'go',
 69: 'going',
 70: 'got',
 71: 'gotten',

In [14]:
#sentences to index convert

sent_to_index = []

for w in non_punctuated_sent.split():
    if w not in word_to_index:
        sent_to_index.append(word_to_index['<UNK>'])
    else:
        sent_to_index.append(word_to_index[w])

In [15]:
sent_to_index[:5]

[85, 9, 83, 167, 19]

In [16]:
index_to_word[83]

'honored'

In [17]:
non_punctuated_sent

'i am honored to be with you today at your commencement from one of the finest universities in the world i never graduated from college truth be told this is the closest i ve ever gotten to a college graduation today i want to tell you three stories from my life that s it no big deal just three stories the first story is about connecting the dots i dropped out of reed college after the first months but then stayed around as a drop in for another months or so before i really quit so why did i drop out it started before i was born my biological mother was a young unwed college graduate student and she decided to put me up for adoption she felt very strongly that i should be adopted by college graduates so everything was all set for me to be adopted at birth by a lawyer and his wife except that when i popped out they decided at the last minute that they really wanted a girl so my parents who were on a waiting list got a call in the middle of the night asking we have an unexpected baby boy

In [18]:
# window size data
seq_len = 4

X = []
Y = []

for l in range(seq_len, len(sent_to_index)):
    X.append(sent_to_index[l-seq_len:l])
    Y.append(sent_to_index[l])

In [19]:
X[:4]

[[85, 9, 83, 167], [9, 83, 167, 19], [83, 167, 19, 191], [167, 19, 191, 197]]

In [20]:
Y[:4]

[19, 191, 197, 168]

In [21]:
sent_to_index[:10]

[85, 9, 83, 167, 19, 191, 197, 168, 16, 199]

In [22]:
len(X)

403

In [23]:
len(vocab)

200

In [24]:
# data split in batch

xarray = np.array(X)
yarray = np.array(Y)

ix_train = int(xarray.shape[0]*0.7)
ix_valid = int(xarray.shape[0]*0.85)

X_train = torch.tensor(xarray[:ix_train], dtype = torch.long)
Y_train = torch.tensor(yarray[:ix_train], dtype = torch.long)

X_valid = torch.tensor(xarray[ix_train:ix_valid], dtype = torch.long)
Y_valid = torch.tensor(yarray[ix_train:ix_valid], dtype = torch.long)

X_test = torch.tensor(xarray[ix_valid:], dtype = torch.long)
Y_test = torch.tensor(yarray[ix_valid:], dtype = torch.long)

In [25]:
Y_train[:5]

tensor([ 19, 191, 197, 168,  16])

In [26]:
#model initializtion

num_embed = len(vocab)
embed_dim = 500
batch = xarray.shape[0]
seq_len = xarray.shape[1]
inp_sz = embed_dim
rnn_hidsz = 100
nlayer = 2
fc_hidsz = seq_len*rnn_hidsz
out_sz = len(vocab)

In [27]:
out_sz

200

In [28]:
def TestAccuracy(pred, yt):
    
    pred_ix = pred.argmax(dim = 1)
    from sklearn.metrics import accuracy_score
    acc = accuracy_score(yt.numpy(), pred_ix.numpy())
    return (np.round(acc, 3))

In [29]:
def ValidationAnalysis(model, xtest, ytest):
    with torch.no_grad():
        yp = model(xtest)
        pred_ix = yp.argmax(dim = 1)
    from sklearn.metrics import accuracy_score
    acc = np.round(accuracy_score(ytest.numpy(), pred_ix.numpy()),3)
    return (yp, acc)

## RNN Model

In [30]:
from model import RNN

In [31]:
#Model

model_rnn = RNN.RNNTG(num_embed, embed_dim, seq_len, inp_sz, rnn_hidsz, nlayer, fc_hidsz, out_sz)

#loss function and optimizer

criterion_rnn = nn.CrossEntropyLoss()
optimizer_rnn = torch.optim.Adam(model_rnn.parameters(), 0.05)

In [32]:
print(model_rnn)

RNNTG(
  (embed): Embedding(200, 500)
  (rnn): RNN(500, 100, num_layers=2)
  (fc): Linear(in_features=400, out_features=200, bias=True)
  (prob): Softmax(dim=1)
  (drop): Dropout(p=0.4, inplace=False)
)


In [33]:
#Training-Validation

epochs = 1000
train_losses = []
train_accuracy = []
valid_losses = []
valid_accuracy = []

print("Training started...")

for e in range(epochs):
    
    optimizer_rnn.zero_grad()
    
    #trainig
    
    out = model_rnn(X_train)
    loss = criterion_rnn(out, Y_train)
    train_losses.append(loss.detach())
    train_acc = TestAccuracy(out, Y_train)
    train_accuracy.append(train_acc)
    
    #validation
    
    pred_valid, valid_acc = ValidationAnalysis(model_rnn, X_valid, Y_valid)
    valid_loss = criterion_rnn(pred_valid, Y_valid)
    valid_losses.append(valid_loss.detach())
    valid_accuracy.append(valid_acc)   
    
    #model save
    
    if (e==0):
        d = "do nothing"
    elif (e > 2):
        if ((train_accuracy[-1]>train_accuracy[-2])):
            FILE = "model_rnn.pth"
            torch.save(model_rnn, FILE)
   
    torch.nn.utils.clip_grad_norm_(model_rnn.parameters(), max_norm=1.0)
    loss.backward()  
    optimizer_rnn.step()
    
    print("Epoch: {} | Train loss: {} | Train Accuracy: {} | Validation loss : {} validation Accuracy: {}".
             format(e, loss.detach(), train_acc, valid_loss.detach(), valid_acc))       

Training started...
Epoch: 0 | Train loss: 5.2982611656188965 | Train Accuracy: 0.007 | Validation loss : 5.2978129386901855 validation Accuracy: 0.017
Epoch: 1 | Train loss: 5.13262939453125 | Train Accuracy: 0.301 | Validation loss : 5.295111656188965 validation Accuracy: 0.0
Epoch: 2 | Train loss: 5.032896518707275 | Train Accuracy: 0.294 | Validation loss : 5.275433540344238 validation Accuracy: 0.05
Epoch: 3 | Train loss: 4.957578182220459 | Train Accuracy: 0.362 | Validation loss : 5.282975673675537 validation Accuracy: 0.017
Epoch: 4 | Train loss: 4.965887546539307 | Train Accuracy: 0.344 | Validation loss : 5.299034595489502 validation Accuracy: 0.0
Epoch: 5 | Train loss: 4.934190273284912 | Train Accuracy: 0.372 | Validation loss : 5.304874897003174 validation Accuracy: 0.0
Epoch: 6 | Train loss: 4.902218341827393 | Train Accuracy: 0.408 | Validation loss : 5.300260543823242 validation Accuracy: 0.0
Epoch: 7 | Train loss: 4.936458110809326 | Train Accuracy: 0.369 | Validation 

Epoch: 66 | Train loss: 4.833051681518555 | Train Accuracy: 0.475 | Validation loss : 5.306195259094238 validation Accuracy: 0.0
Epoch: 67 | Train loss: 4.855969429016113 | Train Accuracy: 0.454 | Validation loss : 5.29523229598999 validation Accuracy: 0.017
Epoch: 68 | Train loss: 4.881511211395264 | Train Accuracy: 0.426 | Validation loss : 5.290103435516357 validation Accuracy: 0.017
Epoch: 69 | Train loss: 4.824021816253662 | Train Accuracy: 0.482 | Validation loss : 5.289345741271973 validation Accuracy: 0.017
Epoch: 70 | Train loss: 4.8654375076293945 | Train Accuracy: 0.443 | Validation loss : 5.284965991973877 validation Accuracy: 0.017
Epoch: 71 | Train loss: 4.794620990753174 | Train Accuracy: 0.514 | Validation loss : 5.27573299407959 validation Accuracy: 0.033
Epoch: 72 | Train loss: 4.860933303833008 | Train Accuracy: 0.447 | Validation loss : 5.305999755859375 validation Accuracy: 0.0
Epoch: 73 | Train loss: 4.853252410888672 | Train Accuracy: 0.454 | Validation loss : 5.

Epoch: 133 | Train loss: 4.8904500007629395 | Train Accuracy: 0.415 | Validation loss : 5.273036479949951 validation Accuracy: 0.033
Epoch: 134 | Train loss: 4.846502780914307 | Train Accuracy: 0.461 | Validation loss : 5.289971828460693 validation Accuracy: 0.017
Epoch: 135 | Train loss: 4.843342304229736 | Train Accuracy: 0.465 | Validation loss : 5.273199081420898 validation Accuracy: 0.033
Epoch: 136 | Train loss: 4.917154312133789 | Train Accuracy: 0.39 | Validation loss : 5.306422233581543 validation Accuracy: 0.0
Epoch: 137 | Train loss: 4.864794731140137 | Train Accuracy: 0.443 | Validation loss : 5.30641508102417 validation Accuracy: 0.0
Epoch: 138 | Train loss: 4.898582935333252 | Train Accuracy: 0.408 | Validation loss : 5.296485900878906 validation Accuracy: 0.017
Epoch: 139 | Train loss: 4.843423366546631 | Train Accuracy: 0.465 | Validation loss : 5.2732110023498535 validation Accuracy: 0.033
Epoch: 140 | Train loss: 4.868093013763428 | Train Accuracy: 0.436 | Validation 

Epoch: 200 | Train loss: 4.842569351196289 | Train Accuracy: 0.465 | Validation loss : 5.289729595184326 validation Accuracy: 0.017
Epoch: 201 | Train loss: 4.850162982940674 | Train Accuracy: 0.457 | Validation loss : 5.302626132965088 validation Accuracy: 0.0
Epoch: 202 | Train loss: 4.865540981292725 | Train Accuracy: 0.44 | Validation loss : 5.289824962615967 validation Accuracy: 0.017
Epoch: 203 | Train loss: 4.827719688415527 | Train Accuracy: 0.479 | Validation loss : 5.306474685668945 validation Accuracy: 0.0
Epoch: 204 | Train loss: 4.888180255889893 | Train Accuracy: 0.418 | Validation loss : 5.30665922164917 validation Accuracy: 0.0
Epoch: 205 | Train loss: 4.852897644042969 | Train Accuracy: 0.454 | Validation loss : 5.304295539855957 validation Accuracy: 0.0
Epoch: 206 | Train loss: 4.873838901519775 | Train Accuracy: 0.433 | Validation loss : 5.306653022766113 validation Accuracy: 0.0
Epoch: 207 | Train loss: 4.842424392700195 | Train Accuracy: 0.465 | Validation loss : 5

Epoch: 266 | Train loss: 4.91649866104126 | Train Accuracy: 0.39 | Validation loss : 5.28046989440918 validation Accuracy: 0.033
Epoch: 267 | Train loss: 4.8906145095825195 | Train Accuracy: 0.415 | Validation loss : 5.289854526519775 validation Accuracy: 0.017
Epoch: 268 | Train loss: 4.901970863342285 | Train Accuracy: 0.404 | Validation loss : 5.306694030761719 validation Accuracy: 0.0
Epoch: 269 | Train loss: 4.913266658782959 | Train Accuracy: 0.394 | Validation loss : 5.289971828460693 validation Accuracy: 0.017
Epoch: 270 | Train loss: 4.896839618682861 | Train Accuracy: 0.408 | Validation loss : 5.306751728057861 validation Accuracy: 0.0
Epoch: 271 | Train loss: 4.863366603851318 | Train Accuracy: 0.443 | Validation loss : 5.306663513183594 validation Accuracy: 0.0
Epoch: 272 | Train loss: 4.890852928161621 | Train Accuracy: 0.418 | Validation loss : 5.306674003601074 validation Accuracy: 0.0
Epoch: 273 | Train loss: 4.903736591339111 | Train Accuracy: 0.404 | Validation loss :

Epoch: 331 | Train loss: 4.838963508605957 | Train Accuracy: 0.468 | Validation loss : 5.2899088859558105 validation Accuracy: 0.017
Epoch: 332 | Train loss: 4.887302398681641 | Train Accuracy: 0.418 | Validation loss : 5.273313045501709 validation Accuracy: 0.033
Epoch: 333 | Train loss: 4.8812360763549805 | Train Accuracy: 0.426 | Validation loss : 5.290008068084717 validation Accuracy: 0.017
Epoch: 334 | Train loss: 4.838199138641357 | Train Accuracy: 0.468 | Validation loss : 5.289089679718018 validation Accuracy: 0.017
Epoch: 335 | Train loss: 4.88413667678833 | Train Accuracy: 0.422 | Validation loss : 5.25683069229126 validation Accuracy: 0.05
Epoch: 336 | Train loss: 4.825057506561279 | Train Accuracy: 0.479 | Validation loss : 5.256826400756836 validation Accuracy: 0.05
Epoch: 337 | Train loss: 4.907139778137207 | Train Accuracy: 0.397 | Validation loss : 5.273897647857666 validation Accuracy: 0.033
Epoch: 338 | Train loss: 4.877401351928711 | Train Accuracy: 0.429 | Validatio

Epoch: 395 | Train loss: 4.886131286621094 | Train Accuracy: 0.418 | Validation loss : 5.30667781829834 validation Accuracy: 0.0
Epoch: 396 | Train loss: 4.834840774536133 | Train Accuracy: 0.472 | Validation loss : 5.306609630584717 validation Accuracy: 0.0
Epoch: 397 | Train loss: 4.865817546844482 | Train Accuracy: 0.44 | Validation loss : 5.306708812713623 validation Accuracy: 0.0
Epoch: 398 | Train loss: 4.884284496307373 | Train Accuracy: 0.422 | Validation loss : 5.306556701660156 validation Accuracy: 0.0
Epoch: 399 | Train loss: 4.894505500793457 | Train Accuracy: 0.411 | Validation loss : 5.28902530670166 validation Accuracy: 0.017
Epoch: 400 | Train loss: 4.888639450073242 | Train Accuracy: 0.418 | Validation loss : 5.3065714836120605 validation Accuracy: 0.0
Epoch: 401 | Train loss: 4.832109451293945 | Train Accuracy: 0.475 | Validation loss : 5.30666446685791 validation Accuracy: 0.0
Epoch: 402 | Train loss: 4.89307975769043 | Train Accuracy: 0.411 | Validation loss : 5.306

Epoch: 461 | Train loss: 4.866790771484375 | Train Accuracy: 0.44 | Validation loss : 5.306397438049316 validation Accuracy: 0.0
Epoch: 462 | Train loss: 4.86966609954834 | Train Accuracy: 0.436 | Validation loss : 5.2900471687316895 validation Accuracy: 0.017
Epoch: 463 | Train loss: 4.820324897766113 | Train Accuracy: 0.486 | Validation loss : 5.28707218170166 validation Accuracy: 0.017
Epoch: 464 | Train loss: 4.838685512542725 | Train Accuracy: 0.468 | Validation loss : 5.289971351623535 validation Accuracy: 0.017
Epoch: 465 | Train loss: 4.842461109161377 | Train Accuracy: 0.465 | Validation loss : 5.306606769561768 validation Accuracy: 0.0
Epoch: 466 | Train loss: 4.851445198059082 | Train Accuracy: 0.454 | Validation loss : 5.306595802307129 validation Accuracy: 0.0
Epoch: 467 | Train loss: 4.85344934463501 | Train Accuracy: 0.454 | Validation loss : 5.305936813354492 validation Accuracy: 0.0
Epoch: 468 | Train loss: 4.853431701660156 | Train Accuracy: 0.454 | Validation loss : 

Epoch: 524 | Train loss: 4.891520977020264 | Train Accuracy: 0.415 | Validation loss : 5.306787967681885 validation Accuracy: 0.0
Epoch: 525 | Train loss: 4.856428623199463 | Train Accuracy: 0.45 | Validation loss : 5.3066277503967285 validation Accuracy: 0.0
Epoch: 526 | Train loss: 4.889739513397217 | Train Accuracy: 0.418 | Validation loss : 5.289933681488037 validation Accuracy: 0.017
Epoch: 527 | Train loss: 4.81033182144165 | Train Accuracy: 0.496 | Validation loss : 5.290106296539307 validation Accuracy: 0.017
Epoch: 528 | Train loss: 4.906212329864502 | Train Accuracy: 0.401 | Validation loss : 5.273355007171631 validation Accuracy: 0.033
Epoch: 529 | Train loss: 4.868405342102051 | Train Accuracy: 0.44 | Validation loss : 5.306716442108154 validation Accuracy: 0.0
Epoch: 530 | Train loss: 4.87860107421875 | Train Accuracy: 0.429 | Validation loss : 5.273420810699463 validation Accuracy: 0.033
Epoch: 531 | Train loss: 4.902493476867676 | Train Accuracy: 0.404 | Validation loss 

Epoch: 590 | Train loss: 4.862030029296875 | Train Accuracy: 0.447 | Validation loss : 5.306802749633789 validation Accuracy: 0.0
Epoch: 591 | Train loss: 4.873571395874023 | Train Accuracy: 0.433 | Validation loss : 5.306687831878662 validation Accuracy: 0.0
Epoch: 592 | Train loss: 4.820563793182373 | Train Accuracy: 0.486 | Validation loss : 5.3067474365234375 validation Accuracy: 0.0
Epoch: 593 | Train loss: 4.838704586029053 | Train Accuracy: 0.468 | Validation loss : 5.306632041931152 validation Accuracy: 0.0
Epoch: 594 | Train loss: 4.854133129119873 | Train Accuracy: 0.454 | Validation loss : 5.291123390197754 validation Accuracy: 0.017
Epoch: 595 | Train loss: 4.867100238800049 | Train Accuracy: 0.44 | Validation loss : 5.290313720703125 validation Accuracy: 0.017
Epoch: 596 | Train loss: 4.842306613922119 | Train Accuracy: 0.465 | Validation loss : 5.273387432098389 validation Accuracy: 0.033
Epoch: 597 | Train loss: 4.843103408813477 | Train Accuracy: 0.465 | Validation loss

Epoch: 654 | Train loss: 4.835127830505371 | Train Accuracy: 0.472 | Validation loss : 5.290071487426758 validation Accuracy: 0.017
Epoch: 655 | Train loss: 4.8780107498168945 | Train Accuracy: 0.429 | Validation loss : 5.290107727050781 validation Accuracy: 0.017
Epoch: 656 | Train loss: 4.857256889343262 | Train Accuracy: 0.45 | Validation loss : 5.2893805503845215 validation Accuracy: 0.017
Epoch: 657 | Train loss: 4.891561031341553 | Train Accuracy: 0.415 | Validation loss : 5.306753158569336 validation Accuracy: 0.0
Epoch: 658 | Train loss: 4.896378993988037 | Train Accuracy: 0.411 | Validation loss : 5.290093421936035 validation Accuracy: 0.017
Epoch: 659 | Train loss: 4.902474403381348 | Train Accuracy: 0.404 | Validation loss : 5.306718349456787 validation Accuracy: 0.0
Epoch: 660 | Train loss: 4.902402400970459 | Train Accuracy: 0.404 | Validation loss : 5.3066935539245605 validation Accuracy: 0.0
Epoch: 661 | Train loss: 4.86030387878418 | Train Accuracy: 0.447 | Validation l

Epoch: 720 | Train loss: 4.864839553833008 | Train Accuracy: 0.443 | Validation loss : 5.290055751800537 validation Accuracy: 0.017
Epoch: 721 | Train loss: 4.88588285446167 | Train Accuracy: 0.422 | Validation loss : 5.289235591888428 validation Accuracy: 0.017
Epoch: 722 | Train loss: 4.831569671630859 | Train Accuracy: 0.475 | Validation loss : 5.306728839874268 validation Accuracy: 0.0
Epoch: 723 | Train loss: 4.870856761932373 | Train Accuracy: 0.436 | Validation loss : 5.306204319000244 validation Accuracy: 0.0
Epoch: 724 | Train loss: 4.857149124145508 | Train Accuracy: 0.45 | Validation loss : 5.289950370788574 validation Accuracy: 0.017
Epoch: 725 | Train loss: 4.877660274505615 | Train Accuracy: 0.429 | Validation loss : 5.28998327255249 validation Accuracy: 0.017
Epoch: 726 | Train loss: 4.848390102386475 | Train Accuracy: 0.457 | Validation loss : 5.273343086242676 validation Accuracy: 0.033
Epoch: 727 | Train loss: 4.881671905517578 | Train Accuracy: 0.426 | Validation los

Epoch: 784 | Train loss: 4.863517761230469 | Train Accuracy: 0.443 | Validation loss : 5.30672550201416 validation Accuracy: 0.0
Epoch: 785 | Train loss: 4.874152183532715 | Train Accuracy: 0.433 | Validation loss : 5.290033340454102 validation Accuracy: 0.017
Epoch: 786 | Train loss: 4.842566967010498 | Train Accuracy: 0.465 | Validation loss : 5.3067193031311035 validation Accuracy: 0.0
Epoch: 787 | Train loss: 4.866657257080078 | Train Accuracy: 0.44 | Validation loss : 5.306683540344238 validation Accuracy: 0.0
Epoch: 788 | Train loss: 4.8506975173950195 | Train Accuracy: 0.457 | Validation loss : 5.306646347045898 validation Accuracy: 0.0
Epoch: 789 | Train loss: 4.891899585723877 | Train Accuracy: 0.415 | Validation loss : 5.290040493011475 validation Accuracy: 0.017
Epoch: 790 | Train loss: 4.886162281036377 | Train Accuracy: 0.422 | Validation loss : 5.290043830871582 validation Accuracy: 0.017
Epoch: 791 | Train loss: 4.855469703674316 | Train Accuracy: 0.45 | Validation loss 

Epoch: 847 | Train loss: 4.828052043914795 | Train Accuracy: 0.479 | Validation loss : 5.290043830871582 validation Accuracy: 0.017
Epoch: 848 | Train loss: 4.906055450439453 | Train Accuracy: 0.401 | Validation loss : 5.287760257720947 validation Accuracy: 0.017
Epoch: 849 | Train loss: 4.853082180023193 | Train Accuracy: 0.454 | Validation loss : 5.290018558502197 validation Accuracy: 0.017
Epoch: 850 | Train loss: 4.905239105224609 | Train Accuracy: 0.401 | Validation loss : 5.3063507080078125 validation Accuracy: 0.0
Epoch: 851 | Train loss: 4.859957218170166 | Train Accuracy: 0.447 | Validation loss : 5.29006290435791 validation Accuracy: 0.017
Epoch: 852 | Train loss: 4.828059196472168 | Train Accuracy: 0.479 | Validation loss : 5.3067545890808105 validation Accuracy: 0.0
Epoch: 853 | Train loss: 4.824666500091553 | Train Accuracy: 0.482 | Validation loss : 5.29004430770874 validation Accuracy: 0.017
Epoch: 854 | Train loss: 4.885978698730469 | Train Accuracy: 0.422 | Validation 

Epoch: 913 | Train loss: 4.848326683044434 | Train Accuracy: 0.457 | Validation loss : 5.289957046508789 validation Accuracy: 0.017
Epoch: 914 | Train loss: 4.852877140045166 | Train Accuracy: 0.454 | Validation loss : 5.290136337280273 validation Accuracy: 0.017
Epoch: 915 | Train loss: 4.827488899230957 | Train Accuracy: 0.479 | Validation loss : 5.289956092834473 validation Accuracy: 0.017
Epoch: 916 | Train loss: 4.849362373352051 | Train Accuracy: 0.457 | Validation loss : 5.294834613800049 validation Accuracy: 0.017
Epoch: 917 | Train loss: 4.849933624267578 | Train Accuracy: 0.457 | Validation loss : 5.273334980010986 validation Accuracy: 0.033
Epoch: 918 | Train loss: 4.878472328186035 | Train Accuracy: 0.429 | Validation loss : 5.256950378417969 validation Accuracy: 0.05
Epoch: 919 | Train loss: 4.868396282196045 | Train Accuracy: 0.44 | Validation loss : 5.279397487640381 validation Accuracy: 0.033
Epoch: 920 | Train loss: 4.834768772125244 | Train Accuracy: 0.472 | Validatio

Epoch: 978 | Train loss: 4.846461296081543 | Train Accuracy: 0.461 | Validation loss : 5.306758403778076 validation Accuracy: 0.0
Epoch: 979 | Train loss: 4.820979595184326 | Train Accuracy: 0.486 | Validation loss : 5.290052890777588 validation Accuracy: 0.017
Epoch: 980 | Train loss: 4.83194637298584 | Train Accuracy: 0.475 | Validation loss : 5.273074626922607 validation Accuracy: 0.033
Epoch: 981 | Train loss: 4.866991996765137 | Train Accuracy: 0.44 | Validation loss : 5.273411750793457 validation Accuracy: 0.033
Epoch: 982 | Train loss: 4.781974792480469 | Train Accuracy: 0.525 | Validation loss : 5.273350238800049 validation Accuracy: 0.033
Epoch: 983 | Train loss: 4.80678129196167 | Train Accuracy: 0.5 | Validation loss : 5.273477554321289 validation Accuracy: 0.033
Epoch: 984 | Train loss: 4.84211540222168 | Train Accuracy: 0.465 | Validation loss : 5.290028095245361 validation Accuracy: 0.017
Epoch: 985 | Train loss: 4.8564300537109375 | Train Accuracy: 0.45 | Validation loss

In [34]:
rnn_MODEL = torch.load("model_rnn.pth")
rnn_MODEL.eval()

RNNTG(
  (embed): Embedding(200, 500)
  (rnn): RNN(500, 100, num_layers=2)
  (fc): Linear(in_features=400, out_features=200, bias=True)
  (prob): Softmax(dim=1)
  (drop): Dropout(p=0.4, inplace=False)
)

### Testing

In [35]:
def TestAnalysis(xtest,model_saved, ytest, i2w):
    word_pred = []
    word_actual = []
    pred_test = model_saved(xtest)
    pred_test_ix = pred_test.argmax(dim=1)
    test_acc = TestAccuracy(pred_test, ytest)
    for i in pred_test_ix:
        word_pred.append(i2w[int(i.numpy())])

    for j in ytest:
        word_actual.append(i2w[int(j.numpy())])
        
    return (' '.join(word_actual), ' '.join(word_pred), test_acc)

In [36]:
actual_text, predicted_text, test_accuracy = TestAnalysis(X_test, rnn_MODEL, Y_test, index_to_word)

In [37]:
print("actual text")
print(actual_text)
print()
print("predicted text")
print(predicted_text)
print()
print("test accuracy")
print(test_accuracy)

actual text
i decided to drop out and trust that it would all work out ok it was pretty scary at the time but looking back it was one of the best decisions i ever made the minute i dropped out i could stop taking the required classes that didn t interest me and begin dropping in on the ones that looked interesting

predicted text
graduated want to set college up to savings i i so i in savings minute stanford biological out biological as have at i as of have so as biological as i as would to and me to my out of a promised so it of my my for as wanted the to birth a set did have of a my i

test accuracy
0.033


## LSTM Model

In [38]:
from model import LSTM

In [39]:
#Model

model_lstm = LSTM.LSTMTG(num_embed, embed_dim, seq_len, inp_sz, rnn_hidsz, nlayer, fc_hidsz, out_sz)

#loss function and optimizer

criterion_lstm = nn.CrossEntropyLoss()
optimizer_lstm = torch.optim.Adam(model_lstm.parameters(), 0.05)

In [40]:
print(model_lstm)

LSTMTG(
  (embed): Embedding(200, 500)
  (rnn): LSTM(500, 100, num_layers=2)
  (fc): Linear(in_features=400, out_features=200, bias=True)
  (prob): Softmax(dim=1)
  (drop): Dropout(p=0.4, inplace=False)
)


In [41]:
#Training-Validation

epochs = 1000
train_losses = []
train_accuracy = []
valid_losses = []
valid_accuracy = []

print("Training started...")

for e in range(epochs):
    
    optimizer_lstm.zero_grad()
    
    #trainig
    
    out = model_lstm(X_train)
    loss = criterion_lstm(out, Y_train)
    train_losses.append(loss.detach())
    train_acc = TestAccuracy(out, Y_train)
    train_accuracy.append(train_acc)
    
    #validation
    
    pred_valid, valid_acc = ValidationAnalysis(model_lstm, X_valid, Y_valid)
    valid_loss = criterion_lstm(pred_valid, Y_valid)
    valid_losses.append(valid_loss.detach())
    valid_accuracy.append(valid_acc)   
    
    #model save
    
    if (e==0):
        d = "do nothing"
    elif (e > 2):
        if ((train_accuracy[-1]>train_accuracy[-2])):
            FILE = "model_lstm.pth"
            torch.save(model_lstm, FILE)
   
    torch.nn.utils.clip_grad_norm_(model_lstm.parameters(), max_norm=1.0)
    loss.backward()  
    optimizer_lstm.step()
    
    print("Epoch: {} | Train loss: {} | Train Accuracy: {} | Validation loss : {} validation Accuracy: {}".
             format(e, loss.detach(), train_acc, valid_loss.detach(), valid_acc))       

Training started...
Epoch: 0 | Train loss: 5.2983293533325195 | Train Accuracy: 0.007 | Validation loss : 5.298365592956543 validation Accuracy: 0.0
Epoch: 1 | Train loss: 5.290013790130615 | Train Accuracy: 0.252 | Validation loss : 5.297151565551758 validation Accuracy: 0.017
Epoch: 2 | Train loss: 5.217000484466553 | Train Accuracy: 0.142 | Validation loss : 5.288207054138184 validation Accuracy: 0.017
Epoch: 3 | Train loss: 5.109452247619629 | Train Accuracy: 0.255 | Validation loss : 5.293539524078369 validation Accuracy: 0.017
Epoch: 4 | Train loss: 5.013035774230957 | Train Accuracy: 0.33 | Validation loss : 5.294549942016602 validation Accuracy: 0.017
Epoch: 5 | Train loss: 4.921617031097412 | Train Accuracy: 0.387 | Validation loss : 5.295417308807373 validation Accuracy: 0.017
Epoch: 6 | Train loss: 4.856619834899902 | Train Accuracy: 0.472 | Validation loss : 5.295106887817383 validation Accuracy: 0.017
Epoch: 7 | Train loss: 4.826559066772461 | Train Accuracy: 0.482 | Valid

Epoch: 64 | Train loss: 4.77080774307251 | Train Accuracy: 0.535 | Validation loss : 5.304214000701904 validation Accuracy: 0.0
Epoch: 65 | Train loss: 4.787468433380127 | Train Accuracy: 0.518 | Validation loss : 5.304132461547852 validation Accuracy: 0.0
Epoch: 66 | Train loss: 4.714233875274658 | Train Accuracy: 0.592 | Validation loss : 5.30556583404541 validation Accuracy: 0.0
Epoch: 67 | Train loss: 4.7960734367370605 | Train Accuracy: 0.511 | Validation loss : 5.3049635887146 validation Accuracy: 0.0
Epoch: 68 | Train loss: 4.7436347007751465 | Train Accuracy: 0.564 | Validation loss : 5.304366588592529 validation Accuracy: 0.0
Epoch: 69 | Train loss: 4.765130519866943 | Train Accuracy: 0.543 | Validation loss : 5.303745746612549 validation Accuracy: 0.0
Epoch: 70 | Train loss: 4.75342321395874 | Train Accuracy: 0.553 | Validation loss : 5.305300235748291 validation Accuracy: 0.0
Epoch: 71 | Train loss: 4.781836032867432 | Train Accuracy: 0.525 | Validation loss : 5.305222511291

Epoch: 128 | Train loss: 4.71743631362915 | Train Accuracy: 0.589 | Validation loss : 5.295943260192871 validation Accuracy: 0.017
Epoch: 129 | Train loss: 4.784696102142334 | Train Accuracy: 0.521 | Validation loss : 5.296509742736816 validation Accuracy: 0.017
Epoch: 130 | Train loss: 4.7776198387146 | Train Accuracy: 0.528 | Validation loss : 5.303260326385498 validation Accuracy: 0.0
Epoch: 131 | Train loss: 4.788235187530518 | Train Accuracy: 0.518 | Validation loss : 5.30477237701416 validation Accuracy: 0.0
Epoch: 132 | Train loss: 4.777584075927734 | Train Accuracy: 0.528 | Validation loss : 5.284268856048584 validation Accuracy: 0.017
Epoch: 133 | Train loss: 4.791723728179932 | Train Accuracy: 0.514 | Validation loss : 5.288839340209961 validation Accuracy: 0.017
Epoch: 134 | Train loss: 4.7705559730529785 | Train Accuracy: 0.535 | Validation loss : 5.304112911224365 validation Accuracy: 0.0
Epoch: 135 | Train loss: 4.809372901916504 | Train Accuracy: 0.496 | Validation loss 

Epoch: 192 | Train loss: 4.710186958312988 | Train Accuracy: 0.596 | Validation loss : 5.299759387969971 validation Accuracy: 0.017
Epoch: 193 | Train loss: 4.7809224128723145 | Train Accuracy: 0.525 | Validation loss : 5.2918524742126465 validation Accuracy: 0.017
Epoch: 194 | Train loss: 4.727898120880127 | Train Accuracy: 0.578 | Validation loss : 5.294409275054932 validation Accuracy: 0.017
Epoch: 195 | Train loss: 4.756207466125488 | Train Accuracy: 0.55 | Validation loss : 5.293195724487305 validation Accuracy: 0.017
Epoch: 196 | Train loss: 4.802225589752197 | Train Accuracy: 0.504 | Validation loss : 5.295740127563477 validation Accuracy: 0.017
Epoch: 197 | Train loss: 4.777406692504883 | Train Accuracy: 0.528 | Validation loss : 5.301219463348389 validation Accuracy: 0.0
Epoch: 198 | Train loss: 4.727914810180664 | Train Accuracy: 0.578 | Validation loss : 5.292379856109619 validation Accuracy: 0.017
Epoch: 199 | Train loss: 4.799373149871826 | Train Accuracy: 0.507 | Validati

Epoch: 256 | Train loss: 4.784371852874756 | Train Accuracy: 0.521 | Validation loss : 5.282849311828613 validation Accuracy: 0.033
Epoch: 257 | Train loss: 4.748976230621338 | Train Accuracy: 0.557 | Validation loss : 5.298539638519287 validation Accuracy: 0.0
Epoch: 258 | Train loss: 4.759618759155273 | Train Accuracy: 0.546 | Validation loss : 5.294958114624023 validation Accuracy: 0.017
Epoch: 259 | Train loss: 4.749023914337158 | Train Accuracy: 0.557 | Validation loss : 5.294516086578369 validation Accuracy: 0.017
Epoch: 260 | Train loss: 4.770284652709961 | Train Accuracy: 0.535 | Validation loss : 5.2932915687561035 validation Accuracy: 0.017
Epoch: 261 | Train loss: 4.734853744506836 | Train Accuracy: 0.571 | Validation loss : 5.290735244750977 validation Accuracy: 0.017
Epoch: 262 | Train loss: 4.77023983001709 | Train Accuracy: 0.535 | Validation loss : 5.2938232421875 validation Accuracy: 0.017
Epoch: 263 | Train loss: 4.777384281158447 | Train Accuracy: 0.528 | Validation 

Epoch: 320 | Train loss: 4.7666215896606445 | Train Accuracy: 0.539 | Validation loss : 5.293142795562744 validation Accuracy: 0.017
Epoch: 321 | Train loss: 4.83028507232666 | Train Accuracy: 0.475 | Validation loss : 5.283698081970215 validation Accuracy: 0.017
Epoch: 322 | Train loss: 4.74538516998291 | Train Accuracy: 0.56 | Validation loss : 5.288887977600098 validation Accuracy: 0.017
Epoch: 323 | Train loss: 4.787869453430176 | Train Accuracy: 0.518 | Validation loss : 5.284487724304199 validation Accuracy: 0.017
Epoch: 324 | Train loss: 4.773707389831543 | Train Accuracy: 0.532 | Validation loss : 5.300600051879883 validation Accuracy: 0.0
Epoch: 325 | Train loss: 4.809026718139648 | Train Accuracy: 0.496 | Validation loss : 5.29467248916626 validation Accuracy: 0.017
Epoch: 326 | Train loss: 4.787827491760254 | Train Accuracy: 0.518 | Validation loss : 5.293331623077393 validation Accuracy: 0.017
Epoch: 327 | Train loss: 4.75598669052124 | Train Accuracy: 0.55 | Validation los

Epoch: 384 | Train loss: 4.766482830047607 | Train Accuracy: 0.539 | Validation loss : 5.2897748947143555 validation Accuracy: 0.017
Epoch: 385 | Train loss: 4.713464736938477 | Train Accuracy: 0.592 | Validation loss : 5.274987697601318 validation Accuracy: 0.033
Epoch: 386 | Train loss: 4.734654903411865 | Train Accuracy: 0.571 | Validation loss : 5.29740047454834 validation Accuracy: 0.017
Epoch: 387 | Train loss: 4.755919933319092 | Train Accuracy: 0.55 | Validation loss : 5.305419921875 validation Accuracy: 0.0
Epoch: 388 | Train loss: 4.762908935546875 | Train Accuracy: 0.543 | Validation loss : 5.305259704589844 validation Accuracy: 0.0
Epoch: 389 | Train loss: 4.759456634521484 | Train Accuracy: 0.546 | Validation loss : 5.27551794052124 validation Accuracy: 0.033
Epoch: 390 | Train loss: 4.759472370147705 | Train Accuracy: 0.546 | Validation loss : 5.274908065795898 validation Accuracy: 0.033
Epoch: 391 | Train loss: 4.784163475036621 | Train Accuracy: 0.521 | Validation loss 

Epoch: 448 | Train loss: 4.741708755493164 | Train Accuracy: 0.564 | Validation loss : 5.291084289550781 validation Accuracy: 0.017
Epoch: 449 | Train loss: 4.769950866699219 | Train Accuracy: 0.535 | Validation loss : 5.291989803314209 validation Accuracy: 0.017
Epoch: 450 | Train loss: 4.759404182434082 | Train Accuracy: 0.546 | Validation loss : 5.305121421813965 validation Accuracy: 0.0
Epoch: 451 | Train loss: 4.798364639282227 | Train Accuracy: 0.507 | Validation loss : 5.289787292480469 validation Accuracy: 0.017
Epoch: 452 | Train loss: 4.769901752471924 | Train Accuracy: 0.535 | Validation loss : 5.29650354385376 validation Accuracy: 0.017
Epoch: 453 | Train loss: 4.745214462280273 | Train Accuracy: 0.56 | Validation loss : 5.296961784362793 validation Accuracy: 0.017
Epoch: 454 | Train loss: 4.769961833953857 | Train Accuracy: 0.535 | Validation loss : 5.289839744567871 validation Accuracy: 0.017
Epoch: 455 | Train loss: 4.798188209533691 | Train Accuracy: 0.507 | Validation 

Epoch: 511 | Train loss: 4.72047758102417 | Train Accuracy: 0.589 | Validation loss : 5.291638374328613 validation Accuracy: 0.017
Epoch: 512 | Train loss: 4.779482364654541 | Train Accuracy: 0.528 | Validation loss : 5.288976192474365 validation Accuracy: 0.017
Epoch: 513 | Train loss: 4.787871360778809 | Train Accuracy: 0.518 | Validation loss : 5.305208206176758 validation Accuracy: 0.0
Epoch: 514 | Train loss: 4.798689842224121 | Train Accuracy: 0.507 | Validation loss : 5.2901458740234375 validation Accuracy: 0.017
Epoch: 515 | Train loss: 4.8051605224609375 | Train Accuracy: 0.5 | Validation loss : 5.295674800872803 validation Accuracy: 0.0
Epoch: 516 | Train loss: 4.718897342681885 | Train Accuracy: 0.589 | Validation loss : 5.3018646240234375 validation Accuracy: 0.0
Epoch: 517 | Train loss: 4.779212474822998 | Train Accuracy: 0.528 | Validation loss : 5.302546977996826 validation Accuracy: 0.0
Epoch: 518 | Train loss: 4.711784839630127 | Train Accuracy: 0.596 | Validation loss

Epoch: 576 | Train loss: 4.998778343200684 | Train Accuracy: 0.309 | Validation loss : 5.305947303771973 validation Accuracy: 0.0
Epoch: 577 | Train loss: 5.001289367675781 | Train Accuracy: 0.309 | Validation loss : 5.306445121765137 validation Accuracy: 0.0
Epoch: 578 | Train loss: 4.995810031890869 | Train Accuracy: 0.312 | Validation loss : 5.306543827056885 validation Accuracy: 0.0
Epoch: 579 | Train loss: 5.019339561462402 | Train Accuracy: 0.287 | Validation loss : 5.303839683532715 validation Accuracy: 0.0
Epoch: 580 | Train loss: 4.997229099273682 | Train Accuracy: 0.305 | Validation loss : 5.30525541305542 validation Accuracy: 0.0
Epoch: 581 | Train loss: 5.042233467102051 | Train Accuracy: 0.266 | Validation loss : 5.305847644805908 validation Accuracy: 0.0
Epoch: 582 | Train loss: 5.0061211585998535 | Train Accuracy: 0.298 | Validation loss : 5.306268692016602 validation Accuracy: 0.0
Epoch: 583 | Train loss: 5.018884181976318 | Train Accuracy: 0.287 | Validation loss : 5.3

Epoch: 640 | Train loss: 4.9004807472229 | Train Accuracy: 0.408 | Validation loss : 5.306242942810059 validation Accuracy: 0.0
Epoch: 641 | Train loss: 4.923010349273682 | Train Accuracy: 0.383 | Validation loss : 5.3064985275268555 validation Accuracy: 0.0
Epoch: 642 | Train loss: 4.933226585388184 | Train Accuracy: 0.376 | Validation loss : 5.289320468902588 validation Accuracy: 0.017
Epoch: 643 | Train loss: 4.935665607452393 | Train Accuracy: 0.372 | Validation loss : 5.28978967666626 validation Accuracy: 0.017
Epoch: 644 | Train loss: 4.96397590637207 | Train Accuracy: 0.344 | Validation loss : 5.289831638336182 validation Accuracy: 0.017
Epoch: 645 | Train loss: 4.959317684173584 | Train Accuracy: 0.348 | Validation loss : 5.2745361328125 validation Accuracy: 0.033
Epoch: 646 | Train loss: 4.962416648864746 | Train Accuracy: 0.344 | Validation loss : 5.27168083190918 validation Accuracy: 0.033
Epoch: 647 | Train loss: 4.944934368133545 | Train Accuracy: 0.362 | Validation loss :

Epoch: 703 | Train loss: 4.9325714111328125 | Train Accuracy: 0.372 | Validation loss : 5.306057453155518 validation Accuracy: 0.0
Epoch: 704 | Train loss: 4.924779891967773 | Train Accuracy: 0.383 | Validation loss : 5.289875507354736 validation Accuracy: 0.017
Epoch: 705 | Train loss: 4.9167351722717285 | Train Accuracy: 0.387 | Validation loss : 5.289968967437744 validation Accuracy: 0.017
Epoch: 706 | Train loss: 4.954944133758545 | Train Accuracy: 0.355 | Validation loss : 5.306671619415283 validation Accuracy: 0.0
Epoch: 707 | Train loss: 4.930905818939209 | Train Accuracy: 0.376 | Validation loss : 5.306002140045166 validation Accuracy: 0.0
Epoch: 708 | Train loss: 4.914753437042236 | Train Accuracy: 0.394 | Validation loss : 5.273449897766113 validation Accuracy: 0.033
Epoch: 709 | Train loss: 4.9243340492248535 | Train Accuracy: 0.383 | Validation loss : 5.306675910949707 validation Accuracy: 0.0
Epoch: 710 | Train loss: 4.929498672485352 | Train Accuracy: 0.376 | Validation l

Epoch: 766 | Train loss: 4.900040626525879 | Train Accuracy: 0.408 | Validation loss : 5.304858207702637 validation Accuracy: 0.0
Epoch: 767 | Train loss: 4.8917388916015625 | Train Accuracy: 0.415 | Validation loss : 5.267205238342285 validation Accuracy: 0.033
Epoch: 768 | Train loss: 4.913069725036621 | Train Accuracy: 0.394 | Validation loss : 5.306495666503906 validation Accuracy: 0.0
Epoch: 769 | Train loss: 4.9188313484191895 | Train Accuracy: 0.387 | Validation loss : 5.289761543273926 validation Accuracy: 0.017
Epoch: 770 | Train loss: 4.905830383300781 | Train Accuracy: 0.401 | Validation loss : 5.291537284851074 validation Accuracy: 0.017
Epoch: 771 | Train loss: 4.8958587646484375 | Train Accuracy: 0.411 | Validation loss : 5.289853096008301 validation Accuracy: 0.017
Epoch: 772 | Train loss: 4.859888076782227 | Train Accuracy: 0.447 | Validation loss : 5.303619384765625 validation Accuracy: 0.0
Epoch: 773 | Train loss: 4.895317077636719 | Train Accuracy: 0.411 | Validation

Epoch: 830 | Train loss: 4.8704400062561035 | Train Accuracy: 0.436 | Validation loss : 5.306650161743164 validation Accuracy: 0.0
Epoch: 831 | Train loss: 4.888151168823242 | Train Accuracy: 0.418 | Validation loss : 5.306614398956299 validation Accuracy: 0.0
Epoch: 832 | Train loss: 4.877516746520996 | Train Accuracy: 0.429 | Validation loss : 5.306434154510498 validation Accuracy: 0.0
Epoch: 833 | Train loss: 4.888250350952148 | Train Accuracy: 0.418 | Validation loss : 5.275802135467529 validation Accuracy: 0.033
Epoch: 834 | Train loss: 4.850142478942871 | Train Accuracy: 0.457 | Validation loss : 5.295224189758301 validation Accuracy: 0.017
Epoch: 835 | Train loss: 4.941397190093994 | Train Accuracy: 0.365 | Validation loss : 5.2892985343933105 validation Accuracy: 0.017
Epoch: 836 | Train loss: 4.908454895019531 | Train Accuracy: 0.397 | Validation loss : 5.289458274841309 validation Accuracy: 0.017
Epoch: 837 | Train loss: 4.89531946182251 | Train Accuracy: 0.411 | Validation l

Epoch: 894 | Train loss: 4.852725982666016 | Train Accuracy: 0.454 | Validation loss : 5.306475639343262 validation Accuracy: 0.0
Epoch: 895 | Train loss: 4.8633036613464355 | Train Accuracy: 0.443 | Validation loss : 5.289954662322998 validation Accuracy: 0.017
Epoch: 896 | Train loss: 4.874035358428955 | Train Accuracy: 0.433 | Validation loss : 5.289898872375488 validation Accuracy: 0.017
Epoch: 897 | Train loss: 4.88458251953125 | Train Accuracy: 0.422 | Validation loss : 5.306467056274414 validation Accuracy: 0.0
Epoch: 898 | Train loss: 4.877551555633545 | Train Accuracy: 0.429 | Validation loss : 5.306437015533447 validation Accuracy: 0.0
Epoch: 899 | Train loss: 4.863377571105957 | Train Accuracy: 0.443 | Validation loss : 5.306572914123535 validation Accuracy: 0.0
Epoch: 900 | Train loss: 4.856258392333984 | Train Accuracy: 0.45 | Validation loss : 5.306563854217529 validation Accuracy: 0.0
Epoch: 901 | Train loss: 4.86337423324585 | Train Accuracy: 0.443 | Validation loss : 5

Epoch: 958 | Train loss: 4.8632731437683105 | Train Accuracy: 0.443 | Validation loss : 5.306406497955322 validation Accuracy: 0.0
Epoch: 959 | Train loss: 4.895167350769043 | Train Accuracy: 0.411 | Validation loss : 5.306497097015381 validation Accuracy: 0.0
Epoch: 960 | Train loss: 4.905819892883301 | Train Accuracy: 0.401 | Validation loss : 5.306481838226318 validation Accuracy: 0.0
Epoch: 961 | Train loss: 4.863252639770508 | Train Accuracy: 0.443 | Validation loss : 5.306459903717041 validation Accuracy: 0.0
Epoch: 962 | Train loss: 4.856128215789795 | Train Accuracy: 0.45 | Validation loss : 5.290392875671387 validation Accuracy: 0.017
Epoch: 963 | Train loss: 4.919940948486328 | Train Accuracy: 0.387 | Validation loss : 5.3063459396362305 validation Accuracy: 0.0
Epoch: 964 | Train loss: 4.880954742431641 | Train Accuracy: 0.426 | Validation loss : 5.306496620178223 validation Accuracy: 0.0
Epoch: 965 | Train loss: 4.870287895202637 | Train Accuracy: 0.436 | Validation loss : 

In [42]:
lstm_MODEL = torch.load("model_lstm.pth")
lstm_MODEL.eval()

LSTMTG(
  (embed): Embedding(200, 500)
  (rnn): LSTM(500, 100, num_layers=2)
  (fc): Linear(in_features=400, out_features=200, bias=True)
  (prob): Softmax(dim=1)
  (drop): Dropout(p=0.4, inplace=False)
)

## Testing

In [43]:
def TestAnalysis(xtest,model_saved, ytest, i2w):
    word_pred = []
    word_actual = []
    pred_test = model_saved(xtest)
    pred_test_ix = pred_test.argmax(dim=1)
    test_acc = TestAccuracy(pred_test, ytest)
    for i in pred_test_ix:
        word_pred.append(i2w[int(i.numpy())])

    for j in ytest:
        word_actual.append(i2w[int(j.numpy())])
        
    return (' '.join(word_actual), ' '.join(word_pred), test_acc)

In [44]:
actual_text, predicted_text, test_accuracy = TestAnalysis(X_test, lstm_MODEL, Y_test, index_to_word)

In [45]:
print("actual text")
print(actual_text)
print()
print("predicted text")
print(predicted_text)
print()
print("test accuracy")
print(test_accuracy)

actual text
i decided to drop out and trust that it would all work out ok it was pretty scary at the time but looking back it was one of the best decisions i ever made the minute i dropped out i could stop taking the required classes that didn t interest me and begin dropping in on the ones that looked interesting

predicted text
that truth quit so college graduation adopted biological i world up finest parents we as world that promised quit mother promised the one night minute to his minute parents you i should out out born world minute dropped she born she minute adoption she class of she one as out so out she of adopted the adopted minute young that quit

test accuracy
0.016


## GRU Model

In [46]:
from model import GRU

In [47]:
#Model

model_gru = GRU.GRUTG(num_embed, embed_dim, seq_len, inp_sz, rnn_hidsz, nlayer, fc_hidsz, out_sz)

#loss function and optimizer

criterion_gru = nn.CrossEntropyLoss()
optimizer_gru = torch.optim.Adam(model_gru.parameters(), 0.05)

In [48]:
print(model_gru)

GRUTG(
  (embed): Embedding(200, 500)
  (rnn): GRU(500, 100, num_layers=2)
  (fc): Linear(in_features=400, out_features=200, bias=True)
  (prob): Softmax(dim=1)
  (drop): Dropout(p=0.4, inplace=False)
)


In [49]:
#Training-Validation

epochs = 1000
train_losses = []
train_accuracy = []
valid_losses = []
valid_accuracy = []

print("Training started...")

for e in range(epochs):
    
    optimizer_gru.zero_grad()
    
    #trainig
    
    out = model_gru(X_train)
    loss = criterion_gru(out, Y_train)
    train_losses.append(loss.detach())
    train_acc = TestAccuracy(out, Y_train)
    train_accuracy.append(train_acc)
    
    #validation
    
    pred_valid, valid_acc = ValidationAnalysis(model_gru, X_valid, Y_valid)
    valid_loss = criterion_gru(pred_valid, Y_valid)
    valid_losses.append(valid_loss.detach())
    valid_accuracy.append(valid_acc)   
    
    #model save
    
    if (e==0):
        d = "do nothing"
    elif (e > 2):
        if ((train_accuracy[-1]>train_accuracy[-2])):
            FILE = "model_gru.pth"
            torch.save(model_gru, FILE)
   
    torch.nn.utils.clip_grad_norm_(model_gru.parameters(), max_norm=1.0)
    loss.backward()  
    optimizer_gru.step()
    
    print("Epoch: {} | Train loss: {} | Train Accuracy: {} | Validation loss : {} validation Accuracy: {}".
             format(e, loss.detach(), train_acc, valid_loss.detach(), valid_acc))       

Training started...
Epoch: 0 | Train loss: 5.298305511474609 | Train Accuracy: 0.0 | Validation loss : 5.298099040985107 validation Accuracy: 0.017
Epoch: 1 | Train loss: 5.219174861907959 | Train Accuracy: 0.319 | Validation loss : 5.297563552856445 validation Accuracy: 0.0
Epoch: 2 | Train loss: 5.0813212394714355 | Train Accuracy: 0.245 | Validation loss : 5.265552997589111 validation Accuracy: 0.05
Epoch: 3 | Train loss: 5.0145087242126465 | Train Accuracy: 0.305 | Validation loss : 5.3032073974609375 validation Accuracy: 0.0
Epoch: 4 | Train loss: 5.012026786804199 | Train Accuracy: 0.291 | Validation loss : 5.274920463562012 validation Accuracy: 0.05
Epoch: 5 | Train loss: 4.946122646331787 | Train Accuracy: 0.362 | Validation loss : 5.291968822479248 validation Accuracy: 0.017
Epoch: 6 | Train loss: 4.98896598815918 | Train Accuracy: 0.319 | Validation loss : 5.2778706550598145 validation Accuracy: 0.033
Epoch: 7 | Train loss: 4.921446323394775 | Train Accuracy: 0.387 | Validati

Epoch: 64 | Train loss: 4.874423503875732 | Train Accuracy: 0.436 | Validation loss : 5.293687343597412 validation Accuracy: 0.017
Epoch: 65 | Train loss: 4.872755527496338 | Train Accuracy: 0.433 | Validation loss : 5.2898149490356445 validation Accuracy: 0.017
Epoch: 66 | Train loss: 4.8867974281311035 | Train Accuracy: 0.422 | Validation loss : 5.3052873611450195 validation Accuracy: 0.0
Epoch: 67 | Train loss: 4.860253810882568 | Train Accuracy: 0.447 | Validation loss : 5.306191921234131 validation Accuracy: 0.0
Epoch: 68 | Train loss: 4.905762195587158 | Train Accuracy: 0.401 | Validation loss : 5.306222438812256 validation Accuracy: 0.0
Epoch: 69 | Train loss: 4.887745380401611 | Train Accuracy: 0.418 | Validation loss : 5.301804065704346 validation Accuracy: 0.0
Epoch: 70 | Train loss: 4.896505355834961 | Train Accuracy: 0.411 | Validation loss : 5.274024963378906 validation Accuracy: 0.033
Epoch: 71 | Train loss: 4.878122329711914 | Train Accuracy: 0.433 | Validation loss : 5.

Epoch: 127 | Train loss: 4.905919551849365 | Train Accuracy: 0.401 | Validation loss : 5.284363269805908 validation Accuracy: 0.017
Epoch: 128 | Train loss: 4.902697563171387 | Train Accuracy: 0.404 | Validation loss : 5.267373085021973 validation Accuracy: 0.033
Epoch: 129 | Train loss: 4.9232988357543945 | Train Accuracy: 0.383 | Validation loss : 5.297230243682861 validation Accuracy: 0.0
Epoch: 130 | Train loss: 4.890678882598877 | Train Accuracy: 0.418 | Validation loss : 5.273185729980469 validation Accuracy: 0.033
Epoch: 131 | Train loss: 4.921173572540283 | Train Accuracy: 0.387 | Validation loss : 5.306177139282227 validation Accuracy: 0.0
Epoch: 132 | Train loss: 4.904788494110107 | Train Accuracy: 0.401 | Validation loss : 5.2910895347595215 validation Accuracy: 0.017
Epoch: 133 | Train loss: 4.930656909942627 | Train Accuracy: 0.379 | Validation loss : 5.276862144470215 validation Accuracy: 0.033
Epoch: 134 | Train loss: 4.908203601837158 | Train Accuracy: 0.401 | Validatio

Epoch: 191 | Train loss: 4.878101348876953 | Train Accuracy: 0.429 | Validation loss : 5.289886474609375 validation Accuracy: 0.017
Epoch: 192 | Train loss: 4.917597770690918 | Train Accuracy: 0.39 | Validation loss : 5.289917469024658 validation Accuracy: 0.017
Epoch: 193 | Train loss: 4.9058427810668945 | Train Accuracy: 0.404 | Validation loss : 5.2873454093933105 validation Accuracy: 0.017
Epoch: 194 | Train loss: 4.932845592498779 | Train Accuracy: 0.376 | Validation loss : 5.2981157302856445 validation Accuracy: 0.0
Epoch: 195 | Train loss: 4.934600353240967 | Train Accuracy: 0.372 | Validation loss : 5.273895740509033 validation Accuracy: 0.033
Epoch: 196 | Train loss: 4.9428229331970215 | Train Accuracy: 0.365 | Validation loss : 5.256562232971191 validation Accuracy: 0.05
Epoch: 197 | Train loss: 4.888514518737793 | Train Accuracy: 0.418 | Validation loss : 5.251914024353027 validation Accuracy: 0.067
Epoch: 198 | Train loss: 4.906411170959473 | Train Accuracy: 0.404 | Validat

Epoch: 255 | Train loss: 4.908508777618408 | Train Accuracy: 0.401 | Validation loss : 5.306574821472168 validation Accuracy: 0.0
Epoch: 256 | Train loss: 4.923518657684326 | Train Accuracy: 0.383 | Validation loss : 5.289930820465088 validation Accuracy: 0.017
Epoch: 257 | Train loss: 4.905694961547852 | Train Accuracy: 0.401 | Validation loss : 5.30639123916626 validation Accuracy: 0.0
Epoch: 258 | Train loss: 4.911223888397217 | Train Accuracy: 0.394 | Validation loss : 5.3064422607421875 validation Accuracy: 0.0
Epoch: 259 | Train loss: 4.894722938537598 | Train Accuracy: 0.411 | Validation loss : 5.278248310089111 validation Accuracy: 0.033
Epoch: 260 | Train loss: 4.928370475769043 | Train Accuracy: 0.379 | Validation loss : 5.305788516998291 validation Accuracy: 0.0
Epoch: 261 | Train loss: 4.897097587585449 | Train Accuracy: 0.408 | Validation loss : 5.305920124053955 validation Accuracy: 0.0
Epoch: 262 | Train loss: 4.885851860046387 | Train Accuracy: 0.422 | Validation loss :

Epoch: 319 | Train loss: 4.914436340332031 | Train Accuracy: 0.394 | Validation loss : 5.2899322509765625 validation Accuracy: 0.017
Epoch: 320 | Train loss: 4.855602264404297 | Train Accuracy: 0.45 | Validation loss : 5.306063175201416 validation Accuracy: 0.0
Epoch: 321 | Train loss: 4.888092994689941 | Train Accuracy: 0.418 | Validation loss : 5.2897820472717285 validation Accuracy: 0.017
Epoch: 322 | Train loss: 4.869461536407471 | Train Accuracy: 0.436 | Validation loss : 5.289750099182129 validation Accuracy: 0.017
Epoch: 323 | Train loss: 4.874000072479248 | Train Accuracy: 0.433 | Validation loss : 5.306656837463379 validation Accuracy: 0.0
Epoch: 324 | Train loss: 4.874401569366455 | Train Accuracy: 0.433 | Validation loss : 5.306521892547607 validation Accuracy: 0.0
Epoch: 325 | Train loss: 4.882318496704102 | Train Accuracy: 0.426 | Validation loss : 5.291946887969971 validation Accuracy: 0.017
Epoch: 326 | Train loss: 4.898782253265381 | Train Accuracy: 0.408 | Validation l

Epoch: 383 | Train loss: 4.8778862953186035 | Train Accuracy: 0.429 | Validation loss : 5.289883613586426 validation Accuracy: 0.017
Epoch: 384 | Train loss: 4.930652618408203 | Train Accuracy: 0.376 | Validation loss : 5.290147304534912 validation Accuracy: 0.017
Epoch: 385 | Train loss: 4.881070613861084 | Train Accuracy: 0.426 | Validation loss : 5.289870262145996 validation Accuracy: 0.017
Epoch: 386 | Train loss: 4.896829605102539 | Train Accuracy: 0.411 | Validation loss : 5.290026664733887 validation Accuracy: 0.017
Epoch: 387 | Train loss: 4.888757705688477 | Train Accuracy: 0.418 | Validation loss : 5.289286136627197 validation Accuracy: 0.017
Epoch: 388 | Train loss: 4.888269424438477 | Train Accuracy: 0.418 | Validation loss : 5.28996467590332 validation Accuracy: 0.017
Epoch: 389 | Train loss: 4.866866588592529 | Train Accuracy: 0.443 | Validation loss : 5.256466865539551 validation Accuracy: 0.05
Epoch: 390 | Train loss: 4.9232177734375 | Train Accuracy: 0.383 | Validation

Epoch: 446 | Train loss: 4.90029239654541 | Train Accuracy: 0.408 | Validation loss : 5.287339687347412 validation Accuracy: 0.017
Epoch: 447 | Train loss: 4.852132797241211 | Train Accuracy: 0.454 | Validation loss : 5.305287837982178 validation Accuracy: 0.0
Epoch: 448 | Train loss: 4.932157516479492 | Train Accuracy: 0.376 | Validation loss : 5.306417465209961 validation Accuracy: 0.0
Epoch: 449 | Train loss: 4.888040542602539 | Train Accuracy: 0.418 | Validation loss : 5.306683540344238 validation Accuracy: 0.0
Epoch: 450 | Train loss: 4.905305862426758 | Train Accuracy: 0.401 | Validation loss : 5.289985656738281 validation Accuracy: 0.017
Epoch: 451 | Train loss: 4.867110252380371 | Train Accuracy: 0.44 | Validation loss : 5.274848461151123 validation Accuracy: 0.033
Epoch: 452 | Train loss: 4.890566825866699 | Train Accuracy: 0.418 | Validation loss : 5.289966106414795 validation Accuracy: 0.017
Epoch: 453 | Train loss: 4.891411304473877 | Train Accuracy: 0.415 | Validation loss

Epoch: 510 | Train loss: 4.8810553550720215 | Train Accuracy: 0.426 | Validation loss : 5.273666858673096 validation Accuracy: 0.033
Epoch: 511 | Train loss: 4.871881484985352 | Train Accuracy: 0.436 | Validation loss : 5.266993045806885 validation Accuracy: 0.033
Epoch: 512 | Train loss: 4.899032115936279 | Train Accuracy: 0.408 | Validation loss : 5.27340841293335 validation Accuracy: 0.033
Epoch: 513 | Train loss: 4.859870433807373 | Train Accuracy: 0.447 | Validation loss : 5.272550106048584 validation Accuracy: 0.033
Epoch: 514 | Train loss: 4.901956081390381 | Train Accuracy: 0.404 | Validation loss : 5.2900214195251465 validation Accuracy: 0.017
Epoch: 515 | Train loss: 4.855853080749512 | Train Accuracy: 0.45 | Validation loss : 5.290261745452881 validation Accuracy: 0.017
Epoch: 516 | Train loss: 4.853254795074463 | Train Accuracy: 0.454 | Validation loss : 5.256536483764648 validation Accuracy: 0.05
Epoch: 517 | Train loss: 4.890539169311523 | Train Accuracy: 0.415 | Validati

Epoch: 573 | Train loss: 4.892871379852295 | Train Accuracy: 0.415 | Validation loss : 5.273344039916992 validation Accuracy: 0.033
Epoch: 574 | Train loss: 4.913792133331299 | Train Accuracy: 0.394 | Validation loss : 5.256597518920898 validation Accuracy: 0.05
Epoch: 575 | Train loss: 4.856273174285889 | Train Accuracy: 0.45 | Validation loss : 5.273245334625244 validation Accuracy: 0.033
Epoch: 576 | Train loss: 4.859337329864502 | Train Accuracy: 0.447 | Validation loss : 5.27340841293335 validation Accuracy: 0.033
Epoch: 577 | Train loss: 4.8564066886901855 | Train Accuracy: 0.45 | Validation loss : 5.273610591888428 validation Accuracy: 0.033
Epoch: 578 | Train loss: 4.8531670570373535 | Train Accuracy: 0.454 | Validation loss : 5.273379802703857 validation Accuracy: 0.033
Epoch: 579 | Train loss: 4.870476722717285 | Train Accuracy: 0.436 | Validation loss : 5.273375988006592 validation Accuracy: 0.033
Epoch: 580 | Train loss: 4.863478660583496 | Train Accuracy: 0.443 | Validatio

Epoch: 637 | Train loss: 4.8918585777282715 | Train Accuracy: 0.415 | Validation loss : 5.30659294128418 validation Accuracy: 0.0
Epoch: 638 | Train loss: 4.824531078338623 | Train Accuracy: 0.482 | Validation loss : 5.289961338043213 validation Accuracy: 0.017
Epoch: 639 | Train loss: 4.88828706741333 | Train Accuracy: 0.418 | Validation loss : 5.289953231811523 validation Accuracy: 0.017
Epoch: 640 | Train loss: 4.856411457061768 | Train Accuracy: 0.45 | Validation loss : 5.282410144805908 validation Accuracy: 0.033
Epoch: 641 | Train loss: 4.880568504333496 | Train Accuracy: 0.426 | Validation loss : 5.2886247634887695 validation Accuracy: 0.017
Epoch: 642 | Train loss: 4.859237194061279 | Train Accuracy: 0.447 | Validation loss : 5.290451526641846 validation Accuracy: 0.017
Epoch: 643 | Train loss: 4.884325981140137 | Train Accuracy: 0.422 | Validation loss : 5.28942346572876 validation Accuracy: 0.017
Epoch: 644 | Train loss: 4.852851390838623 | Train Accuracy: 0.454 | Validation 

Epoch: 701 | Train loss: 4.841515064239502 | Train Accuracy: 0.465 | Validation loss : 5.289856910705566 validation Accuracy: 0.017
Epoch: 702 | Train loss: 4.908974647521973 | Train Accuracy: 0.397 | Validation loss : 5.283307075500488 validation Accuracy: 0.017
Epoch: 703 | Train loss: 4.887240409851074 | Train Accuracy: 0.418 | Validation loss : 5.273991107940674 validation Accuracy: 0.033
Epoch: 704 | Train loss: 4.8350443840026855 | Train Accuracy: 0.472 | Validation loss : 5.274531841278076 validation Accuracy: 0.033
Epoch: 705 | Train loss: 4.86688756942749 | Train Accuracy: 0.44 | Validation loss : 5.274865627288818 validation Accuracy: 0.033
Epoch: 706 | Train loss: 4.867047309875488 | Train Accuracy: 0.44 | Validation loss : 5.265855312347412 validation Accuracy: 0.033
Epoch: 707 | Train loss: 4.845064163208008 | Train Accuracy: 0.461 | Validation loss : 5.269984245300293 validation Accuracy: 0.033
Epoch: 708 | Train loss: 4.89389181137085 | Train Accuracy: 0.415 | Validation

Epoch: 765 | Train loss: 4.870654582977295 | Train Accuracy: 0.436 | Validation loss : 5.2734150886535645 validation Accuracy: 0.033
Epoch: 766 | Train loss: 4.886019229888916 | Train Accuracy: 0.422 | Validation loss : 5.256734848022461 validation Accuracy: 0.05
Epoch: 767 | Train loss: 4.888322353363037 | Train Accuracy: 0.418 | Validation loss : 5.273453235626221 validation Accuracy: 0.033
Epoch: 768 | Train loss: 4.875268459320068 | Train Accuracy: 0.433 | Validation loss : 5.273331165313721 validation Accuracy: 0.033
Epoch: 769 | Train loss: 4.870744228363037 | Train Accuracy: 0.436 | Validation loss : 5.256680488586426 validation Accuracy: 0.05
Epoch: 770 | Train loss: 4.888587951660156 | Train Accuracy: 0.418 | Validation loss : 5.281905174255371 validation Accuracy: 0.033
Epoch: 771 | Train loss: 4.860908031463623 | Train Accuracy: 0.447 | Validation loss : 5.273403644561768 validation Accuracy: 0.033
Epoch: 772 | Train loss: 4.849874496459961 | Train Accuracy: 0.457 | Validati

Epoch: 828 | Train loss: 4.86706018447876 | Train Accuracy: 0.44 | Validation loss : 5.2733473777771 validation Accuracy: 0.033
Epoch: 829 | Train loss: 4.853301048278809 | Train Accuracy: 0.454 | Validation loss : 5.256564617156982 validation Accuracy: 0.05
Epoch: 830 | Train loss: 4.891839027404785 | Train Accuracy: 0.415 | Validation loss : 5.273305416107178 validation Accuracy: 0.033
Epoch: 831 | Train loss: 4.909560203552246 | Train Accuracy: 0.397 | Validation loss : 5.3066487312316895 validation Accuracy: 0.0
Epoch: 832 | Train loss: 4.862490177154541 | Train Accuracy: 0.443 | Validation loss : 5.25689172744751 validation Accuracy: 0.05
Epoch: 833 | Train loss: 4.834417819976807 | Train Accuracy: 0.472 | Validation loss : 5.274169445037842 validation Accuracy: 0.033
Epoch: 834 | Train loss: 4.857440948486328 | Train Accuracy: 0.45 | Validation loss : 5.306500434875488 validation Accuracy: 0.0
Epoch: 835 | Train loss: 4.896995544433594 | Train Accuracy: 0.411 | Validation loss : 

Epoch: 892 | Train loss: 4.8421759605407715 | Train Accuracy: 0.465 | Validation loss : 5.27339506149292 validation Accuracy: 0.033
Epoch: 893 | Train loss: 4.8315815925598145 | Train Accuracy: 0.475 | Validation loss : 5.2900872230529785 validation Accuracy: 0.017
Epoch: 894 | Train loss: 4.856380939483643 | Train Accuracy: 0.45 | Validation loss : 5.290082931518555 validation Accuracy: 0.017
Epoch: 895 | Train loss: 4.888321399688721 | Train Accuracy: 0.418 | Validation loss : 5.262619972229004 validation Accuracy: 0.05
Epoch: 896 | Train loss: 4.9057512283325195 | Train Accuracy: 0.401 | Validation loss : 5.262878894805908 validation Accuracy: 0.05
Epoch: 897 | Train loss: 4.868601322174072 | Train Accuracy: 0.44 | Validation loss : 5.2733683586120605 validation Accuracy: 0.033
Epoch: 898 | Train loss: 4.842260360717773 | Train Accuracy: 0.465 | Validation loss : 5.279802322387695 validation Accuracy: 0.033
Epoch: 899 | Train loss: 4.8316264152526855 | Train Accuracy: 0.475 | Valida

Epoch: 955 | Train loss: 4.874131679534912 | Train Accuracy: 0.433 | Validation loss : 5.290071964263916 validation Accuracy: 0.017
Epoch: 956 | Train loss: 4.842243671417236 | Train Accuracy: 0.465 | Validation loss : 5.306685924530029 validation Accuracy: 0.0
Epoch: 957 | Train loss: 4.845736026763916 | Train Accuracy: 0.461 | Validation loss : 5.28984260559082 validation Accuracy: 0.017
Epoch: 958 | Train loss: 4.863513469696045 | Train Accuracy: 0.443 | Validation loss : 5.290010929107666 validation Accuracy: 0.017
Epoch: 959 | Train loss: 4.849313735961914 | Train Accuracy: 0.457 | Validation loss : 5.290289402008057 validation Accuracy: 0.017
Epoch: 960 | Train loss: 4.866992473602295 | Train Accuracy: 0.44 | Validation loss : 5.290895462036133 validation Accuracy: 0.017
Epoch: 961 | Train loss: 4.8634514808654785 | Train Accuracy: 0.443 | Validation loss : 5.2895426750183105 validation Accuracy: 0.017
Epoch: 962 | Train loss: 4.89187479019165 | Train Accuracy: 0.415 | Validation

In [50]:
gru_MODEL = torch.load("model_gru.pth")
gru_MODEL.eval()

GRUTG(
  (embed): Embedding(200, 500)
  (rnn): GRU(500, 100, num_layers=2)
  (fc): Linear(in_features=400, out_features=200, bias=True)
  (prob): Softmax(dim=1)
  (drop): Dropout(p=0.4, inplace=False)
)

In [51]:
def TestAnalysis(xtest,model_saved, ytest, i2w):
    word_pred = []
    word_actual = []
    pred_test = model_saved(xtest)
    pred_test_ix = pred_test.argmax(dim=1)
    test_acc = TestAccuracy(pred_test, ytest)
    for i in pred_test_ix:
        word_pred.append(i2w[int(i.numpy())])

    for j in ytest:
        word_actual.append(i2w[int(j.numpy())])
        
    return (' '.join(word_actual), ' '.join(word_pred), test_acc)

In [52]:
actual_text, predicted_text, test_accuracy = TestAnalysis(X_test, gru_MODEL, Y_test, index_to_word)

In [53]:
print("actual text")
print(actual_text)
print()
print("predicted text")
print(predicted_text)
print()
print("test accuracy")
print(test_accuracy)

actual text
i decided to drop out and trust that it would all work out ok it was pretty scary at the time but looking back it was one of the best decisions i ever made the minute i dropped out i could stop taking the required classes that didn t interest me and begin dropping in on the ones that looked interesting

predicted text
she the at a the three it a of biological that was my closest that parents almost that relented mother do i that want that parents almost was out out is parents from call to finest story the my to that of to had as from never she working college to that after college and the that a almost it she

test accuracy
0.0
