In [15]:
import torch 
import torchvision
import os
import operator
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
import pandas as pd

from PIL import Image
from collections import Counter
from string import punctuation
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset


In [16]:
def prepare_caption(caption):
    """Strip punctuation from a caption and lower-case what is left.

    Every character found in ``string.punctuation`` is dropped; all other
    characters (including whitespace) are kept in their original order.

    Args:
        caption: The raw caption text.

    Returns:
        The lower-cased caption without punctuation characters.
    """
    kept = []
    for ch in caption:
        if ch in punctuation:
            continue
        kept.append(ch)
    return ''.join(kept).lower()
prepare_caption('My name is Nancy.')

'my name is nancy'

In [17]:
# Image preprocessing pipeline: resize to 224x224, convert to a tensor, then
# normalise each of the three channels with a fixed mean / std.
# NOTE(review): these look like the ImageNet channel statistics that
# torchvision's pretrained models are documented to expect -- confirm.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [18]:
class Flickr30k(Dataset):
    """Image / caption dataset read from a ``|``-delimited captions CSV.

    The CSV must provide the columns ``'image_name'``, ``' comment_number'``
    and ``' comment'`` (the last two with a leading space).  Images are loaded
    from ``<root_dir>/flickr30k_images/<image_name>``.

    Attributes set by ``__init__``:
        df: the captions table as read from the CSV.
        captions: image name -> list of cleaned captions for that image.
        image_names: distinct image names in order of first appearance.
        vocab: ``Counter`` of word frequencies over all cleaned captions.
        word2index / index2word: ids for the ``topk`` most frequent words
            (0 = most frequent).  Words outside that set are encoded as the
            id ``topk``.
    """

    def __init__(self, root_dir, csv_file, transform=None, topk=5000):
        """Read the CSV, group captions per image and build the vocabulary.

        Args:
            root_dir: Directory containing ``csv_file`` and the
                ``flickr30k_images`` sub-directory.
            csv_file: Name of the ``|``-delimited captions file.
            transform: Optional callable applied to each loaded PIL image.
            topk: Number of most frequent words that receive their own id.
        """
        self.df = pd.read_csv(os.path.join(root_dir, csv_file), delimiter='|')
        # Overwrite two cells of row 19999 (presumably a malformed row in the
        # distributed CSV -- confirm).  A single `.loc` call writes into the
        # frame itself; the former chained `iloc[...][...] = ...` assigned
        # into a temporary row object and is not guaranteed to reach
        # `self.df`.  The length guard keeps shorter CSVs loadable.
        if len(self.df) > 19999:
            row_label = self.df.index[19999]
            self.df.loc[row_label, ' comment_number'] = ' 4'
            self.df.loc[row_label, ' comment'] = ' A dog runs across the grass .'

        self.captions = {}
        self.vocab = Counter()
        for img_path, comment in zip(self.df['image_name'], self.df[' comment']):
            caption = prepare_caption(comment)
            self.captions.setdefault(img_path, []).append(caption)
            self.vocab.update(caption.split())

        # dicts keep insertion order, so position i is the i-th distinct image
        # of the CSV.  Indexing through this list keeps __getitem__ consistent
        # with __len__ even when an image does not have exactly five rows.
        self.image_names = list(self.captions)

        self.transform = transform
        self.root_dir = root_dir
        self.topk = topk

        # Rank the vocabulary once and derive both lookup tables from it.
        ranked = sorted(self.vocab.items(), key=operator.itemgetter(1), reverse=True)[:topk]
        self.word2index = {word: index for index, (word, _) in enumerate(ranked)}
        self.index2word = {index: word for word, index in self.word2index.items()}

    def __len__(self):
        """Return the number of distinct images."""
        return len(self.captions)

    def __getitem__(self, x):
        """Return ``(image, caption, caption_encoded)`` for image number ``x``.

        The caption is the longest (by character count) cleaned caption of the
        image; on ties the one appearing last in the CSV is used.
        ``caption_encoded`` is the list of word ids of that caption, with
        ``self.topk`` standing in for out-of-vocabulary words.

        NOTE(review): id 0 belongs to the most frequent word, so zero-padding
        applied downstream cannot be told apart from that word.
        """
        img_name = self.image_names[x]
        img_path = os.path.join(self.root_dir, 'flickr30k_images', img_name)
        # Close the file promptly and force three channels so that a
        # grayscale / palette / RGBA file still fits a 3-channel transform.
        with Image.open(img_path) as img_file:
            img = img_file.convert('RGB')

        caption = sorted(self.captions[img_name], key=len)[-1]
        caption_encoded = [self.word2index.get(word, self.topk) for word in caption.split()]

        if self.transform:
            img = self.transform(img)
        return img, caption, caption_encoded

