In [33]:
import pandas as pd
import numpy as np

import os

import torch
import torch.nn as nn
from torchvision import transforms
import torchvision.models as models
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

from tqdm import tqdm

import matplotlib.pyplot as plt
import time

import spacy
from PIL import Image

import warnings
warnings.filterwarnings('ignore')

In [34]:
class EncoderCNN(nn.Module):
    """Image encoder: Inception v3 with its classifier replaced by a projection.

    The final fully connected layer of the pretrained network is swapped for a
    ``Linear`` layer producing ``embed_size`` values per image; ReLU and dropout
    are applied on top.

    Args:
        embed_size: Size of the feature vector produced for each image.
        train_CNN: Stored on the instance only; nothing in this class reads it.
            NOTE(review): presumably a training loop uses it to decide whether
            the backbone is fine-tuned -- confirm against the training code.
    """

    def __init__(self, embed_size, train_CNN = False):
        super(EncoderCNN, self).__init__()

        self.train_CNN = train_CNN

        self.inception = models.inception_v3(pretrained = True, aux_logits = True)
        self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: Batched image tensor accepted by the Inception backbone.

        Returns:
            Tensor of shape ``(batch, embed_size)`` in both train and eval mode.
        """
        outputs = self.inception(images)
        # With aux_logits enabled, torchvision's Inception v3 returns a
        # (logits, aux_logits) named tuple in training mode but the bare logits
        # tensor in eval mode. Indexing [0] unconditionally would pick the first
        # *sample* of the batch in eval mode, so only unpack real tuples.
        features = outputs[0] if isinstance(outputs, tuple) else outputs
        return self.dropout(self.relu(features))

class DecoderRNN(nn.Module):
    """Caption decoder: token embedding -> LSTM -> projection onto the vocabulary.

    Args:
        embed_size: Embedding size; must equal the encoder's feature size because
            the image feature is fed to the LSTM as the first time step.
        hidden_size: LSTM hidden state size.
        vocab_size: Number of tokens in the vocabulary.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(DecoderRNN, self).__init__()

        self.embed = nn.Embedding(vocab_size,embed_size)
        # batch_first is left at its default, so sequences are (seq_len, batch, ...).
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        # NOTE(review): this layer is created but never applied in forward();
        # it is kept so existing pickled models still match the class layout.
        self.dropout = nn.Dropout(0.5)

    def forward(self, features, captions):
        """Score every vocabulary token at every time step.

        Args:
            features: Image features of shape ``(batch, embed_size)``.
            captions: Token indices of shape ``(seq_len, batch)``.

        Returns:
            Tensor of shape ``(seq_len + 1, batch, vocab_size)``; the extra step
            comes from the image feature prepended to the sequence.
        """
        embeddings = self.embed(captions)
        # The image feature acts as the first "token" of the sequence.
        embeddings = torch.cat((features.unsqueeze(0), embeddings),dim = 0)
        hiddens,_ = self.lstm(embeddings)
        outputs = self.linear(hiddens)
        return outputs

