# Генерация аннотации к изображению

## Предварительная работа

### Библиотеки

In [34]:
%pip install pycocotools
%pip install nltk


Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [35]:
import io
import math
import nltk
import os
import requests
import torch

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from collections             import Counter
from copy                    import deepcopy
from matplotlib.image        import imread
from mpl_toolkits            import mplot3d
from matplotlib              import gridspec
from nerus                   import load_nerus
from skimage.segmentation    import mark_boundaries
from sklearn.metrics         import classification_report
from sklearn.model_selection import ParameterGrid
from torch.utils             import data
from torch.utils.tensorboard import SummaryWriter
from torchvision             import datasets, transforms, models
from tqdm.autonotebook       import tqdm
from PIL                     import Image
from pycocotools.coco        import COCO
from urllib.request          import urlopen


In [36]:
import warnings
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

# Download the NLTK 'punkt' data package
# (presumably required by nltk.tokenize.word_tokenize, which is used below).
nltk.download('punkt')


[nltk_data] Downloading package punkt to /home/panterrich/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

### Установка вычислительного устройства

In [37]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


device(type='cuda')

## Загрузка датасета

In [38]:
class CocoDataset(torch.utils.data.Dataset):
    """Map-style dataset of (image, caption ids) pairs built from COCO captions.

    One item corresponds to one caption annotation, not to one image.
    """

    def __init__(self, root, json, vocab, start, end, transform=None):
        """Load the annotation index and keep a slice of the annotation ids.

        Args:
            root: directory that contains the image files.
            json: path to the COCO captions annotation file.
            vocab: callable mapping a token string to its integer id.
            start: first position (inclusive) in the list of annotation ids.
            end: last position (exclusive) in the list of annotation ids.
            transform: optional callable applied to the loaded PIL image.
        """
        self.root = root
        self.coco = COCO(json)
        annotation_ids = list(self.coco.anns.keys())
        self.ids = annotation_ids[start:end]
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        """Return the image and the encoded caption for annotation `index`.

        Returns:
            (image, target): the RGB image (after `transform`, if one was
            given) and a 1-D float tensor holding
            `<start>`, the caption token ids and `<end>`.
        """
        annotation = self.coco.anns[self.ids[index]]
        image_info = self.coco.loadImgs(annotation['image_id'])[0]
        image_path = os.path.join(self.root, image_info['file_name'])

        image = Image.open(image_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Lower-case and tokenize the caption, then wrap it in the
        # start/end markers while mapping every token to its id.
        words = nltk.tokenize.word_tokenize(str(annotation['caption']).lower())
        word_ids = [self.vocab('<start>')]
        word_ids += [self.vocab(word) for word in words]
        word_ids.append(self.vocab('<end>'))
        return image, torch.Tensor(word_ids)

    def __len__(self):
        """Return the number of selected caption annotations."""
        return len(self.ids)


def collate_fn(data):
    """Create mini-batch tensors from a list of (image, caption) tuples.

    A custom collate function is needed because captions have different
    lengths and must be padded, which the default collate does not do.
    The input list is left untouched; a sorted copy is used internally.

    Args:
        data: non-empty list of (image, caption) tuples.
            - image: torch tensor; all images must share one shape.
            - caption: 1-D torch tensor of token ids; variable length.

    Returns:
        images: tensor of shape (batch_size, *image_shape).
        targets: long tensor of shape (batch_size, max_caption_length),
            padded on the right with 0.
        lengths: list with the valid length of each padded caption,
            in descending order.

    Raises:
        ValueError: if `data` is empty.
    """
    if not data:
        raise ValueError("collate_fn received an empty batch")

    # Longest caption first (required by pack_padded_sequence with its
    # default enforce_sorted=True). `sorted` keeps the caller's list intact.
    ordered = sorted(data, key=lambda pair: len(pair[1]), reverse=True)
    images, captions = zip(*ordered)

    # Merge images (from tuple of N-D tensors to one (N+1)-D tensor).
    images = torch.stack(images, 0)

    # Merge captions (from tuple of 1-D tensors to one padded 2-D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for row, (cap, length) in enumerate(zip(captions, lengths)):
        targets[row, :length] = cap[:length]
    return images, targets, lengths


## Обучение

In [39]:
def train_on_batch(model, images, captions, lengths, optimizer, loss_function):
    """Perform one optimisation step on a single mini-batch.

    Args:
        model: (encoder, decoder) pair of modules.
        images: batch of image tensors.
        captions: padded batch of caption token ids.
        lengths: valid length of every caption in `captions`.
        optimizer: optimizer whose `step()` is called after backprop.
        loss_function: callable taking (outputs, targets).

    Returns:
        The batch loss as a Python number.
    """
    encoder, decoder = model

    # Drop gradients left over from the previous step.
    for module in (encoder, decoder):
        module.zero_grad()

    images = images.to(device)
    captions = captions.to(device)

    # Flatten the padded captions into the packed layout so that they line
    # up with the decoder output.
    packed_captions = torch.nn.utils.rnn.pack_padded_sequence(
        captions, lengths, batch_first=True)
    targets = packed_captions[0]

    outputs = decoder(encoder(images), captions, lengths)
    loss = loss_function(outputs, targets)

    loss.backward()
    optimizer.step()

    return loss.cpu().item()


In [40]:
def train_epoch(train_generator,
                model,
                loss_function,
                optimizer,
                callback = None):
    """Train for one pass over `train_generator` and return the mean loss.

    Args:
        train_generator: iterable yielding (images, captions, lengths).
        model: (encoder, decoder) pair forwarded to `train_on_batch`.
        loss_function: loss callable forwarded to `train_on_batch`.
        optimizer: optimizer forwarded to `train_on_batch`.
        callback: optional callable invoked as `callback(model, batch_loss)`
            after every batch, with gradient tracking disabled.

    Returns:
        The per-sample average of the batch losses over the epoch.
    """
    weighted_loss_sum = 0
    samples_seen = 0

    for images, captions, lengths in train_generator:
        loss_value = train_on_batch(model,
                                    images,
                                    captions,
                                    lengths,
                                    optimizer,
                                    loss_function)

        if callback is not None:
            with torch.no_grad():
                callback(model, loss_value)

        # Weight each batch by its size so a shorter final batch does not
        # skew the average.
        batch_len = len(images)
        weighted_loss_sum += loss_value * batch_len
        samples_seen += batch_len

    return weighted_loss_sum / samples_seen


In [41]:
def trainer(count_of_epoch,
            batch_size,
            model,
            dataset,
            collate_fn,
            loss_function,
            optimizer,
            callback = None):
    """Train `model` for several epochs while showing tqdm progress bars.

    Args:
        count_of_epoch: number of passes over `dataset`.
        batch_size: mini-batch size for the data loader.
        model: (encoder, decoder) pair forwarded to `train_epoch`.
        dataset: dataset to draw shuffled batches from.
        collate_fn: function that merges samples into a batch.
        loss_function: loss callable forwarded to `train_epoch`.
        optimizer: optimizer forwarded to `train_epoch`.
        callback: optional per-batch hook forwarded to `train_epoch`.
    """
    epochs = tqdm(range(count_of_epoch), desc='epoch')
    epochs.set_postfix({'train epoch loss': np.nan})

    # Number of batches per epoch, rounded up; only used as the length of
    # the inner progress bar.
    sample_count = len(dataset)
    batches_per_epoch = (sample_count + batch_size - 1) // batch_size

    for _ in epochs:
        # A fresh loader every epoch, so every epoch sees a new shuffle.
        loader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             collate_fn=collate_fn)
        progress = tqdm(loader, leave=False, total=batches_per_epoch)

        mean_loss = train_epoch(train_generator=progress,
                                model=model,
                                loss_function=loss_function,
                                optimizer=optimizer,
                                callback=callback)

        epochs.set_postfix({'train epoch loss': mean_loss})


