In [148]:
import numpy as np 
import pandas as pd
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torch import tensor
from sklearn.model_selection import train_test_split
import torchtext
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer
import spacy
import torchvision.transforms as transforms
from torchvision.models import resnet34
import random
import matplotlib.pyplot as plt

# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [149]:
# data processing block

''' DATASET LOADING SEGMENT'''
# Seed for the train/val partition. Fixing it keeps the split -- and therefore the
# vocabulary built from the training captions -- identical across kernel restarts.
SPLIT_SEED = 42

# load csv (the code below relies on its 'image' and 'caption' columns)
train = pd.read_csv('../input/flickr8k/captions.txt')

# add image path column
train['path'] = '../input/flickr8k/Images/' + train['image']

# remove unnecessary image name column
# (non in-place drop: `train` is rebuilt from the CSV above, so re-running the cell is safe)
train = train.drop(columns=['image'])

# create train/val split
# NOTE(review): rows are split per caption; if the file holds several captions per image,
# the same image path can land in both sets -- confirm whether a per-image split is wanted.
xtrain, xval, ytrain, yval = train_test_split(
    train['path'], train['caption'], test_size=0.2, random_state=SPLIT_SEED
)
train_data = list(zip(xtrain, ytrain))
val_data = list(zip(xval, yval))

# sort (path, caption) pairs by caption length in characters
train_data.sort(key=lambda x: len(x[1]))
val_data.sort(key=lambda x: len(x[1]))

''' TEXT PROCESSING SEGMENT'''
# tokenizer shared by vocabulary construction and by text_pipeline
tokenizer = get_tokenizer('basic_english')


def yield_tokens(iterator):
    """Lazily produce the token list for the text of every (_, text) pair."""
    yield from (tokenizer(caption) for _, caption in iterator)


# vocabulary is built from the training captions only; the four specials come first
vocab = build_vocab_from_iterator(
    yield_tokens(train_data),
    max_tokens=10000,
    min_freq=2,
    specials=['<pad>', '<sos>', '<eos>', '<unk>'],
    special_first=True,
)
# tokens missing from the vocabulary resolve to the '<unk>' index
vocab.set_default_index(vocab['<unk>'])

# caption string -> list of vocabulary indices
def text_pipeline(text):
    """Tokenize `text` with the shared tokenizer and look every token up in `vocab`."""
    tokens = tokenizer(text)
    return vocab(tokens)

'''IMAGE PROCESSING SEGMENT'''
# Channel statistics the torchvision ImageNet-pretrained weights were trained with.
# The encoder's ResNet backbone is frozen, so its inputs must be normalized the same
# way; a generic 0.5/0.5 normalization feeds it a shifted input distribution.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# define transforms: fixed 224x224 resize -> float tensor in [0, 1] -> per-channel normalization
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# collate_fn
def collate_fn(batch):
    """Turn a list of (image_path, caption) pairs into a padded batch.

    Args:
        batch: iterable of (path, caption) pairs as stored in train_data / val_data.

    Returns:
        (images, captions): `images` is the stack of transformed images, one per pair;
        `captions` is the output of `pad_sequence` over the index tensors, i.e.
        (longest_caption, batch) with shorter captions padded with 0 -- the index
        '<pad>' receives as the first special token in the vocabulary.
    """
    image_list, text_list = [], []
    for image, text in batch:
        # The context manager closes the file handle once the pixels have been read.
        # convert('RGB') guarantees three channels for Normalize and for torch.stack,
        # even if a file is grayscale, palettized or has an alpha channel.
        with Image.open(image) as handle:
            image_list.append(transform(handle.convert('RGB')))
        # Sentence markers are added as text and then tokenized with the caption.
        # NOTE(review): this relies on the tokenizer mapping '<SOS>'/'<EOS>' onto the
        # lower-case '<sos>'/'<eos>' specials -- re-check if the tokenizer is swapped.
        text = '<SOS> ' + text + ' <EOS>'
        text = text_pipeline(text)
        text_list.append(tensor(text, dtype=torch.int64))
    image_list = torch.stack(image_list)
    text_list = torch.nn.utils.rnn.pad_sequence(text_list)
    return image_list, text_list

# Build the loaders. Neither one shuffles, so batches follow the
# length-sorted order of train_data / val_data.
train_loader = DataLoader(
    train_data,
    batch_size=32,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=2,
)
val_loader = DataLoader(
    val_data,
    batch_size=64,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=2,
)

