In [1]:
from IPython.display import clear_output, display

In [2]:
# %pip install torch torchvision pillow spacy numpy
# %pip install torchtext
# %pip install pycocotools

In [3]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CocoCaptions
from torchtext.data.utils import get_tokenizer

from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

from PIL import Image
import spacy



In [4]:
# COCO split to work with; it selects both the image folder and the captions annotation file used below.
dataset_variant = 'val2017'

## Downloading the data

In [5]:
# Define paths for dataset and annotations
data_dir = './data'
images_dir = os.path.join(data_dir, dataset_variant)
annotations_dir = os.path.join(data_dir, 'annotations')

# Create the directory tree if it is missing. `exist_ok=True` keeps the cell idempotent on re-runs
# and avoids the check-then-create race of a separate `os.path.exists` test.
for directory in (data_dir, images_dir, annotations_dir):
    os.makedirs(directory, exist_ok=True)

# Download dataset
#!wget http://images.cocodataset.org/zips/{dataset_variant}.zip -P {data_dir}

# Unzip dataset
# !unzip {data_dir}/{dataset_variant}.zip -d {data_dir}

# Wipe any (potentially very long) download/unzip log from the cell output.
clear_output()


In [6]:
# Download annotations
# !wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip -P {annotations_dir}

In [7]:
# Unzip annotations
# !unzip {annotations_dir}/annotations_trainval2017.zip -d {annotations_dir}

## Loading the Dataset

