In [None]:
!pip install rouge-score

import os
import re
import cv2
import nltk
import json
import random
import numpy as np
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from collections import Counter
from collections import defaultdict
from nltk.corpus import stopwords
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.models.efficientnet import EfficientNet_B0_Weights
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [None]:
IMAGES_DIR = "/kaggle/input/flickr8k/Images"
CAPTIONS_FILE = "/kaggle/input/flickr8k/captions.txt"

BATCH_SIZE = 32
SEED = 42  # single seed for every RNG used below (init, augmentation, caption sampling, splits)

# Seed Python, NumPy and PyTorch up front so a Restart & Run All reproduces the same run.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
def load_captions(captions_file):
    """Parse a Flickr8k-style ``captions.txt`` into ``{image_file: [caption, ...]}``.

    Each data line has the form ``<image_file>,<caption>``. Only the FIRST comma separates
    the two fields, so captions that themselves contain commas are kept intact (splitting on
    every comma silently dropped them). A leading ``image,caption`` header row is skipped, as
    are blank lines and lines without any comma.

    Args:
        captions_file: path to the captions text file (read as UTF-8).

    Returns:
        dict mapping image file name -> list of caption strings, in file order.
    """
    captions_dict = {}
    with open(captions_file, 'r', encoding='utf-8') as file:
        for line_number, raw_line in enumerate(file):
            line = raw_line.strip()
            if not line or "," not in line:
                continue
            image_id, caption = line.split(",", 1)
            if line_number == 0 and image_id.strip().lower() == "image":
                continue  # CSV header row ("image,caption"), not a real sample
            captions_dict.setdefault(image_id, []).append(caption)
    return captions_dict

def display_images_with_captions(images_dir, captions_dict, num_images=3, image_size=(224, 224)):
    """Show ``num_images`` randomly chosen images stacked vertically, each titled with all its captions.

    Args:
        images_dir: directory containing the image files named by the keys of ``captions_dict``.
        captions_dict: mapping image file name -> list of caption strings.
        num_images: number of distinct images to sample and display (must be >= 1).
        image_size: currently unused; kept so existing callers keep working.

    Raises:
        ValueError: if ``num_images`` < 1 or larger than the number of captioned images.
    """
    images = list(captions_dict.keys())
    if num_images < 1 or len(images) < num_images:
        raise ValueError(
            f"Cannot display {num_images} image(s): {len(images)} captioned image(s) available."
        )
    selected_images = random.sample(images, num_images)
    # One row per image, so the figure grows in height (not width) with the number of images.
    plt.figure(figsize=(10, 5 * num_images))
    for i, image in enumerate(selected_images):
        img_path = os.path.join(images_dir, image)
        # Decode into memory and release the file handle immediately.
        with Image.open(img_path) as img:
            pixels = np.asarray(img.convert("RGB"))
        plt.subplot(num_images, 1, i + 1)
        plt.imshow(pixels)
        plt.axis('off')
        captions = "\n".join(captions_dict[image])
        plt.title(captions, fontsize=10, loc='center', wrap=True)
    plt.tight_layout()
    plt.show()

def plot_training_losses(train_losses, val_losses):
    """Plot per-epoch training and validation loss on one chart (epochs numbered from 1)."""
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(range(1, len(train_losses) + 1), train_losses, marker='o', label='Training Loss')
    ax.plot(range(1, len(val_losses) + 1), val_losses, marker='x', label='Validation Loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss Over Epochs')
    ax.grid(True)
    ax.legend()
    plt.show()

# def display_random_images_with_captions(images_dir, captions_dict, num_images=5):
#     selected_images = random.sample(list(captions_dict.keys()), num_images)
#     plt.figure(figsize=(15, 10))

#     for i, image_id in enumerate(selected_images):
#         img_path = f"{images_dir}/{image_id}"
#         img = Image.open(img_path)

#         plt.subplot(num_images, 1, i + 1)
#         # plt.subplot(1, num_images, i + 1)
#         plt.imshow(img)
#         plt.axis("off")
#         plt.title(captions_dict[image_id], fontsize=10)
    
#     plt.tight_layout()
#     plt.show()
        
def plot_caption_length_variation(captions_dict):
    """Scatter the whitespace word count of every caption, in dictionary order, to show length spread."""
    word_counts = [len(text.split()) for group in captions_dict.values() for text in group]
    fig, ax = plt.subplots(figsize=(15, 10))
    ax.scatter(range(len(word_counts)), word_counts, alpha=0.5)
    ax.set_title("Variation in Caption Lengths", fontsize=14)
    ax.set_xlabel("Caption Index", fontsize=12)
    ax.set_ylabel("Caption Length (Number of Words)", fontsize=12)
    ax.grid(True, alpha=0.3)
    plt.show()

def plot_word_frequency_histogram(captions_dict, top_n=20):
    """Bar-chart the ``top_n`` most frequent words after ``preprocess_text`` normalisation.

    Args:
        captions_dict: mapping image id -> list of caption strings.
        top_n: number of most frequent words to show.

    When no words survive (empty input, all stop words, or ``top_n`` <= 0) a message is
    printed and nothing is plotted, instead of failing while unpacking an empty result.
    """
    word_counter = Counter()
    for captions in captions_dict.values():
        for caption in captions:
            preprocessed_caption = preprocess_text(caption)
            word_counter.update(preprocessed_caption.split())
    most_common_words = word_counter.most_common(top_n)
    if not most_common_words:
        print("No words to plot.")
        return
    words, counts = zip(*most_common_words)
    plt.figure(figsize=(12, 6))
    plt.bar(words, counts, color='skyblue')
    plt.title(f"Top {top_n} Most Frequent Words", fontsize=14)
    plt.xlabel("Words", fontsize=12)
    plt.ylabel("Frequency", fontsize=12)
    plt.xticks(rotation=45, fontsize=10)
    plt.tight_layout()
    plt.show()

