In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms, models

from nltk.tokenize import sent_tokenize, word_tokenize
import nltk

from sklearn.model_selection import train_test_split

from collections import Counter
import math

from PIL import Image

from datasets import load_dataset
import os

from transformers import ViTModel

In [2]:
nltk.download('punkt')

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

# Data extraction

In [3]:
images_path = '/kaggle/input/flickr8kimagescaptions/flickr8k/images'
captions_file = '/kaggle/input/flickr8kimagescaptions/flickr8k/captions.txt'

imgs = []
prompts = []

with open(captions_file, 'r') as f:
    next(f)
    for i, line in enumerate(tqdm(f)):
        if i == 40000:
            break
        if i % 3 != 0:
            continue
        image_name, caption = line.strip().split(',', 1)
        
        image_path = os.path.join(images_path, image_name)
        
        img = Image.open(image_path).convert('RGB')
        
        imgs.append(img)
        prompts.append(caption)

0it [00:00, ?it/s]

In [4]:
vocab = set(['<bos>', '<eos>', '<unk>', '<pad>'])

for prompt in prompts:
    for word in word_tokenize(prompt):
        vocab.add(word.lower())

In [5]:
len(vocab)

5526

In [6]:
word2ind = {word: i for i, word in enumerate(vocab)}
ind2word = {i: word for word, i in word2ind.items()}

In [7]:
def preprocess_img(img):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    img = transform(img)
    return img

def preprocess_prompt(prompt):
    tokenized_prompt = [word2ind['<bos>']]
    tokenized_prompt += [word2ind.get(token.lower(), word2ind['<unk>']) for token in word_tokenize(prompt)]
    tokenized_prompt += [word2ind['<eos>']]
    
    return tokenized_prompt

In [8]:
class Img2textDataset(Dataset):
    def __init__(self, imgs, prompts):
        self.imgs = imgs
        self.prompts = prompts
        
    def __getitem__(self, idx):
        return preprocess_img(self.imgs[idx]), preprocess_prompt(self.prompts[idx])
    
    def __len__(self):
        return len(self.imgs)

In [9]:
def collate_fn_with_padding(batch):
    imgs, prompts = zip(*batch)
    imgs = torch.stack(imgs, dim=0)
    
    seq_lens = [len(prompt) for img, prompt in batch]
    max_seq_len = max(seq_lens)
    
    padded_prompts = [prompt + [word2ind['<pad>']] * (max_seq_len - len(prompt)) for prompt in prompts]
    padded_prompts = torch.LongTensor(padded_prompts)
    
    return imgs, padded_prompts[:,:-1], padded_prompts[:,1:]

In [10]:
train_imgs, test_imgs, train_prompts, test_prompts = train_test_split(imgs, prompts, test_size=0.2)

train_dataset = Img2textDataset(train_imgs, train_prompts)
test_dataset = Img2textDataset(test_imgs, test_prompts)

BATCH_SIZE = 64

train_dataloader = DataLoader(
    train_dataset, collate_fn=collate_fn_with_padding, batch_size=BATCH_SIZE, shuffle=True
)

test_dataloader = DataLoader(
    test_dataset, collate_fn=collate_fn_with_padding, batch_size=BATCH_SIZE
)

# Model training

In [218]:
class Encoder(nn.Module):
    def __init__(self, embed_dim, num_unfrozen_layers=2):
        super(Encoder, self).__init__()
        
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        
        for param in self.vit.parameters():
            param.requires_grad = False
            
        if num_unfrozen_layers > 0:
            for layer in self.vit.encoder.layer[-num_unfrozen_layers:]:
                for param in layer.parameters():
                    param.requires_grad = True
        
        self.fc = nn.Linear(self.vit.config.hidden_size, embed_dim)
        
    def forward(self, x):
        with torch.no_grad():
            outputs = self.vit(x)
            x = outputs.last_hidden_state[:, 0]
        
        x = self.fc(x)
        return x

In [219]:
class Decoder(nn.Module):
    def __init__(self, hidden_dim, vocab_size, num_layers):
        super(Decoder, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.rnn = nn.LSTM(hidden_dim, hidden_dim, num_layers=num_layers, dropout=0.5, batch_first=True)
        self.linear = nn.Linear(hidden_dim, hidden_dim)
        self.projection = nn.Linear(hidden_dim, vocab_size)
        self.non_lin = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, input_batch, hidden_state=None):
        embeddings = self.embedding(input_batch)
        
        if hidden_state is None:
            output, hidden_state = self.rnn(embeddings)
        else:
            output, hidden_state = self.rnn(embeddings, hidden_state)

        output = self.dropout(self.linear(self.non_lin(output)))
        projection = self.projection(output)
        return projection, hidden_state

In [220]:
class Img2textModel(nn.Module):
    def __init__(self, embed_dim, hidden_dim, vocab_size, num_layers):
        super(Img2textModel, self).__init__()
        
        self.encoder = Encoder(embed_dim)
        self.decoder = Decoder(hidden_dim, vocab_size, num_layers)
        self.init_hidden_layer = nn.Linear(embed_dim, hidden_dim)
        self.init_cell_layer = nn.Linear(embed_dim, hidden_dim)
        
        self.num_layers = num_layers
        
    def forward(self, images, captions):
        features = self.encoder(images)  
        
        hidden_state = self.init_hidden_layer(features)
        cell_state = self.init_cell_layer(features)

        hidden_state = hidden_state.unsqueeze(0).repeat(self.num_layers, 1, 1)
        cell_state = cell_state.unsqueeze(0).repeat(self.num_layers, 1, 1)
        
        outputs, _ = self.decoder(captions, (hidden_state, cell_state))
        
        return outputs

