In [2]:
import torch.nn as nn
import torchvision
import torch
import torchtext
import matplotlib.pyplot as plt
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


## **Step1 Load the data**

In [3]:
import dataloader
creature_imgs, creature_captions =dataloader.get_torch_creature_data() 

## **Step2 Process the data**
Turn both the image and captions in to embeddings,\
For images we are using **VGG16** to obtain the features
For captions we are using **GloVe** embeddings

* GloVe embedding

In [58]:
creature_tokens = [i.replace(".", " . ").replace(",", " , ").replace(";", " ; ").replace("?", " ? ").lower().split() for i in creature_captions]


        Load glove embedding

In [59]:
glove = torchtext.vocab.GloVe(name="6B", dim=50)

        Get the mean of all embedding vectors, we will use this mean as the embedding of all unseen words

In [60]:
mean = glove.vectors.mean(dim=0)
glove_vector_with_unk = torch.cat((glove.vectors,mean.unsqueeze(0)))

In [61]:
vocab_size = len(glove_vector_with_unk)

        Get the vector representation of all sentences

In [62]:
## obtain the index of all words in the caption
emb_creature_captions =[] 
for caption in creature_tokens:
    emb_creature_captions.append(torch.Tensor([glove.stoi[i] if i in glove.stoi.keys() else 400000 for i in caption]).int())

In [63]:
## pad the captions
from torch.nn.utils.rnn import pad_sequence
padded_emb_creature_captions = pad_sequence(emb_creature_captions,batch_first=True)


In [64]:
## obtain the embedding
glove_emb = nn.Embedding.from_pretrained(glove_vector_with_unk)
target = glove_emb(padded_emb_creature_captions)


* VGG16

        Load VGG16 

In [4]:
vgg = torchvision.models.vgg16(pretrained=True)


        Get the features from VGG16

In [5]:
features = vgg.features(creature_imgs[0:5])

In [6]:
features.shape

torch.Size([5, 512, 7, 7])

In [169]:
transformer_layer = nn.Transformer(128,8,batch_first=True)  

In [182]:
inp = features
print(inp.shape)
print(target.shape)
transformer_layer(features,target)

torch.Size([5, 512, 7, 7])
torch.Size([5, 34, 50])


RuntimeError: the feature number of src and tgt must be equal to d_model

In [67]:
class PositionalEncoding(nn.Module):
    def __init__(self,dropout_rate, input_dimensions):
        super(PositionalEncoding,self).__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        
        positional_embedding = torch.zeros(input_dimensions,7)
        even = torch.arange(0,input_dimensions,2)
        odd = torch.arange(1,input_dimensions,2)
        position = torch.arange(7)
        denominator = torch.float_power(10000,even/input_dimensions)
        positional_embedding[0::2] = torch.sin(position.unsqueeze(0)/denominator.unsqueeze(1))
        positional_embedding[1::2] = torch.cos(position.unsqueeze(0)/denominator.unsqueeze(1))
        horizontal_positional_embedding = positional_embedding
        vertical_positional_embedding = positional_embedding
        self.positional_embedding = horizontal_positional_embedding.unsqueeze(1) + vertical_positional_embedding.unsqueeze(2)


    def forward(self,x):
        x = x + self.positional_embedding.unsqueeze(0)
        return self.dropout(x)       
        
        
        

In [8]:
features.shape

torch.Size([5, 512, 7, 7])

In [70]:
class caption_transformer(nn.Module):
    def __init__(self,num_heads):
        super(caption_transformer,self).__init__()
        self.cnn_emb = vgg.features
        self.cnn_layer = nn.Conv2d(512,128,1)
        self.word_emb = nn.Embedding.from_pretrained(glove_vector_with_unk)
        self.fc_layer = nn.Linear(50,128)      
        self.transformer_layer = nn.Transformer(128,num_heads,batch_first=True)  
        self.fc_layer2 = nn.Linear(128,50)
        self.positional_embedding = PositionalEncoding(0.1,128)
        
    def forward(self, inp, target):
        # embed the image
        emb_inp = torch.relu(self.cnn_layer(inp))
        # positional embedding
        emb_inp = torch.relu(self.positional_embedding(emb_inp))
        # embed the text
        emb_target = torch.relu(self.fc_layer(target))
        N = emb_inp.shape[0]
        dim = emb_inp.shape[1] 
        emb_inp = emb_inp.view(N,dim,-1)
        emb_inp = torch.transpose(emb_inp, 1,2)
        out = self.transformer_layer(emb_inp,emb_target)
        out = torch.relu(self.fc_layer2(out))
        return out
        

        
        
        

In [75]:
small_target = target[:5]


In [77]:
my_transformer =  caption_transformer(8)
out = my_transformer(features,small_target)


In [223]:
target[0][0]


tensor([ 0.2171,  0.4651, -0.4676,  0.1008,  1.0135,  0.7484, -0.5310, -0.2626,
         0.1681,  0.1318, -0.2491, -0.4419, -0.2174,  0.5100,  0.1345, -0.4314,
        -0.0312,  0.2067, -0.7814, -0.2015, -0.0974,  0.1609, -0.6184, -0.1850,
        -0.1246, -2.2526, -0.2232,  0.5043,  0.3226,  0.1531,  3.9636, -0.7136,
        -0.6701,  0.2839,  0.2174,  0.1443,  0.2593,  0.2343,  0.4274, -0.4445,
         0.1381,  0.3697, -0.6429,  0.0241, -0.0393, -0.2604,  0.1202, -0.0438,
         0.4101,  0.1796])

In [229]:
def print_closest_words(vec, n=5):
    dists = torch.norm(glove.vectors - vec, dim=1)     # compute distances to all words
    lst = sorted(enumerate(dists.numpy()), key=lambda x: x[1]) # sort by distance
    for idx, difference in lst[1:n+1]: 					       # take the top n
        return glove.itos[idx]

print_closest_words(out[0][0].detach().numpy(), n=1)


'http://www.mediabynumbers.com'

In [230]:
def get_word_from_vector(emb_word):
    dists = torch.norm(glove.vectors -emb_word, dim=1)
    return glove.itos[torch.argmax(dists)]

In [231]:
def get_sentences_from_vector(emb_vectors):
    sentences = []
    for sentence_vector in emb_vectors:
        sentences.append([get_word_from_vector(word) for word in sentence_vector])
    return sentences

In [237]:
target[0][0] in glove.vectors

True

In [234]:
get_sentences_from_vector(target[0].unsqueeze(0))

[['non-families',
  'non-families',
  '202-383-7824',
  'non-families',
  '202-383-7824',
  '202-383-7824',
  '202-383-7824',
  'non-families',
  '202-383-7824',
  'non-families',
  'non-families',
  '202-383-7824',
  '202-383-7824',
  '202-383-7824',
  'non-families',
  'non-families',
  'non-families',
  'non-families',
  'non-families',
  'non-families',
  'non-families',
  'non-families',
  'non-families',
  'non-families',
  'non-families',
  'non-families',
  'non-families',
  'non-families',
  'non-families',
  'non-families',
  'non-families',
  'non-families',
  'non-families',
  'non-families']]