In [None]:
def split_dataset(captions_dict, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, random_seed=42):
    """Split a ``{image_id: [captions]}`` mapping into train/val/test **by image**.

    Every image, together with all of its captions, lands in exactly one split. The
    previous caption-level split put the same image into several splits (with different
    captions), leaking test images into training and inflating evaluation scores.

    Image ids are sorted before shuffling, so the split depends only on ``random_seed``
    and the set of ids. A private ``random.Random`` is used, so the global ``random``
    state is left untouched.

    Args:
        captions_dict: mapping image id -> list of caption strings.
        train_ratio, val_ratio, test_ratio: fractions of *images* per split; must sum to 1.
        random_seed: seed for the shuffle.

    Returns:
        (train_captions, val_captions, test_captions), each ``{image_id: [captions]}``.
    """
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "Ratios must sum to 1."

    image_ids = sorted(captions_dict)
    random.Random(random_seed).shuffle(image_ids)

    total_images = len(image_ids)
    train_end = int(total_images * train_ratio)
    val_end = train_end + int(total_images * val_ratio)

    def _subset(ids):
        # Copy the caption lists so the splits do not alias the caller's lists.
        return {image_id: list(captions_dict[image_id]) for image_id in ids}

    return (
        _subset(image_ids[:train_end]),
        _subset(image_ids[train_end:val_end]),
        _subset(image_ids[val_end:]),
    )

def save_captions_to_json(captions, filepath):
    """Write a ``{image_id: [captions]}`` mapping to ``filepath`` as indented JSON.

    ``ensure_ascii=False`` emits non-ASCII characters verbatim, so the file is opened with
    an explicit UTF-8 encoding instead of the platform default (which need not be UTF-8).
    """
    with open(filepath, 'w', encoding='utf-8') as file:
        json.dump(captions, file, ensure_ascii=False, indent=4)

In [None]:
# Parse the raw captions file and preview a few random images with their captions.
captions_dict = load_captions(captions_file=CAPTIONS_FILE)
display_images_with_captions(images_dir=IMAGES_DIR, captions_dict=captions_dict)

In [None]:
# English stop-word list used by preprocess_text below.
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))

def preprocess_text(text):
    """Normalise a caption: lowercase, strip punctuation and digits, drop English stop words.

    Returns the surviving words joined by single spaces ("" if none remain).
    """
    no_punctuation = re.sub(r"[^\w\s]", "", text.lower())
    no_digits = re.sub(r"\d+", "", no_punctuation)
    kept_words = (token for token in no_digits.split() if token not in stop_words)
    return " ".join(kept_words)

# Build the word vocabulary from every caption in captions_dict.
# Ids 0-3 are reserved for the special tokens; every other word gets the next free id in
# first-seen order, so the resulting ids depend on the iteration order of captions_dict.
token_to_index = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3}

index_to_word = {v: k for k, v in token_to_index.items()}  # Reverse the token_to_index mapping

vocab = set()        # non-special words added to token_to_index
all_words_dict = {}  # word -> position in first-seen order (special tokens not included)

# NOTE(review): this loop runs over *all* captions, before split_dataset is called further down,
# so validation/test words are part of the vocabulary too.
# NOTE(review): normalization here lowercases and strips punctuation/digits but does NOT drop
# stop words (the preprocess_text call is commented out), whereas FlickrDataset applies
# preprocess_text before tokenizing -- so stop-word ids exist but never appear in model inputs.
for captions in captions_dict.values():
    for caption in captions:
        caption = caption.lower()
        caption = re.sub(r"[^\w\s]", "", caption)  # Remove punctuation
        caption = re.sub(r"\d+", "", caption)      # Remove numbers
        # preprocessed_caption = preprocess_text(caption)
        for word in caption.split():
            if word not in token_to_index:  
                # Unseen word: assign the next id and record it in both directions.
                idx = len(token_to_index)
                token_to_index[word] = idx
                index_to_word[idx] = word
                vocab.add(word)
            
            if word not in all_words_dict:
                all_words_dict[word] = len(all_words_dict)

# Snapshot copies under the names used by the rest of the notebook.
word_to_index_dictionary = dict(token_to_index)
index_to_word_dictionary = dict(index_to_word)

print("Word to Index:", len(word_to_index_dictionary))
print("Index to Word:", len(index_to_word_dictionary))
print("All Words Dictionary:", len(all_words_dict))  # Output the size of the all_words_dict

In [None]:
TOKENIZER_FILE = "tokenizer.json"
# ensure_ascii=False writes raw non-ASCII characters, so pin the file encoding explicitly
# rather than relying on the platform default.
with open(TOKENIZER_FILE, 'w', encoding='utf-8') as file:
    json.dump(word_to_index_dictionary, file, ensure_ascii=False, indent=4)
print(f"Word to Index Tokenizer saved to {TOKENIZER_FILE}. Total tokens: {len(word_to_index_dictionary)}")

# with open(TOKENIZER_FILE, "r") as file:
#     word_to_index_dictionary = json.load(file)

In [None]:
train_captions, val_captions, test_captions = split_dataset(captions_dict)

# Persist each split so FlickrDataset can read it back from disk.
save_captions_to_json(captions=train_captions, filepath="train_captions.json")
save_captions_to_json(captions=val_captions, filepath="val_captions.json")
save_captions_to_json(captions=test_captions, filepath="test_captions.json")

print("\n".join([
    "Dataset split complete:",
    f"Training: {len(train_captions)} images",
    f"Validation: {len(val_captions)} images",
    f"Test: {len(test_captions)} images",
]))

