In [1]:
import os
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models

import nltk

In [2]:
class Embeddings():
    def __init__(self, vocab, emd_dims):
        self.vocab = vocab
        self.embeds = nn.Embedding(len(vocab), emd_dims)
    
    def get_embedding(self, word):
        lookup_tensor = torch.tensor(self.vocab[word], dtype = torch.long)
        return self.embeds(lookup_tensor)
    
    def vocab_size():
        return len(self.vocab.keys())

In [3]:
test_df = pd.read_csv('./data/coco/coco_test_all.csv')

In [4]:
data_folder = ['coco','bing','flickr']
dataframes = []
for f in data_folder:
    files = os.listdir('./data/'+f)   
    for path in files:
        if('.csv' in path):
            csv_path = './data/'+f+'/'+path
            print(csv_path)
            df = pd.read_csv(csv_path)
            dataframes.append(df)
len(dataframes)

./data/coco/coco_test_all.csv
./data/coco/coco_train_all.csv
./data/coco/coco_val_all.csv
./data/bing/bing_train_all.csv
./data/bing/bing_test_all.csv
./data/bing/bing_val_all.csv
./data/flickr/flickr_val_all.csv
./data/flickr/flickr_train_all.csv
./data/flickr/flickr_test_all.csv


9

In [5]:
df = pd.concat(dataframes, axis = 0)
len(df)

14815

In [6]:
questions = list(df['questions'])
freq = {}
for q in questions:
    for question in q.split('---'):
        wordlist = nltk.word_tokenize(question)
        for word in wordlist:
            if(word not in freq):
                freq[word] = 1
            else:
                freq[word] += 1
len(freq)

12098

In [8]:
vocab = {}
counter = 0
for key in freq.keys():
    if freq[key]>=3:
        vocab[key] = counter
        counter += 1


vocab['<eoq>'] = counter
len(vocab)

5062

In [12]:
class Vqgnet(nn.Module):
    def __init__(self, num_lstm_layers, embedding, max_len):
        super(Vqgnet, self).__init__()
        self.embedding = embedding
        self.model_vgg = models.vgg19(pretrained=True)
        for p in self.model_vgg.parameters():
            p.requires_grad = False
        self.features = self.model_vgg.classifier[:-1]
        self.transform_layer = nn.Linear(4096, 512)
        self.feature_to_word = nn.Linear(512, self.embedding.vocab_size())
        self.n_lstm_layers = n_lstm_layers
        self.lstm = nn.LSTM(512, 512)
        self.max_len = max_len
    
    def forward(image, question):
        # teacher forcing using gt question
        
        # getting image features
        x = self.features(image)
        x = F.relu(x)
        x = F.relu(self.transform_layer(x))
        
        cell_state = torch.randn(1, 1, 512)
        predicted_question = []
        # embedding phase
        for i in range(len(question)):
            if(i == 0):
                embed = x
            else:
                word = question[i]
                embed = self.embedding.get_embedding(word)
            output, cell_state = self.lstm(embed, cell_state)
            output = F.softmax(self.feature_to_word(output))
            predicted_question.append(output)
        
        return predicted_question
    
    def test(image):
        # get image features
        x = self.features(image)
        x = F.relu(x)
        x = F.relu(self.tranform_layer(x))
        
        cell_state = torch.randn(1, 1, 512)
        output = ''
        predicted_question = []
        # generate question
        while output != self.embedding.vocab_size():
            x, cell_state = self.lstm(x, cell_state)
            output = F.softmax(self.feature_to_word(x))
            predicted_question.append(output)
            output = torch.argmax(output)
        
        return predicted_question
        

In [13]:
def train(model, dataloader, criterion, optimizer, scheduler, device, embedding, vocab, num_epochs=25):
    train_loss = []
    val_loss = []
    
    # looping over number of epochs
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        
        # looping over train validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
                
            running_loss = 0.0
        
            # looping over phase data 
            for image, question in dataloader[phase]:

                image = image.to(device)
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    if(phase == 'train'):
                        output = model(image, question)
                    else:
                        output = model.test(image)
                        if len(output) < len(question):
                            for i in range(len(question) - len(output)):
                                output.append(torch.zeros(output[0].shape))

                    # getting one_hot encoding
                    one_hot = torch.zeros([len(output), embedding.vocab_size()])
                    for i in range(len(question)):
                        one_hot[i, vocab[question[i]]] = 1.0
                        
                    # finding the loss
                    loss = criterion(output, one_hot)
                    
                    # back propogating the weights
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                        
                # adding to loss
                running_loss += loss.item()
            
            if phase == 'train':
                scheduler.step()
            
            # finding and printing epoch loss
            epoch_loss = running_loss / len(dataloader[phase])
            print('{} Loss: {:.4f} '.format(
                phase, epoch_loss))
            
            # appending loss to list 
            if(phase == 'train'):
                train_loss.append(epoch_loss)
            else:
                val_loss.append(epoch_loss)
                    
    return train_loss, val_loss