In [8]:
# Image preprocessing applied to every sample: fixed-size resize, conversion to a tensor, then
# per-channel normalisation with the mean/std values conventionally used for ImageNet-pretrained models.
transform = transforms.Compose(
    [
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

# Load MS-COCO dataset
# Paths are derived from `images_dir` / `annotations_dir` so that changing `data_dir` or
# `dataset_variant` in one place is enough. The extra 'annotations' level is the folder created
# when the annotations archive is unzipped into `annotations_dir`.
captions_file = os.path.join(annotations_dir, 'annotations', f'captions_{dataset_variant}.json')
train_dataset = CocoCaptions(root=images_dir, annFile=captions_file, transform=transform)

loading annotations into memory...
Done (t=0.06s)
creating index...
index created!


## Building the tokenizer and vocabulary

In [None]:
# !python -m spacy download en_core_web_sm

In [9]:
# Load the small English spaCy pipeline; word_tokenize uses its tokenizer.
spacy_eng = spacy.load("en_core_web_sm")

In [10]:
def word_tokenize(text):
    """Split ``text`` with the spaCy English tokenizer and return the lower-cased token strings."""
    lowered_tokens = []
    for token in spacy_eng.tokenizer(text):
        lowered_tokens.append(token.text.lower())
    return lowered_tokens

In [11]:
# Define the vocabulary and tokenizer
# Special tokens get the first four indices; every other token is appended in order of first appearance.
word_to_index = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3}
index_to_word = {idx: word for word, idx in word_to_index.items()}
word_freq = {}
caption_lengths = []  # NOTE: length of each raw caption in characters, not in tokens


# Tokenize captions and build vocabulary
# Captions are read straight from the COCO annotation index instead of iterating `train_dataset`:
# indexing the dataset would decode and transform every image only for it to be discarded here.
# Images and their captions are visited in the same order as dataset iteration, so the resulting
# token -> index assignment is unchanged.
coco_api = train_dataset.coco
for image_id in tqdm(train_dataset.ids):
    annotations = coco_api.loadAnns(coco_api.getAnnIds(imgIds=image_id))
    for annotation in annotations:
        caption = str(annotation['caption'])
        caption_lengths.append(len(caption))
        for token in word_tokenize(caption.lower()):
            if token not in word_to_index:
                idx = len(word_to_index)
                word_to_index[token] = idx
                index_to_word[idx] = token
            # `.get` keeps counting safe for a token that is already in the vocabulary but has
            # no frequency entry yet (the pre-seeded special tokens).
            word_freq[token] = word_freq.get(token, 0) + 1

100%|██████████| 5000/5000 [00:27<00:00, 181.59it/s]


In [12]:
# Sanity check: the tokenizer breaks the special markers apart ('<', 'sos', '>'), so <SOS> and <EOS>
# are added manually as indices after tokenization rather than being written into the caption text.
word_tokenize('<SOS> hi, my friend <EOS>')

['<', 'sos', '>', 'hi', ',', 'my', 'friend', '<', 'eos', '>']

## Defining the Model

In [13]:
class EncoderCNN(nn.Module):
    """Image encoder: a pretrained ResNet-50 whose classifier is replaced by a trainable projection.

    Args:
        embed_size: Size of the feature vector produced for each image.
    """

    def __init__(self, embed_size):
        super().__init__()

        # Pretrained backbone with every existing parameter frozen.
        self.resnet = models.resnet50(pretrained=True).requires_grad_(False)  # resnet embedding backbone
        # The replacement head is created after the freeze, so it is trainable by default.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_size)
        self.relu = nn.ReLU()
        self.times = []  # not read anywhere in the visible notebook; kept for backward compatibility
        self.dropout = nn.Dropout(0.5)

    def train(self, mode=True):
        """Set train/eval mode, always keeping the frozen backbone in eval mode.

        Freezing parameters with ``requires_grad_(False)`` does not stop BatchNorm layers from
        updating their running statistics (and normalising with batch statistics) while in
        training mode. Holding the backbone in eval mode keeps it truly frozen. The ``fc`` head
        is a plain ``nn.Linear`` and behaves the same in either mode; ``self.dropout`` lives
        outside the backbone and still follows ``mode``.
        """
        super().train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        """Return ``dropout(relu(resnet(images)))`` for a batch of images."""
        features = self.resnet(images)
        return self.dropout(self.relu(features))


class DecoderRNN(nn.Module):
    """LSTM caption decoder that is primed with an image feature vector.

    Args:
        embed_size: Dimension of the token embeddings (and of the image feature fed as first step).
        hidden_size: LSTM hidden state size.
        vocab_size: Number of entries in the vocabulary (size of the output scores).
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        # Two stacked linear layers (no activation in between) mapping hidden states to vocabulary scores.
        projection_layers = [nn.Linear(hidden_size, 1024), nn.Linear(1024, vocab_size)]
        self.linear = nn.Sequential(*projection_layers)
        self.dropout = nn.Dropout(0.5)

    def forward(self, features, captions):
        """Score every vocabulary entry at each step of ``[image feature, caption tokens...]``.

        ``features`` is inserted as an extra leading time step, so the output has one more step
        than ``captions``.
        """
        token_vectors = self.embed(captions)
        token_vectors = self.dropout(token_vectors)
        image_step = features.unsqueeze(1)
        lstm_input = torch.cat([image_step, token_vectors], dim=-2)
        lstm_output, _ = self.lstm(lstm_input)
        return self.linear(lstm_output)


class CNNtoRNN(nn.Module):
    """End-to-end captioning model: a CNN image encoder feeding an LSTM caption decoder."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super().__init__()
        self.encoderCNN = EncoderCNN(embed_size)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, images, captions):
        """Return the decoder's vocabulary scores for ``captions`` conditioned on ``images``."""
        return self.decoderRNN(self.encoderCNN(images), captions)

    def caption_image(self, image, max_length=50):
        """Greedily decode a caption and return it as a list of token strings.

        Decoding stops after ``max_length`` tokens or right after "<EOS>" is produced
        (the "<EOS>" token is included in the result). Indices are mapped to strings through
        the module-level ``index_to_word``.

        NOTE(review): assumes ``image`` is a batch holding a single image - the
        ``squeeze(0)`` / ``argmax(0)`` calls only reduce to one token id in that case.
        """
        decoder = self.decoderRNN
        predicted_ids = []

        with torch.no_grad():
            step_input = self.encoderCNN(image)
            lstm_state = None  # lets the LSTM start from its default zero state

            for _step in range(max_length):
                hiddens, lstm_state = decoder.lstm(step_input, lstm_state)
                scores = decoder.linear(hiddens.squeeze(0))
                best = scores.argmax(0)
                token_id = best.item()
                predicted_ids.append(token_id)
                # The chosen token's embedding becomes the next LSTM input.
                step_input = decoder.embed(best).unsqueeze(0)

                if index_to_word[token_id] == "<EOS>":
                    break

        return [index_to_word[token_id] for token_id in predicted_ids]