class CNNtoRNN(nn.Module):
    """Full captioning model: ``EncoderCNN`` followed by ``DecoderRNN``."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(CNNtoRNN, self).__init__()

        self.encoderCNN = EncoderCNN(embed_size)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, images ,captions):
        """Return per-step vocabulary scores for a batch of images and captions."""
        features = self.encoderCNN(images)
        outputs = self.decoderRNN(features,captions)
        return outputs

    def caption_image(self, image, vocabulary, max_length = 38):
        """Greedily decode a caption for a single image.

        The model should be in eval mode (``model.eval()``) when this is called;
        the method does not switch modes itself.

        Args:
            image: Image tensor with a leading batch dimension of size 1.
            vocabulary: Object exposing ``itos``, a mapping from token index to
                token string.
            max_length: Maximum number of tokens to generate.

        Returns:
            List of token strings in generation order. Decoding stops after the
            first ``'<EOS>'`` token, which is included in the result.
        """
        result_caption = []

        with torch.no_grad():

            # (1, embed_size) -> (seq_len=1, batch=1, embed_size) for the LSTM.
            x = self.encoderCNN(image).unsqueeze(0)
            states = None

            for _ in range(max_length):

                hidden, states = self.decoderRNN.lstm(x,states)
                output = self.decoderRNN.linear(hidden.squeeze(0))   # (1, vocab_size)
                # Keep the batch dimension so the next input stays batched and
                # matches the shape of the carried LSTM states.
                predicted = output.argmax(dim = 1)                     # (1,)
                token_index = predicted.item()
                result_caption.append(token_index)

                # Feed the predicted token back in as the next input step.
                x = self.decoderRNN.embed(predicted).unsqueeze(0)

                if vocabulary.itos[token_index] == '<EOS>':
                    break

        return [vocabulary.itos[idx] for idx in result_caption]

In [35]:
# Preprocessing applied to every image before it reaches the encoder:
# resize to 299x299, convert to a tensor, then normalise each of the three
# channels with mean 0.5 and std 0.5.
# NOTE(review): these constants are presumably the ones the saved model was
# trained with -- keep them in sync with the training notebook.
transform = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [36]:
# spaCy English pipeline; Vocabulary.tokenize_caption uses its tokenizer.
spacy_eng = spacy.load("en_core_web_sm")

class Vocabulary:
    """Bidirectional token <-> index mapping built from caption text.

    Indices 0-3 are reserved for ``<PAD>``, ``<SOS>``, ``<EOS>`` and ``<UNK>``.
    Further tokens receive consecutive indices, starting at 4, in the order in
    which their running count first reaches ``freq_threshold``.

    Tokens that are a single character or contain non-alphabetic characters are
    never added, so they are encoded as ``<UNK>`` (this includes words such as
    "a").
    NOTE(review): changing that rule renumbers the vocabulary; any model saved
    against the current numbering would presumably decode to the wrong words.
    """

    def __init__(self, freq_threshold):
        self.itos = {0:'<PAD>',1:'<SOS>',2:'<EOS>',3:'<UNK>'}
        # Inverse of the table above.
        self.stoi = {token: index for index, token in self.itos.items()}

        self.freq_threshold = freq_threshold

    def __len__(self):
        """Number of known tokens, special tokens included."""
        return len(self.stoi)

    def tokenize_caption(self,text):
        """Split ``text`` with the spaCy tokenizer and lower-case every token."""
        lowered = []
        for token in spacy_eng.tokenizer(text):
            lowered.append(token.text.lower())
        return lowered

    def build_vocabulary(self, sentence_list):
        """Add every eligible token from ``sentence_list`` to the vocabulary.

        A token is added at the moment its count equals ``freq_threshold``,
        provided it is longer than one character and purely alphabetic.
        """
        counts = {}
        next_index = 4

        for sentence in sentence_list:
            for token in self.tokenize_caption(sentence):
                counts[token] = counts.get(token, 0) + 1

                # Only act on the exact step where the threshold is reached,
                # so each token is registered at most once.
                if counts[token] != self.freq_threshold:
                    continue
                if len(token) <= 1 or not token.isalpha():
                    continue

                self.stoi[token] = next_index
                self.itos[next_index] = token
                next_index += 1

    def convert_to_vector(self, text):
        """Encode ``text`` as a list of indices; unknown tokens map to ``<UNK>``."""
        indices = []
        for token in self.tokenize_caption(text):
            if token in self.stoi:
                indices.append(self.stoi[token])
            else:
                indices.append(self.stoi['<UNK>'])
        return indices

In [37]:
class FlickrDataset(Dataset):
    """Image/caption pairs read from a captions CSV with ``image`` and ``caption`` columns.

    Args:
        root_dir: Directory containing the image files named in the ``image`` column.
        captions_file: Path to the CSV file.
        transform: Optional callable applied to each loaded RGB image.
        freq_threshold: Passed to ``Vocabulary``; see that class for its meaning.
        max_samples: Keep only the first ``max_samples`` rows of the CSV
            (default 4000, the value that used to be hard-coded). Pass ``None``
            to use every row. The vocabulary is built from the kept rows only,
            so changing this changes the token numbering.
    """

    def __init__(self, root_dir, captions_file, transform = None, freq_threshold = 10, max_samples = 4000):
        self.root_dir = root_dir
        self.transform = transform

        frame = pd.read_csv(captions_file)
        if max_samples is not None:
            frame = frame.iloc[:max_samples]
        # A fresh 0..n-1 index keeps row positions and labels aligned.
        self.df = frame.reset_index(drop = True)

        self.imgs = self.df['image']
        self.captions = self.df['caption']

        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions)

    def __len__(self):
        """Number of (image, caption) rows."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, caption_tensor)`` for the row at position ``index``.

        The caption tensor holds ``<SOS>``, the encoded caption tokens and
        ``<EOS>``, in that order.
        """
        # Positional lookup: does not depend on the Series' index labels.
        caption = self.captions.iloc[index]
        img_id = self.imgs.iloc[index]

        # Close the file as soon as the pixels are converted instead of leaving
        # the handle open until garbage collection.
        with Image.open(os.path.join(self.root_dir,img_id)) as raw_img:
            img = raw_img.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        caption_vector = [self.vocab.stoi['<SOS>']]
        caption_vector += self.vocab.convert_to_vector(caption)
        caption_vector.append(self.vocab.stoi['<EOS>'])

        return img, torch.tensor(caption_vector)

class MyCollate:
    """Collate function that batches images and pads captions to equal length.

    Args:
        pad_idx: Token index used to fill captions shorter than the longest
            caption in the batch.
    """

    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Combine ``(image, caption)`` samples into two batch tensors.

        Returns:
            ``(imgs, targets)``: images concatenated along a new leading batch
            dimension, and captions padded to shape ``(max_len, batch)``.
        """
        image_rows = []
        captions = []
        for sample in batch:
            # Give each image a leading batch axis so they can be concatenated.
            image_rows.append(sample[0].unsqueeze(0))
            captions.append(sample[1])

        imgs = torch.cat(image_rows, dim = 0)
        # batch_first=False: time steps run along dim 0.
        targets = pad_sequence(captions, batch_first = False, padding_value = self.pad_idx)
        return imgs,targets

