In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import math
import random
from tqdm.notebook import trange, tqdm

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision

from transformers import AutoTokenizer
# Turn off the Hugging Face tokenizers' internal parallelism (avoids its
# fork-after-use warning when DataLoader workers are forked).
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Let CUDA float32 matmuls use TF32 tensor cores where supported (faster, slightly less precise).
torch.backends.cuda.matmul.allow_tf32 = True

In [None]:
# Define the learning rate for the optimizer
learning_rate = 1e-4

# Side length (pixels) of the square image crops fed to the model
image_size = 128

# Define the number of epochs for training
nepochs = 200

# Define the batch size for mini-batch gradient descent
batch_size = 128

# Seed every RNG the notebook relies on (Python `random` for caption sampling,
# NumPy, and torch for weight init, augmentation, shuffling and DataLoader
# worker seeds) so that Restart & Run All is reproducible.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Define the root directory of the dataset
data_set_root = '../Datasets/ROCO2'
train_captions_file = os.path.join(data_set_root, 'train_captions.csv')
test_captions_file = os.path.join(data_set_root, 'test_captions.csv')
train_image_path = os.path.join(data_set_root, 'train_images/train/')
test_image_path = os.path.join(data_set_root, 'test_images/test/')

# ## Data processing and Tokenization

class SampleCaption(nn.Module):
    """Callable that returns one element of a caption sequence, chosen uniformly at random.

    Draws from Python's global ``random`` generator; an empty sequence raises ``ValueError``.
    """

    def __call__(self, sample):
        chosen = random.randrange(len(sample))
        return sample[chosen]

# ImageNet channel statistics, shared by the train and eval pipelines
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: random crop + AutoAugment, then float conversion and normalisation.
# The dataset yields uint8 image tensors (torchvision.io.read_image), so the
# uint8 -> float [0, 1] scaling is done with ConvertImageDtype. ToTensor would raise
# here because it only accepts PIL images and numpy arrays.
train_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.RandomCrop(image_size),
    transforms.AutoAugment(),  # tensor inputs must be uint8, so this runs before the dtype conversion
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Evaluation pipeline: deterministic resize + centre crop, same conversion and normalisation
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Create a custom dataset for ROCO 2
class ROCO2Dataset(torch.utils.data.Dataset):
    """ROCO2 image/caption dataset backed by a captions CSV.

    Each CSV row holds an image file name (column 0, relative to ``image_dir``)
    and its caption text (column 1). Multiple captions for one image are assumed
    to be joined with ``'|'``.

    Args:
        image_dir: directory containing the image files.
        captions_file: CSV path, loaded with ``pd.read_csv``.
        transform: optional callable applied to the image tensor.

    ``__getitem__`` returns ``(image, captions)``. ``image`` is a 3-channel
    tensor (uint8 before ``transform``), and ``captions`` is a list of str.
    """

    def __init__(self, image_dir, captions_file, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.captions_df = pd.read_csv(captions_file)

    def __len__(self):
        return len(self.captions_df)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.captions_df.iloc[idx, 0])
        # Force 3 channels. Left at the default, read_image keeps the file's own channel
        # count (e.g. 1 for grayscale radiographs, 4 for RGBA). That would break the
        # 3-channel Normalize and the model's channels_in=3 patch projection.
        image = torchvision.io.read_image(img_name, mode=torchvision.io.ImageReadMode.RGB)
        captions = self.captions_df.iloc[idx, 1].split('|')  # Assuming captions are separated by '|'

        if self.transform:
            image = self.transform(image)

        return image, captions

# Shared caption sampler used by the training collate function
sample_caption = SampleCaption()


def collate_random_caption(batch):
    """Collate ``(image, captions)`` samples into ``(image_batch, list_of_str)``.

    Keeps one randomly sampled caption per image. PyTorch's default collate
    would instead transpose the per-image caption lists into tuples, which the
    tokenizer cannot consume and which no longer line up with the images.
    """
    images = torch.stack([image for image, _ in batch])
    captions = [sample_caption(caps) for _, caps in batch]
    return images, captions


def collate_first_caption(batch):
    """Like ``collate_random_caption`` but keeps each image's first caption (deterministic, for evaluation)."""
    images = torch.stack([image for image, _ in batch])
    captions = [caps[0] for _, caps in batch]
    return images, captions


train_dataset = ROCO2Dataset(train_image_path, train_captions_file, transform=train_transform)
eval_dataset = ROCO2Dataset(test_image_path, test_captions_file, transform=transform)

