In [1]:
from torchvision import utils
from data_loader import *
import torchvision
from torchvision import transforms
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch
import time
import pickle as pkl
import vocabulary_struct
import AnnoNet_GLOVE
import csv
import numpy as np
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [2]:
#vocabulary_struct.handle_glove("glove.6B.50d.txt")

In [3]:
with open('TrainImageIds.csv', 'r') as f:
    reader = csv.reader(f)
    trainIds = list(reader)

trainIds = [int(i) for i in trainIds[0]]
with open('TestImageIds.csv', 'r') as f:
    reader = csv.reader(f)
    testIds = list(reader)

testIds = [int(i) for i in testIds[0]]

In [4]:
valIds = trainIds[:int(0.2*len(trainIds))]
del trainIds[:int(0.2*len(trainIds))]

In [5]:
with open('glove_vocab','rb') as f:
    vocab = pkl.load(f)
with open('glove_weights','rb') as f:
    glove_weights = pkl.load(f)
glove_weights = torch.from_numpy(glove_weights).float()

In [6]:
batch_size = 32
#Implement normalization later
transform = transforms.Compose([
    transforms.Resize(250),
    transforms.CenterCrop(250),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), 
                             (0.229, 0.224, 0.225))
])
train_loader = get_loader(root = './data/images/train/',
                          json = './data/annotations/captions_train2014.json',
                          ids = trainIds,
                          vocab = vocab,
                          transform = transform,
                          batch_size = batch_size,
                          shuffle = True,
                          num_workers = 4)
val_loader = get_loader(root = './data/images/train/',
                          json = './data/annotations/captions_train2014.json',
                          ids = valIds,
                          vocab = vocab,
                          transform = transform,
                          batch_size = batch_size,
                          shuffle = True,
                          num_workers = 4)
test_loader = get_loader(root = './data/images/test/',
                          json = './data/annotations/captions_val2014.json',
                          ids = testIds,
                          vocab = vocab,
                          transform = transform,
                          batch_size = batch_size,
                          shuffle = True,
                          num_workers = 4)

loading annotations into memory...
Done (t=0.95s)
creating index...
index created!
loading annotations into memory...
Done (t=1.06s)
creating index...
index created!
loading annotations into memory...
Done (t=0.46s)
creating index...
index created!


In [7]:
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight.data)
        #torch.nn.init.xavier_uniform_(m.bias.data)
        torch.nn.init.zeros_(m.bias.data)
        
epochs     = 100
#criterion = # Choose an appropriate loss function from https://pytorch.org/docs/stable/_modules/torch/nn/modules/loss.html
criterion = torch.nn.CrossEntropyLoss()
weights_dict = {'weight': glove_weights}
AnnoNet = AnnoNet_GLOVE.AnnoNet_G(vocab_size = len(vocab), batch_size = batch_size, embedding_dim=50,weights_dict = weights_dict,hidden_dim = 512, hidden_units=1)
AnnoNet.apply(init_weights)
optimizer = optim.Adam(AnnoNet.parameters(), lr=5e-4, weight_decay = .001)

In [None]:
use_gpu = torch.cuda.is_available()
cpu_device = torch.device("cpu")
if use_gpu:
    device = torch.device("cuda:0")
    #fcn_model = fcn_model.cuda()
    #fcn_model = fcn_model.to(device)
    AnnoNet = AnnoNet.to(device)
    
