In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
from batchify import Corpus
import random
import torch.optim as optim
from torch.autograd import Variable
from model import Encoder, Decoder

In [2]:
# Build the project's Corpus object; the training loop pulls its train and
# validation batches from it and `train`/`validation` read `corpus.w2i`.
corpus = Corpus()

In [3]:
# Prefer the GPU, but fall back to the CPU when CUDA is unavailable so that
# `device` always names a device that actually exists on this machine.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [4]:
def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion,corpus):
    """Run one optimisation step of the encoder/decoder pair on a single batch.

    Args:
        input_tensor: source batch, passed unchanged to ``encoder``.
        target_tensor: target batch; dimension 0 is used as the batch size and
            dimension 1 as the number of decoding steps. It is also handed to
            the decoder (presumably for teacher forcing -- confirm in
            ``model.Decoder``).
        encoder: callable returning ``(encoded, hidden)``.
        decoder: callable invoked as
            ``decoder(encoded, hidden, decoder_input, dec_len, target_tensor)``
            and returning ``(decoded, hidden, outputs)``; ``outputs`` is read
            step by step as ``outputs[:, i, :]``.
        encoder_optimizer: optimizer for the encoder (zeroed, then stepped).
        decoder_optimizer: optimizer for the decoder (zeroed, then stepped).
        criterion: loss applied to each decoding step's scores and targets.
        corpus: object whose ``w2i`` mapping provides the ``'<sos>'`` index.

    Returns:
        A 0-dim tensor: the per-step losses summed over ``outputs.size(1)``
        steps and divided by ``target_tensor.size(1)``. It is detached from
        the autograd graph so callers can store it without keeping the graph
        alive.
    """
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    encoded, hidden = encoder(input_tensor)

    batch_size, dec_len = target_tensor.size(0), target_tensor.size(1)
    # Start-of-sequence token for every row, created directly as a long tensor
    # on the same device as the targets (no NumPy round-trip, no hard-coded
    # `.cuda()`).
    decoder_input = torch.full(
        (batch_size, 1),
        int(corpus.w2i['<sos>']),
        dtype=torch.long,
        device=target_tensor.device,
    )
    decoded, hidden, outputs = decoder(encoded, hidden, decoder_input, dec_len, target_tensor)

    # Sum the criterion over decoding steps; kept as a loop because the
    # criterion's own reduction is applied per step.
    loss = 0
    for i in range(outputs.size(1)):
        loss += criterion(outputs[:, i, :], target_tensor[:, i])
    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    # Detach before returning: the value is only used for logging/history.
    return loss.detach() / target_tensor.size(1)

In [5]:
def validation(input_tensor, target_tensor, encoder, decoder, criterion,corpus):
    """Compute the loss of the encoder/decoder pair on one batch, without gradients.

    Args:
        input_tensor: source batch; cast to ``long`` before being encoded.
        target_tensor: target batch; dimension 0 is used as the batch size and
            dimension 1 as the number of decoding steps. Each step's targets
            are cast to ``long`` before the criterion is applied.
        encoder: callable returning ``(encoded, hidden)``.
        decoder: callable invoked as
            ``decoder(encoded, hidden, decoder_input, dec_len, target_tensor, val=True)``
            and returning ``(decoded, hidden, outputs)``.
        criterion: loss applied to each decoding step's scores and targets.
        corpus: object whose ``w2i`` mapping provides the ``'<sos>'`` index.

    Returns:
        The per-step losses summed over ``outputs.size(1)`` steps and divided
        by ``target_tensor.size(1)``.
    """
    with torch.no_grad():
        encoded, hidden = encoder(input_tensor.long())

        batch_size, dec_len = target_tensor.size(0), target_tensor.size(1)
        # Start-of-sequence token for every row, created directly as a long
        # tensor on the targets' device instead of a hard-coded `.cuda()`.
        decoder_input = torch.full(
            (batch_size, 1),
            int(corpus.w2i['<sos>']),
            dtype=torch.long,
            device=target_tensor.device,
        )
        decoded, hidden, outputs = decoder(
            encoded, hidden, decoder_input, dec_len, target_tensor, val=True
        )

        loss = 0
        for i in range(outputs.size(1)):
            loss += criterion(outputs[:, i, :], target_tensor[:, i].long())

    return loss / target_tensor.size(1)