In [None]:
class FlickrDataset(Dataset):
    """Image/caption dataset backed by a JSON ``{image_id: [captions]}`` file.

    Each item pairs one image with one of its captions. Image ids whose file is missing
    from ``images_dir`` are dropped at construction time.

    Args:
        captions_file: path to the JSON captions file (read as UTF-8).
        images_dir: directory holding the image files named by the JSON keys.
        transform: optional callable applied to the RGB PIL image.
        tokenizer: optional callable applied to the preprocessed caption text.
        random_caption: if True (default) a caption is drawn at random on every access;
            if False the first caption is always used, which makes validation/test
            losses deterministic across epochs.
    """

    def __init__(self, captions_file, images_dir, transform=None, tokenizer=None, random_caption=True):
        with open(captions_file, 'r', encoding='utf-8') as file:
            self.captions_dict = json.load(file)
        self.images_dir = images_dir
        self.transform = transform
        self.tokenizer = tokenizer
        self.random_caption = random_caption
        self.image_ids = [
            img for img in self.captions_dict
            if os.path.exists(os.path.join(images_dir, img))
        ]

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_path = os.path.join(self.images_dir, image_id)

        # convert() returns an in-memory copy, so the file can be closed right away.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)

        candidates = self.captions_dict[image_id]
        caption = random.choice(candidates) if self.random_caption else candidates[0]
        caption = preprocess_text(caption)

        if self.tokenizer:
            caption = self.tokenizer(caption)

        return image, caption

In [None]:
plot_caption_length_variation(captions_dict=captions_dict)  # caption word-count spread

In [None]:
plot_word_frequency_histogram(captions_dict=captions_dict, top_n=20)  # 20 most frequent words

In [None]:
def pad_caption(caption, word_to_index, max_length=15):
    """Encode a caption as a fixed-length tensor of token ids.

    The sequence is ``<sos> w1 ... wn <eos>`` followed by ``<pad>`` up to ``max_length``.
    Words missing from ``word_to_index`` map to ``<unk>``. Over-long captions are truncated
    *before* ``<eos>``, so the end marker is always present and the model still learns to stop.

    Args:
        caption: whitespace-separated caption text (lower-cased here).
        word_to_index: mapping containing at least "<pad>", "<sos>", "<eos>" and "<unk>".
        max_length: total length of the result including <sos>/<eos>; must be >= 2.

    Returns:
        1-D ``torch.long`` tensor of length ``max_length``.

    Raises:
        ValueError: if ``max_length`` < 2 (no room for both <sos> and <eos>).
    """
    if max_length < 2:
        raise ValueError(f"max_length must be at least 2 to hold <sos> and <eos>, got {max_length}")
    unk_id = word_to_index["<unk>"]
    word_ids = [word_to_index.get(word, unk_id) for word in caption.lower().split()]
    word_ids = word_ids[:max_length - 2]  # leave room for <sos> and <eos>
    token_ids = [word_to_index["<sos>"], *word_ids, word_to_index["<eos>"]]
    token_ids += [word_to_index["<pad>"]] * (max_length - len(token_ids))
    return torch.tensor(token_ids, dtype=torch.long)

In [None]:
# ImageNet statistics expected by the pretrained EfficientNet backbone; defined once so the
# training/eval normalisation and the display-time inverse can never drift apart.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# NOTE(review): RandomHorizontalFlip mirrors the scene without touching the caption, so
# "left"/"right" in captions may no longer match the image.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Inverse of Normalize: x * std + mean, expressed as Normalize(-mean/std, 1/std).
denormalize = transforms.Compose([
    transforms.Normalize(mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
                         std=[1 / s for s in IMAGENET_STD])
])


def caption_tokenizer(caption):
    """Encode preprocessed caption text with the notebook vocabulary (fixed length via pad_caption)."""
    return pad_caption(caption, word_to_index_dictionary)


train_dataset = FlickrDataset("train_captions.json", IMAGES_DIR, transform=transform, tokenizer=caption_tokenizer)
val_dataset = FlickrDataset("val_captions.json", IMAGES_DIR, transform=val_transform, tokenizer=caption_tokenizer)
test_dataset = FlickrDataset("test_captions.json", IMAGES_DIR, transform=val_transform, tokenizer=caption_tokenizer)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Testing batches: {len(test_loader)}")

In [None]:
def masked_cross_entropy_loss(predictions, targets, pad_token=0, label_smoothing=0.1):
    """Label-smoothed cross-entropy averaged over non-pad target tokens.

    The previous version multiplied an already-averaged scalar by the mask and divided it
    back out, which was a no-op that hid the real reduction and returned NaN for all-pad
    batches. Here the per-token losses are summed and divided by the real token count.

    Args:
        predictions: scores whose last dimension is the vocabulary, e.g. (batch, seq, vocab).
        targets: token ids with the same leading shape as ``predictions`` minus the last dim.
        pad_token: id ignored in the loss.
        label_smoothing: smoothing factor passed to ``F.cross_entropy``.

    Returns:
        Scalar tensor; 0 when every target is ``pad_token``.
    """
    vocab_size = predictions.size(-1)
    flat_predictions = predictions.reshape(-1, vocab_size)
    flat_targets = targets.reshape(-1)
    num_tokens = (flat_targets != pad_token).sum()
    summed_loss = F.cross_entropy(
        flat_predictions,
        flat_targets,
        ignore_index=pad_token,
        reduction="sum",
        label_smoothing=label_smoothing,
    )
    # clamp avoids 0/0 = NaN when a batch contains only padding.
    return summed_loss / num_tokens.clamp(min=1)

