In [None]:
import os
import pickle
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR

import warnings
warnings.filterwarnings('ignore')

# Import Path,Vocabulary, utility, evaluator and datahandler module
from config import Path
from dictionary import Vocabulary
from utils import Utils
from evaluate import Evaluator
from data import DataHandler


import random
import numpy as np
import copy

# Sanity check: report whether PyTorch can see a CUDA device.
print(torch.cuda.is_available())

# Seed through the project's Utils helper so runs are reproducible.
SEED = 1
utils = Utils()
utils.set_seed(SEED)

# SA-LSTM

In [None]:
#Import configuration and model 

from config import ConfigSALSTM
from models.SA_LSTM.model import SALSTM

# Build the SA-LSTM configuration object (opt_encoder=True is forwarded to ConfigSALSTM).
cfg = ConfigSALSTM(opt_encoder=True)
# Dataset to use, one of {'msvd', 'msrvtt'}.
cfg.dataset = 'msrvtt'

# Override selected hyperparameters on the configuration object.
cfg.batch_size = 100           # training batch size
cfg.n_layers = 1               # number of layers in the decoder RNN
cfg.decoder_type = 'lstm'      # one of {'lstm', 'gru'}
cfg.dropout = 0.5              # dropout probability used by the model (NOTE(review): confirm where SALSTM applies it)
cfg.opt_param_init = False     # NOTE(review): flag consumed by the model; its meaning is not visible here

# Path helper built from the configuration and the current working directory.
path = Path(cfg, os.getcwd())

# Vocabulary object. load() reads a previously saved/downloaded vocabulary file;
# comment it out when building the vocabulary for the first time (no saved file yet).
voc = Vocabulary(cfg)
voc.load()

# Remove every word whose count is below min_count.
min_count = 5
voc.trim(min_count=min_count)
print('Vocabulary Size : ', voc.num_words)

In [None]:
# Datasets and dataloaders for the train / val / test splits.
data_handler = DataHandler(cfg, path, voc)
train_dset, val_dset, test_dset = data_handler.getDatasets()
train_loader, val_loader, test_loader = data_handler.getDataloader(train_dset, val_dset, test_dset)

# SA-LSTM model, placed on the configured device. Using cfg.device instead of a
# hard-coded 'cuda:0' keeps the model on the same device that the decoding cells
# move their input batches to via `.to(cfg.device)`.
model = SALSTM(voc, cfg, path).to(cfg.device)

# Evaluators on the test split: default decoding (presumably greedy, per the name) and beam search.
test_evaluator_greedy = Evaluator(model, test_loader, path, cfg, data_handler.test_dict)
test_evaluator_beam = Evaluator(model, test_loader, path, cfg, data_handler.test_dict, decoding_type='beam')

In [None]:
# Resume from the epoch-81 checkpoint. map_location remaps the saved tensors onto
# cfg.device, so a checkpoint written on a different device still loads correctly.
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
model.load_state_dict(torch.load("epochs_81.pth", map_location=cfg.device))

In [None]:
# Training loop: continue after the epoch-81 checkpoint through END_EPOCH (inclusive).
from torch.optim.lr_scheduler import ReduceLROnPlateau

START_EPOCH = 82           # first epoch to run (the resumed checkpoint is epochs_81.pth)
END_EPOCH = 700            # last epoch to run, inclusive
STEP_LR_SCHEDULER = False  # set True to decay the decoder LR when the training loss plateaus
EVAL_EVERY = None          # e.g. 50 to run greedy/beam test evaluation every N epochs; None disables it

cfg.encoder_lr = 1e-4
cfg.decoder_lr = 1e-4
cfg.teacher_forcing_ratio = 1.0
model.update_hyperparameters(cfg)

# Plateau scheduler attached to the decoder optimizer only. `verbose` is not passed
# because it is deprecated in recent PyTorch releases; the LR is printed explicitly instead.
lr_scheduler = ReduceLROnPlateau(model.dec_optimizer, mode='min', factor=cfg.lr_decay_gamma,
                                 patience=cfg.lr_decay_patience)

for e in range(START_EPOCH, END_EPOCH + 1):
    loss_train = model.train_epoch(train_loader, utils)
    print('Epoch -- >', e, 'Loss -->', loss_train)
    if STEP_LR_SCHEDULER:
        lr_scheduler.step(loss_train)
        print('decoder lr :', model.dec_optimizer.param_groups[0]['lr'])
    if EVAL_EVERY and e % EVAL_EVERY == 0:
        print('greedy :', test_evaluator_greedy.evaluate(utils, model, e, loss_train))
        print('beam :', test_evaluator_beam.evaluate(utils, model, e, loss_train))

In [None]:
# Persist the trained weights; the file name records epoch 700, the last epoch trained.
final_checkpoint = "epochs_700.pth"
torch.save(model.state_dict(), final_checkpoint)

In [None]:
# Grab one validation batch for qualitative inspection of the decoders.
# The fifth batch element is unused here; it is bound to `_unused` rather than `_`
# because assigning `_` at top level shadows IPython's last-output variable.
dataiter = iter(val_loader)
features, targets, mask, max_length, _unused, motion_feat, object_feat = next(dataiter)

In [None]:
# Greedy-decode the sampled batch for inspection.
# - torch.no_grad(): inference only, so no autograd graph is built.
# - eval mode: disables train-time behaviour such as dropout (cfg.dropout = 0.5).
# The previous train/eval mode is restored afterwards, so a later training run is unaffected.
# Assumes SALSTM is an nn.Module (it exposes .to/.state_dict/.load_state_dict).
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        tsr, txt, _unused_out = model.GreedyDecoding(features.to(cfg.device))
finally:
    model.train(was_training)
txt

In [None]:
# Ground-truth captions for the same batch, converted from the target token tensor.
reference_captions = utils.target_tensor_to_caption(voc, targets)
reference_captions

In [None]:
# Beam-search decode the same batch. The second argument (10) is presumably the beam
# width — confirm against SALSTM.BeamDecoding.
# Runs in eval mode under torch.no_grad(), then restores the previous train/eval mode.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        txt = model.BeamDecoding(features.to(cfg.device), 10)
finally:
    model.train(was_training)
txt