In [1]:
import torch 
import torchvision
import os
import operator
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
import pandas as pd

from PIL import Image
from collections import Counter
from string import punctuation
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset


In [2]:
def prepare_caption(caption):
    caption = ''.join([char for char in caption if not char in punctuation]).lower()
    return caption
prepare_caption('My name is Nancy.')

'my name is nancy'

In [3]:
transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

In [4]:
class Flickr30k(Dataset):
    
    def __init__(self, root_dir, csv_file, transform=None, topk=5000):
        self.df = pd.read_csv(os.path.join(root_dir, csv_file), delimiter='|')
        self.df.iloc[19999][' comment_number'] = ' 4'
        self.df.iloc[19999][' comment'] = ' A dog runs across the grass .'
        self.captions = {}
        self.vocab = Counter()
        for idx, row in self.df.iterrows():
            img_path = row['image_name']
            caption = prepare_caption(row[' comment'])
            if img_path not in self.captions:
                self.captions.update({img_path: [caption]})
            else:
                self.captions[img_path].append(caption)
                
            for word in caption.split():
                self.vocab[word] += 1
                
        self.transform = transform
        self.root_dir = root_dir
        self.topk = topk
        self.word2index = {word: index for index, (word, count) in enumerate(sorted(self.vocab.items(), key=operator.itemgetter(1), reverse=True)[:topk])}
        self.index2word = {index: word for index, (word, count) in enumerate(sorted(self.vocab.items(), key=operator.itemgetter(1), reverse=True)[:topk])}
    
    def __len__(self):
        return len(self.captions)
    
    def __getitem__(self, x):
        img_name = self.df.iloc[x * 5, 0]
        img = Image.open(os.path.join(self.root_dir, 'flickr30k_images', img_name))
        caption = sorted(self.captions[img_name], key=len)[-1]
        caption_encoded = []
        for word in caption.split():
            if word not in self.word2index:
                caption_encoded.append(self.topk)
            else:
                caption_encoded.append(self.word2index[word])
        
        if self.transform:
            img = self.transform(img)
        return img, caption, caption_encoded

In [5]:
def pad_seq(tensors):
    seq_len = max([tensor.shape[0] for tensor in tensors])
    for i in range(len(tensors)):
        if tensors[i].shape[0] < seq_len:
            tensors[i] = torch.cat([tensors[i], torch.zeros(seq_len - tensors[i].shape[0])], dim=-1)
    return tensors

In [6]:
pad_seq([torch.Tensor([1, 2]), torch.Tensor([1, 3, 5])])

[tensor([1., 2., 0.]), tensor([1., 3., 5.])]

In [7]:
def collate_fn(batch):
    imgs = [example[0] for example in batch]
    captions = [example[1] for example in batch]
    captions_encoded = [torch.Tensor(example[2]) for example in batch]
    
    imgs = torch.stack(imgs, dim=0)
    captions_encoded = pad_seq(captions_encoded)
    captions_encoded = torch.stack(captions_encoded, dim=0)
    
    return imgs, captions, captions_encoded

In [8]:
dataset = Flickr30k('/mnt/c/Users/MAX/Downloads/flickr30k_images', 'results.csv', transform=transform)

In [9]:
data_loader = DataLoader(dataset, batch_size=16, drop_last=True, collate_fn=collate_fn)

In [10]:
class Identity(nn.Module):
    
    def forward(self, x):
        return x

In [11]:
class Encoder(nn.Module):
    
    def __init__(self):
        super(Encoder, self).__init__()
    
        self.model = torchvision.models.resnet18(pretrained=True)
        
    
    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        
        return x

