In [4]:
from decoder import *
from encoder import *
from data_loader import *
import pickle
import random
import torch.optim as optim
import csv
import time
from tqdm import tqdm
import gc
import os

In [5]:
def trainEncoderDecoder(encoder, decoder, criterion, epochs,
                        train_loader,val_loader, test_loader, name):
    """Train an encoder/decoder pair and log the training loss of every epoch.

    Two files are written under ``./logs/`` (created if missing):

    * ``<name><suffix>.log``          - one record per epoch
    * ``<name>_summary<suffix>.log``  - one record at the end of the run

    ``<suffix>`` is empty for the first run with a given ``name`` and then
    ``0``, ``1``, ... so that a new run never appends to an older run's files.
    Both files of one run always share the same suffix.

    Args:
        encoder: module applied to each input batch.
        decoder: module applied to the encoder output; its result is passed to
            ``criterion`` together with the batch labels.
        criterion: loss callable invoked as ``criterion(outputs, labels.long())``.
        epochs: number of passes over ``train_loader``.
        train_loader: iterable yielding ``(inputs, tar, labels)`` triples;
            ``tar`` is discarded.
        val_loader: currently unused (validation is not wired up yet).
        test_loader: currently unused.
        name: base name of the log files.

    Returns:
        list: the loss of the last batch of each epoch as a Python float
        (``None`` for an epoch in which the loader yielded no batch).
    """
    # Find the first free log name: <name>.log, <name>0.log, <name>1.log, ...
    log_dir = './logs/'
    os.makedirs(log_dir, exist_ok=True)
    suffix = ''
    next_index = 0
    while os.path.exists(log_dir + name + suffix + '.log'):
        suffix = str(next_index)
        next_index += 1
    logname = log_dir + name + suffix + '.log'

    print('Loading results to logfile: ' + logname)
    with open(logname, "a") as file:
        file.write("Log file DATA: Validation Loss and Accuracy\n")

    # Same suffix as the main log so two runs can never share a summary file.
    logname_summary = log_dir + name + '_summary' + suffix + '.log'
    print('Loading Summary to : ' + logname_summary)

    use_gpu = torch.cuda.is_available()
    if use_gpu:
        device = torch.device("cuda:0")
        encoder = torch.nn.DataParallel(encoder)
        decoder = torch.nn.DataParallel(decoder)

        encoder.to(device)
        decoder.to(device)

    # Build the optimizer after the modules are on their final device.
    parameters = list(encoder.parameters()) + list(decoder.parameters())
    optimizer = optim.Adam(parameters, lr=5e-3)

    # Validation is not implemented yet; the lists stay empty and are only
    # kept so the summary file retains its layout.
    val_loss_set = []
    val_acc_set = []
    val_iou_set = []

    training_loss = []

    for epoch in range(epochs):
        ts = time.time()
        last_loss = None

        # enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so the progress
        # bar can see the loader's length.
        for batch_idx, (inputs, tar, labels) in enumerate(tqdm(train_loader)):
            del tar  # not used by the training step

            if use_gpu:
                inputs = inputs.to(device)
                labels = labels.to(device)

            optimizer.zero_grad()

            # NOTE(review): assumes the decoder is called with the encoder
            # output only. If Decoder.forward also needs the caption tokens
            # (teacher forcing), pass them here - confirm against decoder.py.
            outputs = decoder(encoder(inputs))
            loss = criterion(outputs, labels.long())

            loss.backward()
            optimizer.step()

            # Keep a plain float: holding the tensor would keep the whole
            # autograd graph of the batch alive.
            last_loss = loss.item()
            del inputs, labels, outputs, loss

            if batch_idx % 10 == 0:
                print("epoch{}, iter{}, loss: {}".format(epoch, batch_idx, last_loss))

        training_loss.append(last_loss)

        with open(logname, "a") as file:
            file.write("writing!\n")
            file.write("Finish epoch {}, time elapsed {}".format(epoch, time.time() - ts))
            file.write("\n training Loss:   " + str(last_loss) + "\n")

    with open(logname_summary, "a") as file:
        file.write("Summary!\n")
        # Early stopping needs a validation loss, which is not computed yet.
        file.write("Finished {} epoch(s); early stopping disabled".format(len(training_loss)))
        file.write("\n training Loss:   " + str(training_loss))
        file.write("\n Validation Loss: " + str(val_loss_set))
        file.write("\n Validation acc:  " + str(val_acc_set))
        file.write("\n Validation iou:  " + str(val_iou_set) + "\n ")

    return training_loss

