In [16]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))


# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [17]:
import torch
import torch.nn as nn
import torchvision.models as models

class CNN_Encoder(nn.Module) : 
    def __init__(self, embedding_len, trainConvNet = False) -> None:
        super(CNN_Encoder, self).__init__()
        self.trainConvNet = False # just using a pre trained CNN to save on some time.
        self.inception = models.inception_v3(pretrained = True)
        self.inception.aux_logits = False
        self.inception.fc = nn.Linear(self.inception.fc.in_features, embedding_len) #remove the last linear layer
        #and replace it with the mapping to the embed size. last layer of cnn to size of input to rnn.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, images) :
        features = self.inception(images) #get the features from the pretrained network
        
        #FINE TUNING OF THE MODEL TO ALLOW FC WEIGHTS AND BIAS TO TRAIN AND ELSE TO FREEZE.
        for name, param in self.inception.named_parameters():
            if "fc.weight" in name or "fc.bias" in name : #only fully connected layer is not frozen over here.
                param.requires_grad = True
            else :
                param.requires_grad = self.trainConvNet # unless you want to retrain the entire cnn keep false.
        return self.dropout(self.relu(features))

class RNN(nn.Module) : 
    def __init__(self, embedding_len, hidden_len, vocab_len, num_layers) :
        super(RNN,self).__init__()
        self.embed = nn.Embedding(vocab_len,embedding_len)
        self.lstm = nn.LSTM(embedding_len,hidden_len, num_layers)
        self.linear = nn.Linear(hidden_len,vocab_len) #hidden len to vocab size which serves as indexing essentially
        self.dropout = nn.Dropout(0.5)
    def forward(self, features, captions):
        embeddings = self.dropout(self.embed(captions))
        #adding dropout post embedding layer is going to let go of certain indices and help prevent overfitting.
        embeddings = torch.cat((features.unsqueeze(0),embeddings),dim = 0)
        #Here the teacher forces all the correct captions inside as embeddings. 
        #We feed in all the correct ground truth to the rnn and later in eval we will just feed prev cell's output
        #but for now this will do well since it won't accidentally end up leaerning incorrect captions.
        #here we want to pass the feature vector from the CNN as the first embedding to our RNN.
        #So we take that feature and make if of the size (1,embeding_len) unsqueeze justs makes it from 
        #(embedding_len,) to (1,em) by adding a dimension of 1 at the 0th pos. Now we concat along row direction
        hiddens, _ = self.lstm(embeddings) #Pass the encoding through the linear layer to get outputs.
        outputs = self.linear(hiddens)
        return outputs

class Bridge(nn.Module):
    def __init__(self, embedding_len, hidden_len,vocab_len, num_layers) :
        super(Bridge,self).__init__()
        self.cnn_encoder = CNN_Encoder(embedding_len)
        self.rnn = RNN(embedding_len,hidden_len,vocab_len,num_layers)
    
    def forward(self, images,captions):
        features = self.cnn_encoder(images) #get image feature and pass to rnn.
        outputs = self.rnn(features,captions)
        return outputs
    def caption_image(self,image,vocabulary,max_len =50):
        result_caption = []
        #to return the captions.
        with torch.no_grad(): #unsqueeze for batch dimension
            x = self.cnn_encoder(image).unsqueeze(0) #start with input embedding as x -> image feature.
            states = None 
            for _ in range(max_len):
                #take the hidden layer and hidden state of lstm output and hiddens will have the encoding.
                #pass encoding through linear layer and get the prediction and hence word.
                hiddens, states = self.rnn.lstm(x,states)
                output = self.rnn.linear(hiddens.squeeze(0))
                predicted = output.argmax(1)
                #Previous prediction x is passed in the next time step to the lstm.
                result_caption.append(predicted.item())
                x = self.rnn.embed(predicted).unsqueeze(0)

                if vocabulary.itos[predicted.item()] == "<EOS>" :
                    break
        return [vocabulary.itos[idx] for idx in result_caption]


In [18]:
import os  # when loading file paths
import pandas as pd  # for lookup in annotation file
import spacy  # for tokenizer
import torch
from torch.nn.utils.rnn import pad_sequence  # pad batch
from torch.utils.data import DataLoader, Dataset
from PIL import Image  # Load img
import torchvision.transforms as transforms


