In [20]:
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import variable
import torchvision as tv
import nntools as nt
import modifiednntools as mnt
import torch

In [21]:
class NNClassifier(mnt.NeuralNetwork):
    """Classification network base that uses cross-entropy as its criterion.

    Builds on ``mnt.NeuralNetwork`` and supplies ``criterion`` so concrete
    models only need to provide their own layers and ``forward``.
    """

    def __init__(self):
        super(NNClassifier, self).__init__()
        self.cross_entropy = nn.CrossEntropyLoss()

    def criterion(self, y, d):
        """Return the cross-entropy loss of predictions ``y`` against targets ``d``.

        Args:
            y: Raw (unnormalized) class scores, in a layout accepted by
                ``nn.CrossEntropyLoss``.
            d: Target labels matching ``y``.

        Returns:
            The loss tensor produced by ``nn.CrossEntropyLoss``.
        """
        # The per-call shape prints were debugging leftovers: they flooded the
        # notebook output on every loss evaluation and have been removed.
        return self.cross_entropy(y, d)

In [22]:
class VGGNet(nn.Module):
    """Image encoder: pretrained VGG16-BN conv features, projected per region.

    The convolutional trunk yields a channel-first feature map. Every spatial
    location ("region") is treated as one ``num_fts``-dimensional descriptor
    and mapped to ``output_features`` dimensions by a shared linear layer,
    followed by tanh and dropout.

    Args:
        output_features: Size of each projected region vector.
        fine_tuning: If True the pretrained VGG weights stay trainable;
            otherwise they are frozen (``requires_grad = False``).
    """

    def __init__(self, output_features, fine_tuning=False):
        super(VGGNet, self).__init__()
        vgg = tv.models.vgg16_bn(pretrained=True)

        # Freeze the pretrained weights unless fine-tuning was requested.
        for param in vgg.parameters():
            param.requires_grad = fine_tuning

        self.features = vgg.features

        # Channel count of the feature map produced by ``vgg.features``.
        self.num_fts = 512
        self.output_features = output_features

        # Shared projection: num_fts -> output_features for every region.
        self.classifier = nn.Linear(self.num_fts, self.output_features)
        self.tanh = nn.Tanh()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Encode a batch of images into per-region feature vectors.

        Args:
            x: Image batch accepted by the VGG feature extractor.

        Returns:
            Tensor of shape ``(batch, regions, output_features)`` where
            ``regions = H * W`` of the conv feature map (e.g. 14 * 14 = 196
            for 448x448 inputs).
        """
        h = self.features(x)  # (batch, num_fts, H, W)

        # The feature map is channel-first, so a plain view(-1, num_fts) would
        # pack consecutive *spatial* values of one channel into each row.
        # Flatten the spatial grid and move channels last so each row really
        # is the descriptor of one region.
        regions = h.view(h.size(0), self.num_fts, -1).transpose(1, 2)

        # nn.Linear acts on the last dimension, giving (batch, regions, out).
        projected = self.classifier(regions)
        return self.dropout(self.tanh(projected))

In [23]:
class LSTM(nn.Module):
    """Question encoder: linear word embedding followed by an LSTM.

    Args:
        vocab_size: Length of each per-token input vector.
        embedding_dim: Size of the learned token embedding.
        num_layers: Number of stacked LSTM layers.
        batch_size: Number of questions per batch; ``forward`` reshapes its
            input to this batch size.
        hidden_dim: Size of the LSTM hidden state.
    """

    def __init__(self, vocab_size, embedding_dim, num_layers=1, batch_size=100, hidden_dim=1024):
        super(LSTM, self).__init__()
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

        # Bias-free linear map over vocab-sized vectors acts as an embedding.
        self.linear = nn.Linear(vocab_size, embedding_dim, bias=False)

        # Default (batch_first=False) layout: input is (seq, batch, feature).
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim,
                            num_layers=num_layers)

    def forward(self, question_vec):
        """Encode a batch of questions into one vector per question.

        Args:
            question_vec: Tensor reshapeable to
                ``(batch_size, seq_len, vocab_size)``.

        Returns:
            Tensor of shape ``(batch_size, 1, hidden_dim)`` holding the final
            hidden state of the last LSTM layer.
        """
        q = question_vec.view(self.batch_size, -1, self.vocab_size)

        embedded = self.linear(q)  # (batch, seq, embedding_dim)

        # nn.LSTM wants (seq, batch, feature). This must be a transpose: a
        # view() to that shape would reinterpret memory and mix tokens of
        # different questions into the same sequence.
        embedded = embedded.transpose(0, 1).contiguous()

        # Omitting (h0, c0) makes nn.LSTM start from zeros on the input's
        # device and with the right number of layers. Random initial states
        # made outputs nondeterministic, and the hard-coded .cuda() calls
        # broke CPU execution.
        _, (hidden_state, _) = self.lstm(embedded)

        # hidden_state: (num_layers, batch, hidden_dim); keep the top layer
        # (same as index 0 when num_layers == 1).
        return hidden_state[-1].view(self.batch_size, 1, self.hidden_dim)

In [24]:
class AttentionNet(nn.Module):
    """Two-hop stacked attention over image regions, followed by answer scoring.

    Each hop scores every image region against the current query, normalizes
    the scores over the regions, and adds the attention-weighted image vector
    to the query to form the refined query for the next step.

    Args:
        num_classes: Number of answer classes scored by the output layer.
        batch_size: Stored for reference; ``forward`` takes the batch size
            from its inputs.
        input_features: Dimension ``d`` shared by region and question vectors.
        output_features: Dimension ``k`` of the hidden attention space.
    """

    def __init__(self, num_classes, batch_size, input_features, output_features):
        super(AttentionNet, self).__init__()
        self.input_features = input_features
        self.output_features = output_features  # k
        self.num_classes = num_classes
        self.batch_size = batch_size

        # First attention hop.
        self.q_transform1 = nn.Linear(input_features, output_features)
        self.image_transform1 = nn.Linear(input_features, output_features, bias=False)
        self.fc31 = nn.Linear(output_features, 1)

        # Second attention hop.
        self.q_transform2 = nn.Linear(input_features, output_features)
        self.image_transform2 = nn.Linear(input_features, output_features, bias=False)
        self.fc32 = nn.Linear(output_features, 1)

        self.answerDist = nn.Linear(input_features, self.num_classes)

        self.tanh = nn.Tanh()
        # Scores have shape (batch, regions, 1); normalize over the regions.
        # dim=2 would act on the size-1 axis and make every weight 1.0.
        self.softmax = nn.Softmax(dim=1)
        self.dropout = nn.Dropout(0.5)

    def _attend(self, image_vec, query, q_transform, image_transform, score):
        """Run one attention hop and return the attended image vector (B, 1, d)."""
        # The (B, 1, k) query term broadcasts over the (B, m, k) region terms.
        h_a = self.tanh(image_transform(image_vec) + q_transform(query))
        p_i = self.softmax(score(h_a))  # (B, m, 1), sums to 1 over regions

        # Weighted sum of region vectors: (B, 1, m) @ (B, m, d) -> (B, 1, d).
        # A transpose is required here; view() to a swapped shape would
        # scramble features across regions.
        return torch.matmul(p_i.transpose(1, 2), image_vec)

    def forward(self, image_vec, question_vec):
        """Score the answer classes for each (image, question) pair.

        Args:
            image_vec: Region features, shape ``(batch, regions, input_features)``.
            question_vec: Question vectors, shape ``(batch, 1, input_features)``.

        Returns:
            Unnormalized class scores of shape ``(batch, 1, num_classes)``,
            shifted so the largest score in the batch is 0.
        """
        u_0 = question_vec

        v_0 = self._attend(image_vec, u_0,
                           self.q_transform1, self.image_transform1, self.fc31)
        u_1 = v_0 + u_0

        # The second hop must be driven by the refined query u_1, not u_0.
        v_1 = self._attend(image_vec, u_1,
                           self.q_transform2, self.image_transform2, self.fc32)
        u_2 = v_1 + u_1

        linear = self.answerDist(u_2)
        # Subtracting a single constant does not change the softmax applied
        # by the loss; it only keeps the returned scores non-positive.
        linear = linear - linear.max()

        return linear

In [25]:
class SAN(NNClassifier):
    """Stacked Attention Network: image encoder + question encoder + attention.

    Args:
        output_vgg: Size of each projected image-region vector.
        vocab_size: Length of each per-token question input vector.
        batch_size: Batch size expected by the question encoder.
        embedding_dim: Size of the token embedding in the question encoder.
        num_classes: Number of answer classes.
        input_attention: Feature size the attention module consumes; it has to
            match both the image-region and question vector sizes.
        output_attention: Hidden size used inside the attention module.
        fine_tuning: If True the pretrained VGG weights remain trainable.
    """

    def __init__(self, output_vgg, vocab_size, batch_size, embedding_dim, num_classes,
                 input_attention, output_attention, fine_tuning=False):
        super(SAN, self).__init__()

        # Forward ``fine_tuning``; it was previously accepted but ignored, so
        # the VGG weights were frozen no matter what the caller asked for.
        self.vgg = VGGNet(output_vgg, fine_tuning=fine_tuning)

        self.lstm = LSTM(vocab_size=vocab_size, embedding_dim=embedding_dim,
                         batch_size=batch_size)

        self.attention = AttentionNet(num_classes=num_classes, batch_size=batch_size,
                                      input_features=input_attention,
                                      output_features=output_attention)

    def forward(self, image, question):
        """Return answer-class scores for a batch of (image, question) pairs."""
        # One feature vector per image region.
        image_embedding = self.vgg(image)

        # One vector per question.
        question_embedding = self.lstm(question)

        # Unnormalized scores over the answer classes.
        return self.attention(image_embedding, question_embedding)