## Отслеживание обучения модели

In [42]:
%load_ext tensorboard
%tensorboard --logdir experiment/


The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 7772), started 0:08:45 ago. (Use '!kill 7772' to kill it.)

In [43]:
def add_images(writer, dataset, num=10):
    """Log the last `num` images of `dataset` to TensorBoard as one figure.

    Args:
        writer: object exposing `add_figure(tag, figure, global_step)`.
        dataset: indexable dataset whose items are (image, caption) pairs;
            the images must be displayable with `imshow`.
        num: how many trailing items to show (default 10). Fewer are shown
            when the dataset is smaller; nothing is logged when it is empty.
    """
    count = min(num, len(dataset))
    if count <= 0:
        return

    # Negative indices pick the trailing items in their dataset order.
    images = [dataset[i][0] for i in range(-count, 0)]

    fig = plt.figure(figsize=(36, 36 / 10 * (count // 10 + 1)))
    gs = gridspec.GridSpec(1, count)
    for i, image in enumerate(images):
        ax = fig.add_subplot(gs[i])
        ax.imshow(image, interpolation='lanczos')
        ax.axis('off')

    writer.add_figure('VISUAL/images', fig, 0)


In [44]:
class callback():
    """Per-batch training hook that logs losses and sample captions.

    Every call logs the training loss. Every `delimeter`-th call additionally
    evaluates the model on the whole `dataset`, logs the mean loss and logs
    greedy-decoded captions for the last items of `dataset`.
    """

    def __init__(self, writer, dataset, collate_fn, loss_function, vocab, delimeter = 100, batch_size=64):
        """Store the logging and evaluation configuration.

        Args:
            writer: object exposing `add_scalar` and `add_text`.
            dataset: evaluation dataset; items are (image tensor, caption).
            collate_fn: function that merges samples into a batch.
            loss_function: loss callable taking (outputs, targets).
            vocab: vocabulary exposing an `idx2word` mapping.
            delimeter: evaluate every `delimeter` training steps.
            batch_size: batch size used during evaluation.
        """
        self.step = 0
        self.writer = writer
        self.delimeter = delimeter
        self.loss_function = loss_function
        self.vocab = vocab
        self.batch_size = batch_size

        self.dataset = dataset
        self.collate_fn = collate_fn

    def forward(self, model, loss):
        """Log `loss` and, on every `delimeter`-th step, evaluate `model`.

        Args:
            model: (encoder, decoder) pair of modules.
            loss: training loss of the batch that was just processed.
        """
        self.step += 1
        self.writer.add_scalar('LOSS/train', loss, self.step)

        if self.step % self.delimeter != 0:
            return

        encoder, decoder = model

        # Remember the current modes: evaluation must not leave the modules
        # in eval mode, otherwise all following training steps would run
        # with e.g. frozen batch-norm statistics.
        encoder_was_training = encoder.training
        decoder_was_training = decoder.training
        encoder.eval()
        decoder.eval()
        try:
            with torch.no_grad():
                self._log_test_loss(encoder, decoder)
                self._log_sample_captions(encoder, decoder)
        finally:
            encoder.train(encoder_was_training)
            decoder.train(decoder_was_training)

    def _log_test_loss(self, encoder, decoder):
        """Log the per-sample mean loss over the whole evaluation dataset."""
        batch_generator = torch.utils.data.DataLoader(dataset = self.dataset,
                                                      batch_size = self.batch_size,
                                                      collate_fn = self.collate_fn)
        total_loss = 0.0
        total_samples = 0
        for images, captions, lengths in batch_generator:
            images = images.to(device)
            captions = captions.to(device)

            targets = torch.nn.utils.rnn.pack_padded_sequence(captions, lengths, batch_first=True)[0]

            features = encoder(images)
            outputs = decoder(features, captions, lengths)

            total_loss += self.loss_function(outputs, targets).cpu().item()*len(images)
            total_samples += len(images)

        # Nothing to report for an empty dataset (avoids division by zero).
        if total_samples > 0:
            self.writer.add_scalar('LOSS/test', total_loss / total_samples, self.step)

    def _log_sample_captions(self, encoder, decoder, num=10):
        """Log greedy captions for the last `num` items of the dataset.

        Items are read straight from the dataset by negative index, so the
        caption tagged `TEXT/test[k]` always belongs to `dataset[-k]`,
        independent of how batches are ordered by the collate function.
        """
        count = min(num, len(self.dataset))
        for i in range(-count, 0):
            image = self.dataset[i][0].unsqueeze(0)

            feature = encoder(image.to(device))
            sampled_ids = decoder.sample(feature)
            sampled_ids = sampled_ids[0].cpu().tolist() # (1, max_seq_length) -> (max_seq_length)

            # Convert word ids to words, cutting at the end marker.
            sampled_caption = []
            for word_id in sampled_ids:
                word = self.vocab.idx2word[word_id]
                if word == '<end>':
                    break
                if word == '<start>':
                    continue

                sampled_caption.append(word)

            sentence = ' '.join(sampled_caption)

            self.writer.add_text('TEXT/test' + '[' + str(-i) + ']', sentence, self.step)

    def __call__(self, model, loss):
        """Make the instance callable; delegates to `forward`."""
        return self.forward(model, loss)


## Модель

In [45]:
class EncoderCNN(torch.nn.Module):
    """Image encoder: ResNet-152 backbone plus a trainable projection."""

    def __init__(self, embed_size):
        """Load the pretrained ResNet-152 and replace its top fc layer.

        Args:
            embed_size: size of the produced image feature vector.
        """
        super().__init__()
        backbone = models.resnet152(pretrained=True)

        # Keep every child module except the final fully connected layer.
        backbone_layers = list(backbone.children())[:-1]
        self.resnet = torch.nn.Sequential(*backbone_layers)

        # Projection from the backbone feature size to `embed_size`.
        self.last_layer = torch.nn.Sequential(
            torch.nn.Linear(backbone.fc.in_features, embed_size),
            torch.nn.BatchNorm1d(embed_size),
        )

    def forward(self, images):
        """Return one feature vector of size `embed_size` per input image."""
        # The backbone runs without gradient tracking, so only `last_layer`
        # receives gradients.
        with torch.no_grad():
            pooled = self.resnet(images)
        flattened = pooled.reshape(pooled.size(0), -1)
        return self.last_layer(flattened)


In [46]:
class DecoderRNN(torch.nn.Module):
    """LSTM caption decoder conditioned on an image feature vector."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
        """Build the embedding, LSTM and output projection layers.

        Args:
            embed_size: size of word embeddings (and of the image feature).
            hidden_size: LSTM hidden state size.
            vocab_size: number of tokens in the vocabulary.
            num_layers: number of stacked LSTM layers.
            max_seq_length: number of tokens produced by `sample`.
        """
        super().__init__()
        self.max_seg_length = max_seq_length

        self.embed = torch.nn.Embedding(vocab_size, embed_size)
        self.lstm = torch.nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = torch.nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions, lengths):
        """Compute vocabulary scores for every valid caption position.

        The image feature is prepended to the embedded caption as the first
        time step, the sequence is packed according to `lengths`, and the
        LSTM output for the packed steps is projected to vocabulary scores.

        Returns:
            Tensor of shape (sum(lengths), vocab_size).
        """
        word_vectors = self.embed(captions)
        sequence = torch.cat((features.unsqueeze(1), word_vectors), 1)

        packed_input = torch.nn.utils.rnn.pack_padded_sequence(
            sequence, lengths, batch_first=True)
        packed_output, _ = self.lstm(packed_input)

        # Element 0 of a PackedSequence is its flat data tensor.
        return self.linear(packed_output[0])

    def sample(self, features, states=None):
        """Generate captions for given image features using greedy search.

        Returns:
            Long tensor of shape (batch_size, max_seq_length) with the
            highest-scoring token id at every step.
        """
        steps = []
        step_input = features.unsqueeze(1)            # (batch_size, 1, embed_size)

        for _ in range(self.max_seg_length):
            hidden, states = self.lstm(step_input, states)   # (batch_size, 1, hidden_size)
            scores = self.linear(hidden.squeeze(1))          # (batch_size, vocab_size)
            best = scores.max(1)[1]                          # (batch_size)
            steps.append(best)
            # Feed the chosen token back in as the next input.
            step_input = self.embed(best).unsqueeze(1)       # (batch_size, 1, embed_size)

        return torch.stack(steps, 1)                  # (batch_size, max_seq_length)


## Выборка

In [47]:
class Vocabulary(object):
    """Two-way mapping between words and consecutive integer ids."""

    def __init__(self):
        """Create an empty vocabulary; ids are handed out starting at 0."""
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        """Register `word` under the next free id; known words are ignored."""
        if word in self.word2idx:
            return
        new_id = self.idx
        self.word2idx[word] = new_id
        self.idx2word[new_id] = word
        self.idx = new_id + 1

    def __call__(self, word):
        """Return the id of `word`, or the id of '<unk>' for unknown words."""
        key = word if word in self.word2idx else '<unk>'
        return self.word2idx[key]

    def __len__(self):
        """Return the number of registered words."""
        return len(self.word2idx)

def build_vocab(json, threshold):
    """Build a vocabulary from the captions of a COCO annotation file.

    Args:
        json: path to the COCO captions annotation file.
        threshold: minimum number of occurrences a token needs in order to
            be included in the vocabulary.

    Returns:
        A `Vocabulary` holding '<pad>', '<start>', '<end>', '<unk>' (ids 0-3,
        in that order) followed by every sufficiently frequent token.
    """
    coco = COCO(json)

    # Count lower-cased tokens over all captions.
    frequencies = Counter()
    for annotation in coco.anns.values():
        text = str(annotation['caption']).lower()
        frequencies.update(nltk.tokenize.word_tokenize(text))

    vocab = Vocabulary()
    for special_token in ('<pad>', '<start>', '<end>', '<unk>'):
        vocab.add_word(special_token)

    # Tokens seen fewer than `threshold` times are left out.
    for word, count in frequencies.items():
        if count >= threshold:
            vocab.add_word(word)
    return vocab


In [48]:
# Locations of the images and of the COCO caption annotations.
root = 'data/resized2014'
json = 'data/annotations_trainval2014/annotations/captions_train2014.json'

# Number of caption annotations used for training / evaluation and the
# minimum token frequency for a word to enter the vocabulary.
train_size = 15000
test_size = 3000
threshold = 4

vocab = build_vocab(json, threshold)

# Per-channel mean / std shared by the training and evaluation pipelines.
normalize = transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))

# Training pipeline: random crop and flip as data augmentation.
transform = transforms.Compose([
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize])

# Evaluation pipeline: deterministic centre crop, no augmentation, so that
# the test loss and the sampled captions are reproducible between runs.
transform_test = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize])

dataset_train = CocoDataset(root, json, vocab, 0,          train_size,             transform)
dataset_test  = CocoDataset(root, json, vocab, train_size, train_size + test_size, transform_test)

# Same annotation slice as `dataset_test`, but with untransformed PIL images
# (used for visualisation).
images_test = CocoDataset(root, json, vocab, train_size, train_size + test_size, None)


loading annotations into memory...
Done (t=0.44s)
creating index...
index created!
loading annotations into memory...
Done (t=0.37s)
creating index...
index created!
loading annotations into memory...
Done (t=0.39s)
creating index...
index created!
loading annotations into memory...
Done (t=0.37s)
creating index...
index created!


## Обучение

In [49]:
# Hyper-parameter grid: every combination below is trained from scratch and
# logged to its own TensorBoard run directory.
grid = ParameterGrid({'emb_dim' : [300, 500],
                      'hidden_dim': [600, 1000],
                      'num_layers': [1, 3],
                      'vocab_dim': [len(vocab)],
                      'max_seq_length': [20]})

for item in tqdm(grid):
    encoder = EncoderCNN(item['emb_dim'])
    decoder = DecoderRNN(item['emb_dim'],
                         item['hidden_dim'],
                         item['vocab_dim'],
                         item['num_layers'],
                         item['max_seq_length'])

    encoder.to(device)
    decoder.to(device)

    name = 'experiment/emb{}_hiden{}_layers{}_vocab{}'.format(
            item['emb_dim'], item['hidden_dim'], item['num_layers'], item['vocab_dim'])

    writer = SummaryWriter(log_dir = name)
    try:
        add_images(writer, images_test)

        # Only the decoder and the encoder's projection head are optimised.
        params = list(decoder.parameters()) + list(encoder.last_layer.parameters())
        optimizer = torch.optim.Adam(params, lr=1e-3)

        # Padded positions do not contribute to the loss.
        loss_function = torch.nn.CrossEntropyLoss(ignore_index=vocab('<pad>'))

        call = callback(writer, dataset_test, collate_fn, loss_function, vocab, delimeter = 10)

        trainer(count_of_epoch = 3,
                batch_size = 64,
                model = (encoder, decoder),
                dataset = dataset_train,
                collate_fn = collate_fn,
                loss_function = loss_function,
                optimizer = optimizer,
                callback = call)
    finally:
        # Flush pending events and release the event file even when a run
        # fails, so that every run directory is complete and readable.
        writer.close()


  0%|          | 0/8 [00:00<?, ?it/s]

epoch:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

epoch:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

epoch:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

epoch:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

epoch:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

epoch:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

epoch:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

epoch:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

  0%|          | 0/235 [00:00<?, ?it/s]

## Результаты

Результаты генерации получились достаточно посредственными: для некоторых картинок описания совпадают. Это может быть связано как с малым размером выбранной подвыборки, так и с неудачным подбором наборов параметров.

Наихудшие результаты показали модели с большим числом слоёв: такая модель получается слишком сложной. Соответственно, лучше всего справились модели с одним слоем. Остальные параметры (emb_dim, hidden_dim) почти не влияют на результаты.

Размер словаря очень сильно влияет на качество генерации описаний. Если строить словарь только по описаниям тестовой выборки, то сгенерированные описания состояли практически из одних артиклей.