In [14]:
def validate_model(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0

    with torch.no_grad():
        for images, inputs, targets in dataloader:
            images = images.to(device)
            inputs = inputs.to(device)
            targets = targets.to(device)
            
            outputs = model(images, inputs)
            
            loss = criterion(outputs.view(-1, outputs.size(2)), targets.reshape(-1))
            
            running_loss += loss.item()

    validation_loss = running_loss / len(dataloader)
    print(f'Validation Loss: {validation_loss:.4f}')
    model.train()

In [15]:
def train(model, train_dataloader, test_dataloader, criterion, optimizer, num_epochs, device):
    model = model.to(device)
    model.train()
    
    for epoch in range(num_epochs):
        running_loss = 0.0
        
        with tqdm(total=len(train_dataloader), desc="Training", unit="batch") as pbar:
            for i, (images, inputs, targets) in enumerate(train_dataloader):
                images = images.to(device)
                inputs = inputs.to(device)
                targets = targets.to(device)

                outputs = model(images, inputs) # (batch_size, seq_len, vocab_size)

                loss = criterion(outputs.view(-1, outputs.size(2)), targets.reshape(-1))

                optimizer.zero_grad()

                loss.backward()

                optimizer.step()

                running_loss += loss.item()

                if (i + 1) % 10 == 0:
                    pbar.set_postfix({'Loss': running_loss / (i + 1)})
    
                pbar.update(1)
                
        epoch_loss = running_loss / len(train_dataloader)
        print(f'Epoch {epoch + 1} Loss: {epoch_loss:.4f}')
        
        validate_model(model, test_dataloader, criterion, device)
        
    model.to('cpu')

In [221]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Img2textModel(embed_dim=1024, hidden_dim=512, vocab_size=len(vocab), num_layers=1)
criterion = nn.CrossEntropyLoss(ignore_index=word2ind['<pad>'])
optimizer = torch.optim.Adam(model.parameters())

In [222]:
num_params = sum(p.numel() for p in model.parameters())
param_size_bytes = 4
total_size_bytes = num_params * param_size_bytes
total_size_megabytes = total_size_bytes / (1024 ** 2)
print(f"Model size: {total_size_megabytes:.2f} MB")

Model size: 367.18 MB


In [224]:
train(model, train_dataloader, test_dataloader, criterion, optimizer, 1, device)

Training:   0%|          | 0/167 [00:00<?, ?batch/s]

Epoch 1 Loss: 2.3674
Validation Loss: 3.0732


min_val_loss: 3.06

# Testing

In [19]:
def generate_caption(model, img, word2ind, ind2word, max_length=50, temperature=0.5, device='cpu'):
    model.to(device)
    model.eval()
    
    img = preprocess_img(img).unsqueeze(0).to(device)
    
    with torch.no_grad():
        features = model.encoder(img)
        
        hidden_state = model.init_hidden_layer(features)
        cell_state = model.init_cell_layer(features)
        num_layers = model.num_layers
        hidden_state = hidden_state.unsqueeze(0).repeat(num_layers, 1, 1)
        cell_state = cell_state.unsqueeze(0).repeat(num_layers, 1, 1)
        
        input_token = torch.LongTensor([word2ind['<bos>']]).unsqueeze(0).to(device)
        
        generated_tokens = []
        
        for _ in range(max_length):
            output, (hidden_state, cell_state) = model.decoder(input_token, (hidden_state, cell_state))
            
            output = output.squeeze(1) / temperature
            probabilities = torch.softmax(output, dim=-1)
            
            next_token = torch.multinomial(probabilities, 1).item()
            
            if next_token == word2ind['<eos>']:
                break
            
            generated_tokens.append(next_token)
            
            
            input_token = torch.LongTensor([next_token]).unsqueeze(0).to(device)
    
    generated_caption = ' '.join([ind2word[idx] for idx in generated_tokens if idx in ind2word])
    model.to('cpu')
    return generated_caption

In [20]:
def rebuild_img(img):
    inv_normalize = transforms.Normalize(
        mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
        std=[1 / 0.229, 1 / 0.224, 1 / 0.225]
    )

    img_ = inv_normalize(img)

    img_ = img_.permute(1, 2, 0).numpy()

    img_ = np.clip(img_, 0, 1)
    
    return img_

In [None]:
img = Image.open('/kaggle/input/coco-2017-dataset/coco2017/train2017/000000000731.jpg')
img

In [282]:
generate_caption(model, img, word2ind, ind2word, device=device)

'a man and a woman in a field of grass and a woman in a field .'

In [128]:
for prompt in prompts:
    if 'front of a crowd of people' in prompt:
        print(prompt)

A black race car starts up in front of a crowd of people .
A person wearing a black coat and a helmet with spiky things is posing on the lawn in front of a crowd of people .
