In [1]:
import torch 
import torchvision
import os
import operator
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
import pandas as pd

from PIL import Image
from collections import Counter
from string import punctuation
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset


In [2]:
def prepare_caption(caption):
    caption = ''.join([char for char in caption if not char in punctuation]).lower()
    return caption
prepare_caption('My name is Nancy.')

'my name is nancy'

In [3]:
transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

In [4]:
class Flickr30k(Dataset):
    
    def __init__(self, root_dir, csv_file, transform=None, topk=5000):
        self.df = pd.read_csv(os.path.join(root_dir, csv_file), delimiter='|')
        self.df.iloc[19999][' comment_number'] = ' 4'
        self.df.iloc[19999][' comment'] = ' A dog runs across the grass .'
        self.captions = {}
        self.vocab = Counter()
        for idx, row in self.df.iterrows():
            img_path = row['image_name']
            caption = prepare_caption(row[' comment'])
            if img_path not in self.captions:
                self.captions.update({img_path: [caption]})
            else:
                self.captions[img_path].append(caption)
                
            for word in caption.split():
                self.vocab[word] += 1
                
        self.transform = transform
        self.root_dir = root_dir
        self.topk = topk
        self.word2index = {word: index for index, (word, count) in enumerate(sorted(self.vocab.items(), key=operator.itemgetter(1), reverse=True)[:topk])}
        self.index2word = {index: word for index, (word, count) in enumerate(sorted(self.vocab.items(), key=operator.itemgetter(1), reverse=True)[:topk])}
    
    def __len__(self):
        return len(self.captions)
    
    def __getitem__(self, x):
        img_name = self.df.iloc[x * 5, 0]
        img = Image.open(os.path.join(self.root_dir, 'flickr30k_images', img_name))
        caption = sorted(self.captions[img_name], key=len)[-1]
        caption_encoded = []
        for word in caption.split():
            if word not in self.word2index:
                caption_encoded.append(self.topk)
            else:
                caption_encoded.append(self.word2index[word])
        
        if self.transform:
            img = self.transform(img)
        return img, caption, caption_encoded

In [5]:
def pad_seq(tensors):
    seq_len = max([tensor.shape[0] for tensor in tensors])
    for i in range(len(tensors)):
        if tensors[i].shape[0] < seq_len:
            tensors[i] = torch.cat([tensors[i], torch.zeros(seq_len - tensors[i].shape[0])], dim=-1)
    return tensors

In [6]:
def collate_fn(batch):
    imgs = [example[0] for example in batch]
    captions = [example[1] for example in batch]
    captions_encoded = [torch.Tensor(example[2]) for example in batch]
    
    imgs = torch.stack(imgs, dim=0)
    captions_encoded = pad_seq(captions_encoded)
    captions_encoded = torch.stack(captions_encoded, dim=0)
    
    return imgs, captions, captions_encoded

In [7]:
dataset = Flickr30k('/mnt/c/Users/MAX/Downloads/flickr30k_images', 'results.csv', transform=transform)

In [8]:
data_loader = DataLoader(dataset, batch_size=2, drop_last=True, collate_fn=collate_fn)

In [9]:
torch.zeros((1, 2, 3), device=torch.device('cpu'))

tensor([[[0., 0., 0.],
         [0., 0., 0.]]])

In [None]:
for idx, data in enumerate(data_loader):
    print(idx, data[0].shape, data[2].shape)

In [2]:
class Identity(nn.Module):
    
    def forward(self, x):
        return x

In [15]:
class Encoder(nn.Module):
    
    def __init__(self):
        super(Encoder, self).__init__()
    
        self.model = torchvision.models.resnet18(pretrained=False)
        
    
    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        
        return x

In [13]:
class Attention(nn.Module):
    
    def __init__(self, features_dim, hidden_dim):
        super(Attention, self).__init__()
        
        self.linear1 = nn.Linear(features_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)
        
    def forward(self, features, hidden):
        features = self.linear1(features)
        hidden = hidden[0]
        
        hidden = self.linear2(hidden)
        
        score = torch.nn.functional.tanh(hidden + features)
        
        attention_weights = torch.nn.functional.softmax(self.linear3(score), dim=1)
        
        attention_vectors = attention_weights * score
        attention_vectors = torch.mean(attention_vectors, dim=1)
        
        return attention_vectors