In [6]:
# Build both networks with the same vocabulary size (50000) and 250-wide
# embedding/hidden layers, and move each one to `device` as it is created.
encoder = Encoder(vocab_size=50000, embedding_size=250, hidden_size=250).to(device)
decoder = Decoder(vocab_size=50000, embedding_dim=250, hidden_dim=250).to(device)

In [7]:
# One Adagrad optimizer per network, both configured identically.
adagrad_settings = {'lr': 0.15, 'initial_accumulator_value': 0.1}
encoder_optimizer = optim.Adagrad(encoder.parameters(), **adagrad_settings)
decoder_optimizer = optim.Adagrad(decoder.parameters(), **adagrad_settings)

# NLLLoss consumes log-probabilities, so the decoder's per-step outputs are
# presumably log-softmax scores -- confirm in model.Decoder.
criterion = nn.NLLLoss()

In [8]:
# Resume from the checkpoint written by the training loop.
# map_location=device remaps the stored tensors onto the device this session
# uses, so a checkpoint saved on a different device can still be restored.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load('model.pth', map_location=device)

# Restore the weights of both networks and the state of their optimizers.
encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])
encoder_optimizer.load_state_dict(checkpoint['encoder_optimizer_state_dict'])
decoder_optimizer.load_state_dict(checkpoint['decoder_optimizer_state_dict'])

# Step to resume from, plus the loss histories the training loop appends to.
step = checkpoint['step']
tl = checkpoint['training_loss']
vl = checkpoint['validation_loss']

# Put both networks in training mode; the final bare expression is the cell's
# displayed output.
encoder.train()
decoder.train()

Decoder(
  (attn): Linear(in_features=500, out_features=400, bias=True)
  (attn_combine): Linear(in_features=500, out_features=250, bias=True)
  (linear1): Linear(in_features=500, out_features=250, bias=True)
  (linear2): Linear(in_features=500, out_features=250, bias=True)
  (embedding): Embedding(50000, 250)
  (gru): GRU(250, 250, batch_first=True)
  (linear): Linear(in_features=250, out_features=50000, bias=True)
)

In [9]:
# Hard-coded value written back into `corpus.counter` before resuming;
# presumably the corpus's read position in the training data -- confirm in
# batchify.Corpus.
resume_counter = 15160
corpus.counter = resume_counter
print(corpus.counter, step)

15160 9960


In [None]:
# `step`, `tl` (training-loss history) and `vl` (validation-loss history) come
# from the checkpoint-loading cell. For a fresh run, set `step = 0` and
# `tl, vl = [], []` before executing this cell.
num_steps = 20000

# A checkpoint is only written when the validation loss does not exceed this
# benchmark. Start from the last recorded validation loss; with an empty
# history any validation loss counts as an improvement.
val_loss_benchmark = float(vl[-1]) if vl else float('inf')

for i in range(step, num_steps):
    input_tensor_train, target_tensor_train = corpus.get_train_minibatch()
    input_tensor_val, target_tensor_val = corpus.get_validation_batch()

    # Plain tensors are sufficient (the `Variable` wrapper is deprecated);
    # move each batch to the configured device.
    input_tensor_train = input_tensor_train.to(device)
    target_tensor_train = target_tensor_train.to(device)
    input_tensor_val = input_tensor_val.to(device)
    target_tensor_val = target_tensor_val.to(device)

    # Convert the losses to Python floats so the history lists (and the
    # checkpoint that pickles them) hold plain numbers instead of tensors.
    train_loss = float(train(input_tensor_train, target_tensor_train, encoder, decoder,
                             encoder_optimizer, decoder_optimizer, criterion, corpus))
    val_loss = float(validation(input_tensor_val, target_tensor_val, encoder, decoder,
                                criterion, corpus))
    tl.append(train_loss)
    vl.append(val_loss)
    print ('Step: {}/{} | Training Loss: {} | Validation Loss: {}'.format(i+1,num_steps,train_loss,val_loss))

    # Skip loop indices 0..10, then checkpoint whenever the validation loss
    # matches or beats the benchmark.
    if i > 10 and val_loss <= val_loss_benchmark:
        print ('%---Saving the model---%')
        torch.save({
            'step': i+1,
            'encoder_state_dict': encoder.state_dict(),
            'decoder_state_dict': decoder.state_dict(),
            'encoder_optimizer_state_dict': encoder_optimizer.state_dict(),
            'decoder_optimizer_state_dict': decoder_optimizer.state_dict(),
            'training_loss': tl,
            'validation_loss': vl,
            # Extra key: records the corpus counter alongside the step so a
            # later resume does not have to rely on a hand-typed value.
            'corpus_counter': corpus.counter,
            }, 'model.pth')
        val_loss_benchmark = val_loss

