### Libraries

In [44]:
from more_itertools import consume
from pycocotools.coco import COCO
from collections import Counter
from threading import Thread
from PIL import Image
from tqdm import tqdm
import numpy as np
import nltk
import os

from torch.utils.tensorboard import SummaryWriter
import torchvision
import torch

nltk.download('punkt')

[nltk_data] Downloading package punkt to /home/vokerlee/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [96]:
# !unzip experiment.zip -d .
%reload_ext tensorboard
%load_ext tensorboard
%tensorboard --logdir ./experiment/

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 23224), started 11 days, 23:01:55 ago. (Use '!kill 23224' to kill it.)

In [14]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

### Dataset & general parameters

#### Download dataset, resize images

Скачивание dataset'а будет очень долгим, поэтому код ниже комменчен. 

In [15]:
# Uncomment to load datasets
# !wget http://images.cocodataset.org/zips/train2014.zip
# !wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip

# !unzip train2014.zip
# !unzip annotations_trainval2014.zip

In [16]:
# !rm train2014.zip
# !rm annotations_trainval2014.zip

Для единообразия и экономии времени сожмём картинки до более худшего качества.

Создадим на каждую картинку отдельный thread, который сожмёт картинку.

In [17]:
!rm -rf resized_images/*
!mkdir -p resized_images/

/bin/bash: line 1: /bin/rm: Argument list too long


In [45]:
img_path_to_resize = 'train2014'
images_to_resize = os.listdir(img_path_to_resize)

# Function just for future tests & training, not for resizing
def load_image(image_path, transform=None, basewidth=128):
    img = Image.open(image_path).convert('RGB')
    img = img.resize([basewidth, basewidth], Image.LANCZOS)

    if transform is not None:
        img = transform(img).unsqueeze(0)

    return img

def resize_image(img_name, basewidth=128):
    img = Image.open(img_path_to_resize + '/' + img_name)
    img = img.resize((basewidth, basewidth), Image.LANCZOS)
    img.save('resized_images/' + img_name)

img_load_threads = []
for img_name in tqdm(images_to_resize):
    thread = Thread(target=resize_image, kwargs={'img_name': img_name})
    thread.start()
    img_load_threads.append(thread)

consume(map(lambda t: t.join(), tqdm(img_load_threads)))

100%|██████████| 82783/82783 [01:14<00:00, 1106.07it/s]
100%|██████████| 82783/82783 [00:00<00:00, 1114822.58it/s]


#### Filling vocabulary using extracted tokens from annotations

На будущее создадим словарь "первичного" embedding'а максимально простым способом: присваивая новым словая последовательные индексы, начиная с нуля.

Для обозначения старта и конца предложения добавим сразу в словарь `<start>` и `<end>`, а также `<pad>` и `<unk>` для выравниваний и неизвестых слов соответственно.

In [66]:
class Vocabulary(object):
    """Simple vocabulary wrapper."""
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.n_idx = 0

        self.add_word('<pad>')
        self.add_word('<start>')
        self.add_word('<end>')
        self.add_word('<unk>')

    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.n_idx
            self.idx2word[self.n_idx] = word
            self.n_idx += 1

    def restore_caption(self, caption_idxs):
        caption = []

        for word_idx in caption_idxs:
            word = self.idx2word[word_idx]
            if (word != '<pad>') and (word != '<start>') and \
               (word != '<unk>') and (word != '<end>'):
                caption.append(word)

            if word == '<end>':
                break

        return caption

    def __call__(self, word):
        if not word in self.word2idx:
            return self.word2idx['<unk>']

        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)

Заполним словарь словами из датасета COCO, отсекая наименее часто встречающиеся слова.

In [67]:
def build_vocab(json, freq_threshold):
    """Build a simple vocabulary wrapper."""
    coco = COCO(json)
    counter = Counter() # map for counting unique words
    ids = coco.anns.keys()

    for id in ids:
        caption = str(coco.anns[id]['caption'])
        tokens = nltk.tokenize.word_tokenize(caption.lower())
        counter.update(tokens)

    # If the word frequency is less than 'freq_threshold', then the word is discarded.
    # So only the most popular words participate in vocabulary -> models.
    words = [word for word, cnt in counter.items() if cnt >= freq_threshold]

    vocab = Vocabulary()

    for word in words:
        vocab.add_word(word)

    return vocab

In [68]:
json_path = 'annotations/captions_train2014.json'
vocab = build_vocab(json=json_path, freq_threshold=16)

loading annotations into memory...
Done (t=0.24s)
creating index...
index created!


#### Creation of torch-dataloader from annotations and images

Создадим обёртку над `torch.utils.data.Dataset`'ом, чтобы абстрагироваться от формата данных, с помощью которых происходит обучение.

Наиболее гланый метод в `CocoDataset` — `__getitem__`, который нужен для функциии `collate_fn` для создания batch'а image-caption.

In [49]:
class CocoDataset(torch.utils.data.Dataset):
    """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
    def __init__(self, root, json, vocab, transform):
        """Set the path for images, captions and vocabulary wrapper.

        Args:
            root: image directory.
            json: coco annotation file path.
            vocab: vocabulary wrapper.
            transform: image transformer.
        """
        self.root = root
        self.coco = COCO(json)
        self.ids = list(self.coco.anns.keys())
        self.vocab = vocab
        self.transform = transform

    # Generally, for collate_fn() function (see below).
    # Return tensors of image and its caption.
    def __getitem__(self, index):
        """Returns one data pair (image and caption).
        Args:
            index: index in image-caption dataset
        """
        # Get raw image & caption for input index
        annotation_id = self.ids[index]
        caption = self.coco.anns[annotation_id]['caption']
        image_id = self.coco.anns[annotation_id]['image_id']

        path = self.coco.loadImgs(image_id)[0]['file_name']
        image = Image.open(os.path.join(self.root, path)).convert('RGB')

        # Convert caption to vector (using vocabulary)
        tokens = nltk.tokenize.word_tokenize(str(caption).lower())
        caption = []
        caption.append(vocab('<start>'))
        caption.extend([vocab(token) for token in tokens])
        caption.append(vocab('<end>'))

        return self.transform(image), torch.Tensor(caption)

    def __len__(self):
        return len(self.ids)

Теперь, имея класс датасета, можем себе позволить организовать генерацию `DataLoader`.

Для этого следует перегрузить дефолтную реализацию `collate_fn`, чтобы обеспечить корректное создание batch'а: caption'ы должны склеиться с учётом padding'а.

In [50]:
def collate_fn(data):
    """Creates mini-batch tensors from the list of tuples (image, caption).

    We should build custom collate_fn rather than using default collate_fn,
    because merging caption (including padding) is not supported in default.
    Args:
        data: list of tuple (image, caption).
            - image: torch tensor of shape (3, 128, 128).
            - caption: torch tensor of shape (?); variable length.
    Returns:
        images: torch tensor of shape (batch_size, 3, 128, 128).
        targets: torch tensor of shape (batch_size, padded_length).
        lengths: list indicating valid length for each caption.
    """
    # Sort a data list by caption length (descending order).
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)

    # Merge images (from tuple of 3D tensor to 4D tensor) within batch dimension
    images = torch.stack(images, 0)

    # Merge captions (from tuple of 1D tensor to 2D tensor) within batch dimension
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()

    # Filling zero-tensor 2D with captions
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]

    return images, targets, lengths

def generate_data_loader(root, json, vocab, transform, batch_size, shuffle, num_workers):
    """Returns torch.utils.data.DataLoader for custom COCO dataset."""
    # COCO caption dataset
    coco = CocoDataset(root=root,
                       json=json,
                       vocab=vocab,
                       transform=transform)

    # This will return (images, captions, lengths) for each iteration.
    # images: a tensor of shape (batch_size, 3, 128, 128).
    # captions: a tensor of shape (batch_size, padded_length).
    # lengths: a list indicating valid length for each caption. len(lengths) == batch_size
    data_loader = torch.utils.data.DataLoader(dataset=coco,
                                              batch_size=batch_size,
                                              shuffle=shuffle,
                                              num_workers=num_workers,
                                              collate_fn=collate_fn)
    return data_loader

### Annotation generation model

В качестве модели encoder будем использовать предобученую модель `resnet152` без последнего слоя.

В качестве модели decoder возьмём `LSTM` (вместе с `Embedding` слоем).

In [84]:
class EncoderCNN(torch.nn.Module):
    def __init__(self, embed_size):
        """Load the pretrained ResNet-152 and replace top fc layer."""
        super(EncoderCNN, self).__init__()

        resnet = torchvision.models.resnet152(weights=torchvision.models.ResNet152_Weights.DEFAULT)
        modules = list(resnet.children())[:-1] # remove the last layer

        self.resnet = torch.nn.Sequential(*modules)
        self.linear = torch.nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = torch.nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        """Extract feature vectors from input images."""
        with torch.no_grad():
            features = self.resnet(images)

        features = features.reshape(features.size(0), -1)
        return self.bn(self.linear(features))

In [85]:
class DecoderRNN(torch.nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_sent_length=20):
        """Set the hyper-parameters and build the layers."""
        super(DecoderRNN, self).__init__()

        self.embed = torch.nn.Embedding(vocab_size, embed_size)
        self.lstm = torch.nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = torch.nn.Linear(hidden_size, vocab_size) # interpretation: probabilities for each word in vocabulary
        self.max_sent_length = max_sent_length

    def forward(self, features, captions, lengths):
        """Decode image feature vectors and generates captions."""
        embeddings = self.embed(captions)
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embeddings, lengths, batch_first=True)
        hiddens, _ = self.lstm(packed)

        return self.linear(hiddens[0])

    def sample(self, features, states=None):
        """Generate captions for given image features using greedy search."""
        sampled_idxs = []
        inputs = features.unsqueeze(1)

        for _ in range(self.max_sent_length):
            # Restore word for batch                             # inputs: (batch_size, 1, embed_size)
            hiddens, states = self.lstm(inputs, states)          # hiddens: (batch_size, 1, hidden_size)
            outputs = self.linear(hiddens.squeeze(1))            # outputs:  (batch_size, vocab_size)
            _, predicted = outputs.max(1)                        # predicted: (batch_size)

            # Save predicted words for batcn
            sampled_idxs.append(predicted)

            # Get new input for next prediction
            inputs = self.embed(predicted)                       # inputs: (batch_size, embed_size)
            inputs = inputs.unsqueeze(1)                         # inputs: (batch_size, 1, embed_size)

        return torch.stack(sampled_idxs, 1)                      # sampled_idxs: (batch_size, max_sent_length)

### Training code

In [91]:
image_dir = 'resized_images/'
caption_path = json_path
batch_size = 128
num_workers = 2

embed_size = 256
hidden_size = 512
num_layers = 1

criterion = torch.nn.CrossEntropyLoss()

num_epochs = 6
log_step = 10

writer = SummaryWriter(log_dir='experiment/')

In [87]:
# For image transformation -> tensor
transform = torchvision.transforms.Compose([
                torchvision.transforms.RandomCrop(64),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.485, 0.456, 0.406),
                                                 (0.229, 0.224, 0.225))]) # officially used values by pytorch for image normalization

# Build data loader
data_loader = generate_data_loader(image_dir, caption_path, vocab,
                                   transform, batch_size,
                                   shuffle=True, num_workers=num_workers)

# Build the models
encoder = EncoderCNN(embed_size).to(device)
decoder = DecoderRNN(embed_size, hidden_size, len(vocab), num_layers).to(device)

params = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
optimizer = torch.optim.Adam(params, lr=0.001)

total_step = len(data_loader)

loading annotations into memory...
Done (t=0.24s)
creating index...
index created!


Выберем картинку, за которой будем поэтапно следить во время обучения (точнее, за её аннотациями).

In [88]:
simple_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

test_img_name = 'resized_images/' + os.listdir('resized_images/')[42]
test_image = load_image(test_img_name, transform)
test_image_tensor = test_image.to(device)
writer.add_image(f'Image {test_img_name}', load_image(test_img_name, transform=simple_transform), dataformats='NCHW')

Теперь запустим само обучение.

In [92]:
for epoch in tqdm(range(num_epochs), leave=False):
    for i, (images, captions, lengths) in enumerate(tqdm(data_loader, leave=False)):

        encoder.train()
        decoder.train()

        # Set mini-batch dataset
        images = images.to(device)
        captions = captions.to(device)
        targets = torch.nn.utils.rnn.pack_padded_sequence(captions, lengths, batch_first=True)[0]

        # Forward, backward and optimize
        features = encoder(images)
        outputs = decoder(features, captions, lengths)

        loss = criterion(outputs, targets)

        decoder.zero_grad()
        encoder.zero_grad()
        loss.backward()
        optimizer.step()

        # Print log info
        if i % log_step == 0:
            writer.add_scalar('LOSS/train', loss.item(), i + total_step * epoch)
            encoder.eval()
            decoder.eval()

            with torch.no_grad():
                feature = encoder(test_image_tensor)
                sampled_idxs = decoder.sample(feature)
                sampled_idxs = sampled_idxs[0].cpu().numpy() # [0] because of batch with batcn_size=1

                sampled_caption = vocab.restore_caption(sampled_idxs)

                sentence = ' '.join(sampled_caption)
                writer.add_text('Captions', sentence, i + total_step * epoch)

                                              

### Conclusions

По итогу всё получилось примерно так, как и ожидалось. На картинке стоит человек около около прилавок с фруктами (апельсины и яблоки).

Модель на последних эпохах утверждает, что на картинке стоит человек в костюме напротив каких-либо объектов, что максимально похоже на правду, с учётом того, что на сжатой картинке детализация прилавок сильно пострадала.