In [150]:
# # def show_image(img, title=None):
# #     #unnormalize 
# #     img[0] = img[0] * 0.229
# #     img[1] = img[1] * 0.224 
# #     img[2] = img[2] * 0.225 
# #     img[0] += 0.485 
# #     img[1] += 0.456 
# #     img[2] += 0.406
# #     img = img.numpy().transpose((1, 2, 0))
# #     plt.imshow(img)
# #     if title is not None:
# #         plt.title(title)
# #     plt.pause(0.002)
    
# image, label = next(iter(train_loader))
# image = image.to(device)
# for i in image:
#     show_image(i.cpu())
#     print(model.caption_image(i.unsqueeze(0)))
# # next(iter(train_loader))[1].shape
# # model.caption_image(image)

In [151]:
# next(iter(train_loader))[0].shape
# Sanity check of the vocabulary size; the training set-up cell reuses len(vocab)
# as both the decoder's input_size and output_size.
print(len(vocab))

In [152]:
# encoder block
class Encoder(nn.Module):
    """Image encoder: frozen pretrained ResNet-34 with a trainable linear head.

    Args:
        output_size: width of the feature vector produced per image. The decoder
            concatenates this vector with its token embeddings along the time axis,
            so it has to equal the decoder's embedding size (not the vocabulary size).
    """
    def __init__(self, output_size):
        super().__init__()
        self.model = resnet34(pretrained=True).to(device)
        # freeze every pretrained weight
        for param in self.model.parameters():
            param.requires_grad_(False)
        # replace the classification head; this is the only trainable part of the backbone
        self.model.fc = nn.Linear(self.model.fc.in_features, output_size).to(device)
        # A fresh nn.Linear is trainable by default. Spell it out anyway: the previous
        # bare `self.model.fc.parameters(True)` call only built a generator and had no effect.
        for param in self.model.fc.parameters():
            param.requires_grad_(True)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Map an image batch to features: backbone -> ReLU -> dropout.

        Returns a tensor with one row of `output_size` features per input image,
        placed on the notebook-level `device`.
        """
        output = self.dropout(self.relu(self.model(x)))
        return output.to(device)

In [153]:
class Decoder(nn.Module):
    """LSTM caption decoder that takes the image feature as its first time step.

    Args:
        input_size: number of rows in the embedding table (vocabulary size).
        embedding_size: width of a token embedding; image features must have the
            same width because both are concatenated along the time axis.
        hidden_size: LSTM hidden width.
        output_size: number of logits produced per time step.
        num_layers: number of stacked LSTM layers.
        dropout_prob: dropout probability applied to the token embeddings.
            This argument used to be ignored in favour of a hard-coded 0.5;
            it is now honoured, so callers get exactly the value they pass.
    """
    def __init__(self, input_size, embedding_size, hidden_size, output_size, num_layers, dropout_prob):
        super().__init__()
        self.embedding = nn.Embedding(input_size, embedding_size)
        self.dropout = nn.Dropout(dropout_prob)
        self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers=num_layers)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, image_features, text):
        """Score every time step of a caption batch.

        Args:
            image_features: (N, embedding_size) features, one row per image.
            text: (seq_len, N) token indices.

        Returns:
            (seq_len + 1, N, output_size) logits; step 0 comes from the image
            feature, step t >= 1 from token t - 1.
        """
        embedding = self.dropout(self.embedding(text))  # (seq_len, N, embedding_size)
        image_features = image_features.unsqueeze(0)  # (1, N, embedding_size)
        # prepend the image as the first "token" of every sequence
        features = torch.cat((image_features, embedding), dim=0)  # (seq_len + 1, N, embedding_size)
        output, _ = self.rnn(features)  # (seq_len + 1, N, hidden_size)
        output = self.fc(output)  # (seq_len + 1, N, output_size)
        return output

In [154]:
# CNN2RNN block

# Print tensors without summarization (debugging aid; affects every later tensor print).
torch.set_printoptions(profile="full")

class CNN2RNN(nn.Module):
    """Captioning model: `encoder` produces image features that `decoder` turns into token logits."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, image, text):
        """Teacher-forced pass.

        Args:
            image: (N, channels, H, W) image batch.
            text: (seq_len, N) token indices fed to the decoder.

        Returns:
            Decoder logits, one time step more than `text` (the image feature is
            prepended as the first decoder input).
        """
        features = self.encoder(image) # input: (N, channels, H, W), output: (N, embed_size)
        outputs = self.decoder(features, text) # input: (N, embed_size) and (seq_len, N), output: (seq_len + 1, N, vocab_size)
        return outputs

    def caption_image(self, image, max_len=50, vocabulary=None):
        """Greedily decode a caption for a single image.

        Args:
            image: batch holding exactly one image, (1, channels, H, W).
            max_len: maximum number of tokens to generate.
            vocabulary: object exposing `get_itos()`; defaults to the notebook-level `vocab`.

        Returns:
            List of token strings, ending with '<eos>' when it is produced within `max_len` steps.

        The model is switched to eval mode for the duration of the call (so dropout does
        not randomize the caption when this is invoked mid-training) and its previous
        train/eval state is restored afterwards.
        """
        lookup = vocab if vocabulary is None else vocabulary
        itos = lookup.get_itos()  # fetched once instead of on every decoding step

        was_training = self.training
        self.eval()
        result = []
        try:
            with torch.no_grad():
                features = self.encoder(image).unsqueeze(0) # (1, 1, embed_size)
                states = None

                for _ in range(max_len):
                    output, states = self.decoder.rnn(features, states) # output: (1, 1, hidden_size)
                    output = self.decoder.fc(output.squeeze(0)) # output: (1, output_size)
                    output = output.argmax(1) # output: (1)

                    token = output.item()
                    result.append(token)
                    if itos[token] == '<eos>':
                        break
                    # the predicted token becomes the next decoder input
                    features = self.decoder.embedding(output).unsqueeze(0) # features: (1, 1, embed_size)
        finally:
            self.train(was_training)

        return [itos[item] for item in result]