# We want to convert text -> numerical values
# 1. We need a Vocabulary mapping each word to a index
# 2. We need to setup a Pytorch dataset to load the data
# 3. Setup padding of every batch (all examples should be
#    of same seq_len and setup dataloader)
# Note that loading the image is very easy compared to the text!

# Download with: python -m spacy download en
spacy_eng = spacy.load("en_core_web_sm")

#vocabulary is just a dictionary class that I got from stackoverflow to make life easier.
class Vocabulary:
    def __init__(self, freq_threshold):
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        frequencies = {}
        idx = 4

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                if word not in frequencies:
                    frequencies[word] = 1

                else:
                    frequencies[word] += 1

                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        tokenized_text = self.tokenizer_eng(text)

        return [
            self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
            for token in tokenized_text
        ]

# REFERENCE : https://www.kaggle.com/code/mdteach/torch-data-loader-flicker-8k
class FlickrDataset(Dataset):
    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        # Get img, caption columns
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        # Initialize vocabulary and build vocab
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        caption = self.captions[index]
        img_id = self.imgs[index]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])

        return img, torch.tensor(numericalized_caption)


class MyCollate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim=0)
        targets = [item[1] for item in batch]
        targets = pad_sequence(targets, batch_first=False, padding_value=self.pad_idx)

        return imgs, targets


def get_loader(
    root_folder,
    annotation_file,
    transform,
    batch_size=32,
    num_workers=8,
    shuffle=True,
    pin_memory=True,
):
    dataset = FlickrDataset(root_folder, annotation_file, transform=transform)

    pad_idx = dataset.vocab.stoi["<PAD>"]

    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
    )

    return loader, dataset


if __name__ == "__main__":
    transform = transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor(),]
    )

    loader, dataset = get_loader(
        "/kaggle/input/flickr8kimagescaptions/flickr8k/images", "/kaggle/input/flickr8kimagescaptions/flickr8k/captions.txt", transform=transform
    )

#     for idx, (imgs, captions) in enumerate(loader):
#         print(imgs.shape)
#         print(captions.shape)