In [19]:
class Attention(nn.Module):
    
    def __init__(self, features_dim, hidden_dim):
        super(Attention, self).__init__()
        
        self.linear1 = nn.Linear(features_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)
        
    def forward(self, features, hidden):
        features = self.linear1(features)
        hidden = hidden[0]
        
        hidden = self.linear2(hidden)
        #rint(features.shape)
        score = torch.nn.functional.tanh(hidden.permute(1, 0, 2) + features)
        
        attention_weights = torch.nn.functional.softmax(self.linear3(score), dim=1)
        
        attention_vectors = attention_weights * score
        attention_vectors = torch.mean(attention_vectors, dim=1)
        
        return attention_vectors

In [20]:
class Decoder(nn.Module):
    
    def __init__(self, hidden_dim, embedding_dim, num_layers, vocab_len):
        super(Decoder, self).__init__()
        
        self.embeddings = nn.Embedding(vocab_len, embedding_dim)
        self.lstm = nn.LSTM(hidden_dim + embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_dim, vocab_len)
        self.attention = Attention(embedding_dim, hidden_dim)
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        
    def init_hidden(self, batch_size):
        if torch.cuda.is_available():
            hidden = (torch.zeros((batch_size, self.num_layers, self.hidden_dim), device=torch.device('cuda')),
                     torch.zeros((batch_size, self.num_layers, self.hidden_dim), device=torch.device('cuda'))
                     )
        else:
            hidden = (torch.zeros((self.num_layers, batch_size, self.hidden_dim), device=torch.device('cpu')),
                     torch.zeros((self.num_layers, batch_size, self.hidden_dim), device=torch.device('cpu'))
                     )
            
        return hidden
    
    def forward(self, x, enc_embed, hidden):
        embeds = self.embeddings(x)
        embeds = torch.unsqueeze(embeds, 1)
        #print(embeds.shape)
        attention_vectors = self.attention(enc_embed, hidden)
        #print(attention_vectors.shape)
        attention_vectors = torch.unsqueeze(attention_vectors, 1)
        x = torch.cat([embeds, attention_vectors], dim=-1)
        x, hidden = self.lstm(x, hidden)
        x = self.linear(x)
        return x, hidden

In [21]:
def loss_with_mask(pred_vec, target_vec, loss_func):
    indices = torch.nonzero(target_vec)
    pred_vec = torch.squeeze(pred_vec[indices])
    target_vec = torch.squeeze(target_vec[indices])
    return loss_func(pred_vec, target_vec)

In [26]:
def train(encoder, decoder, loader, num_epochs):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(decoder.parameters())
    for epoch in range(num_epochs):
        for idx, data in enumerate(loader):
            optimizer.zero_grad()
        
            inputs = data[0].float()
            batch_size = inputs.shape[0]
        
            enc_embeds = encoder(inputs)
            enc_embeds = enc_embeds.view(batch_size, encoder.model.inplanes, -1)
            labels = data[2].long()
            hidden = decoder.init_hidden(batch_size)
        
            outputs = []
            for i in range(labels.shape[1]):
                output, hidden = decoder(labels[:, i], enc_embeds, hidden)
                outputs.append(output)
            outputs = torch.cat(outputs, dim=1)
            outputs = outputs.view(-1, dataset.topk + 1)
            labels = labels.view(-1)
            loss = loss_with_mask(outputs, labels, criterion)
            if idx % 5 == 4:
                print("Num steps: ", idx + 1, " Loss: ", loss)
            loss.backward()
            optimizer.step()
        

In [27]:
encoder = Encoder()
decoder = Decoder(50, 49, 1, dataset.topk + 1)
train(encoder, decoder, data_loader, 5)