def save_checkpoint(model, optimizer, epoch, loss, checkpoint_path="checkpoint.pth"):
    """Serialise model/optimizer state plus the epoch number and loss to ``checkpoint_path``.

    The file holds a dict with keys "model_state_dict", "optimizer_state_dict", "epoch" and
    "loss", which is the layout ``load_checkpoint`` reads back.
    """
    torch.save(
        dict(
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            epoch=epoch,
            loss=loss,
        ),
        checkpoint_path,
    )

def load_checkpoint(checkpoint_path, model, optimizer, map_location=None):
    """Restore model and optimizer state written by ``save_checkpoint``.

    Args:
        checkpoint_path: path to the checkpoint file.
        model: module whose ``load_state_dict`` receives the saved weights.
        optimizer: optimizer whose ``load_state_dict`` receives the saved state.
        map_location: where to map stored tensors. Defaults to the device of the model's
            first parameter (CPU if it has none), so a checkpoint saved on GPU can be
            loaded on a CPU-only machine.

    Returns:
        (epoch, loss) as stored in the checkpoint.
    """
    if map_location is None:
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else "cpu"
    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    epoch = checkpoint["epoch"]
    loss = checkpoint["loss"]
    print(f"Checkpoint loaded from epoch {epoch}")
    return epoch, loss

def generate_caption(model, image, tokenizer, max_length=15):
    """Greedy-decode a caption for one image.

    NOTE: a later cell (next to ``train_model``) defines another ``generate_caption``;
    whichever definition runs last is the one the rest of the notebook calls.

    Args:
        model: object with ``encoder(images)`` and ``decoder(features, token_ids)``; the
            decoder returns a ``(scores, attention)`` tuple with scores shaped (1, seq, vocab).
        image: image tensor, either (C, H, W) or already batched as (1, C, H, W).
        tokenizer: word -> id mapping containing "<sos>" and "<eos>".
        max_length: maximum number of decoding steps.

    Returns:
        Space-joined predicted words, without <sos>/<eos>/<pad>.
    """
    model.eval()

    index_to_word = {index: word for word, index in tokenizer.items()}
    start_token = tokenizer["<sos>"]
    end_token = tokenizer["<eos>"]
    skip_ids = {start_token, end_token, tokenizer.get("<pad>")}

    # Run on whatever device the model lives on (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    batch = image if image.dim() == 4 else image.unsqueeze(0)

    with torch.no_grad():
        features = model.encoder(batch.to(target_device))

        caption = [start_token]  # Begin with <sos>
        for _ in range(max_length):
            caption_tensor = torch.tensor([caption], dtype=torch.long, device=target_device)
            # The decoder returns (scores, attention); indexing the tuple itself was a TypeError.
            scores, _ = model.decoder(features, caption_tensor)
            next_word_idx = scores[0, -1].argmax().item()
            caption.append(next_word_idx)
            if next_word_idx == end_token:
                break

    words = [index_to_word.get(idx, "<unk>") for idx in caption if idx not in skip_ids]
    return " ".join(words)

# def visualize_results(model, dataloader, tokenizer, num_images=5):
#     model.eval()
#     for images, _ in dataloader:
#         images = images[:num_images].to(device)
#         denorm_images = torch.stack([denormalize(img) for img in images])
#         captions = [generate_caption(model, image, tokenizer) for image in images]
#         plt.figure(figsize=(5, num_images * 5))
#         for i in range(num_images):
#             ax = plt.subplot(num_images, 1, i + 1)
#             img = denorm_images[i].permute(1, 2, 0).cpu().numpy()
#             img = img.clip(0, 1)  # Ensure pixel values are in [0,1] range
#             ax.imshow(img)
#             ax.axis("off")  # Remove axes
#             ax.set_title(captions[i], fontsize=12, wrap=True)
#         plt.tight_layout()
#         plt.show()
#         break

In [None]:
def visualize_attention(image, attention_weights, caption, vocab):
    """
    Visualizes the attention maps for each generated word in the caption.

    Args:
        image: (C, H, W) image tensor; min-max rescaled to [0, 1] for display.
        attention_weights: indexable per caption position; attention_weights[idx] holds the
            weights over the encoder regions for word idx. The region count must be a perfect
            square (e.g. 49 -> 7x7); it is no longer hard-coded to 14x14.
        caption: sequence of token-id tensors (``.item()`` is called on each element).
        vocab: id -> word lookup; ids not found via ``in vocab`` are shown as "<unk>".

    NOTE: later cells redefine ``visualize_attention``; the last executed definition wins.
    """
    if len(caption) == 0:
        return
    display_image = image.detach().cpu().permute(1, 2, 0).numpy()
    low, high = display_image.min(), display_image.max()
    if high > low:
        display_image = (display_image - low) / (high - low)
    height, width = display_image.shape[:2]

    # squeeze=False keeps a 2-D Axes array even when the caption has a single word.
    fig, axes = plt.subplots(1, len(caption), figsize=(20, 5), squeeze=False)
    for idx, ax in enumerate(axes[0]):
        weights = attention_weights[idx].detach().cpu().reshape(-1)
        side = int(round(weights.numel() ** 0.5))
        if side * side != weights.numel():
            raise ValueError(f"Cannot arrange {weights.numel()} attention weights into a square grid.")
        ax.imshow(display_image)
        # Stretch the coarse map across the whole image instead of drawing it in pixel units.
        ax.imshow(weights.reshape(side, side).numpy(), cmap='jet', alpha=0.6,
                  extent=(-0.5, width - 0.5, height - 0.5, -0.5))
        ax.axis('off')
        token = caption[idx].item()
        ax.set_title(vocab[token] if token in vocab else "<unk>")
    plt.show()


