In [None]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import seaborn as sns

In [None]:
import nltk
import pickle
from PIL import Image
from collections import Counter
import matplotlib.pyplot as plt
from torch.utils.data import Subset

import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import transforms
import torchvision.models as models
import torchvision.transforms as transforms
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import random_split
from torchtext.data.metrics import bleu_score

# Download NLTK's 'punkt' resource up front; captions are tokenized with
# nltk.tokenize.word_tokenize further down in the notebook.
nltk.download('punkt')

# Image Captioning 

In [None]:
def read_file(file_name, text_dir):
    """Return the lines of ``text_dir/file_name`` as a list of ``bytes``.

    The file is read in binary mode, so no decoding is performed and the
    line terminators are stripped by ``bytes.splitlines``.
    """
    full_path = os.path.join(text_dir, file_name)
    with open(full_path, mode='rb') as handle:
        raw_content = handle.read()
    return raw_content.splitlines()

def map_imgs(caption_lines=None):
    """Build a mapping ``image name -> [first caption]`` from token-file lines.

    Each line is expected to look like ``<image>#<n>\t<caption>``.

    Args:
        caption_lines: iterable of ``bytes`` (UTF-8) or ``str`` lines. When
            omitted, the module-level ``captions`` list is used, which keeps
            the original zero-argument call working.

    Returns:
        dict mapping each image name to a one-element list holding the first
        caption seen for that image (later captions for the same image are
        ignored, as before).
    """
    if caption_lines is None:
        caption_lines = captions

    img_cap_dict = {}
    for raw_line in caption_lines:
        line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
        if not line.strip():
            # Blank lines (e.g. a trailing newline) carry no caption.
            continue
        # Split on the FIRST tab / '#' only, so a caption that itself contains
        # '#' or a tab is kept whole instead of being truncated.
        image_ref, _, image_caption = line.partition('\t')
        image_name = image_ref.split('#', 1)[0]
        img_cap_dict.setdefault(image_name, [image_caption])
    return img_cap_dict


# Directory holding the Flickr8k text annotation files.
text_dir = '/kaggle/input/flickr8k/Flickr8K/Flickr8k_text'
# Raw (bytes) lines of the train / test image-name lists.
train_image_paths = read_file('Flickr_8k.trainImages.txt', text_dir)
test_image_paths = read_file('Flickr_8k.testImages.txt', text_dir)
# Raw (bytes) caption lines; map_imgs() reads this module-level `captions` list.
captions = read_file('Flickr8k.token.txt', text_dir)
img_cap_dict=map_imgs()

In [None]:
# Drop the stray '...jpg.1' entry (original note: the dataset has an odd number
# of entries, not a multiple of five). pop() with a default keeps this cell
# re-runnable; `del` raised KeyError when the cell was executed a second time.
img_cap_dict.pop('2258277193_586949ec62.jpg.1', None)

In [None]:
class Vocab(object):
    """Bidirectional token <-> integer index mapping.

    Attributes:
        w2i: dict mapping token string -> index.
        i2w: dict mapping index -> token string.
        index: next free index to hand out.
    """

    def __init__(self):
        self.w2i={}
        self.i2w={}
        self.index=0

    def __call__(self,token):
        """Return the index of ``token``, or the index of ``'<unk>'`` if unknown.

        Raises:
            KeyError: if ``token`` is unknown and ``'<unk>'`` was never added.
        """
        if token in self.w2i:
            return self.w2i[token]
        # The fallback key must match the '<unk>' token registered by
        # build_vocabulary; the previous spelling '<ukn>' raised KeyError
        # for every out-of-vocabulary token.
        return self.w2i['<unk>']

    def __len__(self):
        """Number of distinct tokens stored."""
        return len(self.w2i)

    def add_token(self,token):
        """Register ``token`` under the next free index; no-op if already present."""
        if token not in self.w2i:
            self.w2i[token]=self.index
            self.i2w[self.index]=token
            self.index+=1
            
