In [None]:
import os
import pickle
import re

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from tqdm.notebook import tqdm

In [None]:
import os
import cv2
from matplotlib import pyplot as plt

# Dataset locations, all expressed relative to the notebook's working directory
data_dir = "."
dataset_dir = f"{data_dir}/Images"
annotations_path = f"{data_dir}/captions.txt"

# Load annotations
# with open(annotations_path, 'r') as file:
#     annotations = file.readlines()
# The whole captions file is read as a single string; it is split into lines further down.
with open(annotations_path, mode='r') as captions_file:
    captions_doc = captions_file.read()

# # Process annotations to create a dictionary mapping image filenames to their captions
# image_captions = {}
# print(annotations[1])
# for annotation in annotations[1:]:
#     parts = annotation.strip().split(',')
#     image_filename = parts[0]
#     image_filename = image_filename
#     caption = parts[1]
#     if image_filename not in image_captions:
#         image_captions[image_filename] = []
#     image_captions[image_filename].append(caption)

# # Load images

# Images = []
# for image_filename in image_captions.keys():
#     image_path = os.path.join(dataset_dir, image_filename)
#     Images.append(image_path)

In [None]:
# Device configuration: prefer CUDA, then Apple's MPS backend, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():  # For macOS
    device = "mps"
else:
    device = "cpu"

print(f"Using {device}")

In [None]:
class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal position signal to a batch of embeddings.

    Args:
        seq_len: number of positions the pre-computed table covers.
        embedding_dim: size of each embedding vector; odd sizes are supported.
    """

    def __init__(self, seq_len, embedding_dim):
        super(PositionalEncoding, self).__init__()
        # A non-persistent buffer follows the module through .to()/.cuda() but adds
        # no key to state_dict(), so checkpoints saved earlier keep the same layout.
        self.register_buffer(
            "encoding",
            self.generate_positional_encoding(seq_len, embedding_dim),
            persistent=False,
        )

    def generate_positional_encoding(self, seq_len, embedding_dim):
        """Build the table of shape [1, seq_len, embedding_dim].

        Even columns hold sines, odd columns hold cosines of the same angles.
        """
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * -(torch.log(torch.tensor(10000.0)) / embedding_dim))
        angles = position * div_term
        pe = torch.zeros(seq_len, embedding_dim)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd embedding_dim there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, : embedding_dim // 2])
        pe = pe.unsqueeze(0)  # Add batch dimension
        return pe

    def forward(self, x):
        """Return ``x`` plus the first ``x.size(1)`` rows of the position table."""
        # Follow the input's device rather than a notebook-level global.
        return x + self.encoding[:, :x.size(1)].to(x.device)


class EncoderDecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention over image features, MLP.

    Each sub-layer is followed by a residual connection and a LayerNorm.

    Args:
        dim: embedding size shared by the token sequence and the image features.
        n_self_heads: number of heads of the self-attention layer.
        n_cross_heads: number of heads of the cross-attention layer.
        mlp_ratio: hidden width of the MLP expressed as a multiple of ``dim``.
        p_dropout: dropout probability used in both attention layers and in the MLP.
    """

    def __init__(self, dim, n_self_heads, n_cross_heads, mlp_ratio=4, p_dropout=0.5):
        super(EncoderDecoderBlock, self).__init__()

        self.dim = dim
        self.n_self_heads = n_self_heads
        self.n_cross_heads = n_cross_heads
        self.p_dropout = p_dropout
        self.mlp_ratio = mlp_ratio
        self.norm1 = nn.LayerNorm(self.dim)
        self.norm2 = nn.LayerNorm(self.dim)
        self.norm3 = nn.LayerNorm(self.dim)
        # No sub-module is moved to a device here: moving only the attention layers
        # would leave the norms and the MLP behind. The owner calls .to(...) on the
        # whole module instead.
        self.cross_attention = nn.MultiheadAttention(self.dim, self.n_cross_heads, dropout=self.p_dropout, batch_first=True)
        self.first_attention = nn.MultiheadAttention(self.dim, self.n_self_heads, dropout=self.p_dropout, batch_first=True)
        self.MLP = nn.Sequential(
            nn.Linear(self.dim, self.dim * mlp_ratio),
            nn.ReLU(),
            nn.Dropout(self.p_dropout),
            nn.Linear(self.dim * mlp_ratio, self.dim)
        )

    def forward(self, x, features, attn_mask, key_mask):
        """
        x : [n_samples, seq_len, dim] token embeddings (queries)
        features : [n_samples, n_feature_tokens, dim] image features (keys / values)
        attn_mask : self-attention mask forwarded to nn.MultiheadAttention
        key_mask : key padding mask of the token sequence, or None
        output : ([n_samples, seq_len, dim], self-attention weights, cross-attention weights)
        """
        # Keep the image features on the same device as the token sequence.
        features = features.to(x.device)

        attention_out, attn1_weights = self.first_attention(x, x, x, attn_mask=attn_mask, key_padding_mask=key_mask)
        first_out = self.norm1(attention_out + x)
        cross_attention, attn2_weights = self.cross_attention(first_out, features, features)
        second_out = self.norm2(first_out + cross_attention)
        mlp_out = self.MLP(second_out)
        output = self.norm3(mlp_out + second_out)
        return output, attn1_weights, attn2_weights