In [19]:
def print_examples(model, device, dataset):
    transform = transforms.Compose(
        [
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    model.eval()
    test_img1 = transform(Image.open("/kaggle/input/test-e/test_examples/dog.jpg").convert("RGB")).unsqueeze(
        0
    )
    print("Example 1 CORRECT: Dog on a beach by the ocean")
    print(
        "Example 1 OUTPUT: "
        + " ".join(model.caption_image(test_img1.to(device), dataset.vocab))
    )
    test_img2 = transform(
        Image.open("/kaggle/input/test-e/test_examples/child.jpg").convert("RGB")
    ).unsqueeze(0)
    print("Example 2 CORRECT: Child holding red frisbee outdoors")
    print(
        "Example 2 OUTPUT: "
        + " ".join(model.caption_image(test_img2.to(device), dataset.vocab))
    )
    test_img3 = transform(Image.open("/kaggle/input/test-e/test_examples/bus.png").convert("RGB")).unsqueeze(
        0
    )
    print("Example 3 CORRECT: Bus driving by parked cars")
    print(
        "Example 3 OUTPUT: "
        + " ".join(model.caption_image(test_img3.to(device), dataset.vocab))
    )
    test_img4 = transform(
        Image.open("/kaggle/input/test-e/test_examples/boat.png").convert("RGB")
    ).unsqueeze(0)
    print("Example 4 CORRECT: A small boat in the ocean")
    print(
        "Example 4 OUTPUT: "
        + " ".join(model.caption_image(test_img4.to(device), dataset.vocab))
    )
    test_img5 = transform(
        Image.open("/kaggle/input/test-e/test_examples/horse.png").convert("RGB")
    ).unsqueeze(0)
    print("Example 5 CORRECT: A cowboy riding a horse in the desert")
    print(
        "Example 5 OUTPUT: "
        + " ".join(model.caption_image(test_img5.to(device), dataset.vocab))
    )
#     model.train()

In [29]:
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    print("!!!!Saving checkpoint!!!!")
    torch.save(state, filename)


def load_checkpoint(checkpoint, model, optimizer):
    print("!!!!Loading checkpoint!!!!")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    step = checkpoint["step"]
    return step

def train():
    transform = transforms.Compose(
        [
            transforms.Resize((356, 356)),
            transforms.RandomCrop((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    train_loader, dataset = get_loader(
        root_folder="/kaggle/input/flickr8kimagescaptions/flickr8k/images",
        annotation_file="/kaggle/input/flickr8kimagescaptions/flickr8k/captions.txt",
        transform=transform,
        num_workers=2,
    )

    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    load_model = True
    save_model = True
    trainConvnet = False

    # Hyperparameters
    embedding_len = 256
    hidden_len = 256
    vocab_len = len(dataset.vocab)
    num_layers = 1
    learning_rate = 3e-4
    num_epochs = 10

    # for tensorboard
    writer = SummaryWriter("runs/flickr")
    step = 0

    # initialize model, loss etc
    model = Bridge(embedding_len, hidden_len, vocab_len, num_layers).to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Only finetune the CNN
    for name, param in model.cnn_encoder.inception.named_parameters():
        if "fc.weight" in name or "fc.bias" in name:
            param.requires_grad = True
        else:
            param.requires_grad = trainConvnet

    if load_model:
        step = load_checkpoint(torch.load("my_checkpoint.pth.tar"), model, optimizer)

    model.train()

    for epoch in range(num_epochs):
        # Uncomment the line below to see a couple of test cases
        # print_examples(model, device, dataset)

        if save_model:
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "step": step,
            }
            save_checkpoint(checkpoint)

        for idx, (imgs, captions) in tqdm(
            enumerate(train_loader), total=len(train_loader), leave=False
        ):
            imgs = imgs.to(device)
            captions = captions.to(device)

            outputs = model(imgs, captions[:-1])
            loss = criterion(
                outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
            )

            writer.add_scalar("Training loss", loss.item(), global_step=step)
            step += 1

            optimizer.zero_grad()
            loss.backward(loss)
            optimizer.step()


if __name__ == "__main__":
    train()

!!!!Loading checkpoint!!!!
!!!!Saving checkpoint!!!!


                                                   

!!!!Saving checkpoint!!!!


                                                   

!!!!Saving checkpoint!!!!


                                                   

!!!!Saving checkpoint!!!!


                                                   

!!!!Saving checkpoint!!!!


                                                   

!!!!Saving checkpoint!!!!


                                                   

!!!!Saving checkpoint!!!!


                                                   

!!!!Saving checkpoint!!!!


                                                   

!!!!Saving checkpoint!!!!


                                                   

In [30]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embedding_len = 256
hidden_len = 256
vocab_len = len(dataset.vocab)
num_layers = 1
learning_rate = 3e-4
model = Bridge(embedding_len, hidden_len, vocab_len, num_layers).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
load_checkpoint(torch.load("my_checkpoint.pth.tar"), model, optimizer)

!!!!Loading checkpoint!!!!


45540

In [31]:
print_examples(model, device, dataset)

Example 1 CORRECT: Dog on a beach by the ocean
Example 1 OUTPUT: <SOS> a brown dog is running through the snow . <EOS>
Example 2 CORRECT: Child holding red frisbee outdoors
Example 2 OUTPUT: <SOS> a little girl in a pink shirt is running through a grassy area . <EOS>
Example 3 CORRECT: Bus driving by parked cars
Example 3 OUTPUT: <SOS> a man in a red jacket is standing on a street with a bicycle . <EOS>
Example 4 CORRECT: A small boat in the ocean
Example 4 OUTPUT: <SOS> a man in a wetsuit is surfing . <EOS>
Example 5 CORRECT: A cowboy riding a horse in the desert
Example 5 OUTPUT: <SOS> a man and a dog are walking through the snow . <EOS>


In [40]:
model.eval()
test_img1 = transform(Image.open("/kaggle/input/flickr8kimagescaptions/flickr8k/images/1022454332_6af2c1449a.jpg").convert("RGB")).unsqueeze(0)
print(
    "Example  OUTPUT: "
        + " ".join(model.caption_image(test_img1.to(device), dataset.vocab))
    )

Example  OUTPUT: <SOS> a man and a woman are sitting on a dock near a lake . <EOS>