In [155]:
# training set-up block
epochs = 20
lr=3e-4
hidden_size = 256
embedding_size = 256
input_size = len(vocab)   # rows of the decoder's embedding table
output_size = len(vocab)  # logits per time step
num_layers = 1
dropout_prob = 0
teacher_force_ratio = 1

# The encoder output is concatenated with the decoder's token embeddings, so it must be
# embedding_size wide. It was previously sized with hidden_size, which only worked
# because both constants are 256.
encoder = Encoder(embedding_size).to(device)
decoder = Decoder(input_size, embedding_size, hidden_size, output_size, num_layers, dropout_prob).to(device)
model = CNN2RNN(encoder, decoder).to(device)

optimizer = optim.Adam(model.parameters(), lr=lr)
# padded positions do not contribute to the loss
loss_fn = nn.CrossEntropyLoss(ignore_index=vocab['<pad>'])

def check_accuracy(iterator, teacher_force_ratio, net=None, pad_index=0):
    """Token-level accuracy (in percent) on the first batch drawn from `iterator`.

    Args:
        iterator: iterable yielding (image, text) batches; only its first batch is used.
        teacher_force_ratio: accepted for call compatibility; not used.
        net: model to evaluate, called as net(image, text); defaults to the
            notebook-level `model`.
        pad_index: target index that marks padding (0 by default).

    Returns:
        100 * correct / total over the non-padding target positions, or 0.0 when the
        batch contains no such position.

    Only non-padding positions are counted, both in the total and in the number of
    correct predictions. The forward pass runs in eval mode under no_grad, and the
    model's previous train/eval state is restored afterwards.
    """
    net = model if net is None else net
    image, text = next(iter(iterator))

    # move the batch to wherever the model's weights live
    anchor = next(net.parameters(), None)
    if anchor is not None:
        image = image.to(anchor.device)
        text = text.to(anchor.device)

    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            # Same input as in training: all rows but the last, so that (with the image
            # feature prepended by the decoder) output step t is scored against text[t].
            output = net(image, text[:-1])
    finally:
        net.train(was_training)

    prediction = output.argmax(2)[: text.shape[0]]  # (seq_len, N)
    valid = text != pad_index
    total = int(valid.sum().item())
    if total == 0:
        return 0.0
    correct = int(((prediction == text) & valid).sum().item())

    return 100 * correct / total

In [156]:
# training block
# Shuffled loaders used only for the periodic accuracy probes further down.
check_train_loader = DataLoader(
    train_data, batch_size=64, shuffle=True, collate_fn=collate_fn
)
check_val_loader = DataLoader(
    val_data, batch_size=64, shuffle=True, collate_fn=collate_fn
)