def train(batch_size, check_num = 5):
    counter = 0 
    losses = []
    accuracies = []
    val_losses = []
    val_accuracies = []
    for epoch in range(epochs):
        ts = time.time()
        rolling_loss = 0
        theCounter = 0
        for iter, (X, tar, Y) in enumerate(train_loader):
            optimizer.zero_grad()
            if use_gpu:
                inputs = X.to(device)# Move your inputs onto the gpu
                labels = tar.to(device,dtype=torch.int64)# Move your labels onto the gpu
            else:
                inputs, labels = (X,tar)# Unpack variables into inputs and labels
            
            #print("lengths: ", Y)
            outputs = AnnoNet(inputs, labels, Y)
            del inputs
            torch.cuda.empty_cache()
            #output_captions, output_labels = output_captioning(outputs)
            '''if iter % 100 == 0:
                print(output_captions)'''
            #print(outputs.shape)
            labels = pack_padded_sequence(labels, Y, batch_first=True)
            #print(labels.data.shape)
            loss = criterion(outputs, labels.data)
            #Acc, _, _, _ = prediction_and_Accuracy(outputs, labels)
            del outputs,labels
            torch.cuda.empty_cache()
            loss.backward()
            optimizer.step()
            
            if iter % 10 == 0:
                print("epoch{}, iter{}, loss: {}".format(epoch, iter, loss.item()))
            
            rolling_loss += loss.item()
            del loss
            torch.cuda.empty_cache()
            theCounter += 1
            
        print("Finish epoch {}, time elapsed {}".format(epoch, time.time() - ts))
        '''Normalizing_Factor = len(train_loader) * batch_size
        losses.append(rolling_loss / Normalizing_Factor)
        accuracies.append(rolling_acc / Normalizing_Factor)
        loss_val, acc_val, IOU_val = val(epoch, batch_size)
        val_losses.append(loss_val)
        val_accuracies.append(acc_val)
        #val_IOUs.append(IOU_val)
        fcn_model.train()'''
        Normalizing_Factor = theCounter * batch_size
        losses.append(rolling_loss / Normalizing_Factor)
        loss_val = val(epoch, batch_size)
        val_losses.append(loss_val)
        AnnoNet.train()
        #Early Stopping for validation Loss
        if epoch == 0:
            torch.save(AnnoNet.state_dict(), 'best_model.pt')
        else:
            if torch.argmin(torch.Tensor(val_losses)) == epoch:
                torch.save(AnnoNet.state_dict(), 'best_model.pt')
                counter = 0
            else:
                counter += 1
        torch.save(val_losses,"val_losses")
        #torch.save(val_accuracies,"val_accuracies")
        #torch.save(val_IOUs,"val_IOUs")
        torch.save(losses,"train_loss")
        #torch.save(accuracies,'train_accs')
        #torch.save(IOUs,'train_IOU')
        if counter == check_num:
            print("early stop achieved")
            break
    
def output_captioning(anOutput, temp):
    softmaxed_output = F.softmax(anOutput/temp, dim = 1)
    output_labels = torch.argmax(softmaxed_output, dim = 1)
    daShape = output_labels.shape
    output_captions = []
    for i in range(daShape[0]):
        tempList = []
        for j in range(daShape[1]):
            tempList.append(vocab(int(output_labels[i, j])))
            
        output_captions.append(tempList)
        
    return output_captions, output_labels 
    
    
    
def val(epoch, batch_size):
    AnnoNet.eval()
    ts = time.time()
    rolling_loss = 0
    rolling_acc = 0
    counter = 0
    for iter, (X, tar, Y) in enumerate(val_loader):
        if use_gpu:
            inputs = X.to(device)# Move your inputs onto the gpu
            labels = tar.to(device,dtype=torch.int64)# Move your labels onto the gpu
        else:
            inputs, labels = (X,tar)# Unpack variables into inputs and labels

        #print("lengths: ", Y)
        outputs = AnnoNet(inputs, labels, Y)
        del inputs
        torch.cuda.empty_cache()
        #output_captions, output_labels = output_captioning(outputs)
        '''if iter % 100 == 0:
            print(output_captions)'''
        #print(outputs.shape)
        labels = pack_padded_sequence(labels, Y, batch_first=True)
        #print(labels.data.shape)
        loss = criterion(outputs, labels.data)
        rolling_loss += loss.item()
        #comp = outputs.data.to(cpu_device)
        #Acc, TP_sum, FP_sum, FN_sum = prediction_and_Accuracy(outputs, labels)
        del outputs,labels
        torch.cuda.empty_cache()
        #rolling_acc += Acc
        
        if iter% 10 == 0:
            print("epoch{}, iter{}, loss: {}".format(epoch, iter, loss.item()))
        del loss
        torch.cuda.empty_cache()
        counter += 1
    
    print("Finish epoch {}, time elapsed {}".format(epoch, time.time() - ts))
    Normalizing_Factor = counter * batch_size
    #rolling_acc /= Normalizing_Factor
    #rolling_iou /= Normalizing_Factor
    rolling_loss /= Normalizing_Factor
    print("Average loss: ",rolling_loss)
    #print("Accuracy: ", rolling_acc)
    #print("IOU: ", rolling_iou)
    #print("TP's: ", rolling_TP)
    #print("FP's: ", rolling_FP)
    #print("FN's: ", rolling_FN)
    return rolling_loss
    #Complete this function - Calculate loss, accuracy and IoU for every epoch
    # Make sure to include a softmax after the output from your model
    