In [None]:
def _caption_batch_loss(model, images, captions):
    """Teacher-forced loss for one batch: the model reads tokens[:-1] and is scored on tokens[1:]."""
    images, captions = images.to(device), captions.to(device)
    outputs, _ = model(images, captions[:, :-1])  # drop the last token as decoder input
    outputs = outputs[:, :captions.size(1) - 1, :]  # align with the shifted targets
    return masked_cross_entropy_loss(outputs, captions[:, 1:])  # targets exclude <sos>


def train_model(
    model, train_loader, val_loader, optimizer, num_epochs=50, 
    checkpoint_path="checkpoint.pth", resume=False, patience=5,
    weight_decay=1e-4
):
    """Train ``model`` with teacher forcing, validating and checkpointing after every epoch.

    * The best model (lowest validation loss) is saved to ``checkpoint_path``.
    * The learning rate is halved when the *validation* loss plateaus for 3 epochs
      (previously the scheduler watched the training loss, which almost never plateaus).
    * Training stops early after ``patience`` epochs without validation improvement.
    * With ``resume=True`` and an existing checkpoint, training continues from the saved
      epoch and the saved validation loss is the bar to beat, so a worse first epoch after
      resuming no longer overwrites the best checkpoint.

    Args:
        weight_decay: unused; configure regularisation on ``optimizer``. Kept for compatibility.

    Returns:
        (train_losses, val_losses): per-epoch mean batch losses for the epochs run in this call.
    """
    start_epoch = 0
    best_val_loss = float('inf')
    if resume and os.path.exists(checkpoint_path):
        start_epoch, best_val_loss = load_checkpoint(checkpoint_path, model, optimizer)

    train_losses = []
    val_losses = []
    patience_counter = 0
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    for epoch in range(start_epoch, num_epochs):
        model.train()
        epoch_loss = 0.0
        tepoch = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
        for images, captions in tepoch:
            optimizer.zero_grad()
            loss = _caption_batch_loss(model, images, captions)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            tepoch.set_postfix(loss=f"{loss.item():.4f}")

        train_losses.append(epoch_loss / max(len(train_loader), 1))
        print(f"Epoch {epoch+1}, Train Loss: {train_losses[-1]:.4f}")

        # Validation step
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, captions in val_loader:
                val_loss += _caption_batch_loss(model, images, captions).item()
        val_loss /= max(len(val_loader), 1)
        val_losses.append(val_loss)
        print(f"Epoch {epoch+1}, Validation Loss: {val_loss:.4f}")

        # LR schedule driven by validation loss; report reductions explicitly
        # (the scheduler's `verbose` flag is deprecated).
        lrs_before = [group["lr"] for group in optimizer.param_groups]
        scheduler.step(val_loss)
        lrs_after = [group["lr"] for group in optimizer.param_groups]
        if lrs_after != lrs_before:
            print(f"Learning rate reduced to {lrs_after}")

        # Save checkpoint
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            save_checkpoint(model, optimizer, epoch + 1, val_loss, checkpoint_path)
        else:
            patience_counter += 1
            print(f"No improvement in validation loss for {patience_counter} epoch(s).")

        # Show sample image with caption at the end of each epoch
        show_sample_caption(model, val_loader)

        # Early stopping
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

    return train_losses, val_losses


def show_sample_caption(model, val_loader, tokenizer=None):
    """Display the first image of ``val_loader`` with its reference and greedy-decoded captions.

    The loader's first batch is used, so with ``shuffle=False`` this is the same image every
    epoch, which makes progress easy to compare.

    Args:
        model: captioning model passed to ``generate_caption``.
        val_loader: loader yielding (images, padded token-id captions).
        tokenizer: word -> id mapping; defaults to the notebook's ``word_to_index_dictionary``.
    """
    if tokenizer is None:
        tokenizer = word_to_index_dictionary
    index_to_word = {index: word for word, index in tokenizer.items()}
    special_ids = {tokenizer["<pad>"], tokenizer["<sos>"], tokenizer["<eos>"]}

    model.eval()
    with torch.no_grad():
        images, captions = next(iter(val_loader))
        image = images[0]  # (C, H, W)
        # The reference is a tensor of ids; decode it to words (joining the tensor raised TypeError).
        real_caption = " ".join(
            index_to_word.get(idx, "<unk>") for idx in captions[0].tolist() if idx not in special_ids
        )
        predicted_caption = generate_caption(model, image.to(device), tokenizer)

    # Min-max rescale the normalised tensor into [0, 1] for display.
    img = image.cpu().permute(1, 2, 0).numpy()
    value_range = img.max() - img.min()
    img = (img - img.min()) / value_range if value_range > 0 else np.zeros_like(img)

    plt.figure(figsize=(6, 6))
    plt.imshow(img)
    plt.axis("off")
    plt.title(f"Real: {real_caption}\nPred: {predicted_caption}", fontsize=10)
    plt.show()


