In [None]:
!pip install -q Pillow gdown

### Importing required libraries

In [None]:
import gdown
from PIL import Image
import json
import os
import shutil

import torchvision
import numpy as np
import pandas as pd 
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from tqdm.notebook import tqdm
import nltk

In [None]:
def delete_files(arr):
    """Remove every file path in ``arr`` that currently exists; absent paths are skipped.

    Uses ``os.remove``, so a directory path raises (``delete_folders`` handles directories).
    """
    present = (path for path in arr if os.path.exists(path))
    for path in present:
        os.remove(path)

In [None]:
def delete_folders(arr):
    """Recursively delete each directory in ``arr`` that exists; absent paths are ignored."""
    for directory in filter(os.path.exists, arr):
        shutil.rmtree(directory)

## Data downloading, preprocessing and analyzing

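**Helper module:** download `beheaded_inception3.py`, which provides `beheaded_inception_v3`, the Inception v3 backbone used by the image encoder.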

In [None]:
# Fetch the helper module that provides beheaded_inception_v3 (imported further below).
gdown.download(
    url="https://drive.google.com/u/0/uc?export=download&confirm=X2sC&id=1TlNmpLUBw7jJEXgpliy29Am1HpHl-KNJ",
    output="beheaded_inception3.py",
    quiet=True,
)

'beheaded_inception3.py'

### Downloading the MS COCO Dataset

**Images download**

In [None]:
%%time
# Download and extract the COCO train2017 image archive into /content/train2017.
# NOTE(review): the recorded run aborted with "write error (disk full?)", so the
# extracted folder is incomplete — free disk space and re-run before training.
!wget -q http://images.cocodataset.org/zips/train2017.zip
print("Unzipping...")
!unzip -q train2017.zip

Unzipping...
train2017/000000259014.jpg:  write error (disk full?).  Continue? (y/n/^C) n

CPU times: user 1.56 s, sys: 315 ms, total: 1.87 s
Wall time: 17min 3s


In [None]:
# Free disk space by removing the image archive.
# NOTE(review): the extraction above stopped on a disk-full error; deleting the
# archive means a full re-download is needed to recover the missing images.
delete_files(arr=["/content/train2017.zip"])

**Captions download**

In [None]:

%%time
# Download and extract the COCO 2017 train/val annotation archive into /content/annotations.
# Only captions_train2017.json is used below; the other files are deleted afterwards.
!wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
print("Unzipping...")
!unzip -q annotations_trainval2017.zip

--2021-01-24 22:35:11--  http://images.cocodataset.org/annotations/annotations_trainval2017.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 52.216.140.4
Connecting to images.cocodataset.org (images.cocodataset.org)|52.216.140.4|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 252907541 (241M) [application/zip]
Saving to: ‘annotations_trainval2017.zip’


2021-01-24 22:35:13 (98.0 MB/s) - ‘annotations_trainval2017.zip’ saved [252907541/252907541]

Unzipping...
CPU times: user 35.4 ms, sys: 14.3 ms, total: 49.7 ms
Wall time: 9.86 s


In [None]:
# Only the caption annotations are needed: remove the archive plus the
# instance and keypoint annotation files to save disk space.
delete_files(arr=[
    "/content/annotations_trainval2017.zip",
    "/content/annotations/instances_train2017.json",
    "/content/annotations/instances_val2017.json",
    "/content/annotations/person_keypoints_train2017.json",
    "/content/annotations/person_keypoints_val2017.json",
])

### Data preprocessing

In [None]:
# Parse the COCO captions file; the `with` block closes the handle instead of leaking it.
with open("/content/annotations/captions_train2017.json", encoding="utf-8") as captions_file:
    data = json.load(captions_file)

In [None]:
def data_table(keys, source=None):
    """Build a one-row-per-caption table from a parsed COCO captions JSON.

    Args:
        keys: ``(images_section, annotations_section)`` — in this notebook
            ``("images", "annotations")``.
        source: the parsed COCO captions dict; defaults to the notebook-global ``data``.

    Returns:
        DataFrame with columns ``file_name``, ``id`` (image id) and ``caption``:
        an inner merge of the image table and the caption table on the image id.

    Raises:
        ValueError: if ``keys`` does not name exactly two supported sections.
    """
    if source is None:
        source = data

    # (value field, id field) per COCO section; captions refer to their image via image_id.
    section_fields = {
        "images": ("file_name", "id"),
        "annotations": ("caption", "image_id"),
    }
    if len(keys) != 2:
        raise ValueError(f"expected exactly two sections, got {keys!r}")

    dataframes = []
    for key in keys:
        if key not in section_fields:
            raise ValueError(f"unsupported COCO section: {key!r}")
        value_field, id_field = section_fields[key]
        records = source[key]
        dataframes.append(pd.DataFrame({
            value_field: [item[value_field] for item in records],
            "id": [item[id_field] for item in records],
        }))

    return pd.merge(dataframes[0], dataframes[1], on="id")

**Build the image–caption table, keeping only the file name, image id and caption fields**

In [None]:
%%time
# Section names of the COCO captions JSON: image records and caption annotations.
img_key, caption_key = "images", "annotations"

df = data_table(keys=(img_key, caption_key))

CPU times: user 591 ms, sys: 23.3 ms, total: 614 ms
Wall time: 702 ms


In [None]:
# Preview of the merged image/caption table (one row per caption).
df

Unnamed: 0,file_name,id,caption
0,000000391895.jpg,391895,A man with a red helmet on a small moped on a ...
1,000000391895.jpg,391895,Man riding a motor bike on a dirt road on the ...
2,000000391895.jpg,391895,A man riding on the back of a motorcycle.
3,000000391895.jpg,391895,A dirt path with a young person on a motor bik...
4,000000391895.jpg,391895,A man in a red shirt and a red hat is on a mot...
...,...,...,...
591748,000000475546.jpg,475546,The patrons enjoy their beverages at the bar.
591749,000000475546.jpg,475546,People having a drink in a basement bar.
591750,000000475546.jpg,475546,A group of friends enjoys a drink while sittin...
591751,000000475546.jpg,475546,Group of people drinking wine at a public loca...


In [None]:

# The image id was only needed for the merge. Reassigning (instead of inplace) with
# errors="ignore" — rename already ignores missing labels — makes the cell safe to re-run.
df = df.drop(columns="id", errors="ignore").rename(columns={"file_name": "image"})

In [None]:
# First rows after dropping the image id and renaming file_name -> image.
df.head(7)

Unnamed: 0,image,caption
0,000000391895.jpg,A man with a red helmet on a small moped on a ...
1,000000391895.jpg,Man riding a motor bike on a dirt road on the ...
2,000000391895.jpg,A man riding on the back of a motorcycle.
3,000000391895.jpg,A dirt path with a young person on a motor bik...
4,000000391895.jpg,A man in a red shirt and a red hat is on a mot...
5,000000522418.jpg,A woman wearing a net on her head cutting a ca...
6,000000522418.jpg,A woman cutting a large white sheet cake.


**Check for missing values**

In [None]:
# Per-column flag for missing values; fail loudly instead of relying on a visual check.
missing_by_column = df.isna().any()
assert not missing_by_column.any(), f"Missing values in: {list(missing_by_column[missing_by_column].index)}"
missing_by_column

image      False
caption    False
dtype: bool

## Creating datasets and iterators

In [None]:
# Column names as built above: image file name first, caption second.
IMAGE_COL_LABEL, CAPTION_COL_LABEL = df.columns[:2]

In [None]:
import string

class TextDataset(Dataset):
    """Caption vocabulary that maps a caption string to a fixed-length list of token ids.

    Despite subclassing ``Dataset``, indexing takes a caption *string* (not a position)
    and returns its id sequence: ``<BOS>``, in-vocabulary words (truncated), ``<EOS>``,
    then ``<PAD>`` up to ``max_sentence_len``.

    Id layout: ``<UNK>``=0 (reserved), ``<BOS>``=1, ``<EOS>``=2, ``<PAD>``=3, then the
    kept words from 4 upward in sorted order. The ids are exactly ``0..len(vocab)-1``,
    so ``len(vocab)`` is a valid embedding size and ids are identical across runs.
    """

    UNK = "<UNK>"
    BOS = "<BOS>"
    EOS = "<EOS>"
    PAD = "<PAD>"

    def __init__(self, text_data: pd.Series, min_word_freq=None, max_word_freq=None, max_sentence_len=20, lower=True, tokenizer=None):
        """
        Args:
            text_data: iterable of caption strings (e.g. a pandas Series).
            min_word_freq: drop words seen fewer times than this (``None`` = no lower bound).
            max_word_freq: drop words seen more times than this (``None`` = no upper bound).
            max_sentence_len: length of every id sequence, including <BOS>/<EOS>; must be >= 2.
            lower: lower-case captions before tokenizing.
            tokenizer: object with a ``tokenize(str) -> list[str]`` method; defaults to
                ``nltk.WordPunctTokenizer()``.

        Raises:
            ValueError: if ``max_sentence_len`` < 2.
        """
        if max_sentence_len < 2:
            raise ValueError("max_sentence_len must be at least 2 (room for <BOS> and <EOS>)")

        self.min_word_freq = -np.inf if min_word_freq is None else min_word_freq
        self.max_word_freq = np.inf if max_word_freq is None else max_word_freq
        self.max_sentence_len = max_sentence_len
        self.lower = lower
        self.tokenizer = nltk.WordPunctTokenizer() if tokenizer is None else tokenizer

        self.additional_tokens = (self.UNK, self.BOS, self.EOS, self.PAD)
        # Fixed special ids; the training code relies on <PAD> == 3 (CrossEntropyLoss ignore_index).
        self.token2idx = {self.UNK: 0, self.BOS: 1, self.EOS: 2, self.PAD: 3}

        # Count words with the SAME tokenization used at lookup time, so counts and
        # lookups agree (punctuation-attached words are not miscounted).
        word_counts = pd.Series(text_data).map(self._split_words).explode().value_counts()
        kept_words = sorted(
            word for word, count in word_counts.items()
            if self.min_word_freq <= count <= self.max_word_freq
            and word not in self.additional_tokens
        )

        self.vocab = set(self.additional_tokens) | set(kept_words)
        first_word_id = len(self.additional_tokens)
        self.token2idx.update({word: first_word_id + offset for offset, word in enumerate(kept_words)})
        self.idx2token = {idx: token for token, idx in self.token2idx.items()}

    def __len__(self):
        """Number of distinct token ids (special tokens included)."""
        return len(self.token2idx)

    def __getitem__(self, text):
        """Return the fixed-length list of token ids for caption ``text``."""
        return [self.token2idx[token] for token in self.__tokenize(text)]

    def _split_words(self, row):
        """Lower-case (if configured), tokenize, and drop tokens made only of punctuation."""
        row = row.lower() if self.lower else row
        return [
            token for token in self.tokenizer.tokenize(row)
            if not all(char in string.punctuation for char in token)
        ]

    def __tokenize(self, row):
        """Split ``row``, keep in-vocabulary words, truncate, add <BOS>/<EOS> and pad."""
        words = [token for token in self._split_words(row) if token in self.vocab]
        tokens = [self.BOS] + words[: self.max_sentence_len - 2] + [self.EOS]
        return tokens + [self.PAD] * (self.max_sentence_len - len(tokens))

In [None]:
tokenizer = nltk.WordPunctTokenizer()
max_sentence_len = 20

def tokenize(row, max_len=None):
    """Tokenize a caption into a fixed-length list of string tokens (no vocabulary filter).

    The caption is lower-cased and split with WordPunctTokenizer. Punctuation-only tokens are
    dropped, and at most ``max_len - 2`` words are kept. The result is wrapped in
    <BOS>/<EOS> and padded with <PAD> — the same string markers TextDataset uses.

    Args:
        row: caption string.
        max_len: total output length; defaults to the module-level ``max_sentence_len``.
    """
    max_len = max_sentence_len if max_len is None else max_len
    words = [
        token for token in tokenizer.tokenize(row.lower())
        if not all(char in string.punctuation for char in token)
    ]
    # String markers (not ints 1/2/3) so the list holds a single token type.
    tokens = [TextDataset.BOS] + words[:max_len - 2] + [TextDataset.EOS]
    return tokens + [TextDataset.PAD] * (max_len - len(tokens))

### Creating text dataset

In [None]:
%%time
# Caption column -> vocabulary; words seen fewer than 5 times are dropped.
text_data = df[CAPTION_COL_LABEL]
text_dataset = TextDataset(
    text_data,
    min_word_freq=5,
    max_sentence_len=20,
    lower=True,
)

CPU times: user 9.37 s, sys: 186 ms, total: 9.56 s
Wall time: 9.67 s


### Tokenizing the captions

In [None]:
# Replace each caption string by its fixed-length list of token ids.
df[CAPTION_COL_LABEL] = df[CAPTION_COL_LABEL].map(text_dataset.__getitem__)

### BatchIterator class for creating data iterators

In [None]:
class BatchIterator:
    """Yields batches with one item per image of ``dataframe``.

    Each item pairs a transformed image with one of that image's captions, chosen at
    random on every pass. Batches are dicts ``{"images": [tensor, ...], "captions": LongTensor}``.
    """

    def __init__(self, dataframe, unique_images, image_col_label, caption_col_label, batch_size, image_transformer, main_img_path, shuffle=False):
        """
        Args:
            dataframe: one row per caption; ``image_col_label`` holds the image file name and
                ``caption_col_label`` the caption (a list of token ids in this notebook).
            unique_images: candidate image names, in iteration order. Only names that occur in
                ``dataframe`` are used, so one global list can be shared by the train/valid/test
                iterators. ``None`` means "every image in ``dataframe``".
            image_col_label, caption_col_label: column names described above.
            batch_size: images per batch (the last batch may be smaller).
            image_transformer: callable applied to each loaded PIL image.
            main_img_path: directory containing the image files.
            shuffle: reshuffle the image order on every pass.

        Raises:
            ValueError: if ``batch_size`` < 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        self.dataframe = dataframe
        self.image_col = image_col_label
        self.caption_col = caption_col_label
        self.image_transformer = image_transformer
        self.main_img_path = main_img_path

        # Group captions per image once, instead of scanning the whole frame for every sample.
        self._captions_by_image = {
            image: list(captions)
            for image, captions in dataframe.groupby(image_col_label, sort=False)[caption_col_label]
        }
        if unique_images is None:
            unique_images = list(self._captions_by_image)
        # Iterate by image NAME (not by row label), restricted to this frame's images.
        self.unique_images = [image for image in unique_images if image in self._captions_by_image]

        self.num_samples = len(self.unique_images)
        self.batch_size = batch_size
        # Ceil division: the final, possibly smaller batch is yielded as well.
        self.batches_count = (self.num_samples + batch_size - 1) // batch_size
        self.shuffle = shuffle

    def __len__(self):
        """Number of batches one pass yields."""
        return self.batches_count

    def __iter__(self):
        order = np.arange(self.num_samples)
        if self.shuffle:
            np.random.shuffle(order)

        for start in range(0, self.num_samples, self.batch_size):
            batch_names = [self.unique_images[i] for i in order[start:start + self.batch_size]]
            yield {
                "images": [self.__get_image_matrix(name) for name in batch_names],
                "captions": torch.tensor([self.__get_caption(name) for name in batch_names]),
            }

    def __get_image_matrix(self, image_name):
        """Load ``image_name`` from ``main_img_path`` as RGB and apply the transformer."""
        path = os.path.join(self.main_img_path, image_name)
        with Image.open(path) as image:
            # Force 3 channels so non-RGB files (e.g. grayscale) work with a 3-channel Normalize.
            rgb_image = image.convert("RGB")
        return self.image_transformer(rgb_image)

    def __get_caption(self, image_name):
        """Pick one of ``image_name``'s captions at random."""
        return random.choice(self._captions_by_image[image_name])

In [None]:
def split_data(dataframe, ratios):
    """Split ``dataframe`` row-wise into consecutive chunks sized by ``ratios``.

    Each chunk gets ``int(len * ratio)`` rows; any rounding remainder goes to the last
    chunk, so together the chunks cover every row. Rows are not shuffled, so consecutive
    captions of one image may straddle two chunks.

    Args:
        dataframe: the frame to split (this argument — not any notebook-global frame).
        ratios: fraction of rows per chunk, e.g. ``[0.8, 0.1, 0.1]``.

    Returns:
        list of DataFrames, one per ratio, keeping the original index labels.
    """
    data_len = len(dataframe)

    lengths = [int(data_len * ratio) for ratio in ratios]
    if lengths and sum(lengths) != data_len:
        lengths[-1] = data_len - sum(lengths[:-1])

    bounds = np.cumsum([0] + lengths)
    # Positional slices replace np.split on a DataFrame, which pandas deprecates.
    return [dataframe.iloc[start:stop] for start, stop in zip(bounds[:-1], bounds[1:])]

### Splitting the data

In [None]:
# 80/10/10 consecutive row split (captions of one image may straddle a boundary).
train_data, valid_data, test_data = split_data(df, ratios=[0.8, 0.1, 0.1])

### Creating data iterators

In [None]:
image_transformer = torchvision.transforms.Compose([
    # 3 channels even for grayscale/other-mode files, to match the 3-channel Normalize below.
    torchvision.transforms.Lambda(lambda image: image.convert("RGB")),
    # Resize the PIL image before ToTensor: PIL resizing is antialiased and supported
    # by every torchvision version (Resize on tensors needs a newer torchvision).
    torchvision.transforms.Resize((299, 299)),
    torchvision.transforms.ToTensor(),
    # Presumably the standard ImageNet channel mean/std — confirm against the backbone's training.
    torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225]),
])

In [None]:
BATCH_SIZE = 32
main_img_path = "/content/train2017"
# Sorted, de-duplicated image file names across the whole (unsplit) table.
unique_images = np.unique(df[IMAGE_COL_LABEL].to_numpy())

In [None]:
# Shuffle only the training data so its batches differ between epochs; a fixed
# order for validation/test keeps their losses comparable from run to run.
train_iterator = BatchIterator(train_data, unique_images, IMAGE_COL_LABEL, CAPTION_COL_LABEL, BATCH_SIZE, image_transformer, main_img_path, shuffle=True)
valid_iterator = BatchIterator(valid_data, unique_images, IMAGE_COL_LABEL, CAPTION_COL_LABEL, BATCH_SIZE, image_transformer, main_img_path)
test_iterator = BatchIterator(test_data, unique_images, IMAGE_COL_LABEL, CAPTION_COL_LABEL, BATCH_SIZE, image_transformer, main_img_path)

# Main Part

### Models

**ENCODER**

In [None]:
class Encoder(nn.Module):
    """Projects CNN image features into the caption embedding space."""

    def __init__(self, cnn_model, embedding_size, cnn_feature_size=2048):
        super().__init__()
        self.cnn_model = cnn_model
        # Linear map from the CNN feature width to the decoder's embedding width.
        self.fc = nn.Linear(cnn_feature_size, embedding_size)

    def forward(self, images):
        """Return ``fc(features)``, where features is the middle item of the CNN's 3-tuple output."""
        cnn_outputs = self.cnn_model(images)
        _, image_features, _ = cnn_outputs
        return self.fc(image_features)

**DECODER**

In [None]:
class Decoder(nn.Module):
    """LSTM caption decoder that consumes the image vector as its first input step."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, dropout=0.3, num_layers=2, bidirectional=True, cnn_feature_size=2048):
        """
        Args:
            vocab_size: number of token ids (embedding rows and output classes).
            embedding_dim: token embedding size; must equal the image vector size.
            hidden_dim: total LSTM output width (split evenly across directions).
            dropout: dropout between stacked LSTM layers (no effect with one layer).
            num_layers: number of stacked LSTM layers.
            bidirectional: run the LSTM in both directions. NOTE: under teacher forcing a
                bidirectional decoder can see the tokens it is asked to predict.
            cnn_feature_size: unused; kept for signature compatibility.

        Raises:
            ValueError: if ``hidden_dim`` is not divisible by the number of directions.
        """
        super().__init__()

        num_directions = 2 if bidirectional else 1
        if hidden_dim % num_directions != 0:
            raise ValueError("hidden_dim must be divisible by the number of LSTM directions")
        # Each direction gets an equal share so the concatenated output is exactly hidden_dim.
        rnn_hidden_dim = hidden_dim // num_directions

        self.vocab_size = vocab_size
        self.embedding_layer = nn.Embedding(vocab_size, embedding_dim)
        # nn.LSTM applies dropout only between layers; with one layer it would just warn.
        self.rnn = nn.LSTM(embedding_dim, rnn_hidden_dim, num_layers=num_layers,
                           dropout=dropout if num_layers > 1 else 0.0,
                           bidirectional=bidirectional, batch_first=True)

        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, image_vectors, captions):
        """
        Args:
            image_vectors: (batch, embedding_dim) image embeddings, prepended as step 0.
            captions: (batch, seq_len) token ids.

        Returns:
            ``(logits, hidden)``: logits of shape (batch, seq_len + 1, vocab_size) and the
            LSTM's final ``(h_n, c_n)``.
        """
        embedded = self.embedding_layer(captions)
        steps = torch.cat((image_vectors.unsqueeze(1), embedded), dim=1)
        rnn_output, hidden = self.rnn(steps)
        return self.fc(rnn_output), hidden

In [None]:
        # captions = captions[:, :-1]
        # embed = self.embedding_layer(captions)
        # embed = torch.cat((features.unsqueeze(1), embed), dim = 1)
        # lstm_outputs, _ = self.lstm(embed)
        # out = self.linear(lstm_outputs)
        
        # return out

**SEQ-2-SEQ**

In [None]:
class CaptionNet(nn.Module):
    """Image-captioning model: encoder output is fed to the decoder as its first step."""

    def __init__(self, encoder, decoder, device, teacher_forcing_ratio = 0.5):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        # Kept for configuration compatibility. forward() always feeds the ground-truth
        # caption (full teacher forcing), so this ratio is currently not used.
        self.teacher_forcing_ratio = teacher_forcing_ratio

    def forward(self, images, captions):
        """
        Args:
            images: batch accepted by ``self.encoder``.
            captions: (batch, max_len) token ids (presumably starting with <BOS>).

        Returns:
            Logits of shape (max_len, batch, vocab_size); ``outputs[t]`` scores ``captions[:, t]``
            (for a unidirectional decoder).
        """
        features = self.encoder(images)
        # Decoder input is [image, c_0, ..., c_{max_len-2}]: the step after the image scores
        # c_0 and the step after c_{t-1} scores c_t, so the final token is never an input.
        # One full-sequence pass replaces the former per-step loop over the same input.
        logits, _ = self.decoder(features, captions[:, :-1])
        return logits.permute(1, 0, 2)

In [None]:
    # def sample(self, inputs, states=None, max_len=20):
    #     " accepts pre-processed  image tensor (inputs) and returns predicted sentence (list of tensor ids of length max_len) "
    #     output_sentence = []
    #     for i in range(max_len):
    #         lstm_outputs, states = self.lstm(inputs, states)
    #         lstm_outputs = lstm_outputs.squeeze(1)
    #         out = self.linear(lstm_outputs)
    #         last_pick = out.max(1)[1]
    #         output_sentence.append(last_pick.item())
    #         inputs = self.embedding_layer(last_pick).unsqueeze(1)
        
    #     return output_sentence

In [None]:
#@title Hyperparameters { run: "auto" }

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device: ", device)

# Size the embedding/output layers by the largest token id actually issued, not by
# len(vocab): any id >= the embedding size crashes nn.Embedding.
VOCAB_SIZE = max(text_dataset.idx2token) + 1
EMBEDDING_DIM = 250 #@param {type:"slider", min:100, max:1000, step:50}
HIDDEN_DIM = 128 #@param ["128", "256", "512", "1024", "2048"] {type:"raw"}
DROPOUT = 0.2 #@param {type:"slider", min:0, max:1, step:0.1}
NUM_LAYERS = 2 #@param {type:"slider", min:1, max:10, step:1}
# Keep False: with teacher forcing, the backward direction of a bidirectional decoder
# reads the next caption token, so it would learn to copy instead of predict.
BIDIRECTIONAL = False #@param {type:"boolean"}
TEACHER_FORCE_RATIO = 0.3 #@param {type:"slider", min:0, max:1, step:0.1}

Device:  cuda


**Loading the Inception model**

In [None]:
%%time
from beheaded_inception3 import beheaded_inception_v3

# Backbone with downloaded weights, put in inference mode (eval() == train(False)).
inception = beheaded_inception_v3().eval()

Downloading: "https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-1a9a5a14.pth


HBox(children=(FloatProgress(value=0.0, max=108857766.0), HTML(value='')))


CPU times: user 3min 32s, sys: 999 ms, total: 3min 33s
Wall time: 3min 34s


In [None]:
def deny_param_train(model):
    """Freeze ``model`` in place: mark every parameter as not requiring gradients."""
    for weight in model.parameters():
        weight.requires_grad_(False)

In [None]:
def count_parameters(model):
    """Total element count of the model's trainable (``requires_grad``) parameters."""
    trainable_sizes = [
        weight.numel()
        for weight in filter(lambda p: p.requires_grad, model.parameters())
    ]
    return np.sum(trainable_sizes)
# print(f'The model has {count_parameters(model):,} trainable parameters')

In [None]:
encoder = Encoder(cnn_model=inception, embedding_size=EMBEDDING_DIM).to(device)
decoder = Decoder(
    vocab_size=VOCAB_SIZE,
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    dropout=DROPOUT,
    num_layers=NUM_LAYERS,
    bidirectional=BIDIRECTIONAL,
).to(device)
captionNet = CaptionNet(encoder, decoder, device, teacher_forcing_ratio=TEACHER_FORCE_RATIO).to(device)

In [None]:
# Freeze only the pretrained CNN. Freezing the whole encoder would also freeze
# encoder.fc, leaving the image-to-embedding projection at its random initialization.
deny_param_train(encoder.cnn_model)

In [None]:
# Padded positions are excluded from the loss; the pad id comes from the vocabulary.
criterion = nn.CrossEntropyLoss(ignore_index=text_dataset.token2idx[text_dataset.PAD]).to(device)
optimizer = torch.optim.Adam(params=captionNet.parameters())

In [None]:
#@title {run: "auto" }

# Upper bound on training epochs (early stopping in the training loop may end sooner).
EPOCHS = 6 #@param {type:"slider", min:1, max:20, step:1}
# Max gradient norm passed to clip_grad_norm_ during training.
CLIP = 2 #@param {type:"slider", min:1, max:10, step:1}

### Training

In [None]:
def iterate_model(mode, model, iterator, optimizer, criterion, device, clip):
    """Run one pass over ``iterator`` and return the mean batch loss.

    Args:
        mode: "train" (update weights) or "valid" (no gradients, no weight updates).
        model: called as ``model(images, captions)``; assumed to return logits shaped
            (max_len, batch, vocab), the layout CaptionNet produces.
        iterator: yields dicts with "images" (list of image tensors) and "captions"
            ((batch, max_len) token ids); must support ``len()`` for the progress bar.
        optimizer: used only in "train" mode (may be None for "valid").
        criterion: loss taking (batch, vocab, max_len) logits and (batch, max_len)
            targets, e.g. ``nn.CrossEntropyLoss``.
        device: device each batch is moved to.
        clip: max gradient norm for ``clip_grad_norm_`` ("train" only).

    Raises:
        ValueError: on an unknown mode or if the iterator yields no batches.
    """
    if mode not in ("train", "valid"):
        raise ValueError("Invalid mode, must be 'train' or 'valid'")
    is_train = mode == "train"
    model.train(is_train)

    total_loss = 0.0
    num_batches = 0
    # Validation must neither build graphs nor step the optimizer.
    with torch.set_grad_enabled(is_train):
        for batch in tqdm(iterator, total=len(iterator)):
            images = torch.stack(batch["images"]).to(device)
            captions = batch["captions"].to(device)

            output = model(images, captions)
            # CrossEntropyLoss wants the class dimension second: (batch, vocab, max_len).
            loss = criterion(output.permute(1, 2, 0), captions)

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
                optimizer.step()

            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("iterator yielded no batches")
    # Average over batches actually seen; len(iterator) may differ from the yielded count.
    return total_loss / num_batches

In [None]:
# Early-stopping state: best validation loss so far and consecutive non-improving epochs.
MIN_LOSS, CUR_PATIENCE = np.inf, 0
# Stop after PATIENCE non-improving epochs; checkpoint every SAVE_EPOCH epochs.
PATIENCE, SAVE_EPOCH = 2, 2

In [None]:
for epoch in range(EPOCHS):

    train_loss = iterate_model("train", captionNet, train_iterator, optimizer, criterion, device, CLIP)
    valid_loss = iterate_model("valid", captionNet, valid_iterator, optimizer, criterion, device, CLIP)
    print(f"Epoch {epoch + 1}/{EPOCHS}: train loss {train_loss:.4f}, valid loss {valid_loss:.4f}")

    if valid_loss < MIN_LOSS:
        MIN_LOSS = valid_loss
        CUR_PATIENCE = 0
        # state_dict() returns the live tensors; clone them so later updates
        # do not silently overwrite the "best" snapshot.
        best_model = {name: tensor.detach().clone() for name, tensor in captionNet.state_dict().items()}
    else:
        CUR_PATIENCE += 1
        if CUR_PATIENCE == PATIENCE:
            CUR_PATIENCE = 0
            break

    if (epoch + 1) % SAVE_EPOCH == 0:
        torch.save(best_model, 'best-model.pt')

# Persist the best weights also after an early stop or when EPOCHS is not a multiple of SAVE_EPOCH.
torch.save(best_model, 'best-model.pt')