for epoch in range(epochs):
    print(f'Epoch: [{epoch} / {epochs}]')
    for i, (image, text) in enumerate(train_loader):
        # image: (N, channels, H, W), text: (seq_len, N)
        image, text = image.to(device), text.to(device)

        # teacher forcing: the decoder sees every row of the caption except the last
        outputs = model(image, text[:-1])

        # every 10th step: show the first caption of the batch next to a greedy decode
        if not i % 10:
            print('True Label: ', end = '')
            for num in text[:, 0]:
                print(vocab.get_itos()[num.item()] + ' ', end='')
            print()

            print('Predicted: ', end='')
            print(model.caption_image(image[0:1]))

        # flatten time and batch so every position is one classification example
        outputs = outputs.reshape(-1, outputs.size(2))
        text = text.flatten()
        loss = loss_fn(outputs, text)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # every 200th step: report the loss and one-batch accuracies
        if not i % 200:
            print(f'Loss: {loss}')

            train_accuracy = check_accuracy(check_train_loader, 1)
            print(f'Train Accuracy: {train_accuracy}%')

            val_accuracy = check_accuracy(check_val_loader, 0)
            print(f'Validation Accuracy: {val_accuracy}%')

In [None]:
def show_image(img, title=None, mean=None, std=None):
    """Display a normalized (C, H, W) image tensor with matplotlib.

    Args:
        img: image tensor as produced by the notebook's `transform`; it is not modified.
        title: optional plot title.
        mean, std: per-channel statistics used to undo the normalization. When omitted
            they are read from the `Normalize` step of the notebook-level `transform`,
            so the display always matches the preprocessing that was actually applied;
            if no such step exists the image is shown as-is.

    The earlier version multiplied and shifted `img` in place with fixed constants,
    which altered the caller's tensor and could disagree with `transform`.
    """
    if mean is None or std is None:
        normalize = next(
            (step for step in getattr(transform, 'transforms', [])
             if isinstance(step, transforms.Normalize)),
            None,
        )
        if mean is None:
            mean = normalize.mean if normalize is not None else 0.0
        if std is None:
            std = normalize.std if normalize is not None else 1.0

    img = img.detach().cpu()
    mean = torch.as_tensor(mean, dtype=img.dtype).reshape(-1, 1, 1)
    std = torch.as_tensor(std, dtype=img.dtype).reshape(-1, 1, 1)
    # unnormalize out of place, then clamp to the [0, 1] range imshow expects for floats
    restored = (img * std + mean).clamp(0, 1)

    plt.imshow(restored.numpy().transpose((1, 2, 0)))  # (C, H, W) -> (H, W, C)
    if title is not None:
        plt.title(title)
    plt.pause(0.002)

In [None]:
# Draw one shuffled batch of 64 training pairs and move it to the device
# for the qualitative check in the next cell.
temp = DataLoader(
    train_data,
    batch_size=64,
    collate_fn=collate_fn,
    shuffle=True,
)
image, text = next(iter(temp))
image, text = image.to(device), text.to(device)
text.shape

In [None]:
# Qualitative check: print the reference and the teacher-forced prediction for every
# caption of the sampled batch, then show the image.
itos = vocab.get_itos()

# One forward pass for the whole batch (it does not depend on the loop index).
# forward() takes (image, text) only; feeding text[:-1] mirrors the training call,
# so output row t lines up with text row t.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        output = model(image, text[:-1])
finally:
    model.train(was_training)
output = output.argmax(2)  # (seq_len, N) predicted token indices

for i in range(text.shape[1]):
    print('True Label: ', end = '')
    for num in (text[:, i]):
        if num != 0:  # skip padding
            print(itos[num.item()] + ' ', end='')
    print()
    print('Predicted Label: ', end='')
    for num in output[:, i]:
        if num != 0:
            print(itos[num.item()] + ' ', end='')
    print()

    # hand over a copy so the batch stays untouched whatever show_image does with it
    show_image(image[i].detach().cpu().clone())

In [None]:
# NOTE(review): `batch` is not assigned at notebook scope in any visible cell (it only
# appears as the parameter of collate_fn), so on a fresh kernel this line presumably
# raises NameError -- confirm what was meant to be inspected here, or drop the cell.
batch[1].shape