def test(batch_size):
    #Complete this function - Calculate accuracy and IoU 
    # Make sure to include a softmax after the output from your model
    #Net = AnnoNet(vocab_size = len(vocab), batch_size = batch_size, embedding_dim=60,hidden_dim = 5, hidden_units=1)
    #use_gpu = torch.cuda.is_available()
    #cpu_device = torch.device("cpu")
    #if use_gpu:
    #    device = torch.device("cuda:0")
    #    #fcn_model = fcn_model.cuda()
    #    Net = Net.to(device)
    #Net.load_state_dict(torch.load(path))
    AnnoNet.eval()
    ts = time.time()
    rolling_loss = 0
    rolling_acc = 0
    for iter, (X, tar, Y) in enumerate(test_loader):
        if use_gpu:
            inputs = X.to(device)# Move your inputs onto the gpu
            labels = tar.to(device,dtype=torch.int64)# Move your labels onto the gpu
        else:
            inputs, labels = (X,tar)# Unpack variables into inputs and labels
        outputs = AnnoNet(inputs, labels, Y)
        del inputs
        torch.cuda.empty_cache()
        loss = criterion(outputs,labels)
        rolling_loss += loss.item()
        #comp = outputs.data.to(cpu_device)
        #Acc, TP_sum, FP_sum, FN_sum = prediction_and_Accuracy(outputs, labels)
        del outputs,labels
        torch.cuda.empty_cache()
        #rolling_acc += Acc              
        if iter% 10 == 0:
            print("test, iter{}, loss: {}".format(iter, loss.item()))
        del loss
        torch.cuda.empty_cache()

    print("Finish test, time elapsed {}".format(time.time() - ts))
    Normalizing_Factor = len(test_loader)# * batch_size
    rolling_acc /= Normalizing_Factor
    #rolling_iou /= Normalizing_Factor
    rolling_loss /= Normalizing_Factor
    print("Accuracy: ", rolling_acc)
    print("IOU: ", rolling_iou)
    print("TP's: ", rolling_TP)
    print("FP's: ", rolling_FP)
    print("FN's: ", rolling_FN)
    return rolling_loss, rolling_acc, rolling_iou
    
if __name__ == "__main__":
    #val(0,batch_size)  # show the accuracy before training
    train(batch_size)
    print("yay")
    #test("best_model_save.pt", batch_size)

epoch0, iter0, loss: 12.899178504943848
epoch0, iter10, loss: 12.727303504943848
epoch0, iter20, loss: 10.783590316772461
epoch0, iter30, loss: 8.283262252807617
epoch0, iter40, loss: 7.569025993347168
epoch0, iter50, loss: 7.508580684661865
epoch0, iter60, loss: 6.761693954467773
epoch0, iter70, loss: 6.838247299194336
epoch0, iter80, loss: 6.948451995849609
epoch0, iter90, loss: 6.557712078094482
epoch0, iter100, loss: 6.1033854484558105
epoch0, iter110, loss: 6.1755242347717285
epoch0, iter120, loss: 6.005166053771973
epoch0, iter130, loss: 6.300960063934326
epoch0, iter140, loss: 5.842947006225586
epoch0, iter150, loss: 5.944695472717285
epoch0, iter160, loss: 5.632076263427734
epoch0, iter170, loss: 5.443544864654541
epoch0, iter180, loss: 5.979888439178467
epoch0, iter190, loss: 5.425782203674316
epoch0, iter200, loss: 5.534978866577148
epoch0, iter210, loss: 5.454797267913818
epoch0, iter220, loss: 5.545133590698242
epoch0, iter230, loss: 5.591846466064453
epoch0, iter240, loss:

epoch0, iter1980, loss: 3.7380948066711426
epoch0, iter1990, loss: 3.777279853820801
epoch0, iter2000, loss: 3.823946952819824
epoch0, iter2010, loss: 3.754255771636963
epoch0, iter2020, loss: 3.726382255554199
epoch0, iter2030, loss: 3.864579677581787
epoch0, iter2040, loss: 3.6906445026397705
epoch0, iter2050, loss: 3.8011250495910645
epoch0, iter2060, loss: 4.120181083679199
epoch0, iter2070, loss: 3.775834560394287
Finish epoch 0, time elapsed 632.4531817436218
epoch0, iter0, loss: 3.756016254425049
epoch0, iter10, loss: 3.9889397621154785
epoch0, iter20, loss: 3.9806737899780273
epoch0, iter30, loss: 3.854557991027832
epoch0, iter40, loss: 4.365982532501221
epoch0, iter50, loss: 3.731600046157837
epoch0, iter60, loss: 3.6225838661193848
epoch0, iter70, loss: 3.9070823192596436
epoch0, iter80, loss: 3.8480429649353027
epoch0, iter90, loss: 4.120165824890137
epoch0, iter100, loss: 3.852818489074707
epoch0, iter110, loss: 3.9345366954803467
epoch0, iter120, loss: 4.329192638397217
ep

epoch1, iter1330, loss: 3.615802049636841
epoch1, iter1340, loss: 3.5875396728515625
epoch1, iter1350, loss: 3.878119945526123
epoch1, iter1360, loss: 3.7888545989990234
epoch1, iter1370, loss: 3.4260175228118896
epoch1, iter1380, loss: 3.7252607345581055
epoch1, iter1390, loss: 3.3073320388793945
epoch1, iter1400, loss: 3.7629520893096924
epoch1, iter1410, loss: 3.5317909717559814
epoch1, iter1420, loss: 3.292107105255127
epoch1, iter1430, loss: 3.731891393661499
epoch1, iter1440, loss: 3.6067678928375244
epoch1, iter1450, loss: 3.7174384593963623
epoch1, iter1460, loss: 3.8105292320251465
epoch1, iter1470, loss: 3.6220943927764893
epoch1, iter1480, loss: 3.67153000831604
epoch1, iter1490, loss: 3.7821760177612305
epoch1, iter1500, loss: 3.4439597129821777
epoch1, iter1510, loss: 3.462078094482422
epoch1, iter1520, loss: 3.5134193897247314
epoch1, iter1530, loss: 3.466668128967285
epoch1, iter1540, loss: 3.5164225101470947
epoch1, iter1550, loss: 3.4730522632598877
epoch1, iter1560, l

epoch2, iter660, loss: 3.5674898624420166
epoch2, iter670, loss: 3.6899824142456055
epoch2, iter680, loss: 3.5422263145446777
epoch2, iter690, loss: 3.891930103302002
epoch2, iter700, loss: 4.059230804443359
epoch2, iter710, loss: 3.611468553543091
epoch2, iter720, loss: 3.7151243686676025
epoch2, iter730, loss: 3.5383219718933105
epoch2, iter740, loss: 3.5082452297210693
epoch2, iter750, loss: 3.5645904541015625
epoch2, iter760, loss: 3.6469967365264893
epoch2, iter770, loss: 3.631530284881592
epoch2, iter780, loss: 3.4902281761169434
epoch2, iter790, loss: 3.470036268234253
epoch2, iter800, loss: 3.6616437435150146
epoch2, iter810, loss: 3.6730711460113525
epoch2, iter820, loss: 3.4810705184936523
epoch2, iter830, loss: 3.915956497192383
epoch2, iter840, loss: 3.671893358230591
epoch2, iter850, loss: 3.746896982192993
epoch2, iter860, loss: 3.5568478107452393
epoch2, iter870, loss: 3.5321738719940186
epoch2, iter880, loss: 3.6111114025115967
epoch2, iter890, loss: 3.7418622970581055