Num steps:  5  Loss:  tensor(8.4885, grad_fn=<NllLossBackward>)
Num steps:  10  Loss:  tensor(8.4329, grad_fn=<NllLossBackward>)
Num steps:  15  Loss:  tensor(8.3629, grad_fn=<NllLossBackward>)
Num steps:  20  Loss:  tensor(8.2771, grad_fn=<NllLossBackward>)
Num steps:  25  Loss:  tensor(8.1355, grad_fn=<NllLossBackward>)
Num steps:  30  Loss:  tensor(8.0506, grad_fn=<NllLossBackward>)
Num steps:  35  Loss:  tensor(7.7341, grad_fn=<NllLossBackward>)
Num steps:  40  Loss:  tensor(7.4051, grad_fn=<NllLossBackward>)
Num steps:  45  Loss:  tensor(7.0200, grad_fn=<NllLossBackward>)
Num steps:  50  Loss:  tensor(6.6714, grad_fn=<NllLossBackward>)
Num steps:  55  Loss:  tensor(6.4411, grad_fn=<NllLossBackward>)
Num steps:  60  Loss:  tensor(6.3063, grad_fn=<NllLossBackward>)
Num steps:  65  Loss:  tensor(6.0539, grad_fn=<NllLossBackward>)
Num steps:  70  Loss:  tensor(6.3910, grad_fn=<NllLossBackward>)
Num steps:  75  Loss:  tensor(6.2321, grad_fn=<NllLossBackward>)
Num steps:  80  Loss:  ten

KeyboardInterrupt: 

In [16]:
def evaluate(inputs, labels):
    batch_size = inputs.shape[0]
    enc_embeds = encoder(inputs)
    enc_embeds = enc_embeds.view(batch_size, encoder.model.inplanes, -1)
    hidden = decoder.init_hidden(batch_size)
    
    #print(hidden[0].shape)
    outputs = []
    for i in range(labels.shape[1]):
        output, hidden = decoder(labels[:, i], enc_embeds, hidden)
        outputs.append(output)
    outputs = torch.cat(outputs, dim=1)
    outputs = outputs.view(-1, dataset.topk + 1)
    outputs = nn.functional.softmax(outputs, dim=-1)
    words_idx = torch.argmax(outputs, dim=-1).numpy()
    print(words_idx)
    words = ' '.join([dataset.index2word[word_idx] if word_idx < 5000 else "Unknown" for word_idx in words_idx])
    return outputs, words

In [20]:
val_data_loader = DataLoader(dataset, batch_size=1, drop_last=True, collate_fn=collate_fn)

In [21]:
iterator = iter(val_data_loader)
inputs, _, labels = iterator.next()
#print(inputs.shape[0])
print(evaluate(inputs, labels.long())[1])
print(labels)

torch.Size([1, 1, 50])
Attn:  torch.Size([1, 1, 50])
Embeds:  torch.Size([1, 1, 49])
torch.Size([1, 1, 50])
Attn:  torch.Size([1, 1, 50])
Embeds:  torch.Size([1, 1, 49])
torch.Size([1, 1, 50])
Attn:  torch.Size([1, 1, 50])
Embeds:  torch.Size([1, 1, 49])
torch.Size([1, 1, 50])
Attn:  torch.Size([1, 1, 50])
Embeds:  torch.Size([1, 1, 49])
torch.Size([1, 1, 50])
Attn:  torch.Size([1, 1, 50])
Embeds:  torch.Size([1, 1, 49])
torch.Size([1, 1, 50])
Attn:  torch.Size([1, 1, 50])
Embeds:  torch.Size([1, 1, 49])
torch.Size([1, 1, 50])
Attn:  torch.Size([1, 1, 50])
Embeds:  torch.Size([1, 1, 49])
torch.Size([1, 1, 50])
Attn:  torch.Size([1, 1, 50])
Embeds:  torch.Size([1, 1, 49])
torch.Size([1, 1, 50])
Attn:  torch.Size([1, 1, 50])
Embeds:  torch.Size([1, 1, 49])
torch.Size([1, 1, 50])
Attn:  torch.Size([1, 1, 50])
Embeds:  torch.Size([1, 1, 49])
torch.Size([1, 1, 50])
Attn:  torch.Size([1, 1, 50])
Embeds:  torch.Size([1, 1, 49])
torch.Size([1, 1, 50])
Attn:  torch.Size([1, 1, 50])
Embeds:  tor