Step: 9961/20000 | Training Loss: 5.336607933044434 | Validation Loss: 6.302183628082275
Step: 9962/20000 | Training Loss: 5.715422630310059 | Validation Loss: 6.294032573699951
Step: 9963/20000 | Training Loss: 5.420490264892578 | Validation Loss: 6.303720951080322
Step: 9964/20000 | Training Loss: 5.215981960296631 | Validation Loss: 6.309159278869629
Step: 9965/20000 | Training Loss: 5.363572120666504 | Validation Loss: 6.310832500457764
Step: 9966/20000 | Training Loss: 5.6544036865234375 | Validation Loss: 6.300436973571777
Step: 9967/20000 | Training Loss: 5.662583351135254 | Validation Loss: 6.289553642272949
Step: 9968/20000 | Training Loss: 5.521991729736328 | Validation Loss: 6.2934184074401855
Step: 9969/20000 | Training Loss: 6.0064473152160645 | Validation Loss: 6.303558826446533
Step: 9970/20000 | Training Loss: 5.026932239532471 | Validation Loss: 6.308955669403076
Step: 9971/20000 | Training Loss: 5.810401916503906 | Validation Loss: 6.303990364074707
Step: 9972/20000 |

Step: 10053/20000 | Training Loss: 5.185098171234131 | Validation Loss: 6.311402797698975
Step: 10054/20000 | Training Loss: 6.142661094665527 | Validation Loss: 6.2993998527526855
Step: 10055/20000 | Training Loss: 5.228925704956055 | Validation Loss: 6.305155277252197
Step: 10056/20000 | Training Loss: 5.202666759490967 | Validation Loss: 6.3082075119018555
Step: 10057/20000 | Training Loss: 6.134332656860352 | Validation Loss: 6.30678653717041
Step: 10058/20000 | Training Loss: 5.232161998748779 | Validation Loss: 6.302610874176025
Step: 10059/20000 | Training Loss: 5.092964172363281 | Validation Loss: 6.314047813415527
Step: 10060/20000 | Training Loss: 5.749735355377197 | Validation Loss: 6.296676158905029
Step: 10061/20000 | Training Loss: 5.19612455368042 | Validation Loss: 6.3007988929748535
Step: 10062/20000 | Training Loss: 5.6711626052856445 | Validation Loss: 6.3036112785339355
Step: 10063/20000 | Training Loss: 5.430266380310059 | Validation Loss: 6.302990436553955
Step: 1

Step: 10144/20000 | Training Loss: 5.589088439941406 | Validation Loss: 6.287568092346191
Step: 10145/20000 | Training Loss: 4.6944966316223145 | Validation Loss: 6.34027624130249
Step: 10146/20000 | Training Loss: 5.238375663757324 | Validation Loss: 6.329432964324951
Step: 10147/20000 | Training Loss: 5.853811264038086 | Validation Loss: 6.303055286407471
Step: 10148/20000 | Training Loss: 4.667355060577393 | Validation Loss: 6.333863735198975
Step: 10149/20000 | Training Loss: 5.32357120513916 | Validation Loss: 6.31191349029541
Step: 10150/20000 | Training Loss: 5.708406925201416 | Validation Loss: 6.303318023681641
Step: 10151/20000 | Training Loss: 5.6416192054748535 | Validation Loss: 6.302514553070068
Step: 10152/20000 | Training Loss: 5.76275634765625 | Validation Loss: 6.289306163787842
Step: 10153/20000 | Training Loss: 5.2579450607299805 | Validation Loss: 6.303128719329834
Step: 10154/20000 | Training Loss: 6.187253952026367 | Validation Loss: 6.282688140869141
%---Saving 