Finish epoch 2, time elapsed 70.39189124107361
Average loss:  0.11088440547118315
epoch3, iter0, loss: 3.6068289279937744
epoch3, iter10, loss: 3.6182196140289307
epoch3, iter20, loss: 3.785083532333374
epoch3, iter30, loss: 3.45871639251709
epoch3, iter40, loss: 3.358154535293579
epoch3, iter50, loss: 3.373025417327881
epoch3, iter60, loss: 3.5328822135925293
epoch3, iter70, loss: 3.534076690673828
epoch3, iter80, loss: 3.460949659347534
epoch3, iter90, loss: 3.34708309173584
epoch3, iter100, loss: 3.2780869007110596
epoch3, iter110, loss: 3.508272886276245
epoch3, iter120, loss: 3.4575371742248535
epoch3, iter130, loss: 3.7379043102264404
epoch3, iter140, loss: 3.6669809818267822
epoch3, iter150, loss: 4.002566814422607
epoch3, iter160, loss: 3.6422834396362305
epoch3, iter170, loss: 3.4250776767730713
epoch3, iter180, loss: 3.4229347705841064
epoch3, iter190, loss: 3.597707748413086
epoch3, iter200, loss: 3.5042128562927246
epoch3, iter210, loss: 3.350057363510132
epoch3, iter220, l

epoch3, iter1940, loss: 3.3655221462249756
epoch3, iter1950, loss: 3.598308563232422
epoch3, iter1960, loss: 3.1986937522888184
epoch3, iter1970, loss: 3.192089557647705
epoch3, iter1980, loss: 3.4944682121276855
epoch3, iter1990, loss: 3.5591812133789062
epoch3, iter2000, loss: 3.2855122089385986
epoch3, iter2010, loss: 3.1564676761627197
epoch3, iter2020, loss: 3.292203903198242
epoch3, iter2030, loss: 3.5663414001464844
epoch3, iter2040, loss: 3.256190061569214
epoch3, iter2050, loss: 3.728153944015503
epoch3, iter2060, loss: 3.5738725662231445
epoch3, iter2070, loss: 3.809415102005005
Finish epoch 3, time elapsed 635.1166062355042
epoch3, iter0, loss: 2.943115472793579
epoch3, iter10, loss: 3.7065577507019043
epoch3, iter20, loss: 3.3817551136016846
epoch3, iter30, loss: 3.4895176887512207
epoch3, iter40, loss: 3.4571692943573
epoch3, iter50, loss: 3.3206474781036377
epoch3, iter60, loss: 3.2282328605651855
epoch3, iter70, loss: 3.7003395557403564
epoch3, iter80, loss: 3.7262117862

epoch4, iter1290, loss: 3.3938329219818115
epoch4, iter1300, loss: 3.498778820037842
epoch4, iter1310, loss: 3.7189512252807617
epoch4, iter1320, loss: 3.5090208053588867
epoch4, iter1330, loss: 3.522216796875
epoch4, iter1340, loss: 3.622100353240967
epoch4, iter1350, loss: 3.646538019180298
epoch4, iter1360, loss: 3.6036181449890137
epoch4, iter1370, loss: 3.6962108612060547
epoch4, iter1380, loss: 3.3919360637664795
epoch4, iter1390, loss: 3.576357364654541
epoch4, iter1400, loss: 3.4119386672973633
epoch4, iter1410, loss: 3.505222797393799
epoch4, iter1420, loss: 3.6531903743743896
epoch4, iter1430, loss: 3.383723020553589
epoch4, iter1440, loss: 3.904603958129883
epoch4, iter1450, loss: 3.4568281173706055
epoch4, iter1460, loss: 3.6213202476501465
epoch4, iter1470, loss: 3.5080974102020264
epoch4, iter1480, loss: 3.451042652130127
epoch4, iter1490, loss: 3.4867870807647705
epoch4, iter1500, loss: 3.399686336517334
epoch4, iter1510, loss: 3.723374128341675
epoch4, iter1520, loss: 3