In [None]:
# Create mapping of image to captions
mapping = {}
caption_lines = captions_doc.split('\n')[1:]  # the first line is skipped
for line in tqdm(caption_lines):
    if len(line) < 2:
        continue
    # Everything after the first comma is caption text; the pieces are re-joined with spaces.
    image_name, *caption_parts = line.split(',')
    image_id = image_name.split('.')[0]
    mapping.setdefault(image_id, []).append(" ".join(caption_parts))

In [None]:
# Clean the captions
def clean(mapping):
    """Normalise every caption of ``mapping`` in place.

    Each caption is lower-cased, stripped of every character that is neither a
    letter nor whitespace, reduced to its words of at least two characters and
    wrapped in the ``startseq`` / ``endseq`` markers.

    Args:
        mapping: dict whose values are lists of caption strings; the lists are modified.
    """
    for captions in mapping.values():
        for i, caption in enumerate(captions):
            caption = caption.lower()
            # str.replace() matches its argument literally, so the patterns need re.sub().
            # Whitespace is excluded from the removal so that words are not glued together.
            caption = re.sub(r'[^a-z\s]', '', caption)
            caption = re.sub(r'\s+', ' ', caption)
            words = [word for word in caption.split() if len(word) > 1]
            captions[i] = 'startseq ' + " ".join(words) + ' endseq'

In [None]:
# Preprocess the text: clean() rewrites the caption lists of `mapping` in place
# (lower-casing them and wrapping each one in 'startseq' ... 'endseq').
clean(mapping)

In [None]:
# Flatten the per-image caption lists into a single list, in dictionary order
all_captions = []
for captions in mapping.values():
    all_captions.extend(captions)
len(all_captions)

In [None]:
import torchtext
from torchtext.data import get_tokenizer

# torchtext's "basic_english" tokenizer
tokenizer = get_tokenizer("basic_english")

# Tokenize the text
tokenized_text = list(map(tokenizer, all_captions))

# Build vocabulary : Mapping every token to an integer index
vocab = torchtext.vocab.build_vocab_from_iterator(tokenized_text)
vocab_size = len(vocab)
print(vocab_size)

In [None]:
# Longest caption, measured in whitespace-separated words.
# NOTE(review): the dataset indexes by tokenizer tokens, which may differ from this
# word count if the tokenizer splits punctuation -- confirm.
caption_word_counts = [len(caption.split()) for caption in all_captions]
max_length = max(caption_word_counts)
print(max_length)

In [None]:

def one_hot(a, num_classes):
    """Return a float vector of length ``num_classes`` holding 1 at index ``a`` and 0 elsewhere."""
    encoded = np.zeros(shape=num_classes)
    encoded[a] = 1
    return encoded


In [None]:
# Extract features from images
features = {}
# directory = 'Images'


from torchvision.models.vision_transformer import vit_b_16
from torchvision.models import ViT_B_16_Weights


vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT).to(device)
vit.eval()  # the backbone is only used for inference here
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# No gradients are needed for feature extraction; disabling autograd avoids building
# (and immediately discarding) a computation graph for every image.
with torch.no_grad():
    for img_name in tqdm(os.listdir(dataset_dir)):
        img_path = os.path.join(dataset_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        image = transform(image).unsqueeze(0).to(device)

        # Replicates the start of the ViT forward pass but stops after the encoder,
        # so every token embedding is kept rather than only the classification output.
        feats = vit._process_input(image)
        # Expand the class token to the full batch
        batch_class_token = vit.class_token.expand(image.shape[0], -1, -1)
        feature = torch.cat([batch_class_token, feats], dim=1)
        feature = vit.encoder(feature)

        # Drop the batch dimension and store the result as a NumPy array on the CPU.
        image_id = img_name.split('.')[0]
        features[image_id] = feature.squeeze(0).cpu().numpy()

In [None]:
# Store the extracted features in a pickle file next to the notebook
features_path = os.path.join('./', 'features_vit.pkl')
with open(features_path, 'wb') as out_file:
    pickle.dump(features, out_file)

In [None]:
# Load the features back from the pickle file.
# NOTE: pickle can execute arbitrary code on load -- only read a file this notebook wrote.
features_path = os.path.join('./', 'features_vit.pkl')
with open(features_path, 'rb') as in_file:
    features = pickle.load(in_file)

In [None]:

# Sanity check: shape of the feature array stored for one hard-coded image id
print(features['3250076419_eb3de15063'].shape)

In [None]:
class CaptioningDataset(Dataset):
    """Dataset of (image features, input tokens, one-hot next-token targets).

    Only the first caption of every image is used. Relies on the notebook-level
    ``vocab`` and ``vocab_size``.

    Args:
        data_keys: list of image ids served by this dataset.
        features: dict mapping an image id to its pre-computed feature array.
        mapping: dict mapping an image id to its list of captions.
        transform: stored on the instance but not applied by ``__getitem__``.
        tokenizer: callable turning a caption string into a list of tokens.
        max_length: fixed length every token sequence is padded (or cut) to.
    """

    def __init__(self, data_keys, features, mapping, transform, tokenizer, max_length):
        self.data_keys = data_keys
        self.mapping = mapping
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.features = features

    def __len__(self):
        return len(self.data_keys)

    def __getitem__(self, idx):
        """Return ``(feats, input2, y, idx)`` for sample ``idx``.

        ``input2`` is an int tensor of shape [max_length] holding the caption tokens
        except the last one; ``y`` is a float tensor of shape [max_length, vocab_size]
        whose row ``p`` one-hot encodes the token that follows ``input2[p]``. Rows past
        the end of the caption stay all-zero.
        """
        key = self.data_keys[idx]
        caption = self.mapping[key][0]#np.random.choice(len(captions))]

        input2 = torch.zeros(self.max_length).int()
        y = torch.zeros((self.max_length, vocab_size))

        caption_indices = [vocab[token] for token in self.tokenizer(caption)]
        feats = torch.as_tensor(self.features[key])

        # A caption of T tokens yields T - 1 (input, target) pairs. The count is capped
        # so that a caption longer than max_length + 1 tokens is cut off instead of
        # raising an IndexError.
        n_pairs = min(len(caption_indices) - 1, self.max_length)
        for pos in range(n_pairs):
            input2[pos] = int(caption_indices[pos])
            # Set the one-hot entry directly rather than building a full row first.
            y[pos, int(caption_indices[pos + 1])] = 1
        return feats, input2, y, idx




In [None]:
# Same image pipeline as the one defined in the feature-extraction cell:
# resize to 224x224, convert to a tensor, normalise each channel.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

In [None]:
# Ordered (non-shuffled) 75% / 25% split of the image ids
image_ids = list(mapping)
split = int(0.75 * len(image_ids))
train, test = image_ids[:split], image_ids[split:]

In [None]:
# Number of distinct image ids that have at least one caption
print(len(image_ids))

In [None]:
batch_size = 32
# Training data is reshuffled at every epoch
train_dataset = CaptioningDataset(
    data_keys=train,
    features=features,
    mapping=mapping,
    transform=transform,
    tokenizer=tokenizer,
    max_length=max_length,
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

In [None]:

test_dataset = CaptioningDataset(test, features, mapping, transform, tokenizer, max_length)
# Evaluation data is served in a fixed order so that repeated evaluations are comparable.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
# Sizes of the first three items of the first training sample: features, input tokens, targets
first_sample = train_dataset[0]
for tensor in first_sample[:3]:
    print(tensor.size())

In [None]:
class ImageCaptioningModel(nn.Module):
    """Caption decoder conditioned on pre-computed image features.

    Args:
        encoder_decoder: module called as ``(tokens, feats, attn_mask, key_mask)`` and
            returning ``(hidden, self_attention, cross_attention)``.
        pos_enc: module adding positional information to the token embeddings.
        vocab_size: number of tokens in the vocabulary (size of the output layer).
        img_embedding_dim: size of the incoming image feature vectors.
        token_embedding_dim: size of the token embeddings and of the decoder states.
    """

    def __init__(self, encoder_decoder, pos_enc, vocab_size, img_embedding_dim, token_embedding_dim):
        super().__init__()
        self.vocab_size = vocab_size
        self.img_embedding_dim = img_embedding_dim
        self.token_embedding_dim = token_embedding_dim

        # Sub-modules are registered as they are; device placement is left to the
        # caller's model.to(...) so that every part of the model moves together.
        self.transformer = encoder_decoder
        self.pos_enc = pos_enc
        # Projects the image features to the token embedding size.
        self.process_feats = nn.Sequential(
            nn.Linear(self.img_embedding_dim, 512),
            nn.ReLU(),
            nn.Linear(512, self.token_embedding_dim)
        )
        self.embedding = nn.Embedding(self.vocab_size, self.token_embedding_dim)
        self.decoder = nn.Linear(self.token_embedding_dim, vocab_size)

    def forward(self, feats, input2, attn_mask, key_mask=None):
        """Return ``(logits, self_attention, cross_attention)``.

        ``feats`` are the image features, ``input2`` the integer token ids; both masks
        are passed through unchanged to the encoder-decoder block. ``logits`` has
        ``vocab_size`` entries along its last dimension.
        """
        embedding_out = self.embedding(input2)
        pe_out = self.pos_enc(embedding_out)

        feats_out = self.process_feats(feats)

        output, attn1, attn2 = self.transformer(pe_out, feats_out, attn_mask, key_mask)
        output = self.decoder(output)

        return output, attn1, attn2




In [None]:
# Pull a single batch from the training loader as a quick smoke test of the pipeline.
# A batch unpacks into (image features, input token ids, one-hot targets, sample indices).
data_iter = iter(train_loader)
single_batch = next(data_iter)

image, input2, targets, idx = single_batch

In [None]:
# Instantiate the model
token_embedding_size = 512
img_embedding_size = 768

n_heads = 8
dropout = 0.0

# Single decoder block; self- and cross-attention use the same number of heads
transformer = EncoderDecoderBlock(
    dim=token_embedding_size,
    n_self_heads=n_heads,
    n_cross_heads=n_heads,
    mlp_ratio=2,
    p_dropout=dropout,
)
pe_enc = PositionalEncoding(seq_len=max_length, embedding_dim=token_embedding_size)

model = ImageCaptioningModel(
    encoder_decoder=transformer,
    pos_enc=pe_enc,
    vocab_size=vocab_size,
    img_embedding_dim=img_embedding_size,
    token_embedding_dim=token_embedding_size,
).to(device)

In [None]:
# Cross-entropy objective, optimised with Adam over every model parameter
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
# Additive causal mask: -inf strictly above the diagonal, 0 on and below it
at_mask = torch.full((max_length, max_length), -np.inf)
attn_mask = torch.triu(at_mask, diagonal=1)
print(attn_mask)

In [None]:
# Evaluation
def idx_to_word(index):
    """Return the token the global ``vocab`` stores at ``index``, or None if there is none."""
    try:
        return vocab.get_itos()[index]
    except (IndexError, TypeError):
        # Out-of-range or non-integer index. Anything else (KeyboardInterrupt, a
        # missing vocab, ...) is allowed to propagate instead of being hidden.
        return None

def _causal_mask(size):
    """Return a (size, size) boolean mask that is True strictly above the diagonal.

    True marks a key position a query may not attend to, i.e. every later position.
    """
    return torch.triu(torch.ones(size, size, dtype=torch.bool), diagonal=1)


def predict_caption(model, image_id, max_length):
    """Greedily decode a caption for ``image_id`` and score the model on its reference.

    Reads the notebook-level ``features``, ``mapping``, ``vocab``, ``tokenizer``,
    ``criterion`` and ``device``. The model is left in eval mode.

    Args:
        model: captioning model called as ``model(feats, tokens, attn_mask, key_mask)``
            and returning ``(logits, self_attention, cross_attention)``.
        image_id: key into ``features`` and ``mapping``.
        max_length: length of the padded token sequence fed to the model.

    Returns:
        ``(text, loss, attns)`` where ``text`` is the decoded caption starting with
        ``startseq``; ``loss`` is a float, the cross-entropy of the first reference
        caption when its own tokens are fed as input (``nan`` if that caption has
        fewer than two tokens); ``attns`` holds one ``[self_attention,
        cross_attention]`` pair per decoding step.
    """
    model.eval()
    with torch.no_grad():
        feats = torch.as_tensor(features[image_id]).unsqueeze(0).to(device)
        # A 2-D mask is broadcast by nn.MultiheadAttention over batch and heads.
        causal_mask = _causal_mask(max_length).to(device)
        positions = torch.arange(max_length, device=device)

        # ---- greedy decoding -------------------------------------------------
        input2 = torch.zeros(1, max_length, dtype=torch.int64, device=device)
        input2[0, 0] = vocab['startseq']
        words = ['startseq']
        attns = []
        for step in range(max_length - 1):
            # Positions after `step` are still padding and must not be used as keys.
            key_mask = (positions > step).unsqueeze(0)
            outputs, attn1, attn2 = model(feats, input2, causal_mask, key_mask)
            attns.append([attn1.detach(), attn2.detach()])

            # The dataset pairs the input at position p with the token at p + 1, so
            # the prediction for the next word is the output row of the newest input
            # token (`step`), not the row after it.
            next_id = int(torch.argmax(outputs[0, step], dim=-1))
            word = idx_to_word(next_id)
            if word is None:
                break
            words.append(word)
            if word == 'endseq':
                break
            input2[0, step + 1] = next_id

        # ---- loss on the reference caption -----------------------------------
        # Same input/target layout as the dataset: tokens[:-1] in, tokens[1:] out.
        target_ids = [vocab[token] for token in tokenizer(mapping[image_id][0])]
        target_ids = target_ids[:max_length + 1]
        n_pairs = len(target_ids) - 1
        if n_pairs < 1:
            test_loss = float('nan')
        else:
            gt_input = torch.zeros(1, max_length, dtype=torch.int64, device=device)
            gt_input[0, :n_pairs] = torch.tensor(target_ids[:-1], dtype=torch.int64, device=device)
            gt_target = torch.tensor(target_ids[1:], dtype=torch.int64, device=device)
            key_mask = (positions >= n_pairs).unsqueeze(0)
            logits, _, _ = model(feats, gt_input, causal_mask, key_mask)
            test_loss = criterion(logits[0, :n_pairs], gt_target).item()

    return " ".join(words), test_loss, attns

In [None]:
# Per-epoch histories filled in by the training loop (one value appended per epoch).
# NOTE: re-running the training cell without re-running this one keeps appending to them.
loss_test_hist = []
loss_train_hist = []
blue_score_hist = []

In [None]:
# Train the model

from nltk.translate.bleu_score import corpus_bleu

num_epochs = 100

# The causal mask depends only on max_length, so it is built once instead of once per
# batch. nn.MultiheadAttention broadcasts a 2-D (L, L) mask over batch and heads, which
# removes the need for one copy per head and per sample. True marks a forbidden
# (future) position; using a boolean mask matches the dtype of the padding mask below.
causal_mask = torch.triu(
    torch.ones(max_length, max_length, dtype=torch.bool), diagonal=1
).to(device)

# Fixed subset of the held-out images that is scored after every epoch
eval_keys = test[:128]

for epoch in range(num_epochs):
    total_loss_train = 0
    total_loss_test = 0

    # predict_caption() leaves the model in eval mode, so switch back every epoch.
    model.train()
    for batch in train_loader:
        image, inputs2, targets, _ = batch
        image, inputs2, targets = image.to(device), inputs2.to(device), targets.to(device)

        # Padded positions are the ones whose one-hot target row is all zeros.
        mask = torch.sum(targets, dim=-1) == 0
        output, attn1, attn2 = model(image, inputs2, causal_mask, mask)

        # Only real (non-padded) positions contribute to the loss.
        keep = ~mask.view(-1)
        output_masked = output.view(-1, vocab_size)[keep]
        targets_masked = targets.view(-1, vocab_size)[keep]
        loss = criterion(output_masked, targets_masked)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss_train += loss.item()

    average_loss_train = total_loss_train / len(train_loader)

    # Evaluation: decode each held-out image and compare with its first reference caption.
    actual, predicted = [], []
    for key in eval_keys:
        captions = mapping[key]
        y_pred, test_loss, attn = predict_caption(model, key, max_length)

        actual.append([captions[0].split()])
        predicted.append(y_pred.split())

        total_loss_test += test_loss

    average_loss_test = total_loss_test / len(eval_keys)
    # Unigram-only BLEU
    bleu1_test = corpus_bleu(actual, predicted, weights=(1.0, 0, 0, 0))

    loss_train_hist.append(average_loss_train)
    loss_test_hist.append(average_loss_test)
    blue_score_hist.append(bleu1_test)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss Train: {average_loss_train:.4f}, Loss Test: {average_loss_test:.4f} BLEU-1 Score Test: {bleu1_test}')

In [None]:
# One integer x position per recorded epoch. Deriving it from the history length keeps
# the plot valid when training was interrupted or the training cell was run again.
x = np.arange(1, len(loss_train_hist) + 1)

plt.plot(x, loss_train_hist, 'b', label='train loss')
plt.plot(x, loss_test_hist, 'r', label='test loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.title('Loss evolution over epochs')
plt.savefig('loss_total.png')

In [None]:
# The x axis is built from the history itself, so this cell does not depend on an `x`
# left over from another cell and always matches the number of recorded epochs.
bleu_epochs = np.arange(1, len(blue_score_hist) + 1)
plt.plot(bleu_epochs, blue_score_hist, label='blue score')
plt.xlabel('epoch')
plt.ylabel('BLEU-1')
plt.title('Bleu score evolution over epochs')
plt.savefig('bleu_score_total.png')

In [None]:
# Generate caption for an image
def generate_caption(model, image_name, max_len=None):
    """Print the reference captions and the predicted caption of an image, then show it.

    Args:
        model: captioning model handed to ``predict_caption``.
        image_name: file name of the image inside ``dataset_dir``.
        max_len: sequence length used for decoding. Defaults to the notebook-level
            ``max_length`` instead of a hard-coded number, so it stays in step with
            the value the rest of the notebook uses.

    Returns:
        None; everything is printed or plotted.
    """
    if max_len is None:
        max_len = max_length

    image_id = image_name.split('.')[0]
    img_path = os.path.join(dataset_dir, image_name)
    image = Image.open(img_path)

    captions = mapping[image_id]
    print('---------------------Actual---------------------')
    for caption in captions:
        print(caption)

    y_pred, loss, attns = predict_caption(model, image_id, max_len)
    print('--------------------Predicted--------------------')
    print(y_pred)

    plt.imshow(image)

In [None]:
# generate_caption() prints its own output and returns None, so it is not wrapped in print()
generate_caption(model, f'{train[900]}.jpg')
#tensor([4998, 5071,  619,  402, 5488,  642, 1609, 5000, 5381, 6045, 3427, 2610,

In [None]:
# generate_caption() prints its own output and returns None, so it is not wrapped in print()
generate_caption(model, f'{test[160]}.jpg')

In [None]:
# Save only the model's state_dict (not the model object) to 'alphaModel.pt'
torch.save(model.state_dict(), 'alphaModel.pt')