Step: 10235/20000 | Training Loss: 5.428486347198486 | Validation Loss: 6.287701606750488
Step: 10236/20000 | Training Loss: 6.367542266845703 | Validation Loss: 6.304799556732178
Step: 10237/20000 | Training Loss: 5.354739665985107 | Validation Loss: 6.300777435302734
Step: 10238/20000 | Training Loss: 5.828014373779297 | Validation Loss: 6.299764633178711
Step: 10239/20000 | Training Loss: 5.468513488769531 | Validation Loss: 6.300422668457031
Step: 10240/20000 | Training Loss: 5.623467922210693 | Validation Loss: 6.310715675354004
Step: 10241/20000 | Training Loss: 4.968943119049072 | Validation Loss: 6.311724662780762
Step: 10242/20000 | Training Loss: 6.022042274475098 | Validation Loss: 6.289191246032715
Step: 10243/20000 | Training Loss: 5.6338701248168945 | Validation Loss: 6.291934967041016
Step: 10244/20000 | Training Loss: 5.807989120483398 | Validation Loss: 6.304866790771484
Step: 10245/20000 | Training Loss: 5.232584476470947 | Validation Loss: 6.316832065582275
Step: 102

Step: 10326/20000 | Training Loss: 5.19545841217041 | Validation Loss: 6.295201778411865
Step: 10327/20000 | Training Loss: 5.8711419105529785 | Validation Loss: 6.3362016677856445
Step: 10328/20000 | Training Loss: 5.810329914093018 | Validation Loss: 6.281959056854248
Step: 10329/20000 | Training Loss: 6.29758358001709 | Validation Loss: 6.276095390319824
Step: 10330/20000 | Training Loss: 5.833550930023193 | Validation Loss: 6.277763366699219
Step: 10331/20000 | Training Loss: 5.660676956176758 | Validation Loss: 6.30195426940918
Step: 10332/20000 | Training Loss: 6.303869247436523 | Validation Loss: 6.2795915603637695
Step: 10333/20000 | Training Loss: 5.467462062835693 | Validation Loss: 6.2984724044799805
Step: 10334/20000 | Training Loss: 6.099415302276611 | Validation Loss: 6.280705451965332
Step: 10335/20000 | Training Loss: 5.561647891998291 | Validation Loss: 6.2809929847717285
Step: 10336/20000 | Training Loss: 5.969916820526123 | Validation Loss: 6.282216548919678
Step: 10

Step: 10418/20000 | Training Loss: 4.908864974975586 | Validation Loss: 6.2901458740234375
Step: 10419/20000 | Training Loss: 5.803831100463867 | Validation Loss: 6.278071403503418
Step: 10420/20000 | Training Loss: 5.628860950469971 | Validation Loss: 6.269072532653809
Step: 10421/20000 | Training Loss: 5.759814262390137 | Validation Loss: 6.263304710388184
Step: 10422/20000 | Training Loss: 5.609277248382568 | Validation Loss: 6.267226696014404
Step: 10423/20000 | Training Loss: 5.180486679077148 | Validation Loss: 6.280616760253906
Step: 10424/20000 | Training Loss: 5.630953311920166 | Validation Loss: 6.310789108276367
Step: 10425/20000 | Training Loss: 6.270550727844238 | Validation Loss: 6.258435249328613
Step: 10426/20000 | Training Loss: 5.676966190338135 | Validation Loss: 6.260519981384277
Step: 10427/20000 | Training Loss: 5.748951435089111 | Validation Loss: 6.259109973907471
Step: 10428/20000 | Training Loss: 5.542107582092285 | Validation Loss: 6.26997184753418
Step: 1042