epoch5, iter620, loss: 3.4445526599884033
epoch5, iter630, loss: 3.3485107421875
epoch5, iter640, loss: 3.181551218032837
epoch5, iter650, loss: 3.424758195877075
epoch5, iter660, loss: 3.523463249206543
epoch5, iter670, loss: 3.631316900253296
epoch5, iter680, loss: 3.2144277095794678
epoch5, iter690, loss: 3.62845778465271
epoch5, iter700, loss: 3.3189034461975098
epoch5, iter710, loss: 3.3099100589752197
epoch5, iter720, loss: 3.6732869148254395
epoch5, iter730, loss: 3.618372678756714
epoch5, iter740, loss: 3.265782594680786
epoch5, iter750, loss: 3.5560479164123535
epoch5, iter760, loss: 3.2410483360290527
epoch5, iter770, loss: 3.1558353900909424
epoch5, iter780, loss: 3.323070764541626
epoch5, iter790, loss: 3.5282883644104004
epoch5, iter800, loss: 3.232008695602417
epoch5, iter810, loss: 3.730205535888672
epoch5, iter820, loss: 3.7732272148132324
epoch5, iter830, loss: 3.492213249206543
epoch5, iter840, loss: 3.359849452972412
epoch5, iter850, loss: 3.088974714279175
epoch5, i

epoch5, iter480, loss: 3.536928176879883
epoch5, iter490, loss: 3.33455491065979
epoch5, iter500, loss: 2.8867437839508057
epoch5, iter510, loss: 3.136993885040283
Finish epoch 5, time elapsed 69.42300462722778
Average loss:  0.10756579947632712
epoch6, iter0, loss: 3.6461105346679688
epoch6, iter10, loss: 3.641261100769043
epoch6, iter20, loss: 3.5170750617980957
epoch6, iter30, loss: 3.3426673412323
epoch6, iter40, loss: 3.3848063945770264
epoch6, iter50, loss: 3.25150203704834
epoch6, iter60, loss: 3.4229772090911865
epoch6, iter70, loss: 3.1919333934783936
epoch6, iter80, loss: 3.2234246730804443
epoch6, iter90, loss: 3.1929519176483154
epoch6, iter100, loss: 3.168135166168213
epoch6, iter110, loss: 3.447538375854492
epoch6, iter120, loss: 3.376899242401123
epoch6, iter130, loss: 3.2897067070007324
epoch6, iter140, loss: 3.1681430339813232
epoch6, iter150, loss: 3.4738800525665283
epoch6, iter160, loss: 3.3261759281158447
epoch6, iter170, loss: 3.1447601318359375
epoch6, iter180, l

epoch6, iter1900, loss: 3.382688283920288
epoch6, iter1910, loss: 3.3084683418273926
epoch6, iter1920, loss: 3.8418033123016357
epoch6, iter1930, loss: 3.172325372695923
epoch6, iter1940, loss: 3.693877696990967
epoch6, iter1950, loss: 3.755488157272339
epoch6, iter1960, loss: 3.2519590854644775
epoch6, iter1970, loss: 3.631856679916382
epoch6, iter1980, loss: 3.3802506923675537
epoch6, iter1990, loss: 3.4179325103759766
epoch6, iter2000, loss: 3.426042079925537
epoch6, iter2010, loss: 3.4321789741516113
epoch6, iter2020, loss: 3.817146062850952
epoch6, iter2030, loss: 3.3339481353759766
epoch6, iter2040, loss: 3.2636797428131104
epoch6, iter2050, loss: 3.647645950317383
epoch6, iter2060, loss: 3.1691300868988037
epoch6, iter2070, loss: 3.311338186264038
Finish epoch 6, time elapsed 633.5119941234589
epoch6, iter0, loss: 3.397704839706421
epoch6, iter10, loss: 3.4886510372161865
epoch6, iter20, loss: 3.7867109775543213
epoch6, iter30, loss: 3.2933175563812256
epoch6, iter40, loss: 3.87