In [6]:
if __name__=='__main__':
    # Script entry point: build the id splits, data loaders and models, then
    # start training.

    # Only the first row of each CSV is read; every cell must parse as an int.
    with open('TrainImageIds.csv', 'r', newline='') as f:
        trainIds = [int(i) for i in list(csv.reader(f))[0]]

    with open('TestImageIds.csv', 'r', newline='') as f:
        testIds = [int(i) for i in list(csv.reader(f))[0]]

    # Shuffle in case the csv is ordered. A dedicated, seeded generator keeps
    # the train/validation split identical from run to run.
    split_seed = 0
    random.Random(split_seed).shuffle(trainIds)
    splitIdx = len(trainIds) // 5

    # Selecting 1/5 of training set as validation
    valIds = trainIds[:splitIdx]
    trainIds = trainIds[splitIdx:]

    trainValRoot = "./data/images/train/"
    # NOTE(review): this is the *train* image folder although testJson below is
    # the val2014 caption file - looks like a copy/paste slip; confirm where the
    # test images actually live before changing it.
    testRoot = "./data/images/train/"

    trainValJson = "./data/annotations/captions_train2014.json"
    testJson = "./data/annotations/captions_val2014.json"

    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # load a vocab.pkl that was produced locally.
    with open('./data/vocab.pkl', 'rb') as f:
        vocab = pickle.load(f)

    # Loader settings, shared by all three loaders.
    transform = None
    batch_size = 1
    shuffle = True
    num_workers = 1

    # NOTE(review): the last recorded run of this cell failed inside
    # data_loader.__getitem__ with a KeyError on coco.anns[ann_id]. The id
    # lists here come from *ImageIds.csv files, so image ids may be getting
    # used where annotation ids are expected - verify against get_loader.
    trainDl = get_loader(trainValRoot, trainValJson, trainIds, vocab,
                         transform=transform, batch_size=batch_size,
                         shuffle=shuffle, num_workers=num_workers)
    valDl = get_loader(trainValRoot, trainValJson, valIds, vocab,
                       transform=transform, batch_size=batch_size,
                       shuffle=shuffle, num_workers=num_workers)
    testDl = get_loader(testRoot, testJson, testIds, vocab,
                        transform=transform, batch_size=batch_size,
                        shuffle=shuffle, num_workers=num_workers)

    encoded_feature_dim = 10
    hidden_dim = 50

    encoder = Encoder(encoded_feature_dim)
    # NOTE(review): vocab.idx is presumably the vocabulary size - confirm
    # against the Decoder constructor.
    decoder = Decoder(encoded_feature_dim, hidden_dim, vocab.idx)

    criterion = nn.NLLLoss()

    epochs = 100
    trainEncoderDecoder(encoder, decoder, criterion, epochs,
                        trainDl, valDl, testDl, "LSTM")
    
    
    
    
    
    
    

loading annotations into memory...
Done (t=0.98s)
creating index...
index created!
loading annotations into memory...
Done (t=0.64s)
creating index...
index created!
loading annotations into memory...
Done (t=0.35s)
creating index...
index created!
Loading results to logfile: ./logs/LSTM7.log
Loading Summary to : ./logs/LSTM_summary7.log



0it [00:00, ?it/s][A

dict_items([('image_id', 99402), ('id', 276009), ('caption', 'an image of a tour bus driving down the road')])


KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/datasets/home/home-02/64/364/rhadden/Image-Captioning-using-LSTM-network/data_loader.py", line 36, in __getitem__
    print(coco.anns[ann_id].items())
KeyError: 31462


In [None]:
%debug