Step: 10509/20000 | Training Loss: 5.884615421295166 | Validation Loss: 6.274389743804932
Step: 10510/20000 | Training Loss: 5.4999566078186035 | Validation Loss: 6.27646017074585
Step: 10511/20000 | Training Loss: 5.315322399139404 | Validation Loss: 6.284076690673828
Step: 10512/20000 | Training Loss: 5.70214319229126 | Validation Loss: 6.279187202453613
Step: 10513/20000 | Training Loss: 6.0215654373168945 | Validation Loss: 6.271651268005371
Step: 10514/20000 | Training Loss: 5.984991073608398 | Validation Loss: 6.268319606781006
Step: 10515/20000 | Training Loss: 5.767414569854736 | Validation Loss: 6.275357723236084
Step: 10516/20000 | Training Loss: 5.269726276397705 | Validation Loss: 6.286497592926025
Step: 10517/20000 | Training Loss: 5.724367618560791 | Validation Loss: 6.287006378173828
Step: 10518/20000 | Training Loss: 5.799556255340576 | Validation Loss: 6.28331184387207
Step: 10519/20000 | Training Loss: 4.527338981628418 | Validation Loss: 6.344598293304443
Step: 10520

Step: 10601/20000 | Training Loss: 5.601198673248291 | Validation Loss: 6.292696952819824
Step: 10602/20000 | Training Loss: 5.994409084320068 | Validation Loss: 6.28103494644165
Step: 10603/20000 | Training Loss: 5.562344551086426 | Validation Loss: 6.274105548858643
Step: 10604/20000 | Training Loss: 5.869498252868652 | Validation Loss: 6.27333402633667
Step: 10605/20000 | Training Loss: 6.074211120605469 | Validation Loss: 6.272628307342529
Step: 10606/20000 | Training Loss: 5.937117099761963 | Validation Loss: 6.276615619659424
Step: 10607/20000 | Training Loss: 6.149474143981934 | Validation Loss: 6.272722244262695
Step: 10608/20000 | Training Loss: 5.202827453613281 | Validation Loss: 6.289615154266357
Step: 10609/20000 | Training Loss: 5.3923516273498535 | Validation Loss: 6.300896167755127
Step: 10610/20000 | Training Loss: 5.176161766052246 | Validation Loss: 6.312606334686279
Step: 10611/20000 | Training Loss: 5.890388488769531 | Validation Loss: 6.296887397766113
Step: 10612

Step: 10692/20000 | Training Loss: 5.355832576751709 | Validation Loss: 6.289157390594482
Step: 10693/20000 | Training Loss: 5.726739406585693 | Validation Loss: 6.2887725830078125
Step: 10694/20000 | Training Loss: 5.956144332885742 | Validation Loss: 6.285572052001953
Step: 10695/20000 | Training Loss: 5.816746711730957 | Validation Loss: 6.285782337188721
Step: 10696/20000 | Training Loss: 5.076285362243652 | Validation Loss: 6.299558639526367
Step: 10697/20000 | Training Loss: 5.7727742195129395 | Validation Loss: 6.286783218383789
Step: 10698/20000 | Training Loss: 5.458215236663818 | Validation Loss: 6.301659107208252
Step: 10699/20000 | Training Loss: 5.738052845001221 | Validation Loss: 6.290841102600098
Step: 10700/20000 | Training Loss: 5.776301860809326 | Validation Loss: 6.286419868469238
Step: 10701/20000 | Training Loss: 5.297580718994141 | Validation Loss: 6.303589344024658
Step: 10702/20000 | Training Loss: 5.991734981536865 | Validation Loss: 6.291856288909912
Step: 10