epoch7, iter1240, loss: 3.1007444858551025
epoch7, iter1250, loss: 3.814894437789917
epoch7, iter1260, loss: 3.402437210083008
epoch7, iter1270, loss: 3.2003886699676514
epoch7, iter1280, loss: 2.965158462524414
epoch7, iter1290, loss: 3.587110757827759
epoch7, iter1300, loss: 3.532130241394043
epoch7, iter1310, loss: 3.3641867637634277
epoch7, iter1320, loss: 3.0874414443969727
epoch7, iter1330, loss: 3.347928285598755
epoch7, iter1340, loss: 3.1378304958343506
epoch7, iter1350, loss: 3.2594144344329834
epoch7, iter1360, loss: 3.249619960784912
epoch7, iter1370, loss: 3.1204440593719482
epoch7, iter1380, loss: 3.3203463554382324
epoch7, iter1390, loss: 3.618009328842163
epoch7, iter1400, loss: 3.2041704654693604
epoch7, iter1410, loss: 3.513514757156372
epoch7, iter1420, loss: 3.404068946838379
epoch7, iter1430, loss: 3.306751251220703
epoch7, iter1440, loss: 3.324153184890747
epoch7, iter1450, loss: 3.2817227840423584
epoch7, iter1460, loss: 3.549757242202759
epoch7, iter1470, loss: 

epoch8, iter570, loss: 3.539581298828125
epoch8, iter580, loss: 3.161011219024658
epoch8, iter590, loss: 3.4718949794769287
epoch8, iter600, loss: 3.291203498840332
epoch8, iter610, loss: 3.6351444721221924
epoch8, iter620, loss: 3.6235735416412354
epoch8, iter630, loss: 3.451643943786621
epoch8, iter640, loss: 3.4115517139434814
epoch8, iter650, loss: 2.965175151824951
epoch8, iter660, loss: 3.3647122383117676
epoch8, iter670, loss: 3.578000783920288
epoch8, iter680, loss: 3.374776840209961
epoch8, iter690, loss: 3.3411097526550293
epoch8, iter700, loss: 3.2728431224823
epoch8, iter710, loss: 3.037637710571289
epoch8, iter720, loss: 3.4408605098724365
epoch8, iter730, loss: 3.2366280555725098
epoch8, iter740, loss: 3.0147762298583984
epoch8, iter750, loss: 3.4678094387054443
epoch8, iter760, loss: 3.5893490314483643
epoch8, iter770, loss: 3.5798282623291016
epoch8, iter780, loss: 3.323486328125
epoch8, iter790, loss: 3.4398069381713867
epoch8, iter800, loss: 3.112659454345703
epoch8, 

epoch8, iter430, loss: 3.415677070617676
epoch8, iter440, loss: 3.405123233795166
epoch8, iter450, loss: 3.4072511196136475
epoch8, iter460, loss: 3.417980909347534
epoch8, iter470, loss: 3.3364033699035645
epoch8, iter480, loss: 3.3546853065490723
epoch8, iter490, loss: 3.2390992641448975
epoch8, iter500, loss: 3.2244601249694824
epoch8, iter510, loss: 3.274512529373169
Finish epoch 8, time elapsed 68.38964700698853
Average loss:  0.10620019252638559
epoch9, iter0, loss: 3.4493188858032227
epoch9, iter10, loss: 3.234549045562744
epoch9, iter20, loss: 3.094661235809326
epoch9, iter30, loss: 3.056356191635132
epoch9, iter40, loss: 3.125501871109009
epoch9, iter50, loss: 3.3093857765197754
epoch9, iter60, loss: 2.839493989944458
epoch9, iter70, loss: 3.2856736183166504
epoch9, iter80, loss: 3.4575626850128174
epoch9, iter90, loss: 3.415083169937134
epoch9, iter100, loss: 3.450549364089966
epoch9, iter110, loss: 3.172800064086914
epoch9, iter120, loss: 3.353508472442627
epoch9, iter130, l