data_loader_train = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                               num_workers=8, collate_fn=collate_random_caption)
data_loader_eval = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False,
                              num_workers=8, collate_fn=collate_first_caption)

# Sanity check: pull a single evaluation batch and inspect one image/caption pair
test_images, test_captions = next(iter(data_loader_eval))

index = 0
fig, ax = plt.subplots(figsize=(3, 3))
# normalize=True rescales the normalised tensor into [0, 1] for display
grid = torchvision.utils.make_grid(test_images[index].unsqueeze(0), 1, normalize=True)
ax.imshow(grid.numpy().transpose((1, 2, 0)))
ax.set_title(f"Eval image {index}")
ax.axis("off")
plt.show()

caption = test_captions[index]
print(caption)

# We'll use a pre-built Tokenizer for the BERT Model
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

tokens = tokenizer(test_captions, padding=True, truncation=True, return_tensors="pt")

# Decode the selected caption's token ids back to text as a tokenizer sanity check
token_ids = tokens['input_ids'][index]
print(tokenizer.decode(token_ids))

class TokenDrop(nn.Module):
    """Token-dropout augmentation: randomly replace token ids with a blank id.

    Each position is independently replaced with ``blank_token`` with probability
    ``prob``. Two kinds of position are always kept: those holding ``eos_token``,
    and the first position of every row.

    Args:
        prob: per-token replacement probability.
        blank_token: id written into dropped positions.
        eos_token: id that is never dropped. The default 102 is presumably the
            ``[SEP]`` id of the BERT-uncased vocabulary; confirm with
            ``tokenizer.sep_token_id``.
    """

    def __init__(self, prob=0.1, blank_token=1, eos_token=102):
        # nn.Module.__init__ must run. Without it the module's internal registries
        # are missing, and repr(), .to(), state_dict(), parameters() etc. raise.
        super().__init__()
        self.prob = prob
        self.eos_token = eos_token
        self.blank_token = blank_token

    def forward(self, sample):
        """Return a copy of ``sample`` with random positions replaced by ``blank_token``.

        Implemented as ``forward`` (not ``__call__``) so that nn.Module's call
        machinery, including hooks, applies.

        Args:
            sample: tensor of token ids, at least 2-D (column 0 is protected),
                e.g. (batch, seq_len).

        Returns:
            Tensor with the same shape and dtype as ``sample``.
        """
        drop = torch.bernoulli(torch.full(sample.shape, float(self.prob), device=sample.device)).bool()
        drop &= sample != self.eos_token  # never drop the end-of-sequence token
        drop[:, 0] = False                # never drop the first (start) position
        return torch.where(drop, torch.full_like(sample, self.blank_token), sample)