def generate_caption(model, image, tokenizer=None, max_length=20):
    """Greedy-decode a caption for one image (replace with beam search if needed).

    This replaces a version that looked up ``vocab.word2idx`` (``vocab`` is a plain set, so
    every call raised AttributeError) and that could not serve the
    ``generate_caption(model, image, tokenizer)`` call used during evaluation.

    Args:
        model: object with ``encoder(images)`` and ``decoder(features, token_ids)``; the
            decoder returns a ``(scores, attention)`` tuple with scores shaped (1, seq, vocab).
        image: image tensor, either (C, H, W) or already batched as (1, C, H, W).
        tokenizer: word -> id mapping with "<sos>"/"<eos>"; defaults to the notebook's
            ``word_to_index_dictionary``. For backward compatibility an int passed in this
            position is treated as ``max_length`` (the old third positional parameter).
        max_length: maximum number of decoding steps.

    Returns:
        Space-joined predicted words, without <sos>/<eos>/<pad>.
    """
    if isinstance(tokenizer, int) and not isinstance(tokenizer, bool):
        tokenizer, max_length = None, tokenizer
    if tokenizer is None:
        tokenizer = word_to_index_dictionary

    index_to_word = {index: word for word, index in tokenizer.items()}
    start_token = tokenizer["<sos>"]
    end_token = tokenizer["<eos>"]
    skip_ids = {start_token, end_token, tokenizer.get("<pad>")}

    # Run on whatever device the model lives on (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    batch = image if image.dim() == 4 else image.unsqueeze(0)

    model.eval()
    with torch.no_grad():
        features = model.encoder(batch.to(target_device))  # encode once, reuse every step
        token_ids = [start_token]
        for _ in range(max_length):
            input_seq = torch.tensor([token_ids], dtype=torch.long, device=target_device)
            scores, _ = model.decoder(features, input_seq)
            next_word_idx = scores[0, -1].argmax().item()
            if next_word_idx == end_token:
                break
            token_ids.append(next_word_idx)

    words = [index_to_word.get(idx, "<unk>") for idx in token_ids if idx not in skip_ids]
    return " ".join(words)


In [None]:
class EncoderCNN(nn.Module):
    """EfficientNet-B0 backbone that returns one embedding per spatial region.

    The previous version used the globally pooled vector, i.e. a single region, so the
    decoder's softmax attention was always exactly 1 and the attention maps (which the
    visualisation code reshapes to 7x7) could not be produced. Here the backbone's final
    feature map is flattened into ``H' * W'`` regions (7x7 = 49 for a 224x224 input,
    given EfficientNet-B0's stride of 32), each projected to ``embed_size``.

    Parameter names match the old layout (``model.*`` and ``fc``), so existing
    state dicts still load.
    """

    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        # `weights=` replaces the deprecated `pretrained=True` (same ImageNet-1k weights).
        self.model = models.efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
        self.model.classifier = nn.Identity()  # pooling/classifier head is not used
        self.fc = nn.Linear(1280, embed_size)
        self.relu = nn.ReLU()

    def forward(self, images):
        feature_map = self.model.features(images)         # (batch, 1280, H', W')
        regions = feature_map.flatten(2).transpose(1, 2)  # (batch, H'*W', 1280)
        features = self.relu(self.fc(regions))            # (batch, num_regions, embed_size)
        return features


class Attention(nn.Module):
    """Additive attention: score = v(tanh(W [feature; hidden])), softmax-normalised over regions."""

    def __init__(self, feature_dim, hidden_dim):
        super().__init__()
        self.attn = nn.Linear(feature_dim + hidden_dim, hidden_dim)  # W over [feature; hidden]
        self.v = nn.Linear(hidden_dim, 1)                            # energy -> scalar score

    def forward(self, features, hidden):
        """
        Args:
            features: (batch, num_regions, feature_dim)
            hidden: (batch, hidden_dim)

        Returns:
            context: (batch, feature_dim), attention-weighted sum of ``features``
            attn_weights: (batch, num_regions, 1), sums to 1 over the region axis
        """
        num_regions = features.size(1)
        tiled_hidden = hidden.unsqueeze(1).expand(-1, num_regions, -1)
        energies = torch.tanh(self.attn(torch.cat([features, tiled_hidden], dim=2)))
        attn_weights = torch.softmax(self.v(energies), dim=1)
        context = torch.sum(attn_weights * features, dim=1)
        return context, attn_weights


class DecoderRNN(nn.Module):
    """Attention LSTM decoder that is unrolled one token at a time.

    At each step the attention context (computed from the top LSTM layer's hidden state)
    is concatenated with the current word embedding and fed to the LSTM.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, feature_dim, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(feature_dim, hidden_size)
        self.lstm = nn.LSTM(embed_size + feature_dim, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, features, captions):
        """
        Args:
            features: (batch, num_regions, feature_dim) encoder output.
            captions: (batch, seq_len) token ids fed as decoder inputs.

        Returns:
            log_probs: (batch, seq_len, vocab_size) log-softmax scores.
            attention_weights: (batch, seq_len, num_regions, 1) attention used at each step.
        """
        embeddings = self.embedding(captions)  # (batch, seq_len, embed_dim)
        # One state slice per LSTM layer; the old hard-coded first dim of 1 broke num_layers > 1.
        state_shape = (self.lstm.num_layers, features.size(0), self.lstm.hidden_size)
        h = features.new_zeros(state_shape)
        c = features.new_zeros(state_shape)

        outputs = []
        attention_weights = []  # kept for visualisation

        for t in range(captions.size(1)):
            # Attend with the top layer's hidden state, shape (batch, hidden_size).
            context, attn_weights = self.attention(features, h[-1])
            attention_weights.append(attn_weights)
            lstm_input = torch.cat((embeddings[:, t, :], context), dim=1).unsqueeze(1)
            output, (h, c) = self.lstm(lstm_input, (h, c))
            outputs.append(self.fc(output))

        outputs = torch.cat(outputs, dim=1)
        attention_weights = torch.stack(attention_weights, dim=1)
        return self.softmax(outputs), attention_weights

class ImageCaptioningModel(nn.Module):
    """CNN encoder + attention LSTM decoder.

    Args:
        vocab_size: number of output tokens.
        embed_size: word-embedding size, also used as the encoder feature size.
        hidden_size: LSTM hidden size.
        num_layers: number of LSTM layers.
        dropout: dropout probability applied to the encoder features before decoding
            (active only in training mode). It was previously accepted but ignored.
    """

    def __init__(self, vocab_size, embed_size=512, hidden_size=512, num_layers=1, dropout=0.5):
        super(ImageCaptioningModel, self).__init__()
        self.encoder = EncoderCNN(embed_size).to(device)
        self.decoder = DecoderRNN(embed_size, hidden_size, vocab_size, feature_dim=embed_size, num_layers=num_layers).to(device)
        self.dropout = nn.Dropout(dropout)  # parameter-free: existing state dicts still load

    def forward(self, images, captions):
        """Return the decoder's ``(log_probs, attention_weights)`` for teacher-forced ``captions``."""
        features = self.dropout(self.encoder(images))  # Extract (and regularise) image features
        return self.decoder(features, captions)