Step: 10784/20000 | Training Loss: 5.952221393585205 | Validation Loss: 6.257375717163086
Step: 10785/20000 | Training Loss: 5.679624557495117 | Validation Loss: 6.254715442657471
%---Saving the model---%
Step: 10786/20000 | Training Loss: 5.4871649742126465 | Validation Loss: 6.259748458862305
Step: 10787/20000 | Training Loss: 6.111208438873291 | Validation Loss: 6.252424240112305
%---Saving the model---%
Step: 10788/20000 | Training Loss: 5.8610405921936035 | Validation Loss: 6.257777214050293
Step: 10789/20000 | Training Loss: 5.955638408660889 | Validation Loss: 6.262369155883789
Step: 10790/20000 | Training Loss: 5.885756969451904 | Validation Loss: 6.276546478271484
Step: 10791/20000 | Training Loss: 5.842705726623535 | Validation Loss: 6.269923210144043
Step: 10792/20000 | Training Loss: 6.137452125549316 | Validation Loss: 6.257787704467773
Step: 10793/20000 | Training Loss: 5.718854904174805 | Validation Loss: 6.254543781280518
Step: 10794/20000 | Training Loss: 5.66992044448

Step: 10874/20000 | Training Loss: 6.150381565093994 | Validation Loss: 6.257813930511475
Step: 10875/20000 | Training Loss: 5.971368789672852 | Validation Loss: 6.2706990242004395
Step: 10876/20000 | Training Loss: 6.28732442855835 | Validation Loss: 6.263794898986816
Step: 10877/20000 | Training Loss: 5.835357666015625 | Validation Loss: 6.263377666473389
Step: 10878/20000 | Training Loss: 6.055538654327393 | Validation Loss: 6.262484073638916
Step: 10879/20000 | Training Loss: 5.6258225440979 | Validation Loss: 6.26190710067749
Step: 10880/20000 | Training Loss: 5.533100605010986 | Validation Loss: 6.255860805511475
Step: 10881/20000 | Training Loss: 5.90877628326416 | Validation Loss: 6.265361309051514
Step: 10882/20000 | Training Loss: 6.205242156982422 | Validation Loss: 6.261037349700928
Step: 10883/20000 | Training Loss: 5.827335357666016 | Validation Loss: 6.252494812011719
Step: 10884/20000 | Training Loss: 5.896160125732422 | Validation Loss: 6.26103401184082
Step: 10885/200

Step: 10965/20000 | Training Loss: 5.79731559753418 | Validation Loss: 6.290469169616699
Step: 10966/20000 | Training Loss: 5.972685813903809 | Validation Loss: 6.2710041999816895
Step: 10967/20000 | Training Loss: 5.97210693359375 | Validation Loss: 6.253302574157715
Step: 10968/20000 | Training Loss: 6.147520065307617 | Validation Loss: 6.2506608963012695
Step: 10969/20000 | Training Loss: 5.772219181060791 | Validation Loss: 6.265109539031982
Step: 10970/20000 | Training Loss: 5.535101413726807 | Validation Loss: 6.264286041259766
Step: 10971/20000 | Training Loss: 5.6436872482299805 | Validation Loss: 6.261476039886475
Step: 10972/20000 | Training Loss: 5.219757556915283 | Validation Loss: 6.271806716918945
Step: 10973/20000 | Training Loss: 6.495218753814697 | Validation Loss: 6.257348537445068
Step: 10974/20000 | Training Loss: 5.884106636047363 | Validation Loss: 6.271995544433594
Step: 10975/20000 | Training Loss: 5.737949848175049 | Validation Loss: 6.267213344573975
Step: 109

Step: 11056/20000 | Training Loss: 6.448904991149902 | Validation Loss: 6.259054183959961
Step: 11057/20000 | Training Loss: 5.133313179016113 | Validation Loss: 6.278114318847656
Step: 11058/20000 | Training Loss: 5.514863967895508 | Validation Loss: 6.266664028167725
Step: 11059/20000 | Training Loss: 6.3541741371154785 | Validation Loss: 6.2562713623046875
Step: 11060/20000 | Training Loss: 5.469985008239746 | Validation Loss: 6.275471210479736
Step: 11061/20000 | Training Loss: 6.201667308807373 | Validation Loss: 6.258304119110107
Step: 11062/20000 | Training Loss: 6.136003017425537 | Validation Loss: 6.260190010070801
Step: 11063/20000 | Training Loss: 6.498969078063965 | Validation Loss: 6.307629585266113
Step: 11064/20000 | Training Loss: 6.26291036605835 | Validation Loss: 6.263864517211914
Step: 11065/20000 | Training Loss: 5.679205894470215 | Validation Loss: 6.269524097442627
Step: 11066/20000 | Training Loss: 4.995804309844971 | Validation Loss: 6.293070316314697
Step: 110