In [37]:
class Decoder(nn.Module):
    
    def __init__(self, hidden_dim, embedding_dim, num_layers, vocab_len):
        super(Decoder, self).__init__()
        
        self.embeddings = nn.Embedding(vocab_len, embedding_dim)
        self.lstm = nn.LSTM(hidden_dim + embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_dim, vocab_len)
        self.attention = Attention(embedding_dim, hidden_dim)
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        
    def init_hidden(self, batch_size):
        if torch.cuda.is_available():
            hidden = (torch.zeros((batch_size, self.num_layers, self.hidden_dim), device=torch.device('cuda')),
                     torch.zeros((batch_size, self.num_layers, self.hidden_dim), device=torch.device('cuda'))
                     )
        else:
            hidden = (torch.zeros((self.num_layers, batch_size, self.hidden_dim), device=torch.device('cpu')),
                     torch.zeros((self.num_layers, batch_size, self.hidden_dim), device=torch.device('cpu'))
                     )
            
        return hidden
    
    def forward(self, x, enc_embed, hidden):
        embeds = self.embeddings(x)
        embeds = torch.unsqueeze(embeds, 1)
        attention_vectors = self.attention(enc_embed, hidden)
        attention_vectors = torch.unsqueeze(attention_vectors, 1)
        #print(attention_vectors.shape)
        x = torch.cat([embeds, attention_vectors], dim=-1)
        x, hidden = self.lstm(x, hidden)
        x = self.linear(x)
        return x, hidden

In [38]:
def train(encoder, decoder, loader):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(decoder.parameters())
    for idx, data in enumerate(loader):
        optimizer.zero_grad()
        
        inputs = data[0].float()
        batch_size = inputs.shape[0]
        
        enc_embeds = encoder(inputs)
        enc_embeds = enc_embeds.view(batch_size, encoder.model.inplanes, -1)
        labels = data[2].long()
        hidden = decoder.init_hidden(batch_size)
        print(hidden[0].shape)
        outputs = []
        for i in range(labels.shape[1]):
            output, hidden = decoder(labels[:, i], enc_embeds, hidden)
            outputs.append(output)
        outputs = torch.stack(outputs, dim=0)
        loss = criterion(outputs, labels)
        print(loss)
        loss.backward()
        optimizer.step()
        

In [39]:
encoder = Encoder()
decoder = Decoder(20, 49, 1, dataset.topk + 1)
train(encoder, decoder, data_loader)

torch.Size([1, 2, 20])


RuntimeError: The size of tensor a (2) must match the size of tensor b (512) at non-singleton dimension 1

In [56]:
batch_size = 1
seq_len = 10
embedding_dim = 49
hidden_dim = 10
num_layers = 1
vocab_len = 29
model = Decoder(hidden_dim, embedding_dim, num_layers, vocab_len)
attention = Attention(embedding_dim, hidden_dim)
encoder_output = torch.rand(batch_size, 512, 49)
#inputs from embeddings [batch_size, seq_len, embedding_dim + hidden_dim]
input_str = torch.rand(batch_size, seq_len, embedding_dim)
hidden = torch.rand(2, num_layers, batch_size, hidden_dim)

for i in range(input_str.shape[1]):
    
    attention_vector = attention(encoder_output, hidden)
    lstm_input = torch.unsqueeze(torch.cat([input_str[:, i, :], attention_vector], dim=-1), 1)
    output, hidden = model(lstm_input, hidden)
    print(hidden)
    

(tensor([[[-0.1206,  0.0362, -0.2240,  0.2395, -0.0155, -0.0090,  0.0532,
           0.1678,  0.1748,  0.0265]]], grad_fn=<ViewBackward>), tensor([[[-0.4015,  0.0666, -0.2668,  0.4081, -0.0563, -0.0465,  0.1942,
           0.5321,  1.1125,  0.0434]]], grad_fn=<ViewBackward>))
(tensor([[[-0.1229, -0.2185, -0.4701,  0.4513, -0.1097, -0.0544,  0.0641,
           0.0982,  0.2417, -0.0161]]], grad_fn=<ViewBackward>), tensor([[[-0.3627, -0.7442, -0.5807,  0.9857, -0.2122, -0.3104,  0.3643,
           0.7520,  0.9708, -0.0301]]], grad_fn=<ViewBackward>))
(tensor([[[-0.1978, -0.4159, -0.3750,  0.3909, -0.0475, -0.1323,  0.1368,
           0.1838,  0.1406,  0.0098]]], grad_fn=<ViewBackward>), tensor([[[-0.4194, -1.0675, -0.6118,  1.0504, -0.1487, -0.6193,  0.3856,
           0.7525,  0.7458,  0.0186]]], grad_fn=<ViewBackward>))
(tensor([[[-0.1297, -0.4165, -0.4629,  0.4339, -0.0909, -0.0684,  0.1993,
           0.1639,  0.0761, -0.0210]]], grad_fn=<ViewBackward>), tensor([[[-0.3372, -1.0228, -0

In [27]:
print(x.shape)
print(hidden)

torch.Size([1, 10, 29])
(tensor([[[ 0.1567, -0.3082,  0.2041, -0.2354,  0.0321, -0.0569,  0.0033,
          -0.0245,  0.1133,  0.1696]]], grad_fn=<ViewBackward>), tensor([[[ 0.4709, -0.5611,  0.3368, -0.3590,  0.0493, -0.1002,  0.0086,
          -0.0744,  0.4667,  0.2902]]], grad_fn=<ViewBackward>))