In [None]:
vocab_size = len(word_to_index_dictionary)
learning_rate = 0.001
weight_decay = 1e-4

model = ImageCaptioningModel(
    vocab_size, embed_size=256, hidden_size=512, num_layers=1, dropout=0.5
).to(device)

# Pass the configured weight decay; previously it was defined but AdamW silently used its own default.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [None]:
# Train for up to 50 epochs; the best model (lowest validation loss) goes to caption_model.pth.
train_losses, val_losses = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    num_epochs=50,
    checkpoint_path="caption_model.pth",
    patience=15,
)

In [None]:
plot_training_losses(train_losses=train_losses, val_losses=val_losses)  # learning curves

In [None]:
model.eval()
index_to_word = {index: word for word, index in word_to_index_dictionary.items()}

# BLEU/ROUGE must see word tokens: with raw strings NLTK scores characters, not words.
# Each test image is scored against *all* of its preprocessed reference captions, not just
# the one caption the dataset happened to sample for the batch.
references, hypotheses = [], []  # references[i]: list of token lists; hypotheses[i]: token list
dataset_position = 0  # test_loader has shuffle=False, so batches follow test_dataset.image_ids
with torch.no_grad():
    for images, _ in tqdm(test_loader, desc="Evaluating", leave=False):
        images = images.to(device)
        for image in images:
            image_id = test_dataset.image_ids[dataset_position]
            dataset_position += 1
            references.append([preprocess_text(text).split() for text in test_dataset.captions_dict[image_id]])
            hypotheses.append(generate_caption(model, image, word_to_index_dictionary).split())
assert dataset_position == len(test_dataset), "test_loader and test_dataset are out of sync"

smoothie = SmoothingFunction().method1
bleu1 = corpus_bleu(references, hypotheses, weights=(1.0, 0, 0, 0), smoothing_function=smoothie)
bleu2 = corpus_bleu(references, hypotheses, weights=(0.5, 0.5, 0, 0), smoothing_function=smoothie)
bleu3 = corpus_bleu(references, hypotheses, weights=(1 / 3, 1 / 3, 1 / 3, 0), smoothing_function=smoothie)
bleu4 = corpus_bleu(references, hypotheses, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smoothie)

# ROUGE-L per image against its best-matching reference, then averaged.
scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
rouge_l_scores = [
    max((scorer.score(" ".join(ref), " ".join(hyp))['rougeL'].fmeasure for ref in refs), default=0.0)
    for refs, hyp in zip(references, hypotheses)
]
rouge_l = float(np.mean(rouge_l_scores)) if rouge_l_scores else 0.0

print(f"Evaluation on test set:")
print(f"BLEU-1 Score: {bleu1:.4f}")
print(f"BLEU-2 Score: {bleu2:.4f}")
print(f"BLEU-3 Score: {bleu3:.4f}")
print(f"BLEU-4 Score: {bleu4:.4f}")
print(f"ROUGE-L Score: {rouge_l:.4f}")

# Show a few predictions. Images are re-read from the dataset (val_transform is deterministic)
# instead of keeping every test image in memory during evaluation.
selected_indices = random.sample(range(len(hypotheses)), min(5, len(hypotheses)))
if selected_indices:
    fig, axes = plt.subplots(len(selected_indices), 1, figsize=(8, len(selected_indices) * 3), squeeze=False)
    for row, idx in enumerate(selected_indices):
        image_tensor, _ = test_dataset[idx]
        img = denormalize(image_tensor).clamp(0, 1).permute(1, 2, 0).numpy()
        ax = axes[row, 0]
        ax.imshow(img)
        ax.axis("off")
        ax.set_title(f"Real: {' '.join(references[idx][0])}\nPred: {' '.join(hypotheses[idx])}", fontsize=10)
    plt.tight_layout()
    plt.show()

In [None]:
def visualize_attention(image, attention_weights, captions, vocab):
    """
    Visualizes the attention maps for each generated word in the caption.

    Args:
        image: (C, H, W) image tensor; min-max rescaled to [0, 1] for display.
        attention_weights: batched attention; attention_weights[0, idx] holds the weights over
            the encoder regions for word idx. The region count must be a perfect square
            (e.g. 49 -> 7x7); it is no longer hard-coded to 14x14.
        captions: sequence of token-id tensors (``.item()`` is called on each element).
        vocab: id -> word lookup, indexed as ``vocab[token_id]``.

    NOTE: this notebook defines ``visualize_attention`` more than once; the last executed
    definition is the one in effect.
    """
    if len(captions) == 0:
        return
    display_image = image.detach().cpu().permute(1, 2, 0).numpy()
    low, high = display_image.min(), display_image.max()
    if high > low:
        display_image = (display_image - low) / (high - low)
    height, width = display_image.shape[:2]

    # squeeze=False keeps a 2-D Axes array even for a one-word caption.
    fig, axes = plt.subplots(1, len(captions), figsize=(20, 5), squeeze=False)
    for idx, ax in enumerate(axes[0]):
        weights = attention_weights[0, idx].detach().cpu().reshape(-1)
        side = int(round(weights.numel() ** 0.5))
        if side * side != weights.numel():
            raise ValueError(f"Cannot arrange {weights.numel()} attention weights into a square grid.")
        ax.imshow(display_image)
        # Stretch the coarse map over the whole image (overlay attention map).
        ax.imshow(weights.reshape(side, side).numpy(), cmap='jet', alpha=0.6,
                  extent=(-0.5, width - 0.5, height - 0.5, -0.5))
        ax.axis('off')
        ax.set_title(vocab[captions[idx].item()])
    plt.show()