def build_vocabulary(map):
    """Create a ``Vocab`` covering every token in the captions of ``map``.

    ``map`` is a dict whose values are lists of caption strings. Captions are
    lower-cased and tokenized with ``nltk.tokenize.word_tokenize``. The four
    special tokens are registered first, followed by every distinct caption
    token in order of first appearance.
    """
    frequencies = Counter()
    for image_captions in map.values():
        for text in image_captions:
            frequencies.update(nltk.tokenize.word_tokenize(text.lower()))

    vocab = Vocab()
    for special in ('<pad>', '<start>', '<end>', '<unk>'):
        vocab.add_token(special)
    # Iterating a Counter yields its keys in insertion order.
    for word in frequencies:
        vocab.add_token(word)
    return vocab

In [None]:
# Build the vocabulary from the image -> caption mapping and pickle it so the
# training and prediction cells below can reload it from disk.
vocab = build_vocabulary(img_cap_dict)
vocab_path = '/kaggle/working/vocabulary.pkl'
with open(vocab_path, 'wb') as f:
    pickle.dump(vocab, f)

## Reshape Images

In [None]:
def reshape_images(input_path, output_path, shape):
    """Resize every image file in ``input_path`` and save it to ``output_path``.

    Args:
        input_path: directory containing the source images.
        output_path: directory to write resized images to (created if missing).
        shape: target ``(width, height)`` as a 2-element sequence.

    Each output keeps the source file name and the source image format.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_path, exist_ok=True)
    target_size = tuple(shape)

    for file_name in os.listdir(input_path):
        source_file = os.path.join(input_path, file_name)
        if not os.path.isfile(source_file):
            # Sub-directories cannot be opened as images.
            continue
        with Image.open(source_file) as image:
            # Capture the format before resizing: images returned by resize()
            # have format None, so the save format was previously only
            # inferred from the file extension.
            source_format = image.format
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same
            # filter and is what predict_caption already uses.
            resized = image.resize(target_size, Image.LANCZOS)
            resized.save(os.path.join(output_path, file_name), source_format)

# Resize the raw Flickr8k images to 224x224 and write the copies to the
# working directory; the data loader below reads from output_path.
input_path = '/kaggle/input/flickr8k/Flickr8K/Flicker8k_Images/'
output_path = '/kaggle/working/Flickr8K/resized_images/'
image_shape = [224, 224]
reshape_images(input_path, output_path, image_shape)

## Data Loader

In [None]:
class FlickrDataLoader(data.Dataset):
    """Dataset yielding ``(image, caption_tensor)`` pairs.

    ``map`` is a dict from image file name to a list of caption strings;
    image files are looked up relative to ``data_path``.
    """

    def __init__(self,data_path,map, vocabulary,transform=None):
        self.root = data_path
        self.map=map
        self.indices = list(map.keys())
        self.vocabulary = vocabulary
        self.transform = transform

    def __len__(self):
        """Number of images (keys of ``map``)."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Load image ``idx`` and encode its caption(s) as one index tensor.

        The sequence is ``<start>``, the tokens of every caption listed for
        the image (in order), then ``<end>``. It is returned as a
        ``torch.Tensor`` (float dtype).
        """
        to_index = self.vocabulary
        image_id = self.indices[idx]
        image_captions = self.map[image_id]

        image = Image.open(os.path.join(self.root, image_id)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        token_ids = [to_index('<start>')]
        for text in image_captions:
            words = nltk.tokenize.word_tokenize(str(text).lower())
            token_ids += [to_index(word) for word in words]
        token_ids.append(to_index('<end>'))

        return image, torch.Tensor(token_ids)
def collate_function(data_batch):
    """Collate ``(image, caption)`` samples into padded batch tensors.

    Args:
        data_batch: list of ``(image_tensor, caption_tensor)`` pairs.

    Returns:
        ``(imgs, tgts, cap_lens)`` where samples are ordered by decreasing
        caption length (as required by ``pack_padded_sequence``), ``imgs`` is
        the stacked image tensor, ``tgts`` is a ``long`` tensor of captions
        right-padded with 0, and ``cap_lens`` lists the true caption lengths.

    Raises:
        ValueError: if ``data_batch`` is empty.
    """
    if not data_batch:
        raise ValueError("collate_function received an empty batch")

    # sorted() returns a new list, so the caller's batch is no longer
    # reordered in place as a side effect.
    ordered = sorted(data_batch, key=lambda sample: len(sample[1]), reverse=True)
    imgs, caps = zip(*ordered)
    imgs = torch.stack(imgs, 0)
    cap_lens = [len(cap) for cap in caps]
    tgts = torch.zeros(len(caps), max(cap_lens), dtype=torch.long)
    for row, (cap, cap_len) in enumerate(zip(caps, cap_lens)):
        tgts[row, :cap_len] = cap[:cap_len]
    return imgs, tgts, cap_lens
 
def get_loader(data_path, map, vocabulary, transform, batch_size, shuffle, num_workers,
               test_size=1000, seed=None):
    """Build the train and test ``DataLoader`` objects for the caption dataset.

    Args:
        data_path: directory containing the (resized) images.
        map: dict from image file name to list of caption strings.
        vocabulary: callable token -> index mapping.
        transform: transform applied to each loaded image (or None).
        batch_size: batch size of the training loader.
        shuffle: passed to both loaders.
        num_workers: passed to both loaders.
        test_size: number of samples held out for the test loader. The
            training set is everything else, so the split no longer has to be
            edited by hand when the dataset size changes (it was hard-coded
            to 7091/1000).
        seed: optional int; when given, the random split is reproducible.

    Returns:
        ``(train_data_loader, test_data_loader)``; the test loader uses a
        batch size of 1.

    Raises:
        ValueError: if ``test_size`` does not leave at least one sample on
            each side of the split.
    """
    flicker_dataset = FlickrDataLoader(data_path=data_path,map=map,vocabulary=vocabulary,transform=transform)

    total = len(flicker_dataset)
    if not 0 < test_size < total:
        raise ValueError(
            "test_size must be between 1 and {} (dataset size - 1), got {}".format(total - 1, test_size))
    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    train_ds, test_ds = random_split(flicker_dataset, [total - test_size, test_size], **split_kwargs)

    train_data_loader = torch.utils.data.DataLoader(dataset=train_ds, batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,collate_fn=collate_function)
    test_data_loader = torch.utils.data.DataLoader(dataset=test_ds, batch_size=1,shuffle=shuffle,num_workers=num_workers,collate_fn=collate_function)
    return train_data_loader,test_data_loader

# Models

In [None]:
class CNNModel(nn.Module):
    """Image encoder: pretrained ResNet-152 without its final layer, followed
    by a linear projection to ``embedding_size`` and 1-D batch normalisation."""

    def __init__(self, embedding_size):
        super().__init__()
        backbone = models.resnet152(pretrained=True)
        # Keep every child module except the last one (the classifier head).
        backbone_layers = list(backbone.children())[:-1]
        self.resnet_module = nn.Sequential(*backbone_layers)
        self.linear_layer = nn.Linear(backbone.fc.in_features, embedding_size)
        self.batch_norm = nn.BatchNorm1d(embedding_size, momentum=0.01)

    def forward(self, input_images):
        """Return the batch-normalised embedding of ``input_images``.

        The backbone runs under ``torch.no_grad()``, so no gradients flow
        into it; only the linear and batch-norm layers receive gradients.
        """
        with torch.no_grad():
            pooled = self.resnet_module(input_images)
        flattened = pooled.reshape(pooled.size(0), -1)
        projected = self.linear_layer(flattened)
        return self.batch_norm(projected)

In [None]:
class LSTMModel(nn.Module):
    """Caption decoder: token embedding -> LSTM -> linear layer over the vocabulary."""

    def __init__(self, embedding_size, hidden_layer_size, vocabulary_size, num_layers, max_seq_len=20):
        super(LSTMModel, self).__init__()
        self.embedding_layer = nn.Embedding(vocabulary_size, embedding_size)
        self.lstm_layer = nn.LSTM(embedding_size, hidden_layer_size, num_layers, batch_first=True)
        self.linear_layer = nn.Linear(hidden_layer_size, vocabulary_size)
        self.max_seq_len = max_seq_len

    def forward(self, input_features, capts, lens):
        """Compute vocabulary scores for every (packed) time step.

        Args:
            input_features: image embeddings, one row per sample.
            capts: padded caption index tensor, one row per sample.
            lens: caption lengths, in decreasing order (the default
                requirement of ``pack_padded_sequence``).

        Returns:
            Linear-layer outputs over the packed sequence data, i.e. one row
            per non-padded time step.
        """
        # Use the `capts` argument. The body previously read `caps`, which is
        # not a parameter and only resolved because the training loop happened
        # to define a global with that name.
        embeddings = self.embedding_layer(capts)
        # The image embedding is fed as the first time step.
        embeddings = torch.cat((input_features.unsqueeze(1), embeddings), 1)
        lstm_input = pack_padded_sequence(embeddings, lens, batch_first=True)
        hidden_variables, _ = self.lstm_layer(lstm_input)
        # hidden_variables is a PackedSequence; [0] is its flat data tensor.
        model_outputs = self.linear_layer(hidden_variables[0])
        return model_outputs

    def sample(self, input_features, lstm_states=None):
        """Greedily decode ``max_seq_len`` token indices per sample.

        At each step the highest-scoring token is taken and fed back as the
        next input. Decoding always runs for ``max_seq_len`` steps; callers
        are responsible for cutting the result at the end token.

        Returns:
            Tensor of sampled indices with one row per sample and
            ``max_seq_len`` columns.
        """
        sampled_indices = []
        lstm_inputs = input_features.unsqueeze(1)
        for i in range(self.max_seq_len):
            hidden_variables, lstm_states = self.lstm_layer(lstm_inputs, lstm_states)
            model_outputs = self.linear_layer(hidden_variables.squeeze(1))
            _, predicted_outputs = model_outputs.max(1)
            sampled_indices.append(predicted_outputs)
            lstm_inputs = self.embedding_layer(predicted_outputs)
            lstm_inputs = lstm_inputs.unsqueeze(1)
        sampled_indices = torch.stack(sampled_indices, 1)
        return sampled_indices

## Training

In [None]:
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory the checkpoints are written to after training.
if not os.path.exists('/kaggle/working/models/'):
    os.makedirs('/kaggle/working/models/')

# Convert images to tensors and normalise each RGB channel with the given mean / std.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

# Reload the vocabulary pickled earlier in the notebook.
with open('/kaggle/working/vocabulary.pkl', 'rb') as f:
    vocabulary = pickle.load(f)

# Training batches of 128 samples, shuffled, loaded in the main process.
train_data_loader,test_data_loader = get_loader('/kaggle/working/Flickr8K/resized_images/', img_cap_dict, vocabulary, transform, 128,shuffle=True, num_workers=0) 

# 512-dimensional image embedding feeding a single-layer LSTM with 512 hidden units.
encoder_model = CNNModel(512).to(device)
decoder_model = LSTMModel(512, 512, len(vocabulary), 1).to(device) 
 
loss_criterion = nn.CrossEntropyLoss()
# Only the decoder and the encoder's linear / batch-norm layers are optimised;
# the ResNet backbone parameters are not handed to the optimizer.
parameters = list(decoder_model.parameters()) + list(encoder_model.linear_layer.parameters()) + list(encoder_model.batch_norm.parameters())
optimizer = torch.optim.Adam(parameters, lr=0.001)


# Number of passes over the training set. The final `epoch` value (+1) is what
# names the checkpoint files saved in the next cell.
num_epochs = 10
total_num_steps = len(train_data_loader)
for epoch in range(num_epochs):
    for i, (imgs, caps, lens) in enumerate(train_data_loader):
        imgs = imgs.to(device)
        caps = caps.to(device)
        # Pack the targets the same way the decoder packs its inputs so that
        # output rows and target entries line up one-to-one.
        tgts = pack_padded_sequence(caps, lens, batch_first=True)[0]
        feats = encoder_model(imgs)
        outputs = decoder_model(feats, caps, lens)
        loss = loss_criterion(outputs, tgts)
        decoder_model.zero_grad()
        encoder_model.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 10 == 0:
            # Report the real epoch total (the message previously hard-coded 5
            # while the loop ran 10 epochs) and a 1-based epoch number.
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i, total_num_steps, loss.item()))
 


### Save Checkpoint

In [None]:
# Save the trained weights. `epoch` is the loop variable left over from the
# training cell, so the files are named after the last completed epoch.
torch.save(decoder_model.state_dict(), os.path.join('/kaggle/working/models/', 'decoder-{}.ckpt'.format(epoch+1)))
torch.save(encoder_model.state_dict(), os.path.join('/kaggle/working/models/', 'encoder-{}.ckpt'.format(epoch+1)))

In [None]:
# Scratch check: total of the 7091 / 1000 train / test split sizes.
sum([7091, 1000])


## Prediction and BLEU Score

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE: pickle.load executes arbitrary code from the file; only load a
# vocabulary file produced by this notebook.
with open('/kaggle/working/vocabulary.pkl', 'rb') as f:
    vocabulary = pickle.load(f)

# Both models are put in eval mode for inference.
encoder_model = CNNModel(512).to(device).eval()
decoder_model = LSTMModel(512, 512, len(vocabulary), 1).to(device).eval()

# map_location lets checkpoints written on a GPU be restored on a CPU-only
# kernel (and vice versa); without it torch.load fails when the saving device
# is unavailable.
encoder_model.load_state_dict(torch.load('/kaggle/working/models/encoder-10.ckpt', map_location=device))
decoder_model.load_state_dict(torch.load('/kaggle/working/models/decoder-10.ckpt', map_location=device))

# Tokenised predictions and references, filled by the evaluation loop below.
candidate_corpus=[]
reference_corpus=[]

In [None]:
def _indices_to_words(token_indices, i2w):
    """Convert a sequence of vocabulary indices into a list of caption words.

    Decoding stops at the first '<end>' token, so anything generated after
    the end of the sentence is dropped. '<pad>', '<start>', '<unk>' and '.'
    are skipped.
    """
    skipped = {'<pad>', '<start>', '<unk>', '.'}
    words = []
    for token_index in token_indices:
        word = i2w[int(token_index)]
        if word == '<end>':
            break
        if word in skipped:
            continue
        words.append(word)
    return words


# Start from empty corpora so that re-running this cell does not append a
# second copy of every caption.
candidate_corpus = []
reference_corpus = []

# Inference only: no_grad avoids building autograd graphs for each image.
with torch.no_grad():
    for imgs, caps, lens in test_data_loader:
        imgs = imgs.to(device)
        feat = encoder_model(imgs)
        sampled_indices = decoder_model.sample(feat).cpu().numpy()
        target_indices = caps.numpy()
        # One candidate / reference pair per sample, so a batch with several
        # rows is no longer merged into a single caption.
        for predicted_row, target_row in zip(sampled_indices, target_indices):
            candidate_corpus.append(_indices_to_words(predicted_row, vocabulary.i2w))
            reference_corpus.append([_indices_to_words(target_row, vocabulary.i2w)])

In [None]:
# Corpus-level BLEU of the predicted captions against the reference captions
# collected in the evaluation loop.
print(bleu_score(candidate_corpus, reference_corpus))

## Testing

In [None]:
# Cache of (vocabulary, encoder, decoder, transform) keyed by device string.
# Re-running this cell resets the cache.
_CAPTIONER_CACHE = {}


def _load_captioner(device):
    """Load the vocabulary, both models and the image transform for ``device``.

    The result is cached, so the ResNet encoder and the checkpoints are only
    loaded on the first call for a given device.
    """
    cache_key = str(device)
    if cache_key not in _CAPTIONER_CACHE:
        # NOTE: pickle.load executes arbitrary code from the file; only load
        # a vocabulary file produced by this notebook.
        with open('/kaggle/working/vocabulary.pkl', 'rb') as f:
            vocabulary = pickle.load(f)

        encoder_model = CNNModel(512).to(device).eval()
        decoder_model = LSTMModel(512, 512, len(vocabulary), 1).to(device).eval()
        # map_location allows GPU-written checkpoints to load on a CPU kernel.
        encoder_model.load_state_dict(torch.load('/kaggle/working/models/encoder-10.ckpt', map_location=device))
        decoder_model.load_state_dict(torch.load('/kaggle/working/models/decoder-10.ckpt', map_location=device))

        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
        _CAPTIONER_CACHE[cache_key] = (vocabulary, encoder_model, decoder_model, transform)
    return _CAPTIONER_CACHE[cache_key]


def predict_caption(image_file_path):
    """Generate a caption for the image at ``image_file_path``.

    The image is converted to RGB, resized to 224x224, normalised and passed
    through the encoder and the greedy decoder.

    Returns:
        The predicted tokens joined by single spaces, up to and including the
        first ``'<end>'`` token (or all sampled tokens if none is produced).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Models used to be rebuilt and reloaded on every call; they are now
    # loaded once and reused.
    vocabulary, encoder_model, decoder_model, transform = _load_captioner(device)

    # The context manager closes the image file; convert() returns an
    # independent in-memory copy, so it stays usable afterwards.
    with Image.open(image_file_path) as raw_image:
        img = raw_image.convert('RGB').resize((224, 224), Image.LANCZOS)
    img_tensor = transform(img).unsqueeze(0).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        feat = encoder_model(img_tensor)
        sampled_indices = decoder_model.sample(feat)[0].cpu().numpy()

    predicted_caption = []
    for token_index in sampled_indices:
        word = vocabulary.i2w[int(token_index)]
        predicted_caption.append(word)
        if word == '<end>':
            break
    return ' '.join(predicted_caption)

