In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (20, 20)

import os
import json
import torch 
from PIL import Image
import numpy as np
import pandas as pd
from collections import Counter
from string import punctuation
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.optim import Adam
from tqdm import tqdm_notebook as tqdm
import torchvision.models as models
from torch.nn.utils.rnn import pack_padded_sequence
from torch import nn

In [None]:
def remove_punctuation(input_string):
    '''Return `input_string` with every `string.punctuation` character deleted.'''
    # A translation table mapping a character to None deletes that character.
    deletion_table = str.maketrans(dict.fromkeys(punctuation))
    return input_string.translate(deletion_table)

# coco

In [None]:
# Directory holding the COCO val2014 images (note the trailing slash: paths are
# built by plain string concatenation).
coco_image_path = '/home/jupyter/datasets/coco/images/val2014/'

# Full path of every file in the image directory, in os.listdir order.
image_paths = []
for file_name in os.listdir(coco_image_path):
    image_paths.append(coco_image_path + file_name)

# Image id = final path component up to its first '.'.
image_ids = [path.rsplit('/', 1)[-1].split('.', 1)[0] for path in image_paths]

In [None]:
# Caption annotations; the cells below read the 'images' and 'annotations' keys.
with open('datasets/coco/captions_val2014.json') as captions_file:
    meta = json.load(captions_file)

In [None]:
# Join every caption annotation to its image record (image `id` == annotation
# `image_id`), keeping only the caption text and the image file name.
image_records = pd.DataFrame(meta['images']).set_index('id')
caption_records = pd.DataFrame(meta['annotations']).set_index('image_id')

df = (
    pd.merge(image_records, caption_records, left_index=True, right_index=True)
    .reset_index()
    [['caption', 'file_name']]
    # turn the bare file name into a full path under the image directory
    .assign(file_name=lambda frame: coco_image_path + frame['file_name'])
)

In [None]:
# Number of (caption, image) rows
df.shape[0]

# build vocabulary

In [None]:
def _clean_caption(text):
    '''Keep only letters and whitespace, lower-case, then collapse whitespace runs.'''
    letters_and_spaces = ''.join(filter(lambda c: c.isalpha() or c.isspace(), text))
    return ' '.join(letters_and_spaces.lower().split())


df['caption'] = df['caption'].apply(_clean_caption)

In [None]:
# Flat list of every whitespace-separated token, caption by caption
all_text = [token for caption in df['caption'].values for token in caption.split()]

In [None]:
class Vocabulary(object):
    '''Two-way mapping between words and consecutive integer indices.

    Indices are handed out in insertion order starting at 0. Calling the
    instance with a word returns its index; a word that was never added
    resolves to the index of '<unk>' (which must therefore have been added).
    '''
    def __init__(self):
        self.word_to_index = {}
        self.index_to_word = {}
        self.index = 0  # next index to hand out

    def add_word(self, word):
        '''Register `word` under the next free index; known words are ignored.'''
        if word in self.word_to_index:
            return
        new_index = self.index
        self.word_to_index[word] = new_index
        self.index_to_word[new_index] = word
        self.index = new_index + 1

    def __call__(self, word):
        '''Return the index of `word`, or the index of '<unk>' if it is unknown.'''
        known_index = self.word_to_index.get(word)
        if known_index is None:
            return self.word_to_index['<unk>']
        return known_index

    def __len__(self):
        '''Number of distinct words registered.'''
        return len(self.word_to_index)

In [None]:
vocabulary = Vocabulary()

# Special tokens go in first so they get the fixed indices 0..3
# (<pad>=0, <start>=1, <unk>=2, <end>=3).
for special_token in ('<pad>', '<start>', '<unk>', '<end>'):
    vocabulary.add_word(special_token)

# Then every distinct caption word, in order of first appearance.
for word in Counter(all_text):
    vocabulary.add_word(word)

# train test split

In [None]:
# Seeded generator so that Restart & Run All reproduces the same ~80/20 split
# (the unseeded global np.random state gave a different split on every run).
split_rng = np.random.RandomState(42)
mask = split_rng.rand(len(df)) < 0.8
train_df, test_df = df[mask], df[~mask]

# NOTE(review): rows are individual captions, so if an image has several
# captions they can land on both sides of the split - confirm this is acceptable.
len(train_df), len(test_df)

# dataset