def extract_patches(image_tensor, patch_size=16):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        image_tensor: 4-D tensor unpacked as (batch, channels, height, width).
        patch_size: side length of each square patch, also used as the stride.

    Returns:
        Tensor of shape (batch, num_patches, channels * patch_size * patch_size).
    """
    batch, channels, _, _ = image_tensor.size()
    # unfold gives (batch, channels * p * p, num_patches); move the patch axis to dim 1
    columns = nn.functional.unfold(image_tensor, kernel_size=patch_size, stride=patch_size)
    patch_dim = channels * patch_size * patch_size
    return columns.transpose(1, 2).reshape(batch, -1, patch_dim)

class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of scalar positions.

    Maps a 1-D tensor of ``n`` positions to an ``(n, dim)`` float tensor. The
    first ``dim // 2`` columns are sines and the next ``dim // 2`` are cosines,
    over frequencies spaced geometrically from 1 down to 1/10000. For odd
    ``dim`` a trailing zero column keeps the width equal to ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        # Guard the divisor. With a single frequency (dim 2 or 3) the exponent is
        # 0 anyway, so 1 avoids ZeroDivisionError without changing any dim >= 4.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        if self.dim % 2:
            # Odd dim: pad one zero column so callers always get exactly `dim` features.
            emb = nn.functional.pad(emb, (0, 1))
        return emb

class Decoder(nn.Module):
    """Causal Transformer decoder producing per-position vocabulary logits.

    Token embeddings plus fixed sinusoidal position embeddings pass through a
    stack of ``nn.TransformerDecoderLayer`` blocks. These self-attend causally and
    cross-attend to ``encoder_output``. A linear head maps the result to
    ``num_emb`` logits.

    Args:
        num_emb: vocabulary size (embedding rows and output logit width).
        hidden_size: model width shared by embeddings, attention and output head.
        num_layers: number of stacked decoder layers.
        num_heads: attention heads per layer.
    """

    def __init__(self, num_emb, hidden_size=128, num_layers=3, num_heads=4):
        super().__init__()
        self.embedding = nn.Embedding(num_emb, hidden_size)
        # Shrink the freshly initialised token embeddings 1000x, keeping them small
        # next to the [-1, 1] sinusoidal position signal added in forward().
        with torch.no_grad():
            self.embedding.weight.mul_(0.001)
        self.pos_emb = SinusoidalPosEmb(hidden_size)
        layer = nn.TransformerDecoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=4 * hidden_size,
            dropout=0.0,
            batch_first=True,
        )
        self.decoder_layers = nn.TransformerDecoder(layer, num_layers=num_layers)
        self.fc_out = nn.Linear(hidden_size, num_emb)

    def forward(self, input_seq, encoder_output, input_padding_mask=None,
                encoder_padding_mask=None):
        """Decode ``input_seq`` (batch, seq_len) token ids into (batch, seq_len, num_emb) logits.

        ``input_padding_mask`` / ``encoder_padding_mask`` use PyTorch's key-padding
        convention (True = ignore that position).
        """
        token_embs = self.embedding(input_seq)
        batch, length, width = token_embs.shape
        positions = torch.arange(length, device=input_seq.device)
        pos_embs = self.pos_emb(positions).reshape(1, length, width).expand(batch, length, width)
        # True strictly above the diagonal: position t may not attend to positions > t
        future_mask = torch.ones(length, length, device=input_seq.device).triu(1).bool()
        hidden = self.decoder_layers(
            tgt=token_embs + pos_embs,
            memory=encoder_output,
            tgt_mask=future_mask,
            tgt_key_padding_mask=input_padding_mask,
            memory_key_padding_mask=encoder_padding_mask,
        )
        return self.fc_out(hidden)

class VisionEncoder(nn.Module):
    """ViT-style image encoder.

    Cuts the image into non-overlapping patches and projects each flattened patch
    to ``hidden_size``. It then adds a learned positional embedding and runs a
    Transformer encoder stack.

    Args:
        image_size: side length of the (square) input images.
        channels_in: number of image channels.
        patch_size: side length of each square patch.
        hidden_size: model width.
        num_layers: number of encoder layers.
        num_heads: attention heads per layer.
    """

    def __init__(self, image_size, channels_in, patch_size=16, hidden_size=128, num_layers=3, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        patch_dim = channels_in * patch_size * patch_size
        self.fc_in = nn.Linear(patch_dim, hidden_size)
        num_patches = (image_size // patch_size) ** 2
        # One learned position vector per patch, initialised N(0, 0.02^2)
        self.pos_embedding = nn.Parameter(torch.empty(1, num_patches, hidden_size).normal_(std=0.02))
        layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=4 * hidden_size,
            dropout=0.0,
            batch_first=True,
        )
        self.encoder_layers = nn.TransformerEncoder(layer, num_layers)

    def forward(self, image):
        """Encode a (batch, channels, H, W) image batch into (batch, num_patches, hidden_size)."""
        patch_tokens = self.fc_in(extract_patches(image, patch_size=self.patch_size))
        return self.encoder_layers(patch_tokens + self.pos_embedding)

class VisionEncoderDecoder(nn.Module):
    """Image-captioning model: a VisionEncoder whose output is cross-attended by a causal Decoder.

    Args:
        image_size, channels_in, patch_size: forwarded to the VisionEncoder.
        num_emb: vocabulary size for the Decoder.
        hidden_size, num_heads: shared by encoder and decoder.
        num_layers: ``(encoder layers, decoder layers)``.
    """

    def __init__(self, image_size, channels_in, num_emb, patch_size=16,
                 hidden_size=128, num_layers=(3, 3), num_heads=4):
        super().__init__()
        encoder_depth = num_layers[0]
        decoder_depth = num_layers[1]
        self.encoder = VisionEncoder(image_size=image_size, channels_in=channels_in,
                                     patch_size=patch_size, hidden_size=hidden_size,
                                     num_layers=encoder_depth, num_heads=num_heads)
        self.decoder = Decoder(num_emb=num_emb, hidden_size=hidden_size,
                               num_layers=decoder_depth, num_heads=num_heads)

    def forward(self, input_image, target_seq, padding_mask):
        """Return decoder logits for ``target_seq`` conditioned on ``input_image``.

        Args:
            input_image: image batch passed to the encoder.
            target_seq: (batch, seq_len) decoder input token ids.
            padding_mask: (batch, seq_len) mask aligned with ``target_seq`` where
                entries equal to 0 mark padding (e.g. a tokenizer ``attention_mask``:
                1 = real token, 0 = pad). It is converted here to the boolean
                True-means-ignore key-padding mask the decoder expects.
        """
        memory = self.encoder(image=input_image)
        return self.decoder(input_seq=target_seq,
                            encoder_output=memory,
                            input_padding_mask=(padding_mask == 0))



In [None]:
# ## Initialise Model and Optimizer

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Vocabulary size of the pretrained tokenizer: sets the decoder's embedding rows and logit width
num_tokens = tokenizer.vocab_size

model = VisionEncoderDecoder(
    image_size=image_size,
    channels_in=3,        # RGB input images
    num_emb=num_tokens,
    hidden_size=128,
    num_layers=(3, 3),    # (encoder layers, decoder layers)
    num_heads=4,
)


In [None]:
# Move model to device
model.to(device)

# Initialize the optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Token-level cross-entropy. Padding targets are ignored, so the model is not
# trained (or scored) on predicting [PAD] after a caption has ended.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# Directory for saving checkpoints
checkpoint_dir = './checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

log_every = 10  # print the running mean loss every `log_every` batches

# ## Training Loop
# Start training
for epoch in range(nepochs):
    model.train()
    running_loss = 0.0
    epoch_loss = 0.0
    num_batches = 0
    for i, (images, captions) in enumerate(tqdm(data_loader_train, desc=f"Epoch {epoch + 1}/{nepochs}")):
        optimizer.zero_grad()

        # Tokenize captions and move everything to the device
        images = images.to(device)
        tokens = tokenizer(captions, padding=True, truncation=True, return_tensors="pt")
        token_ids = tokens['input_ids'].to(device)
        # attention_mask is 1 for real tokens and 0 for padding. That is the convention
        # VisionEncoderDecoder.forward expects, since it treats `padding_mask == 0` as padding.
        attention_mask = tokens['attention_mask'].to(device)

        # Teacher forcing: feed tokens [0, L-1) and predict tokens [1, L).
        # The padding mask describes the *input* tokens, so it is sliced the same way.
        outputs = model(input_image=images, target_seq=token_ids[:, :-1],
                        padding_mask=attention_mask[:, :-1])

        # Compute loss
        loss = criterion(outputs.reshape(-1, outputs.size(-1)), token_ids[:, 1:].reshape(-1))
        loss.backward()
        optimizer.step()

        # Print statistics
        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        num_batches += 1
        if (i + 1) % log_every == 0:
            print(f"Epoch [{epoch + 1}/{nepochs}], Step [{i + 1}/{len(data_loader_train)}], Loss: {running_loss / log_every:.4f}")
            running_loss = 0.0

    # Save checkpoint at the end of each epoch. The stored loss is the epoch mean;
    # max(..., 1) avoids dividing by zero if the loader yielded no batches.
    mean_epoch_loss = epoch_loss / max(num_batches, 1)
    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch + 1}.pth')
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': mean_epoch_loss,
    }, checkpoint_path)

    print(f"Checkpoint saved: {checkpoint_path}")

In [None]:
# ## Evaluation
# Token-level cross-entropy over the evaluation set. Padding targets are ignored,
# and every real token is weighted equally, so a smaller final batch is not over-weighted.

model.eval()
pad_id = tokenizer.pad_token_id
total_loss = 0.0
total_tokens = 0
with torch.no_grad():
    for images, captions in tqdm(data_loader_eval, desc="Evaluating"):
        images = images.to(device)
        tokens = tokenizer(captions, padding=True, truncation=True, return_tensors="pt")
        token_ids = tokens['input_ids'].to(device)
        # 1 = real token, 0 = pad: the convention VisionEncoderDecoder.forward expects
        attention_mask = tokens['attention_mask'].to(device)

        # Same teacher-forcing split as training; the mask is aligned with the inputs
        outputs = model(input_image=images, target_seq=token_ids[:, :-1],
                        padding_mask=attention_mask[:, :-1])

        targets = token_ids[:, 1:].reshape(-1)
        total_loss += nn.functional.cross_entropy(
            outputs.reshape(-1, outputs.size(-1)), targets,
            ignore_index=pad_id, reduction='sum').item()
        total_tokens += (targets != pad_id).sum().item()

mean_eval_loss = total_loss / max(total_tokens, 1)
print(f"Evaluation Loss: {mean_eval_loss:.4f}")
print(f"Evaluation Perplexity: {math.exp(mean_eval_loss):.2f}")