In [None]:
def predict_captions(df):
    """Predict a caption for every row of ``df`` and write them to a CSV file.

    Args:
        df: DataFrame with an ``'image'`` column of image file names.

    Writes ``/kaggle/working/output_captions.txt`` with the header
    ``image,caption`` and one row per row of ``df``.
    """
    import csv

    image_dir = '/kaggle/input/flickr8k/Flickr8K/Flicker8k_Images/'
    # The same image may appear on several rows; predict each one only once.
    predictions = {}
    # `with` guarantees the file is closed even if a prediction fails.
    with open("/kaggle/working/output_captions.txt", "w", newline="", encoding="utf-8") as output_file:
        # csv.writer quotes captions that contain commas, which previously
        # produced rows with a varying number of columns.
        writer = csv.writer(output_file, lineterminator="\n")
        writer.writerow(["image", "caption"])
        for img_file in tqdm(df['image']):
            if img_file not in predictions:
                predictions[img_file] = predict_caption(image_dir + img_file)
            writer.writerow([img_file, predictions[img_file]])

In [None]:
# Dataset root and the annotation files derived from it.
data_set = '/kaggle/input/flickr8k/Flickr8K/'
captions_dir = f"{data_set}Flickr8k_text/"
train_images = f"{captions_dir}Flickr_8k.trainImages.txt"
test_images = f"{captions_dir}Flickr_8k.testImages.txt"

