In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file available under the Kaggle input directory.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for path in (os.path.join(root, name) for name in files):
        print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import torch
import torch.nn as nn
import torchvision.models as models
from tqdm import tqdm
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image


import spacy  # for tokenizer
from torch.nn.utils.rnn import pad_sequence  # pad batch
from torch.utils.data import DataLoader, Dataset


In [3]:
# Small English pipeline; only its tokenizer is used (see Vocabulary.tokenizer_eng).
spacy_eng = spacy.load(name="en_core_web_sm")

In [4]:
# Dataset root and the caption/split text files that live under it.
data_set = '/kaggle/input/flickr8k/Flickr8K/'
captions_dir = f"{data_set}Flickr8k_text/"
train_images = f"{captions_dir}Flickr_8k.trainImages.txt"
test_images = f"{captions_dir}Flickr_8k.testImages.txt"

In [5]:
def load_all_captions(file_name):
    """Parse a Flickr8k token file into a DataFrame.

    Each non-blank line is expected to look like ``<image>#<n><TAB><caption>``.

    Parameters
    ----------
    file_name : path to the token file.

    Returns
    -------
    DataFrame with columns ``image``, ``caption#`` (the single character after
    ``#``, kept as a string) and ``caption``. Blank lines are skipped.
    """
    rows = []
    # `with` guarantees the file is closed; the previous version leaked the handle.
    with open(file_name, "r") as text_file:
        for raw_line in text_file:
            line = raw_line.strip()
            if not line:
                # Blank lines used to turn into ['', '', ''] rows.
                continue
            hash_pos = line.find("#")
            image_name = line[:hash_pos]
            caption_number = line[hash_pos + 1:hash_pos + 2]
            caption = line[line.find("\t") + 1:]
            rows.append([image_name, caption_number, caption])
    return pd.DataFrame(rows, columns=['image', 'caption#', 'caption'])

In [6]:
# Every (image, caption#, caption) triple from the Flickr8k token file.
captions_df = load_all_captions(os.path.join(captions_dir, "Flickr8k.token.txt"))

In [7]:
# Test split: one image filename per line, no header. Built from captions_dir
# (not the cwd-relative "../input/..." path) so it agrees with the config cell.
test_images = pd.read_csv(captions_dir + 'Flickr_8k.testImages.txt', header=None, names=["image"])

In [8]:
def get_ground_captions(test_df, captions_df, captions_per_image=5):
    """Expand a list of images into one row per ground-truth caption.

    Parameters
    ----------
    test_df : DataFrame with an ``image`` column (e.g. a split file).
    captions_df : DataFrame with ``image`` and ``caption`` columns, as
        returned by ``load_all_captions``.
    captions_per_image : keep at most this many captions per image, in the
        order they appear in ``captions_df`` (default 5, as in Flickr8k).

    Returns
    -------
    DataFrame with columns ``image`` and ``caption``, ordered as in
    ``test_df``. Images without captions contribute no rows instead of
    raising IndexError.
    """
    # Index captions once. This is O(len(captions_df) + len(test_df)) instead
    # of filtering the whole captions frame for every image.
    captions_by_image = {}
    for image_name, caption in zip(captions_df['image'], captions_df['caption']):
        captions_by_image.setdefault(image_name, []).append(caption)

    # Build the frame from a list of records in one go. DataFrame.append was
    # removed in pandas 2.0, and growing a frame row by row is quadratic.
    # Iterating the column's values also works for any index, not only RangeIndex.
    records = [
        {'image': image_name, 'caption': caption}
        for image_name in test_df['image']
        for caption in captions_by_image.get(image_name, [])[:captions_per_image]
    ]
    return pd.DataFrame(records, columns=['image', 'caption'])


In [9]:
# One row per (test image, reference caption).
test_final = get_ground_captions(test_df=test_images, captions_df=captions_df)

In [10]:
# Training split: one image filename per line, no header. Built from captions_dir
# (not the cwd-relative "../input/..." path) so it agrees with the config cell.
train_images = pd.read_csv(captions_dir + 'Flickr_8k.trainImages.txt', header=None, names=["image"])
train_final = get_ground_captions(train_images, captions_df)

In [11]:
# Persist the per-caption tables; FlickrDataset reads train_final.csv back in.
for frame, csv_path in ((test_final, 'test_final.csv'), (train_final, 'train_final.csv')):
    frame.to_csv(csv_path, index=False)