In [38]:
# Directory holding the Flickr8k image files.
# NOTE(review): absolute, machine-specific path -- adjust for your environment.
dir_path = 'D:/Dataset/flickr8k/images'

# Building the dataset also builds its vocabulary, which the captioning cells
# below use to turn predicted indices back into words.
dataset = FlickrDataset(
    root_dir = dir_path,
    captions_file = 'D:/Dataset/flickr8k/captions.txt',
    transform = transform,
)

# Index used to pad shorter captions within a batch.
pad_idx = dataset.vocab.stoi['<PAD>']

loader = DataLoader(
    dataset = dataset,
    batch_size = 64,
    shuffle = True,
    pin_memory = True,
    collate_fn = MyCollate(pad_idx = pad_idx),
)

In [39]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [40]:
# Restore the complete pickled model object (architecture + weights). The model
# classes defined earlier in this notebook must be in scope for unpickling.
#
# map_location places the restored tensors on `device`, so a checkpoint saved
# on a GPU still loads on a CPU-only machine and the model ends up on the same
# device as the inputs passed to caption_image below.
#
# NOTE(security): torch.load unpickles arbitrary Python objects; only load
# checkpoint files from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects whole-model pickles; pass weights_only=False there if loading fails.
new_model = torch.load('entire_model.pth', map_location = device)

In [46]:
# Inference only: switch dropout / batch-norm layers to evaluation behaviour.
new_model.eval()

# (image file, human reference caption) pairs to caption and print.
examples = [
    ("child.jpg", "Child holding a red frisbee outdoors"),
    ("boat.png", "A small boat in the ocean"),
    ("dog.jpg", "Dog on a beach by the ocean"),
]

for example_number, (image_file, reference_caption) in enumerate(examples, start = 1):
    # Blank line between examples, but not before the first one.
    if example_number > 1:
        print()
    print('## EXAMPLE {} ##'.format(example_number))

    # Close the image file once it has been converted; add the batch dimension
    # that caption_image expects.
    with Image.open(image_file) as raw_image:
        test_img1 = transform(raw_image.convert("RGB")).unsqueeze(0)

    print("CORRECT: " + reference_caption)
    print(
        "OUTPUT: "
        + " ".join(new_model.caption_image(test_img1.to(device), dataset.vocab))
    )

## EXAMPLE 1 ##
CORRECT: Child holding a red frisbee outdoors
OUTPUT: <SOS> <UNK> little girl in <UNK> <UNK> <UNK> runs through the sand at the beach <EOS>

## EXAMPLE 2 ##
CORRECT: A small boat in the ocean
OUTPUT: <SOS> <UNK> man <UNK> <UNK> <UNK> while <UNK> <UNK> <EOS>

## EXAMPLE 3 ##
CORRECT: Dog on a beach by the ocean
OUTPUT: <SOS> <UNK> brown dog <UNK> <UNK> black and white dog <UNK> and <UNK> black dog are all standing in <UNK> field <UNK> <EOS>


In [15]:
# NOTE(review): this cell's execution count (15) is lower than the cells above,
# and its saved output matches neither 'cuda' nor 'cpu', the only values the
# device assignment can produce -- the output looks stale; re-run after a
# kernel restart.
device

'cude'