epoch9, iter1850, loss: 3.310908794403076
epoch9, iter1860, loss: 3.417971611022949
epoch9, iter1870, loss: 3.1903538703918457
epoch9, iter1880, loss: 3.493523359298706
epoch9, iter1890, loss: 3.901566982269287
epoch9, iter1900, loss: 3.1891026496887207
epoch9, iter1910, loss: 3.6699869632720947
epoch9, iter1920, loss: 3.489750623703003
epoch9, iter1930, loss: 2.9910125732421875
epoch9, iter1940, loss: 3.067253351211548
epoch9, iter1950, loss: 3.572356939315796
epoch9, iter1960, loss: 3.385000228881836
epoch9, iter1970, loss: 3.5176775455474854
epoch9, iter1980, loss: 3.3593978881835938
epoch9, iter1990, loss: 3.4549710750579834
epoch9, iter2000, loss: 3.6361634731292725
epoch9, iter2010, loss: 3.349573850631714
epoch9, iter2020, loss: 3.5515263080596924
epoch9, iter2030, loss: 3.7863779067993164
epoch9, iter2040, loss: 3.3395280838012695
epoch9, iter2050, loss: 3.684882879257202
epoch9, iter2060, loss: 3.629758834838867
epoch9, iter2070, loss: 3.3247640132904053
Finish epoch 9, time e

epoch10, iter1170, loss: 3.3962173461914062
epoch10, iter1180, loss: 3.2569706439971924
epoch10, iter1190, loss: 3.078261613845825
epoch10, iter1200, loss: 3.5210378170013428
epoch10, iter1210, loss: 3.3967649936676025
epoch10, iter1220, loss: 3.0998318195343018
epoch10, iter1230, loss: 3.691671848297119
epoch10, iter1240, loss: 3.400357484817505
epoch10, iter1250, loss: 3.3977363109588623
epoch10, iter1260, loss: 3.3402702808380127
epoch10, iter1270, loss: 3.493067741394043
epoch10, iter1280, loss: 3.2324302196502686
epoch10, iter1290, loss: 3.1999056339263916
epoch10, iter1300, loss: 3.4352471828460693
epoch10, iter1310, loss: 3.2653419971466064
epoch10, iter1320, loss: 3.296992540359497
epoch10, iter1330, loss: 3.14339542388916
epoch10, iter1340, loss: 3.4651713371276855
epoch10, iter1350, loss: 3.6672091484069824
epoch10, iter1360, loss: 3.171729564666748
epoch10, iter1370, loss: 3.3934707641601562
epoch10, iter1380, loss: 3.1630725860595703
epoch10, iter1390, loss: 3.0012931823730

epoch11, iter460, loss: 3.477633237838745
epoch11, iter470, loss: 3.1129310131073
epoch11, iter480, loss: 3.4915647506713867
epoch11, iter490, loss: 3.1527140140533447
epoch11, iter500, loss: 3.81381893157959
epoch11, iter510, loss: 3.523545503616333
epoch11, iter520, loss: 3.5587353706359863
epoch11, iter530, loss: 2.9326188564300537
epoch11, iter540, loss: 3.4719438552856445
epoch11, iter550, loss: 3.1237740516662598
epoch11, iter560, loss: 3.0976786613464355
epoch11, iter570, loss: 3.50243878364563
epoch11, iter580, loss: 3.6117446422576904
epoch11, iter590, loss: 3.5072975158691406
epoch11, iter600, loss: 3.4927780628204346
epoch11, iter610, loss: 3.291506290435791
epoch11, iter620, loss: 3.5454115867614746
epoch11, iter630, loss: 2.8368008136749268
epoch11, iter640, loss: 3.4830820560455322
epoch11, iter650, loss: 3.575608253479004
epoch11, iter660, loss: 3.324195384979248
epoch11, iter670, loss: 3.555264949798584
epoch11, iter680, loss: 3.4665417671203613
epoch11, iter690, loss: 

In [None]:
tensor1 = torch.tensor([[1,2,3,4,25120,6,7],[8,9,10,25120,12,13,14]])
tensor2 = torch.reshape(tensor1, (-1,))
end_list = (tensor2 == 25120).nonzero().tolist()
random = list(range(0,2))
random = [x * 7 for x in random]
new_list = []
for i in range(len(random)):
    new_list.append(tensor2[random])
    

In [None]:
print(vocab(0))