In [19]:
def pad_seq(tensors, pad_value=0):
    """Right-pad 1-D tensors so that they all share the longest length.

    The list is modified in place (shorter entries are replaced by padded
    copies) and the same list object is returned.  Padding is created with
    the dtype and device of the tensor being padded, so integer id tensors
    stay integer instead of being promoted to float.

    Args:
        tensors: List of 1-D tensors; may be empty.
        pad_value: Value appended to the shorter tensors (default 0).

    Returns:
        ``tensors``, with every entry of length ``max(len(t) for t in tensors)``.
    """
    if len(tensors) == 0:
        # Nothing to pad; max() over an empty sequence would raise.
        return tensors
    seq_len = max(tensor.shape[0] for tensor in tensors)
    for i, tensor in enumerate(tensors):
        missing = seq_len - tensor.shape[0]
        if missing > 0:
            filler = tensor.new_full((missing,), pad_value)
            tensors[i] = torch.cat([tensor, filler], dim=-1)
    return tensors

In [20]:
def collate_fn(batch):
    """Merge ``(image, caption, caption_encoded)`` examples into one batch.

    Args:
        batch: Sequence of examples; ``example[0]`` is an image tensor,
            ``example[1]`` the caption text and ``example[2]`` a list of
            integer word ids.

    Returns:
        A tuple ``(imgs, captions, captions_encoded)`` where ``imgs`` is the
        images stacked along a new first dimension, ``captions`` is the list
        of caption texts and ``captions_encoded`` is a ``torch.long`` tensor
        of shape ``(len(batch), longest_caption)`` right-padded with 0.

    NOTE(review): 0 is used as padding value; if 0 is also a real word id
    the two cannot be distinguished afterwards.
    """
    from torch.nn.utils.rnn import pad_sequence

    imgs = torch.stack([example[0] for example in batch], dim=0)
    captions = [example[1] for example in batch]
    # Word ids must be integers to be usable as indices (embedding lookup,
    # classification targets); torch.Tensor(...) would make them float32.
    id_tensors = [torch.tensor(example[2], dtype=torch.long) for example in batch]
    captions_encoded = pad_sequence(id_tensors, batch_first=True, padding_value=0)

    return imgs, captions, captions_encoded

In [21]:
# Dataset root: taken from the FLICKR30K_DIR environment variable when set,
# otherwise the original local path is used, so existing setups keep working.
dataset = Flickr30k(os.environ.get('FLICKR30K_DIR', '/mnt/c/Users/MAX/Downloads/flickr30k_images'),
                    'results.csv', transform=transform)

In [22]:
# Batches of two examples, merged by collate_fn; a trailing incomplete batch
# is dropped.
data_loader = DataLoader(
    dataset,
    batch_size=2,
    collate_fn=collate_fn,
    drop_last=True,
)

In [23]:
# Walk through every batch of the loader and report the batch number together
# with the shapes of the image tensor and of the encoded-caption tensor.
for idx, data in enumerate(data_loader):
    print(f"{idx} {data[0].shape} {data[2].shape}")

1000092795.jpg
 two young guys with shaggy hair look at their hands while hanging out in the yard  [10, 17, 320, 8, 2103, 106, 182, 14, 59, 154, 21, 319, 69, 1, 2, 485]
10002456.jpg
 several men in hard hats are operating a giant pulley system  [115, 27, 1, 326, 270, 11, 1306, 0, 806, 3842, 2628]