In [None]:
class CaptionsDataset(Dataset):
    '''Dataset of (image, caption token indices) pairs.

    `path_df` must provide 'file_name' (image path) and 'caption' (text)
    columns; `vocab` is called on each token to obtain its index; `transform`
    (if not None) is applied to the RGB PIL image.
    '''
    def __init__(self, path_df, vocab, transform=transforms.ToTensor()):
        self.vocab = vocab
        self.transform = transform
        self.ids = path_df.index.values
        self.image_paths = path_df['file_name'].values
        self.titles = path_df['caption'].values

    def __getitem__(self, index):
        '''Return the (transformed) image and its caption as a 1-D tensor of indices.'''
        image = Image.open(self.image_paths[index]).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Wrap the caption words in the start/end markers before indexing them.
        words = ['<start>'] + self.titles[index].split() + ['<end>']
        target = torch.Tensor([self.vocab(word) for word in words])
        return image, target

    def __len__(self):
        return len(self.ids)

In [None]:
# Random augmentation pipeline producing 224x224 tensors.
# NOTE(review): this same random pipeline is also passed to the test dataset and
# to the inference cells further down - confirm that is intended.
augmentation_steps = [
    transforms.RandomResizedCrop(224, scale=[0.5, 0.9]),
    transforms.RandomHorizontalFlip(),
    transforms.RandomGrayscale(0.8),
    transforms.ToTensor(),
]
transform = transforms.Compose(augmentation_steps)

In [None]:
# Both splits share the vocabulary and the image transform.
train_dataset, test_dataset = (
    CaptionsDataset(split_df, vocabulary, transform=transform)
    for split_df in (train_df, test_df)
)

In [None]:
# Sanity check: first (image, caption indices) pair
train_dataset[0]

# dataloader
with custom `collate_fn` (allowing for variable-length padding on captions)

In [None]:
def collate_fn(data):
    '''Collate (image, caption) pairs into a padded batch.

    Args:
        data: sequence of (image tensor, 1-D caption tensor) pairs; all image
            tensors must share one shape so they can be stacked.

    Returns:
        images: tensor of the stacked images, longest caption first.
        targets: long tensor (batch, max caption length), right-padded with 0.
        lengths: list of the true caption lengths, in descending order
            (the ordering `pack_padded_sequence` is given downstream).
    '''
    # sorted() instead of data.sort(): the caller's batch is left untouched and
    # any sequence type (not only list) is accepted. Both sorts are stable.
    ordered = sorted(data, key=lambda pair: len(pair[1]), reverse=True)
    images, captions = zip(*ordered)
    lengths = [len(caption) for caption in captions]

    images = torch.stack(images, 0)

    # Index 0 doubles as the padding value.
    targets = torch.zeros(len(captions), max(lengths)).long()
    for row, (caption, length) in enumerate(zip(captions, lengths)):
        targets[row, :length] = caption

    return images, targets, lengths

In [None]:
batch_size = 128

# Settings shared by both loaders; only the training loader shuffles.
shared_loader_options = dict(batch_size=batch_size,
                             num_workers=5,
                             collate_fn=collate_fn)

train_loader = DataLoader(dataset=train_dataset, shuffle=True, **shared_loader_options)
test_loader = DataLoader(dataset=test_dataset, **shared_loader_options)

# fasttext vectors

In [None]:
fasttext = {}

# Line 0 is skipped and lines 1..99999 are used, exactly as the former
# `f.readlines()[1:100000]` slice did - but the file is now read line by line
# instead of being loaded into memory in full just to keep its first rows.
n_fasttext_lines = 100000

with open('datasets/wiki.en.vec', encoding='utf-8') as f:
    f.readline()  # skip the first line (the header of the .vec file)
    for _ in tqdm(range(n_fasttext_lines - 1)):
        line = f.readline()
        if not line:  # file is shorter than the requested number of lines
            break
        line = line.split()
        # The last 300 fields are the vector; whatever precedes them is the
        # word, which may itself contain spaces.
        word, vector = ' '.join(line[:-300]), np.array(line[-300:]).astype(np.float32)
        fasttext[word.lower()] = vector

In [None]:
# Largest absolute component over all loaded vectors
wv_max = np.abs(np.array(list(fasttext.values()))).max()

# Rescale every vector by that value so all components fall within [-1, 1]
for word, vector in tqdm(fasttext.items()):
    fasttext[word] = vector / wv_max

# Mean vector; used later for vocabulary words that have no fasttext entry
wv_mean = np.stack(list(fasttext.values())).mean(axis=0)

In [None]:
# Each special token gets its own one-hot 300-d vector: the k-th token has a
# single 1 at position k.
for position, special_token in enumerate(('<pad>', '<start>', '<unk>', '<end>')):
    one_hot = np.zeros((300,))
    one_hot[position] = 1
    fasttext[special_token] = one_hot

In [None]:
# One row per vocabulary word, in index order; words without a fasttext entry
# fall back to the mean vector.
vocabulary_vectors = []
for word in vocabulary.index_to_word.values():
    vocabulary_vectors.append(fasttext.get(word, wv_mean))

all_fasttext_vectors = torch.Tensor(np.stack(vocabulary_vectors))
all_fasttext_vectors.requires_grad_(False)