In [14]:
def convert_sentence_to_idxs(sentence):
    """Tokenize ``sentence`` and map each token to its vocabulary index.

    Tokens that are not in ``word_to_index`` are mapped to the '<UNK>' index instead of raising
    ``KeyError``, so captions containing unseen words can still be encoded.

    Returns:
        A list of integer indices, one per token.
    """
    unk_idx = word_to_index['<UNK>']
    return [word_to_index.get(word, unk_idx) for word in word_tokenize(sentence)]


def convert_idxs_to_sentence(idxs):
    """Look up every index in ``index_to_word`` and join the words with single spaces."""
    return ' '.join(index_to_word[idx] for idx in idxs)


# Need to define a collate function which pads the sentence tokens in the batch to be of the same length so they can be stacked.
# We also take care of converting string tokens to idxs here.
def collate_fn(data):
    """Collate ``(image, captions)`` samples into a batch.

    Args:
        data: Sequence of ``(image_tensor, captions)`` pairs, where ``captions`` holds the
            reference captions of that image.

    Returns:
        ``(images, captions)``: the stacked image tensor and a ``(batch, max_len)`` tensor of
        token indices (dtype ``torch.long``), right-padded with the '<PAD>' index.
    """
    images, captions = zip(*data)
    images = torch.stack(images, 0)

    sos_idx = word_to_index['<SOS>']
    eos_idx = word_to_index['<EOS>']
    pad_idx = word_to_index['<PAD>']

    encoded = []
    for image_captions in captions:
        # each image has multiple captions. we use just the first one here.
        text = str(image_captions[0]).lower().strip()
        # <SOS>/<EOS> are added as indices after tokenization because the tokenizer would split the markers.
        token_ids = [sos_idx] + convert_sentence_to_idxs(text) + [eos_idx]
        # Token ids are integers: build a long tensor directly instead of a float `torch.Tensor`.
        encoded.append(torch.tensor(token_ids, dtype=torch.long))

    captions = pad_sequence(encoded, batch_first=True, padding_value=pad_idx)
    return images, captions

batch_size = 128
# Reshuffle the samples every epoch: with a fixed order every epoch sees identical batches in the
# same sequence, which is rarely what is wanted for SGD-style training.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=4)

In [15]:
# Use the GPU when one is available, otherwise fall back to the CPU; the bare name echoes the choice.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [16]:
# Hyperparameters for the captioning model and its training run.
embed_size = 2048  # size of the image feature vector and of each token embedding
hidden_size = 256  # LSTM hidden state size
vocab_size = len(word_to_index)  # special tokens plus every token collected while building the vocabulary
num_lstm_layers = 1  # number of stacked LSTM layers in the decoder
learning_rate = 5e-4  # learning rate passed to Adam
num_epochs = 100  # number of passes over train_loader

In [17]:
# Build the captioning model, the loss and the optimizer.
model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_lstm_layers).to(device)
# Padding positions must not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=word_to_index['<PAD>'])
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Inside the CNN only the replaced `fc` head is trained; every other ResNet parameter stays frozen.
for name, param in model.encoderCNN.resnet.named_parameters():
    param.requires_grad = any(marker in name for marker in ("fc.weight", "fc.bias"))



In [18]:
# Loading the model (example)
# The weights are loaded into the EXISTING `model` instead of building a new CNNtoRNN: `optimizer`
# was created from this instance's parameters, so rebinding `model` to a fresh object would leave
# the optimizer stepping parameters that the training loop no longer uses.
checkpoint_path = 'lstm-model.pth'
if os.path.exists(checkpoint_path):
    # map_location lets a checkpoint saved on one device be loaded on another (e.g. GPU -> CPU).
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))  # Load the saved state dictionary
else:
    # The checkpoint is only written by the final cell, so it does not exist on a first run.
    print(f'No checkpoint found at {checkpoint_path}; continuing with freshly initialised weights.')