Step: 11147/20000 | Training Loss: 6.450538158416748 | Validation Loss: 6.326948642730713
Step: 11148/20000 | Training Loss: 6.111885070800781 | Validation Loss: 6.268406391143799
Step: 11149/20000 | Training Loss: 5.998831272125244 | Validation Loss: 6.253241539001465
Step: 11150/20000 | Training Loss: 6.045167922973633 | Validation Loss: 6.251236438751221
Step: 11151/20000 | Training Loss: 5.767307758331299 | Validation Loss: 6.248231410980225
Step: 11152/20000 | Training Loss: 5.965418815612793 | Validation Loss: 6.258783340454102
Step: 11153/20000 | Training Loss: 6.048044681549072 | Validation Loss: 6.267022132873535
Step: 11154/20000 | Training Loss: 5.366024971008301 | Validation Loss: 6.271442890167236
Step: 11155/20000 | Training Loss: 6.57291316986084 | Validation Loss: 6.260072708129883
Step: 11156/20000 | Training Loss: 5.856303691864014 | Validation Loss: 6.2587761878967285
Step: 11157/20000 | Training Loss: 6.207441806793213 | Validation Loss: 6.258338928222656
Step: 1115

Step: 11238/20000 | Training Loss: 5.937983989715576 | Validation Loss: 6.244182586669922
Step: 11239/20000 | Training Loss: 6.17411470413208 | Validation Loss: 6.2434587478637695
Step: 11240/20000 | Training Loss: 5.50005578994751 | Validation Loss: 6.246013641357422
Step: 11241/20000 | Training Loss: 5.651598930358887 | Validation Loss: 6.247166633605957
Step: 11242/20000 | Training Loss: 5.468685150146484 | Validation Loss: 6.257784843444824
Step: 11243/20000 | Training Loss: 5.774059772491455 | Validation Loss: 6.254687786102295
Step: 11244/20000 | Training Loss: 5.666408538818359 | Validation Loss: 6.252106666564941
Step: 11245/20000 | Training Loss: 5.74297571182251 | Validation Loss: 6.249843597412109
Step: 11246/20000 | Training Loss: 6.385883808135986 | Validation Loss: 6.242934226989746
Step: 11247/20000 | Training Loss: 6.058100700378418 | Validation Loss: 6.249375343322754
Step: 11248/20000 | Training Loss: 5.42416524887085 | Validation Loss: 6.252216815948486
Step: 11249/2

Step: 11329/20000 | Training Loss: 5.919845104217529 | Validation Loss: 6.238550662994385
Step: 11330/20000 | Training Loss: 5.90155029296875 | Validation Loss: 6.23322868347168
Step: 11331/20000 | Training Loss: 6.186261177062988 | Validation Loss: 6.231151103973389
%---Saving the model---%
Step: 11332/20000 | Training Loss: 6.0237507820129395 | Validation Loss: 6.230689525604248
%---Saving the model---%
Step: 11333/20000 | Training Loss: 5.844259738922119 | Validation Loss: 6.237310409545898
Step: 11334/20000 | Training Loss: 5.683506011962891 | Validation Loss: 6.23783540725708
Step: 11335/20000 | Training Loss: 5.411160469055176 | Validation Loss: 6.254985332489014
Step: 11336/20000 | Training Loss: 6.014742374420166 | Validation Loss: 6.32317590713501
Step: 11337/20000 | Training Loss: 6.0147857666015625 | Validation Loss: 6.280334949493408
Step: 11338/20000 | Training Loss: 5.778575420379639 | Validation Loss: 6.2415690422058105
Step: 11339/20000 | Training Loss: 5.16601610183715

In [17]:
# Display the current value of `corpus.counter` as the cell output.
corpus.counter

15160