In [None]:
# (vocabulary size, embedding dimension)
all_fasttext_vectors.size()

# model

In [None]:
class EncoderCNN(nn.Module):
    '''Image encoder: frozen pretrained ResNet-152 plus a trainable projection.

    The ResNet (without its final fc layer) is used purely as a fixed feature
    extractor; only `linear` and `bn` are trained.
    '''
    def __init__(self, embedding_size):
        '''Load the pretrained ResNet-152 and replace top fc layer.'''
        super(EncoderCNN, self).__init__()
        resnet = models.resnet152(pretrained=True)
        for p in resnet.parameters():
            p.requires_grad = False
        modules = list(resnet.children())[:-1]  # drop the classification layer
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embedding_size)
        self.bn = nn.BatchNorm1d(embedding_size, momentum=0.01)
        # Freezing the weights is not enough: BatchNorm layers in train mode
        # keep updating their running statistics on every forward pass, even
        # under no_grad. Pin the backbone to eval mode.
        self.resnet.eval()

    def train(self, mode=True):
        '''Switch the trainable head between train/eval; the backbone stays in eval.'''
        super(EncoderCNN, self).train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        '''Extract feature vectors from input images.'''
        with torch.no_grad():
            features = self.resnet(images)
        # flatten everything after the batch dimension before projecting
        features = features.reshape(features.size(0), -1)
        features = self.bn(self.linear(features))
        return features


class DecoderRNN(nn.Module):
    '''LSTM caption decoder on top of a pretrained word-embedding matrix.'''
    def __init__(self, embedding_size, hidden_size, vocab_size, n_layers, embedding_matrix, max_seq_length=20):
        '''Build the embedding lookup, the LSTM and the output projection.

        `embedding_matrix` supplies the embedding weights; `max_seq_length`
        is the number of tokens generated by `sample`.
        '''
        super().__init__()
        self.embed = nn.Embedding.from_pretrained(embedding_matrix)
        self.lstm = nn.LSTM(input_size=embedding_size,
                            hidden_size=hidden_size,
                            num_layers=n_layers,
                            batch_first=True)
        self.linear = nn.Linear(in_features=hidden_size, out_features=vocab_size)
        self.max_seg_length = max_seq_length

    def forward(self, features, captions, lengths):
        '''Return vocabulary scores for every non-padded step of the batch.

        The image feature is prepended to the caption embeddings as the first
        time step; the result has one row per packed time step.
        '''
        word_vectors = self.embed(captions)
        sequence = torch.cat([features.unsqueeze(1), word_vectors], dim=1)
        packed_input = pack_padded_sequence(sequence, lengths, batch_first=True)
        packed_output, _ = self.lstm(packed_input)
        return self.linear(packed_output.data)

    def sample(self, features, states=None):
        '''Greedily generate `max_seg_length` token ids per image feature.'''
        sampled_ids = []
        inputs = features.unsqueeze(1)
        for _ in range(self.max_seg_length):
            hiddens, states = self.lstm(inputs, states)
            scores = self.linear(hiddens.squeeze(1))
            predicted = scores.max(1)[1]  # index of the highest-scoring word
            sampled_ids.append(predicted)
            # the chosen word becomes the next input step
            inputs = self.embed(predicted).unsqueeze(1)
        return torch.stack(sampled_ids, 1)

In [None]:
# Run on the GPU when one is available, otherwise on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

embedding_size = 300   # width of the word vectors and of the image feature
hidden_size = 500      # LSTM hidden state size
n_layers = 1           # number of stacked LSTM layers
learning_rate = 0.01   # passed to Adam

In [None]:
# Build both halves of the captioning model, then move them to the chosen device
encoder = EncoderCNN(embedding_size)
decoder = DecoderRNN(embedding_size=embedding_size,
                     hidden_size=hidden_size,
                     vocab_size=len(vocabulary),
                     n_layers=n_layers,
                     embedding_matrix=all_fasttext_vectors)

encoder = encoder.to(device)
decoder = decoder.to(device)

# train

In [None]:
def train_epoch(encoder, decoder, train_loader, epoch, 
                loss_function, optimiser, device=device):
    '''Run one pass over `train_loader`, updating encoder and decoder.

    Args:
        encoder, decoder: the two halves of the captioning model.
        train_loader: yields (images, padded captions, lengths) batches.
        epoch: zero-based epoch number, used only for the progress bar.
        loss_function: called as loss_function(outputs, targets).
        optimiser: optimiser stepped once per batch.
        device: device the batches are moved to.

    Side effects: appends every batch loss to the module-level `losses` list
    and reads the module-level `n_epochs` for the progress-bar label.
    '''
    # Make sure both models are in training mode, whatever an earlier cell
    # (e.g. an inference cell calling .eval()) left behind.
    encoder.train()
    decoder.train()

    loop = tqdm(train_loader)
    loop.set_description('Epoch {}/{}'.format(epoch + 1, n_epochs))

    for images, captions, lengths in loop:
        # .to(device) rather than .cuda(): honours the `device` argument and
        # keeps the function usable on a CPU-only machine.
        images = images.to(device, non_blocking=True)
        captions = captions.to(device, non_blocking=True)
        # Flatten the padded captions the same way the decoder flattens its
        # outputs, so the two line up row by row.
        targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

        features = encoder(images)
        outputs = decoder(features, captions, lengths)
        loss = loss_function(outputs, targets)
        losses.append(loss.item())

        decoder.zero_grad()
        encoder.zero_grad()
        loss.backward()
        optimiser.step()

        loop.set_postfix(loss=loss.item())