In [25]:
# Validation split, used for the BLEU evaluation below. Built from captions_dir
# so it agrees with the config cell instead of depending on the cwd.
val_images = pd.read_csv(captions_dir + 'Flickr_8k.valImages.txt', header=None, names=["image"])
val_images.head()

In [13]:

class Vocabulary:
    """Bidirectional word <-> index mapping built from caption text.

    Indices 0-3 are reserved for ``<PAD>``, ``<SOS>``, ``<EOS>`` and ``<UNK>``.
    Other words are appended in the order in which they first reach
    ``freq_threshold`` occurrences.
    """

    def __init__(self, freq_threshold):
        # index -> token and token -> index; build_vocabulary keeps them in sync.
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Number of known tokens, including the four special tokens."""
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """Split ``text`` with the spaCy English tokenizer and lowercase each token."""
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """Add every word that occurs at least ``freq_threshold`` times in ``sentence_list``.

        Counts are local to this call. New indices continue after the current
        vocabulary instead of restarting at a hard-coded 4, and words that are
        already known are skipped. Calling this more than once therefore never
        overwrites existing ``itos`` entries or desynchronizes ``stoi``/``itos``.
        """
        frequencies = {}
        next_idx = len(self.itos)

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1
                # `>=` combined with the membership test adds each word exactly
                # once, at the occurrence that first reaches the threshold.
                if frequencies[word] >= self.freq_threshold and word not in self.stoi:
                    self.stoi[word] = next_idx
                    self.itos[next_idx] = word
                    next_idx += 1

    def numericalize(self, text):
        """Tokenize ``text`` and map each token to its index (``<UNK>`` if unknown)."""
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]


class FlickrDataset(Dataset):
    """Map-style dataset yielding ``(image, caption_tensor)`` pairs.

    ``captions_file`` is a CSV with ``image`` and ``caption`` columns, one row
    per caption. Image filenames are resolved relative to ``root_dir``. The
    caption tensor is ``<SOS>`` + numericalized caption + ``<EOS>``.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=3):
        self.root_dir = root_dir
        self.transform = transform
        self.df = pd.read_csv(captions_file)

        # Column views indexed by __getitem__.
        self.imgs, self.captions = self.df["image"], self.df["caption"]

        # The vocabulary is built from this file's captions only.
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        caption = self.captions[index]
        image_path = os.path.join(self.root_dir, self.imgs[index])
        image = Image.open(image_path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        stoi = self.vocab.stoi
        token_ids = [stoi["<SOS>"], *self.vocab.numericalize(caption), stoi["<EOS>"]]
        return image, torch.tensor(token_ids)


class MyCollate:
    """DataLoader collate function: stack images and pad captions to one length.

    Captions are padded with ``pad_idx``. Because ``batch_first=False`` they
    come back shaped ``(max_len, batch)``.
    """

    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        images = torch.stack([sample[0] for sample in batch], dim=0)
        captions = pad_sequence(
            [sample[1] for sample in batch],
            batch_first=False,
            padding_value=self.pad_idx,
        )
        return images, captions


def get_loader(
    root_folder,
    annotation_file,
    transform,
    batch_size=64,
    num_workers=8,
    shuffle=True,
    pin_memory=True,
):
    """Build a FlickrDataset over ``annotation_file`` and a padded-batch DataLoader.

    Returns ``(loader, dataset)``. The dataset is returned as well so callers
    can reach its vocabulary.
    """
    dataset = FlickrDataset(root_folder, annotation_file, transform=transform)
    collate = MyCollate(pad_idx=dataset.vocab.stoi["<PAD>"])

    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=collate,
    )
    return DataLoader(dataset=dataset, **loader_options), dataset


In [14]:

class EncoderForCNN(nn.Module):
    """Inception-v3 image encoder whose ``fc`` layer is replaced to output ``embedding_size``.

    The new ``fc`` head is always trainable. The rest of the backbone is
    trainable only when ``train_CNN_model`` is True (or after a later
    ``set_cnn_trainable(True)``).
    """

    def __init__(self,embedding_size,train_CNN_model=False):
        super(EncoderForCNN,self).__init__()
        self.train_CNN_model=train_CNN_model
        self.inception = models.inception_v3(pretrained=True,aux_logits= False)
        self.inception.fc = nn.Linear(self.inception.fc.in_features,embedding_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        # Freeze once here. The old forward() re-applied this on every call,
        # which silently undid any later change by the caller
        # (e.g. train(train_CNN=True) stopped training the backbone after batch 1).
        self.set_cnn_trainable(train_CNN_model)

    def set_cnn_trainable(self, train_cnn):
        """Keep the ``fc`` head trainable and set every other Inception parameter's requires_grad to ``train_cnn``."""
        self.train_CNN_model = train_cnn
        for name, param in self.inception.named_parameters():
            if "fc.weight" in name or "fc.bias" in name:
                param.requires_grad = True
            else:
                param.requires_grad = train_cnn

    def forward(self,images):
        """Encode a batch of images, then apply ReLU and dropout to the projected features."""
        features = self.inception(images)
        return self.dropout(self.relu(features))

class DecoderForRNN(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    The image feature is used as the first time step, followed by the
    dropout-regularized word embeddings of ``captions``.
    """

    def __init__(self,embedding_size,hidden_size, vocab_size,num_layers):
        super(DecoderForRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self,features, captions):
        """Return vocabulary logits for every step of the image + caption sequence.

        Sequences are concatenated along dim 0 and the LSTM is built without
        ``batch_first``, so ``captions`` is time-major. The output has one
        more time step than ``captions``.
        """
        word_vectors = self.dropout(self.embed(captions))
        sequence = torch.cat([features[None], word_vectors], dim=0)
        lstm_out = self.lstm(sequence)[0]
        return self.linear(lstm_out)
    

class CNNtoRNN(nn.Module):
    """Image-captioning model: an EncoderForCNN feeding a DecoderForRNN."""

    def __init__(self,embedding_size,hidden_size,vocab_size,num_layers) :
        super().__init__()
        self.encoderCNN = EncoderForCNN(embedding_size=embedding_size)
        self.decoderRNN = DecoderForRNN(
            embedding_size=embedding_size,
            hidden_size=hidden_size,
            vocab_size=vocab_size,
            num_layers=num_layers,
        )

    def forward(self,images,captions):
        """Teacher-forced pass: encode ``images`` and return decoder logits for ``captions``."""
        return self.decoderRNN(self.encoderCNN(images), captions)

    def caption_image(self,image,vocabulary, max_length=50):
        """Greedily decode a caption for one image.

        Returns predicted tokens as strings, exactly as emitted, so special
        tokens such as ``<EOS>`` are included. Decoding stops once ``<EOS>``
        is produced or after ``max_length`` steps. ``.item()`` is used, so
        ``image`` must be a batch of exactly one. This method does not switch
        the model to eval mode itself.
        """
        decoder = self.decoderRNN
        token_ids = []

        with torch.no_grad():
            step_input = self.encoderCNN(image).unsqueeze(0)
            state = None

            for _ in range(max_length):
                hidden, state = decoder.lstm(step_input, state)
                best = decoder.linear(hidden.squeeze(0)).argmax(1)
                token_id = best.item()

                token_ids.append(token_id)
                step_input = decoder.embed(best).unsqueeze(0)

                if vocabulary.itos[token_id] == "<EOS>":
                    break

        return [vocabulary.itos[token_id] for token_id in token_ids]

In [15]:
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Announce, then serialize ``state`` to ``filename`` with ``torch.save``."""
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)


def load_checkpoint(checkpoint, model, optimizer):
    """Restore model, then optimizer, state from ``checkpoint`` and return its ``step``."""
    print("=> Loading checkpoint")
    for target, key in ((model, "state_dict"), (optimizer, "optimizer")):
        target.load_state_dict(checkpoint[key])
    return checkpoint["step"]


# ===========================================================================  #

def train(
    embed_size = 256,
    hidden_size = 256,
    num_layers = 4,
    learning_rate = 3e-4,
    num_epochs = 20,
    load_model = False,
    save_model = True,
    train_CNN = False,
):
    """Train the CNN->RNN captioning model on ``./train_final.csv``.

    ``load_model`` resumes from ``my_checkpoint.pth.tar``. ``save_model``
    writes that file after every epoch, so the weights of the final epoch are
    kept. ``train_CNN`` makes the Inception backbone trainable in addition to
    its new fc head.

    Returns the list of mean training losses, one per epoch.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((356, 356)),
            transforms.RandomCrop((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    train_loader, dataset = get_loader(
        root_folder="../input/flickr8k/Flickr8K/Flicker8k_Images",
        annotation_file="./train_final.csv",
        transform=transform,
        num_workers=4,
    )

    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vocab_size = len(dataset.vocab)
    step = 0

    model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Always train the encoder's new fc head; train the rest of Inception only if train_CNN.
    for name, param in model.encoderCNN.inception.named_parameters():
        if "fc.weight" in name or "fc.bias" in name:
            param.requires_grad = True
        else:
            param.requires_grad = train_CNN

    if load_model:
        # map_location lets a checkpoint saved on GPU resume on a CPU-only machine.
        step = load_checkpoint(
            torch.load("my_checkpoint.pth.tar", map_location=device), model, optimizer
        )

    model.train()
    epoch_losses = []

    for epoch in range(num_epochs):
        batch_losses = []
        for imgs, captions in tqdm(
            train_loader,
            total=len(train_loader),
            leave=False,
            desc=f"epoch {epoch + 1}/{num_epochs}",
        ):
            imgs = imgs.to(device)
            captions = captions.to(device)

            # The decoder prepends the image feature as step 0, so feeding
            # captions[:-1] yields len(captions) output steps scored against captions.
            outputs = model(imgs, captions[:-1])
            loss = criterion(
                outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
            )

            optimizer.zero_grad()
            # Plain backward(). The former backward(loss) passed the loss as the
            # upstream gradient, which scaled every gradient by the loss value.
            loss.backward()
            optimizer.step()
            step += 1

            # Keep Python floats. Storing the loss tensors held them for the
            # whole epoch and printed hundreds of tensors at the end.
            batch_losses.append(loss.item())

        mean_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float("nan")
        epoch_losses.append(mean_loss)
        print(f"epoch {epoch + 1}/{num_epochs} - mean loss: {mean_loss:.4f}")

        # Save after the epoch. It used to be saved before each epoch, so the
        # last epoch's training was never written to disk.
        if save_model:
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "step": step,
            }
            save_checkpoint(checkpoint)

    return epoch_losses

In [16]:
train(load_model=False, save_model=True)

In [17]:
import nltk
from nltk.translate.bleu_score import sentence_bleu


In [27]:
def print_test(model, device, dataset,val_images,captions_df):
    """Print and return the mean sentence-BLEU of greedy captions over ``val_images``.

    Uses up to five reference captions per image from ``captions_df``.
    References are tokenized with the vocabulary's own lowercase spaCy
    tokenizer. ``<SOS>``/``<EOS>``/``<PAD>`` are removed from the predicted
    caption, so tokenization artifacts do not distort the score. Images
    without references are skipped. The model's train/eval mode is restored
    afterwards.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    images_dir = "../input/flickr8k/Flickr8K/Flicker8k_Images"
    special_tokens = {"<SOS>", "<EOS>", "<PAD>"}
    tokenize = dataset.vocab.tokenizer_eng

    # Group raw captions by image once rather than filtering captions_df per image.
    captions_by_image = {}
    for image_name, caption in zip(captions_df['image'], captions_df['caption']):
        captions_by_image.setdefault(image_name, []).append(caption)

    was_training = model.training
    model.eval()
    score = 0.0
    n = 0
    try:
        for img_id in tqdm(val_images['image']):
            raw_refs = captions_by_image.get(img_id, [])[:5]
            if not raw_refs:
                continue  # nothing to score against
            # Tokenize references exactly like the vocabulary, so "A" vs "a"
            # or split punctuation is not counted as a mismatch.
            reference = [tokenize(ref) for ref in raw_refs]

            img = Image.open(os.path.join(images_dir, img_id)).convert("RGB")
            img = transform(img).unsqueeze(0)
            predicted = model.caption_image(img.to(device), dataset.vocab)
            # caption_image returns special tokens such as <EOS>. They never
            # appear in references and would only lower the score.
            candidate = [tok for tok in predicted if tok not in special_tokens]

            score += sentence_bleu(reference, candidate)
            n += 1
    finally:
        model.train(was_training)

    mean_score = score / n if n else 0.0
    print("bleu score :", mean_score)
    return mean_score

In [28]:
def test(
    val_images,
    captions_df,
    embed_size = 256,
    hidden_size = 256,
    num_layers = 4,
    learning_rate = 3e-4,
    num_epochs = 20,
    load_model = True,
    save_model = True,
    train_CNN = False,
):
    """Rebuild the captioning model, load ``my_checkpoint.pth.tar`` and report validation BLEU.

    ``./train_final.csv`` is re-read only to rebuild the vocabulary the model
    was trained with. ``num_epochs``, ``save_model`` and ``train_CNN`` are
    accepted to mirror ``train``'s signature but are not used here.
    Returns whatever ``print_test`` returns.
    """
    # get_loader requires a transform; only dataset.vocab is used below.
    transform = transforms.Compose(
        [
            transforms.Resize((356, 356)),
            transforms.RandomCrop((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    _, dataset = get_loader(
        root_folder="../input/flickr8k/Flickr8K/Flicker8k_Images",
        annotation_file="./train_final.csv",
        transform=transform,
        num_workers=4,
    )

    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vocab_size = len(dataset.vocab)

    model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
    # Only needed because load_checkpoint also restores optimizer state.
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    if load_model:
        # map_location lets a checkpoint saved on GPU be evaluated on a CPU-only machine.
        load_checkpoint(
            torch.load("my_checkpoint.pth.tar", map_location=device), model, optimizer
        )

    return print_test(model, device, dataset, val_images, captions_df)
    

In [29]:
test(val_images=val_images, captions_df=captions_df)


In [32]:
import matplotlib.pyplot as plt

In [42]:
def print_test_few(model, device, dataset,val_images,captions_df, num_examples=11, pause=True):
    """Show a few validation images with their reference and predicted captions.

    For each of the first ``num_examples`` images (default 11, as before):
    display the image, print each reference caption on its own line, print the
    greedy prediction, and accumulate sentence-BLEU. BLEU uses the
    vocabulary's tokenizer for references, and special tokens are removed from
    the prediction. When ``pause`` is True, waits for Enter after each example;
    pass ``pause=False`` so the cell can run unattended. Prints and returns the
    mean BLEU over the examples shown. The model's train/eval mode is restored
    afterwards.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    images_dir = "../input/flickr8k/Flickr8K/Flicker8k_Images"
    special_tokens = {"<SOS>", "<EOS>", "<PAD>"}
    tokenize = dataset.vocab.tokenizer_eng

    was_training = model.training
    model.eval()
    score = 0.0
    n = 0
    try:
        for img_id in val_images['image'].head(num_examples):
            raw_refs = captions_df.loc[captions_df['image'] == img_id, 'caption'].head(5).tolist()
            if not raw_refs:
                continue  # no ground truth for this image
            reference = [tokenize(ref) for ref in raw_refs]

            # Load once; reuse the PIL image for display and for the model input.
            raw_img = Image.open(os.path.join(images_dir, img_id)).convert("RGB")
            img = transform(raw_img).unsqueeze(0)
            predicted = model.caption_image(img.to(device), dataset.vocab)
            candidate = [tok for tok in predicted if tok not in special_tokens]

            # A fresh figure plus show() renders each image next to its captions,
            # instead of redrawing one shared axes that appears only at cell end.
            fig, ax = plt.subplots()
            ax.imshow(raw_img)
            ax.set_title(img_id)
            ax.axis("off")
            plt.show()
            plt.close(fig)

            print("Correct :")
            for ref in raw_refs:
                print("  " + ref)
            print("Predicted: " + " ".join(candidate))

            score += sentence_bleu(reference, candidate)
            n += 1
            if pause:
                input("enter ")
    finally:
        model.train(was_training)

    mean_score = score / n if n else 0.0
    print("bleu score :", mean_score)
    return mean_score

In [43]:
def test_few(
    val_images,
    captions_df,
    embed_size = 256,
    hidden_size = 256,
    num_layers = 4,
    learning_rate = 3e-4,
    num_epochs = 20,
    load_model = True,
    save_model = True,
    train_CNN = False,
):
    """Rebuild the model, load ``my_checkpoint.pth.tar`` and show a few captioned validation images.

    ``./train_final.csv`` is re-read only to rebuild the vocabulary the model
    was trained with. ``num_epochs``, ``save_model`` and ``train_CNN`` are
    accepted to mirror ``train``'s signature but are not used here.
    Returns whatever ``print_test_few`` returns.
    """
    # get_loader requires a transform; only dataset.vocab is used below.
    transform = transforms.Compose(
        [
            transforms.Resize((356, 356)),
            transforms.RandomCrop((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    _, dataset = get_loader(
        root_folder="../input/flickr8k/Flickr8K/Flicker8k_Images",
        annotation_file="./train_final.csv",
        transform=transform,
        num_workers=4,
    )

    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vocab_size = len(dataset.vocab)

    model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
    # Only needed because load_checkpoint also restores optimizer state.
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    if load_model:
        # map_location lets a checkpoint saved on GPU be evaluated on a CPU-only machine.
        load_checkpoint(
            torch.load("my_checkpoint.pth.tar", map_location=device), model, optimizer
        )

    return print_test_few(model, device, dataset, val_images, captions_df)

In [44]:
# test_few(val_images,captions_df)