In [None]:
def load_all_captions(file_name):
    """Parse a caption token file into a DataFrame.

    Each non-blank line is expected to look like ``<image>#<n>\t<caption>``.

    Args:
        file_name: path of the token file.

    Returns:
        DataFrame with string columns ``'image'``, ``'caption#'`` and
        ``'caption'``, one row per non-blank line. A line without a tab gets
        an empty caption; a line without ``'#'`` gets an empty caption number.
    """
    data_set = []
    # `with` closes the file (the handle used to be leaked).
    with open(file_name, "r", encoding="utf-8") as text_file:
        for raw_line in text_file:
            line = raw_line.strip()
            if not line:
                # Blank lines used to become rows of empty strings.
                continue
            # Split on the first tab, then on the first '#' of the left part,
            # so '#' characters inside the caption text are left alone.
            image_ref, _, caption = line.partition("\t")
            image_name, _, caption_number = image_ref.partition("#")
            data_set.append([image_name, caption_number, caption])
    return pd.DataFrame(data_set, columns =['image', 'caption#', 'caption'])

In [None]:
# Every caption line of the token file as a DataFrame; preview the first rows.
captions_df = load_all_captions(captions_dir + "Flickr8k.token.txt")
captions_df.head()

In [None]:
def load_test_image(file_name):
    """Read a file listing one image name per line into a DataFrame.

    Args:
        file_name: path of the list file.

    Returns:
        DataFrame with a single ``'image'`` column holding the stripped,
        non-blank lines in file order.
    """
    # `with` closes the file (the handle used to be leaked). Blank lines are
    # skipped so they do not turn into empty image names.
    with open(file_name, "r", encoding="utf-8") as text_file:
        data_set = [line.strip() for line in text_file if line.strip()]
    return pd.DataFrame(data_set, columns =['image'])