0 torch.Size([2, 3, 224, 224]) torch.Size([2, 16])
1000268201.jpg
 a child in a pink dress is climbing up a set of stairs in an entry way  [0, 46, 1, 0, 83, 112, 6, 240, 45, 0, 355, 7, 400, 1, 15, 5000, 649]
1000344755.jpg
 someone in a blue shirt and hat is standing on stair and leaning against a window  [273, 1, 0, 23, 19, 4, 60, 6, 29, 3, 2831, 4, 360, 221, 0, 228]
1 torch.Size([2, 3, 224, 224]) torch.Size([2, 17])
1000366164.jpg
 two men  one in a gray shirt  one in a black shirt  standing near a stove  [10, 27, 39, 1, 0, 116, 19, 39, 1, 0, 20, 19, 29, 77, 0, 1392]
1000523639.jpg
 two people in the photo are playing the guitar and the other is poking at him  [10, 13, 1, 2, 356, 11, 31, 2, 

 a mottled black and gray dog in a blue collar jumping over a fallen tree  [0, 5000, 20, 4, 116, 30, 1, 0, 23, 735, 86, 66, 0, 883, 158]
28 torch.Size([2, 3, 224, 224]) torch.Size([2, 19])
101559400.jpg
 men in the business suits are crossing the street  and there are people with placards are gathering on the street  [27, 1, 2, 642, 525, 11, 458, 2, 32, 4, 146, 11, 13, 8, 5000, 11, 663, 3, 2, 32]
1015712668.jpg
 a man in a red longsleeved shirt bikes over a body of water on a bridge  [0, 5, 1, 0, 24, 1088, 19, 450, 66, 0, 295, 7, 40, 3, 0, 346]
29 torch.Size([2, 3, 224, 224]) torch.Size([2, 20])
10160966.jpg
 a barefooted man wearing olive green shorts grilling hotdogs on a small propane grill while holding a blue plastic cup  [0, 2473, 5, 16, 2633, 44, 118, 1291, 2368, 3, 0, 63, 5000, 592, 21, 38, 0, 23, 407, 524]
101654506.jpg
 the white and brown dog is running over the surface of the snow  [2, 18, 4, 56, 30, 6, 71, 66, 2, 990, 7, 2, 91]
30 torch.Size([2, 3, 224, 224]) torch.Size([2

52 torch.Size([2, 3, 224, 224]) torch.Size([2, 19])
103106960.jpg
 a young male kneeling in front of a hockey goal with a hockey stick in his right hand  [0, 17, 164, 632, 1, 35, 7, 0, 361, 748, 8, 0, 361, 293, 1, 22, 337, 128]
103195344.jpg
 the man with the backpack is sitting in a buildings courtyard in front of an art sculpture reading  [2, 5, 8, 2, 340, 6, 26, 1, 0, 452, 1316, 1, 35, 7, 15, 495, 798, 206]
53 torch.Size([2, 3, 224, 224]) torch.Size([2, 18])
1031973097.jpg
 a toddler in a blue shirt with red shorts and hat looks out from behind a fenced in area of a brick patio  [0, 328, 1, 0, 23, 19, 8, 24, 118, 4, 60, 94, 69, 58, 84, 0, 1383, 1, 167, 7, 0, 268, 1292]
103205630.jpg
 two men  standing on an ice  looking into something covered with a blue tarp  [10, 27, 29, 3, 15, 266, 47, 65, 113, 255, 8, 0, 23, 2305]
54 torch.Size([2, 3, 224, 224]) torch.Size([2, 23])
1032122270.jpg
 three dogs are standing in the grass and a person is sitting next to them [42, 99, 11, 29, 1, 2, 90

 a girl paddling down a large river  as seen from behind her  [0, 25, 1219, 34, 0, 51, 301, 48, 759, 58, 84, 36]
1053804096.jpg
 a girl with pigtails is playing in the ocean by the beach  [0, 25, 8, 1822, 6, 31, 1, 2, 225, 41, 2, 81]
80 torch.Size([2, 3, 224, 224]) torch.Size([2, 12])
1054620089.jpg


KeyboardInterrupt: 

In [2]:
class Identity(nn.Module):
    """Pass-through module: ``forward`` hands its argument back untouched."""

    def forward(self, x):
        # No computation at all; the very same object is returned.
        return x

In [3]:
class Encoder(nn.Module):
    """Feature extractor built on a pretrained torchvision ResNet-18.

    ``forward`` runs the input through the network's ``conv1``, ``bn1``,
    ``relu`` and ``maxpool`` layers followed by ``layer1`` .. ``layer4`` and
    returns the output of ``layer4``; no other submodule of the ResNet is
    invoked.
    """

    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet18(pretrained=True)

    def forward(self, x):
        """Return the ``layer4`` output computed from ``x``."""
        backbone = self.model
        stem = (backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool)
        stages = (backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4)
        for layer in stem + stages:
            x = layer(x)
        return x

In [46]:
class Attention(nn.Module):
    """Additive attention of a recurrent hidden state over encoder features.

    Features and hidden state are each projected to ``hidden_dim``, summed and
    passed through ``tanh``; a linear layer turns the result into one weight
    per position (softmax over dimension 1), and the weighted result is
    averaged over dimension 1.
    """

    def __init__(self, features_dim, hidden_dim):
        """
        Args:
            features_dim: Size of the last dimension of ``features``.
            hidden_dim: Size of the hidden state and of the returned vector.
        """
        super(Attention, self).__init__()

        self.linear1 = nn.Linear(features_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)

    def forward(self, features, hidden):
        """Compute one attention vector per batch element.

        Args:
            features: Tensor of shape ``(batch, positions, features_dim)``.
            hidden: Recurrent state whose element ``[0]`` is the hidden state,
                e.g. an ``(h, c)`` tuple or a stacked tensor.  ``hidden[0]``
                may be ``(num_layers, batch, hidden_dim)`` -- the last layer
                is used -- or ``(batch, hidden_dim)``.

        Returns:
            Tensor of shape ``(batch, hidden_dim)``.
        """
        features = self.linear1(features)  # (batch, positions, hidden_dim)

        query = hidden[0]
        if query.dim() == 3:
            # (num_layers, batch, hidden_dim): attend with the last layer.
            query = query[-1]
        # Insert the positions axis so the query broadcasts over positions
        # rather than the layer/batch axes being matched against
        # batch/positions (which only worked for batch == num_layers == 1).
        query = self.linear2(query).unsqueeze(1)  # (batch, 1, hidden_dim)

        # torch.tanh replaces the deprecated torch.nn.functional.tanh.
        score = torch.tanh(query + features)

        attention_weights = torch.nn.functional.softmax(self.linear3(score), dim=1)

        # NOTE(review): the weighted scores are averaged, not summed, over the
        # positions -- kept as in the original; confirm this is intended.
        attention_vectors = attention_weights * score
        attention_vectors = torch.mean(attention_vectors, dim=1)

        return attention_vectors

In [39]:
class Decoder(nn.Module):
    """LSTM followed by a linear projection to vocabulary-sized scores.

    The LSTM consumes vectors of size ``hidden_dim + embedding_dim`` (batch
    first) and its outputs of size ``hidden_dim`` are mapped to ``vocab_len``
    values by the linear layer.
    """

    def __init__(self, hidden_dim, embedding_dim, num_layers, vocab_len):
        super().__init__()

        step_input_dim = hidden_dim + embedding_dim
        self.lstm = nn.LSTM(
            input_size=step_input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.linear = nn.Linear(in_features=hidden_dim, out_features=vocab_len)

    def forward(self, x, hidden):
        """Run the LSTM on ``x`` from state ``hidden``.

        Returns:
            ``(scores, new_hidden)``: the linear projection of the LSTM
            outputs and the state returned by the LSTM.
        """
        lstm_out, new_hidden = self.lstm(x, hidden)
        return self.linear(lstm_out), new_hidden

In [None]:
def train(encoder, decoder, loader):
    """Run the encoder over every batch of ``loader`` (unfinished stub).

    Only the encoder forward pass exists so far: ``decoder`` is not used yet,
    no loss is computed and nothing is returned.

    Args:
        encoder: Module applied to the image batch.
        decoder: Currently unused.
        loader: Iterable of batches whose element ``[0]`` is the image tensor.
    """
    # Send inputs to wherever the encoder lives instead of unconditionally
    # calling .cuda(), which fails on CPU-only machines or when the encoder is
    # on a different device.  An encoder on the current CUDA device behaves
    # exactly as before.
    first_param = next(encoder.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for idx, data in enumerate(loader):
        inputs = data[0].to(device)
        enc_embeds = encoder(inputs)
        
        

In [56]:
# Smoke test: step the Decoder one token at a time, feeding it an attention
# vector computed from random stand-in encoder features.
batch_size = 1
seq_len = 10
embedding_dim = 49
hidden_dim = 10
num_layers = 1
vocab_len = 29
model = Decoder(hidden_dim, embedding_dim, num_layers, vocab_len)
# NOTE(review): Attention's features_dim is the last dimension of
# encoder_output (49); it only coincides with embedding_dim here.
attention = Attention(embedding_dim, hidden_dim)
encoder_output = torch.rand(batch_size, 512, 49)
# Stand-in for embedded caption tokens: [batch_size, seq_len, embedding_dim].
# The hidden_dim-sized attention vector is concatenated per step below, which
# gives the LSTM its embedding_dim + hidden_dim input size.
input_str = torch.rand(batch_size, seq_len, embedding_dim)
# nn.LSTM takes its initial state as an (h_0, c_0) tuple, each of shape
# [num_layers, batch_size, hidden_dim], rather than as one stacked tensor.
hidden = (torch.rand(num_layers, batch_size, hidden_dim),
          torch.rand(num_layers, batch_size, hidden_dim))

for i in range(input_str.shape[1]):
    attention_vector = attention(encoder_output, hidden)
    # [batch_size, embedding_dim + hidden_dim] -> add a sequence axis of length 1
    lstm_input = torch.unsqueeze(torch.cat([input_str[:, i, :], attention_vector], dim=-1), 1)
    output, hidden = model(lstm_input, hidden)
    print(hidden)
    

(tensor([[[-0.1206,  0.0362, -0.2240,  0.2395, -0.0155, -0.0090,  0.0532,
           0.1678,  0.1748,  0.0265]]], grad_fn=<ViewBackward>), tensor([[[-0.4015,  0.0666, -0.2668,  0.4081, -0.0563, -0.0465,  0.1942,
           0.5321,  1.1125,  0.0434]]], grad_fn=<ViewBackward>))
(tensor([[[-0.1229, -0.2185, -0.4701,  0.4513, -0.1097, -0.0544,  0.0641,
           0.0982,  0.2417, -0.0161]]], grad_fn=<ViewBackward>), tensor([[[-0.3627, -0.7442, -0.5807,  0.9857, -0.2122, -0.3104,  0.3643,
           0.7520,  0.9708, -0.0301]]], grad_fn=<ViewBackward>))
(tensor([[[-0.1978, -0.4159, -0.3750,  0.3909, -0.0475, -0.1323,  0.1368,
           0.1838,  0.1406,  0.0098]]], grad_fn=<ViewBackward>), tensor([[[-0.4194, -1.0675, -0.6118,  1.0504, -0.1487, -0.6193,  0.3856,
           0.7525,  0.7458,  0.0186]]], grad_fn=<ViewBackward>))
(tensor([[[-0.1297, -0.4165, -0.4629,  0.4339, -0.0909, -0.0684,  0.1993,
           0.1639,  0.0761, -0.0210]]], grad_fn=<ViewBackward>), tensor([[[-0.3372, -1.0228, -0

In [27]:
# NOTE(review): `x` is not assigned in any cell visible in this notebook; this
# cell appears to depend on state left over from an earlier kernel session --
# confirm where `x` comes from before relying on it.
print(x.shape, hidden, sep='\n')

torch.Size([1, 10, 29])
(tensor([[[ 0.1567, -0.3082,  0.2041, -0.2354,  0.0321, -0.0569,  0.0033,
          -0.0245,  0.1133,  0.1696]]], grad_fn=<ViewBackward>), tensor([[[ 0.4709, -0.5611,  0.3368, -0.3590,  0.0493, -0.1002,  0.0086,
          -0.0744,  0.4667,  0.2902]]], grad_fn=<ViewBackward>))