In [None]:
# Parameters still requiring gradients (decoder first, then encoder); only
# these are handed to the optimiser.
params = [parameter
          for module in (decoder, encoder)
          for parameter in module.parameters()
          if parameter.requires_grad]


n_epochs = 4
losses = []  # per-batch training losses, appended to by train_epoch

In [None]:
# Build the loss and the optimiser once, outside the loop: creating a new Adam
# every epoch threw away its running moment estimates at each epoch boundary.
loss_function = nn.CrossEntropyLoss()
optimiser = Adam(params, lr=learning_rate)

for epoch in range(n_epochs):
    train_epoch(encoder=encoder, 
                decoder=decoder,
                train_loader=train_loader,
                loss_function=loss_function,
                optimiser=optimiser,
                epoch=epoch)

In [None]:
# Smooth the per-batch losses with a 15-step rolling mean before plotting
rolling_window = 15
loss_series = pd.Series(losses)
loss_data = loss_series.rolling(window=rolling_window).mean()
ax = loss_data.plot(subplots=True);

# sample result from coco

In [None]:
def load_image(image_path, transform=None):
    '''Open `image_path` as an RGB image.

    Without a transform the PIL image is returned as is; with one, the
    transformed result is returned with an extra leading dimension added via
    `unsqueeze(0)`.
    '''
    image = Image.open(image_path).convert('RGB')
    if transform is None:
        return image
    return transform(image).unsqueeze(0)

# Pick one random test row: its reference caption and image path
caption, path = test_df.sample().values[0]
print(caption)

img = load_image(path, transform=transform).to(device)

# Show the transformed image: scale to 0..255 bytes and move channels last
pixels = (img.to('cpu').data.numpy() * 255).astype(np.uint8)
Image.fromarray(pixels.reshape(3, 224, 224).transpose(1, 2, 0))

In [None]:
encoder.eval()
features = encoder(img)
sampled_ids = decoder.sample(features)[0].cpu().numpy()

# 3 and 1 are the indices '<end>' and '<start>' received when the vocabulary
# was built: stop at the first '<end>', skip any '<start>'.
output_sentence_list = []
for index in sampled_ids:
    if index == 3:
        break
    if index != 1:
        output_sentence_list.append(vocabulary.index_to_word[index])

' '.join(output_sentence_list)


# sample result from wellcome

In [None]:
base_path = '/home/jupyter/datasets/small_images/'

# Every file one level down: base_path/<subdir>/<file>
wellcome_paths = []
for subdir in os.listdir(base_path):
    for image_id in os.listdir(base_path + subdir):
        wellcome_paths.append(base_path + subdir + '/' + image_id)

# Id = file name up to its first '.'
wellcome_ids = [path.rsplit('/', 1)[-1].split('.', 1)[0]
                for path in wellcome_paths]

In [None]:
# Works metadata, stored as one JSON document per line
works_path = '/home/jupyter/datasets/works.json'
metadata = pd.read_json(works_path, lines=True)

In [None]:
# Index the works by the 'value' of their first identifier
first_identifier_values = metadata['identifiers'].apply(
    lambda identifiers: identifiers[0]['value'])
metadata.index = first_identifier_values.rename('miro id')

In [None]:
# Caption one randomly chosen wellcome image
path = np.random.choice(wellcome_paths)
img = load_image(path, transform=transform).to(device)

encoder.eval()
features = encoder(img)
sampled_ids = decoder.sample(features)[0].cpu().numpy()

# 3 and 1 are the indices '<end>' and '<start>' received when the vocabulary
# was built: stop at the first '<end>', skip any '<start>'.
output_sentence_list = []
for index in sampled_ids:
    if index == 3:
        break
    if index != 1:
        output_sentence_list.append(vocabulary.index_to_word[index])

print(' '.join(output_sentence_list))

# Show the transformed image: scale to 0..255 bytes and move channels last
pixels = (img.to('cpu').data.numpy() * 255).astype(np.uint8)
Image.fromarray(pixels.reshape(3, 224, 224).transpose(1, 2, 0))