model.to(device)  # Move the model to the appropriate device

CNNtoRNN(
  (encoderCNN): EncoderCNN(
    (resnet): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(in

## Pre-training Testing

In [19]:
# A few fixed validation images used for qualitative checks of the generated captions.
test_img_paths = [
    'data/val2017/000000000139.jpg',
    'data/val2017/000000000632.jpg',
    'data/val2017/000000000724.jpg',
]
# Keep the PIL versions for display ...
imgs_pil = [
    Image.open(img_path).convert('RGB')
    for img_path in test_img_paths
]
# ... and the preprocessed tensors, already on the target device, for the model.
imgs_test = [
    transform(pil_image).to(device)
    for pil_image in imgs_pil
]

In [20]:

# Caption the fixed test images with the model's current weights.
model.eval()  # eval mode disables the dropout layers during generation
captions = []

with torch.no_grad():
    for img in imgs_test:
        # caption_image is called with a leading batch dimension, hence unsqueeze(0).
        caption_tokens = model.caption_image(img.unsqueeze(0))
        captions.append(' '.join(caption_tokens))

print(('\n'+'-'*20+'\n').join(captions))

  return F.conv2d(input, weight, bias, self.stride,


<SOS> a living room filled with furniture and a flat screen tv . <EOS>
--------------------
<SOS> a living room filled with furniture and a flat screen tv . <EOS>
--------------------
<SOS> a red stop sign sitting on the side of a road . <EOS>


In [24]:
# Display the third test image (index 2 of imgs_pil).
# NOTE(review): the recorded IndexError together with the out-of-sequence execution count suggests
# this cell was last run against a shorter, stale imgs_pil - re-run top to bottom to confirm.
imgs_pil[2]

IndexError: list index out of range

## Training the model

In [None]:
for epoch in range(num_epochs):

    model.train()
    epoch_loss_sum = 0.0
    num_batches = 0

    for imgs, captions in tqdm(train_loader, total=len(train_loader), leave=False):
        imgs = imgs.to(device)
        captions = captions.to(device).type(torch.long)

        # Teacher forcing: the decoder receives the image feature followed by every caption token
        # except the last, which yields exactly one prediction per token of the full caption.
        outputs = model(imgs, captions[:, :-1])
        loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss_sum += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch; the last batch alone is small and noisy.
    mean_loss = epoch_loss_sum / max(num_batches, 1)
    print(f'\nEpoch: {epoch+1}/{num_epochs}', "Training loss: ", mean_loss)

In [None]:
# Scratch check: list the first 20 files in the validation image folder.
!ls data/val2017/ | head -20

In [None]:
# Scratch check: list the contents of the data directory.
!ls data/

In [None]:
# Open one more validation image, convert it to RGB, and display it as the cell output.
img = Image.open('data/val2017/000000001490.jpg').convert('RGB')
img

In [None]:
# Preprocess the image, move it to the model's device and add the leading batch dimension.
img_t = transform(img)
img_t = img_t.to(device)
img_t = img_t.unsqueeze(0)

model.eval()
with torch.no_grad():
    caption = model.caption_image(img_t)

# Echo the generated token list as the cell output.
caption

In [None]:
# Show every test image next to its generated caption.
# Set eval mode here so the cell does not depend on an earlier cell having done it
# (the training loop leaves the model in train mode, where dropout is active).
model.eval()
with torch.no_grad():
    for img, img_pil in zip(imgs_test, imgs_pil):
        caption = model.caption_image(img.unsqueeze(0))
        caption = ' '.join(caption)

        display(img_pil)
        print(caption)
        print('-'*20)

In [None]:
# Persist only the weights (state_dict); loading them back requires a CNNtoRNN built with the
# same hyperparameters and vocabulary size.
torch.save(model.state_dict(),'lstm-model.pth')