In [None]:
def visualize_attention(image, alphas, caption, index_to_word, save_path=None):
    """Overlay each decoding step's attention map on the input image, one figure per step.

    Args:
        image: (C, H, W) tensor normalised with ImageNet mean/std; it is un-normalised here.
        alphas: sequence of attention tensors; ``alphas[t]`` holds the weights over the
            encoder regions at decoding step t. The region count must be a perfect square
            (49 -> 7x7 for the EfficientNet encoder); it is no longer hard-coded to 7x7.
        caption: token ids starting with <sos>, as built by generate_and_visualize_attention.
            ``alphas[t]`` is the attention used to *predict* ``caption[t + 1]``, so that word
            labels figure t (previously the input token ``caption[t]`` was shown).
        index_to_word: id -> word mapping.
        save_path: optional prefix; figure t is saved to ``f"{save_path}_t{t}.png"``. The
            prefix's directory is created if needed.

    Raises:
        ValueError: if an attention vector cannot be arranged as a square grid.
    """
    image = image.detach().permute(1, 2, 0).cpu().numpy()  # Convert image to HWC format
    image = np.clip(image * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]), 0, 1)  # Unnormalize
    height, width = image.shape[:2]

    if save_path:
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)  # savefig does not create missing folders

    for t, alpha in enumerate(alphas):
        num_regions = alpha.numel()
        side = int(round(num_regions ** 0.5))
        if side * side != num_regions:
            raise ValueError(f"Cannot arrange {num_regions} attention weights into a square grid.")

        fig = plt.figure(figsize=(8, 8))
        plt.imshow(image)

        # Upsample the coarse attention grid to image resolution: (N, C, H, W) for interpolate.
        alpha_map = alpha.detach().float().reshape(1, 1, side, side)
        alpha_map = F.interpolate(alpha_map, size=(height, width), mode="bilinear", align_corners=False)
        plt.imshow(alpha_map.squeeze().cpu().numpy(), cmap="jet", alpha=0.5)  # Overlay heat map

        word_position = t + 1  # caption[0] is <sos>; step t predicts caption[t + 1]
        token = caption[word_position] if word_position < len(caption) else None
        word = index_to_word.get(token, "<unk>")
        plt.title(f"Word: {word}", fontsize=16)
        plt.axis("off")
        if save_path:
            plt.savefig(f"{save_path}_t{t}.png", bbox_inches="tight", dpi=200)  # Save with high resolution
        plt.show()
        plt.close(fig)  # avoid accumulating one open figure per word

def generate_and_visualize_attention(model, test_loader, tokenizer, index_to_word, max_length=15, save_path=None):
    """Greedy-caption the first test image, print the caption, then plot per-step attention.

    Decoding re-runs the decoder over the growing prefix and keeps the attention of the
    last position at every step; it stops after ``max_length`` steps or at <eos>.
    """
    model.eval()
    first_batch_images, _ = next(iter(test_loader))
    image_batch = first_batch_images[:1].to(device)  # first image, batch dim kept: (1, C, H, W)

    sos_id, eos_id = tokenizer["<sos>"], tokenizer["<eos>"]

    with torch.no_grad():
        encoded = model.encoder(image_batch)
        token_ids = [sos_id]
        step_attention = []

        for _step in range(max_length):
            prefix = torch.tensor([token_ids], device=device)
            log_probs, attention = model.decoder(encoded, prefix)
            predicted_id = log_probs[0, -1].argmax().item()
            step_attention.append(attention[0, -1])
            token_ids.append(predicted_id)
            if predicted_id == eos_id:
                break

    caption_text = " ".join(index_to_word[token] for token in token_ids if token not in (sos_id, eos_id))
    print(f"Generated Caption: {caption_text}")

    visualize_attention(image_batch[0], step_attention, token_ids, index_to_word, save_path)






    # outputs, attention_weights = model(images, captions[:, :-1])  # Exclude the last word for teacher forcing            
    # outputs = outputs[:, :captions.size(1) - 1, :]  # Match target sequence length

    # loss = masked_cross_entropy_loss(outputs, captions[:, 1:])  # Exclude <sos> for target
    # loss.backward()
    # optimizer.step()
    # epoch_loss += loss.item()
    # tepoch.set_postfix(loss=f"{loss.item():.4f}")

    # # Visualize attention for the first batch of every epoch
    # if batch_idx == 0:
    #     visualize_attention(images[0], attention_weights[0], captions[0], word_to_index_dictionary)


In [None]:
# generate_and_visualize_attention(
#     model, test_loader, tokenizer=word_to_index_tokenizer,
#     index_to_word={index: word for word, index in word_to_index_tokenizer.items()},
#     max_length=15#, save_path="attention_heatmap"
# )

In [None]:
ATTENTION_SAVE_PATH = "output/attention_map"  # per-step figures are saved as <prefix>_t<step>.png
# plt.savefig does not create missing directories, so make sure the prefix's folder exists.
os.makedirs(os.path.dirname(ATTENTION_SAVE_PATH), exist_ok=True)

generate_and_visualize_attention(
    model=model,
    test_loader=test_loader,
    tokenizer=token_to_index,  # Use the token_to_index as the tokenizer
    index_to_word=index_to_word,
    max_length=15,  # Set the maximum caption length
    save_path=ATTENTION_SAVE_PATH,
)