In [None]:
# Names of the test-split images, one per row; preview the first rows.
test_df = load_test_image('/kaggle/input/flickr8k/Flickr8K/Flickr8k_text/Flickr_8k.testImages.txt')
test_df.head()

In [None]:
def get_ground_captions(test_df, captions_df, captions_per_image=5):
    """Collect the ground-truth captions of every image listed in ``test_df``.

    Args:
        test_df: DataFrame with an ``'image'`` column.
        captions_df: DataFrame with ``'image'`` and ``'caption'`` columns.
        captions_per_image: maximum number of captions kept per image
            (``None`` keeps all). Defaults to 5, the count that was
            previously hard-coded.

    Returns:
        DataFrame with columns ``'image'`` and ``'caption'``: for each row of
        ``test_df`` (in order) the image's captions in ``captions_df`` order.
        Images with fewer captions contribute fewer rows instead of raising
        ``IndexError``.
    """
    # Group captions by image once instead of filtering the frame per image.
    captions_by_image = {}
    for image, caption in zip(captions_df['image'], captions_df['caption']):
        captions_by_image.setdefault(image, []).append(caption)

    # Build all rows first and create the frame once: DataFrame.append was
    # removed in pandas 2.0 and growing a frame row by row is quadratic.
    rows = [
        {'image': image, 'caption': caption}
        for image in test_df['image']
        for caption in captions_by_image.get(image, [])[:captions_per_image]
    ]
    return pd.DataFrame(rows, columns=['image', 'caption'])

In [None]:
# Ground-truth captions of the test images (one row per caption); preview the first rows.
test_captions_df = get_ground_captions(test_df, captions_df)
test_captions_df.head()

In [None]:
# Write a predicted caption for every row of test_captions_df.
# NOTE(review): test_captions_df has one row per ground-truth caption, so each
# image appears on several rows and gets the same prediction written each
# time; confirm whether test_df (one row per image) was intended here.
predict_captions(test_captions_df)