In [None]:
import os  
import pandas as pd
import spacy  
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence  
from torch.utils.data import DataLoader, Dataset
from PIL import Image  
import torchvision.transforms as transforms
import torchvision.models as models
import torch.optim as optim
from tqdm.auto import tqdm

In [None]:
!python -m spacy download en

Collecting en_core_web_sm==2.2.5
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.2.5/en_core_web_sm-2.2.5.tar.gz (12.0 MB)
[K     |████████████████████████████████| 12.0 MB 5.7 MB/s 
[38;5;2m✔ Download and installation successful[0m
You can now load the model via spacy.load('en_core_web_sm')
[38;5;2m✔ Linking successful[0m
/usr/local/lib/python3.7/dist-packages/en_core_web_sm -->
/usr/local/lib/python3.7/dist-packages/spacy/data/en
You can now load the model via spacy.load('en')


In [None]:
# Load the English pipeline by its full package name. The "en" shortcut link
# only exists on spaCy 2.x, whereas "en_core_web_sm" resolves on any version
# where that package is installed (the download cell installs it).
spacy_eng = spacy.load("en_core_web_sm")

class Vocabulary:
    """Two-way mapping between caption tokens and integer ids.

    Ids 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``; corpus words receive ids from 4 upwards, in the
    order in which each word's running count reaches ``freq_threshold``.
    """

    def __init__(self, freq_threshold):
        """Create a vocabulary that holds only the four special tokens."""
        special_tokens = ("<PAD>", "<SOS>", "<EOS>", "<UNK>")
        self.itos = dict(enumerate(special_tokens))
        self.stoi = {token: index for index, token in self.itos.items()}
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Number of known tokens, special tokens included."""
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """Split ``text`` with the module-level spaCy tokenizer and lower-case each token."""
        lowered = []
        for tok in spacy_eng.tokenizer(text):
            lowered.append(tok.text.lower())
        return lowered

    def build_vocabulary(self, sentence_list):
        """Register every word whose count over ``sentence_list`` reaches the threshold."""
        counts = {}
        next_index = 4  # 0-3 are taken by the special tokens

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                counts[word] = counts.get(word, 0) + 1

                # A word is added exactly once: at the moment its count hits the threshold.
                if counts[word] != self.freq_threshold:
                    continue
                self.stoi[word] = next_index
                self.itos[next_index] = word
                next_index += 1

    def numericalize(self, text):
        """Convert ``text`` to a list of token ids, using ``<UNK>`` for unknown words."""
        ids = []
        for token in self.tokenizer_eng(text):
            key = token if token in self.stoi else "<UNK>"
            ids.append(self.stoi[key])
        return ids


In [None]:
class FlickrDataset(Dataset):
    """Image/caption pairs read from a captions CSV and an image directory.

    The CSV must provide an ``image`` column (file name relative to
    ``root_dir``) and a ``caption`` column. A ``Vocabulary`` is fitted on all
    captions at construction time and exposed as ``self.vocab``.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        self.root_dir = root_dir
        self.transform = transform

        # Keep the full frame plus direct handles on the two columns we index.
        self.df = pd.read_csv(captions_file)
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        # Fit the vocabulary on every caption in the file.
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        """One sample per row of the captions file."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, caption_ids)`` for row ``index``.

        The image is loaded as RGB and passed through ``self.transform`` when
        one was given; the caption is wrapped in ``<SOS>`` ... ``<EOS>`` ids
        and returned as a 1-D tensor.
        """
        caption = self.captions[index]
        img_path = os.path.join(self.root_dir, self.imgs[index])
        img = Image.open(img_path).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        sos_id = self.vocab.stoi["<SOS>"]
        body_ids = self.vocab.numericalize(caption)
        eos_id = self.vocab.stoi["<EOS>"]

        return img, torch.tensor([sos_id, *body_ids, eos_id])


In [None]:
# Where the Flickr8k images live, and the captions file that indexes them.
root_dir = './Dataset/flickr8k/images'
captions_file = './Dataset/flickr8k/captions.txt'

# Training-time preprocessing: resize, take a random 299x299 crop, convert to
# a tensor and normalise every channel with mean 0.5 / std 0.5.
# NOTE(review): torchvision's pretrained Inception v3 appears to enable
# `transform_input`, which presumes ImageNet-normalised inputs — confirm that
# the 0.5/0.5 statistics used here are intended.
transform = transforms.Compose([
    transforms.Resize((356, 356)),
    transforms.RandomCrop((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),
])

dataset = FlickrDataset(root_dir, captions_file, transform=transform)

class MyCollate:
    """Collate callable that stacks images and pads captions to equal length."""

    def __init__(self, pad_idx):
        # Value used to fill captions shorter than the longest one in the batch.
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Turn a list of ``(image, caption)`` pairs into two batched tensors.

        Images are stacked along a new leading dimension; captions are
        right-padded with ``pad_idx`` and returned batch-first.
        """
        images, captions = [], []
        for image, caption in batch:
            images.append(image)
            captions.append(caption)

        stacked = torch.stack(images, dim=0)
        padded = pad_sequence(captions, batch_first=True, padding_value=self.pad_idx)
        return stacked, padded

# Captions in a batch are padded with the vocabulary's <PAD> id.
pad_idx = dataset.vocab.stoi["<PAD>"]

# Shuffled mini-batches of 32, loaded in the main process (no worker processes).
train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=MyCollate(pad_idx=pad_idx),
    num_workers=0,
    pin_memory=False,
)
 

torch.Size([32, 3, 299, 299])
torch.Size([32, 23])


In [None]:
class EncoderCNN(nn.Module):
    """Image encoder: pretrained Inception v3 with a replaced final ``fc`` layer.

    Args:
        embed_size: Output width of the replaced ``fc`` layer.
        train_CNN: When False (default) only the replaced ``fc`` layer is
            trainable and every other Inception parameter is frozen; when
            True the whole network is trainable.
    """

    def __init__(self, embed_size, train_CNN=False):
        super(EncoderCNN, self).__init__()
        self.train_CNN = train_CNN
        self.inception = models.inception_v3(pretrained=True, aux_logits=False)
        self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

        # Apply train_CNN here so the flag is honoured even when no external
        # code freezes the backbone. Previously the argument was stored but
        # never used. The name test mirrors the one in the training cell.
        for name, param in self.inception.named_parameters():
            is_new_head = "fc.weight" in name or "fc.bias" in name
            param.requires_grad = True if is_new_head else bool(train_CNN)

    def forward(self, images):
        """Encode ``images`` and apply ReLU followed by dropout to the result."""
        features = self.inception(images)
        return self.dropout(self.relu(features))
 
class DecoderRNN(nn.Module):
    """LSTM decoder that scores vocabulary tokens given image features and a caption prefix."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, features, captions):
        """Return vocabulary scores for every step of the sequence.

        The image feature vector is inserted as the first time step, ahead of
        the (dropout-regularised) caption embeddings, so the output has one
        more step than ``captions``.
        """
        token_vectors = self.embed(captions)
        token_vectors = self.dropout(token_vectors)

        image_step = features.unsqueeze(1)
        lstm_inputs = torch.cat([image_step, token_vectors], dim=1)

        lstm_outputs, _ = self.lstm(lstm_inputs)
        return self.linear(lstm_outputs)


class CNNtoRNN(nn.Module):
    """Captioning model that chains ``EncoderCNN`` into ``DecoderRNN``."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super().__init__()
        self.encoderCNN = EncoderCNN(embed_size)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, images, captions):
        """Encode ``images`` and decode them against ``captions``; returns the decoder output."""
        return self.decoderRNN(self.encoderCNN(images), captions)

    def caption_image(self, image, vocabulary, max_length=50):
        """Greedily decode a caption for a single image.

        Starting from the encoded image, the LSTM is stepped one token at a
        time; the highest-scoring token is emitted and fed back as the next
        input. Decoding stops after ``<EOS>`` is emitted or after
        ``max_length`` steps. Gradients are disabled throughout. Because each
        step calls ``.item()`` on the prediction, ``image`` must hold exactly
        one image.

        Returns:
            The decoded tokens as a list of strings looked up in
            ``vocabulary.itos``.
        """
        decoder = self.decoderRNN
        token_ids = []

        with torch.no_grad():
            step_input = self.encoderCNN(image).unsqueeze(0)
            state = None  # let the LSTM start from its default initial state

            for _ in range(max_length):
                hidden, state = decoder.lstm(step_input, state)
                next_token = decoder.linear(hidden).argmax(2)
                token_id = next_token.item()
                token_ids.append(token_id)
                step_input = decoder.embed(next_token)

                if vocabulary.itos[token_id] == "<EOS>":
                    break

        return [vocabulary.itos[token_id] for token_id in token_ids]

In [None]:
def print_examples(model, device, dataset):
    """Print the reference caption and the model's caption for one fixed image.

    The model is put in eval mode for the prediction and switched back to
    train mode before returning.
    """
    image_path = "./Dataset/flickr8k/images/1000268201_693b08cb0e.jpg"

    # Deterministic preprocessing: plain resize to 299x299, no random crop.
    preprocess = transforms.Compose(
        [
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    model.eval()
    image = Image.open(image_path).convert("RGB")
    test_img1 = preprocess(image).unsqueeze(0)  # shape is (1, 3, 299, 299)

    print("Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .")
    predicted_words = model.caption_image(test_img1.to(device), dataset.vocab)
    print("Example 1 OUTPUT: " + " ".join(predicted_words))
    model.train()

In [None]:
# --- Runtime configuration ---------------------------------------------------
torch.backends.cudnn.benchmark = True
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoint switches (not referenced by the cells shown in this notebook).
load_model, save_model = False, False
# Whether the pretrained CNN backbone is trained along with the new layers.
train_CNN = False

# --- Hyperparameters ---------------------------------------------------------
embed_size = 256
hidden_size = 256
num_layers = 1
vocab_size = len(dataset.vocab)
learning_rate = 3e-4
num_epochs = 10

In [None]:
# Build the model, the loss (padding positions are ignored) and the optimizer.
model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# The replaced fc head of the CNN is always trained; the rest of the CNN is
# trained only when train_CNN is True.
for name, param in model.encoderCNN.inception.named_parameters():
    if "fc.weight" in name or "fc.bias" in name:
        param.requires_grad = True
    else:
        param.requires_grad = train_CNN

model.train()
for epoch in range(num_epochs):

    for idx, (imgs, captions) in tqdm(enumerate(train_loader), total=len(train_loader), leave=False):
        # Show a sample caption every 100 batches to monitor progress.
        if idx % 100 == 0:
            print_examples(model, device, dataset)
        imgs = imgs.to(device)
        captions = captions.to(device)

        # The decoder prepends the image feature as the first step, so feeding
        # the caption without its last token yields one output per caption token.
        outputs = model(imgs, captions[:, :-1])
        loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))

        optimizer.zero_grad()
        # Plain backward pass: passing the loss itself as the `gradient`
        # argument would scale every gradient by the loss value.
        loss.backward()
        optimizer.step()

Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth


  0%|          | 0.00/104M [00:00<?, ?B/s]

  0%|          | 0/1265 [00:00<?, ?it/s]

Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Example 1 OUTPUT: hikers digs kayaker her change hit either box fence skips ear course ear holds whilst pets goalie galloping innertube practices leafy explosion beaded mom monster old helps buckets motorized next shore shore seen many numbered tackling chopsticks hawk offering behind snowboarding pro houses practices an background practicing flipped man sling
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man a in a . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a man in a <UNK> . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a black and a dog is on a <UNK> . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a red and white dog is running . <EOS>
Example 1 CORRECT: A 

  0%|          | 0/1265 [00:00<?, ?it/s]

Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a red shirt is standing on a <UNK> . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a red shirt is standing on a <UNK> . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a red shirt is sitting on a <UNK> . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a red shirt is standing on a beach . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a red shirt is standing on a <UNK> . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a red shirt is standing on a bench .

  0%|          | 0/1265 [00:00<?, ?it/s]

Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a red shirt and a black shirt and a black shirt and a black and white shirt and a black and white shirt and a black and white shirt and a black and white shirt and a black and white shirt and a black and white
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a red shirt is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a red shirt and a black shirt and a black shirt and a black shirt and a black shirt and a black shirt and a black shirt and a black shirt and a black shirt and a black shirt and a black shirt and a black
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a red shirt is standing on a bench . <EOS>


  0%|          | 0/1265 [00:00<?, ?it/s]

Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a blue shirt is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a blue shirt and a black shirt and a black shirt and a black shirt and a black shirt and a black shirt and a black shirt and a black shirt and a black shirt and a black shirt and a black shirt and a black
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a red shirt is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a red shirt is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a red shirt and a black shirt is standin

  0%|          | 0/1265 [00:00<?, ?it/s]

Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a red shirt and a black shirt is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a red shirt is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a blue shirt is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a red shirt and a black hat is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a red shirt and a black shirt is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPU

  0%|          | 0/1265 [00:00<?, ?it/s]

Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a blue shirt is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a blue shirt and a hat is standing in front of a building . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a black shirt and a hat is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a blue shirt and a hat is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a black shirt is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <

  0%|          | 0/1265 [00:00<?, ?it/s]

Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a black shirt and a black hat is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a red shirt is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a black shirt and a hat is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a black shirt and a black hat is standing in front of a building . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a black shirt and a black hat is standing in front of a building . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of sta

  0%|          | 0/1265 [00:00<?, ?it/s]

Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a black shirt and a black hat is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a black shirt and a black hat is standing in front of a building . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a black jacket is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man and a woman are sitting on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man and a woman are sitting on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS>

  0%|          | 0/1265 [00:00<?, ?it/s]

Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a black shirt and a hat is standing in front of a building . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man and a woman are sitting on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a black shirt and a hat is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a blue shirt and a hat is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a black shirt is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS

  0%|          | 0/1265 [00:00<?, ?it/s]

Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man and a woman are sitting on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man and a woman are sitting on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man and a woman are sitting on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man and a woman are sitting on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a black shirt and jeans is standing on a bench . <EOS>
Example 1 CORRECT: A child in a pink dress is climbing up a set of stairs in an entry way .
Example 1 OUTPUT: <SOS> a man in a black shirt and a hat is standing