# Imports and Dependencies

In [None]:
!pip install git+https://github.com/openai/CLIP.git
!pip install rouge-score
!pip install git+https://github.com/salaniz/pycocoevalcap.git

In [None]:
import json
import os
import re
import gzip
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
# clip
import clip
# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
# eval
from multiprocessing import Pool
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.spice.spice import Spice
from scipy.stats import ttest_ind, sem
# misc
from sklearn.metrics.pairwise import cosine_similarity
import string
from collections import Counter
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from collections import defaultdict

# Path

Please replace `your_path` in the cell below with the path to your folder containing all the data and the annotations.

In [None]:
# Base locations used by every os.path.join() call in this notebook.
# your_path = "path/to/folder"
your_path = "drive/MyDrive/MIT/3_JuniorYear/6.8611" # julia's path
project_folder = "6.8611_Final_Project" # Update to yours

In [None]:
# Mount gdrive for data access
# (google.colab is only importable inside a Colab runtime.)
from google.colab import drive
drive.mount('/content/drive')

# Our Model

In [None]:
class LightweightCaptioningModel(nn.Module):
    """LSTM caption decoder conditioned on a projected CLIP image embedding.

    The image embedding is projected to ``hidden_dim`` and, when ``use_kg`` is
    set, fused with a projected knowledge-graph (KG) embedding.  At every
    decoding step the fused vector is concatenated with the embedding of the
    previous word and fed to the LSTM.

    ``forward`` dispatches on ``captions``: with captions it runs teacher
    forcing (``training_step``), without them it runs beam search
    (``evaluate``).
    """

    def __init__(self, clip_dim, kg_dim, hidden_dim, vocab_size, use_kg=False, num_layers=1):
        super().__init__()
        self.use_kg = use_kg

        # Image feature projection
        self.clip_img_proj = nn.Linear(clip_dim, hidden_dim)

        # Word embedding for decoder
        self.word_embedding = nn.Embedding(vocab_size, hidden_dim)

        # Optional KG projection
        if use_kg:
            self.kg_proj = nn.Linear(kg_dim, hidden_dim)
            self.feature_combine = nn.Linear(hidden_dim * 2, hidden_dim)

        # Decoder LSTM: input is [word embedding ; visual features]
        self.decoder_rnn = nn.LSTM(hidden_dim * 2, hidden_dim, num_layers=num_layers, batch_first=True)
        self.output_layer = nn.Linear(hidden_dim, vocab_size)

    def forward(self, clip_image_embedding, kg_embedding=None, captions=None, lengths=None, **kwargs):
        """Run teacher forcing (``captions`` given) or beam search (``captions`` omitted).

        Args:
            clip_image_embedding: image embeddings whose last dim is ``clip_dim``.
            kg_embedding: optional KG embeddings (last dim ``kg_dim``); only
                used on the training path when ``use_kg`` is set.
            captions: token-id tensor; when given, selects the training path.
            lengths: per-caption count of non-padding tokens (training path).
            **kwargs: forwarded unchanged to ``evaluate`` on the inference path.

        Returns:
            ``(logits, target_captions)`` when training, otherwise the tensor
            of generated token ids returned by ``evaluate``.
        """
        # Process image features
        img_features = F.relu(self.clip_img_proj(clip_image_embedding))

        # Combine with KG features if using
        if self.use_kg and kg_embedding is not None:
            kg_features = F.relu(self.kg_proj(kg_embedding))
            projected_features = F.relu(self.feature_combine(torch.cat([img_features, kg_features], dim=-1)))
        else:
            projected_features = img_features

        if captions is not None:  # Training mode
            return self.training_step(projected_features, captions, lengths)

        # Inference mode: evaluate() rebuilds the KG features itself from the
        # words generated so far, so it receives the image-only features.
        return self.evaluate(img_features, **kwargs)

    def training_step(self, projected_features, captions, lengths=None):
        """Teacher-forced pass: predict token ``t + 1`` from tokens ``<= t``.

        Args:
            projected_features: ``(batch, hidden_dim)`` visual features.
            captions: ``(batch, seq)`` token ids.
            lengths: number of real (non-padding) tokens per caption; when
                omitted every caption is treated as full length.

        Returns:
            ``(logits, target_captions)``.  ``logits`` covers only as many
            steps as the longest caption in the batch needs (it is produced by
            ``pad_packed_sequence``); ``target_captions`` is ``captions[:, 1:]``.
        """
        caption_embeddings = self.word_embedding(captions[:, :-1])
        seq_len = caption_embeddings.size(1)

        # Repeat visual features for each timestep
        repeated_features = projected_features.unsqueeze(1).expand(-1, seq_len, -1)
        decoder_input = torch.cat([caption_embeddings, repeated_features], dim=-1)

        if lengths is None:
            lengths = torch.full((captions.size(0),), captions.size(1), dtype=torch.long)
        # The input drops the last token, hence one step fewer than `lengths`.
        # pack_padded_sequence needs the lengths on the CPU.
        step_counts = torch.as_tensor(lengths).cpu() - 1

        # Pack sequence for RNN
        packed_input = nn.utils.rnn.pack_padded_sequence(
            decoder_input, step_counts, batch_first=True, enforce_sorted=False
        )

        # Run through decoder
        packed_output, _ = self.decoder_rnn(packed_input)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        logits = self.output_layer(outputs)

        target_captions = captions[:, 1:]

        return logits, target_captions

    def evaluate(self, img_features, idx_to_word, numberbatch_df, idf_values, word_to_idx, max_length=20, beam_width=3):
        """Generate one caption per image with beam search.

        Args:
            img_features: ``(batch, hidden_dim)`` projected image features.
            idx_to_word: mapping token id -> word.
            numberbatch_df: DataFrame of word vectors indexed by word (only
                read when ``use_kg`` is set).
            idf_values: mapping word -> IDF weight (only read when ``use_kg``).
            word_to_idx: mapping word -> token id; must contain ``<s>``,
                ``</s>`` and ``<PAD>``.
            max_length: length of every returned sequence, ``<s>`` included.
            beam_width: number of hypotheses kept after each step.

        Returns:
            ``(batch, max_length)`` tensor of token ids, right-padded with
            ``<PAD>``.
        """
        batch_size = img_features.size(0)
        vocab_size = len(word_to_idx)
        start_idx = word_to_idx["<s>"]
        end_idx = word_to_idx["</s>"]
        pad_idx = word_to_idx["<PAD>"]
        special_tokens = ("<PAD>", "<s>", "</s>", "<UNK>")
        outputs = []

        for batch_idx in range(batch_size):
            curr_img_features = img_features[batch_idx].unsqueeze(0)
            # Each beam carries its own LSTM state so the decoder sees the
            # whole prefix, exactly as it does under teacher forcing.
            beams = [(torch.tensor([start_idx], device=img_features.device), 0.0, None)]
            completed_sequences = []

            for _ in range(max_length - 1):
                candidates = []

                for seq, score, hidden_state in beams:
                    if seq[-1].item() == end_idx:
                        completed_sequences.append((seq, score))
                        continue

                    # Get word embedding and combine with image features
                    word_embedding = self.word_embedding(seq[-1])

                    if self.use_kg:
                        # Compute KG embeddings for the current sequence
                        generated_words = [idx_to_word.get(idx, "<UNK>") for idx in seq.tolist()]
                        kg_embedding = self.compute_caption_tfidf_embedding(
                            generated_words, numberbatch_df, idf_values
                        )
                        kg_embedding = torch.tensor(kg_embedding, dtype=torch.float32, device=img_features.device)
                        kg_features = F.relu(self.kg_proj(kg_embedding))
                        projected_features = F.relu(self.feature_combine(
                            torch.cat([curr_img_features, kg_features.unsqueeze(0)], dim=-1)
                        ))
                    else:
                        projected_features = curr_img_features

                    # Run through decoder, continuing from this beam's state
                    decoder_input = torch.cat([word_embedding.unsqueeze(0), projected_features], dim=-1)
                    output, new_hidden = self.decoder_rnn(decoder_input.unsqueeze(0), hidden_state)
                    word_logits = self.output_layer(output.squeeze(0))

                    # Apply repetition penalty.  The direction depends on the
                    # sign: halving a negative logit would make the repeated
                    # token MORE likely, so negative logits are doubled instead.
                    seen_mask = torch.zeros_like(word_logits, dtype=torch.bool)
                    seen_mask[0, list(set(seq.tolist()))] = True
                    penalised = torch.where(word_logits > 0, word_logits * 0.5, word_logits / 0.5)
                    word_logits = torch.where(seen_mask, penalised, word_logits)
                    word_probs = F.log_softmax(word_logits, dim=-1)

                    # Get more candidates than needed to account for filtering
                    top_k = min(beam_width * 2, word_probs.size(-1))
                    values, indices = word_probs.topk(top_k)
                    indices = indices.squeeze(0)
                    values = values.squeeze(0)

                    # Track number of candidates added for this beam
                    candidates_added = 0
                    for value, token_tensor in zip(values, indices):
                        next_token = token_tensor.item()

                        # Skip if token index is invalid or token doesn't exist
                        if next_token >= vocab_size or next_token not in idx_to_word:
                            continue

                        # Skip special tokens.
                        # NOTE(review): this also blocks "</s>", so a hypothesis
                        # can never finish early and the end-token checks in
                        # this method never fire — confirm this is intended.
                        if idx_to_word[next_token] in special_tokens:
                            continue

                        # Skip for trigram blocking
                        if self._has_excessive_repetition(seq, next_token):
                            continue

                        candidate_seq = torch.cat([seq, token_tensor.unsqueeze(0)])

                        # Length normalization for score.
                        # NOTE(review): `score` was already normalised on the
                        # previous step, so the penalty compounds — kept as-is.
                        length_penalty = ((5.0 + len(candidate_seq)) / 6.0) ** 0.65
                        candidate_score = (score + value.item()) / length_penalty

                        candidates.append((candidate_seq, candidate_score, new_hidden))
                        candidates_added += 1

                        # Break if we have enough candidates for this beam
                        if candidates_added >= beam_width:
                            break

                # Keep top candidates that aren't complete
                incomplete_candidates = [cand for cand in candidates if cand[0][-1].item() != end_idx]
                if not incomplete_candidates:
                    break

                # Update beams
                incomplete_candidates.sort(key=lambda cand: cand[1], reverse=True)
                beams = incomplete_candidates[:beam_width]

            # Add any remaining sequences, then pick the best-scoring one
            completed_sequences.extend((seq, score) for seq, score, _ in beams)
            best_seq = max(completed_sequences, key=lambda item: item[1])[0]

            # Pad sequence as needed
            seq_len = best_seq.size(0)
            if seq_len < max_length:
                padding = torch.full((max_length - seq_len,), pad_idx,
                                     dtype=best_seq.dtype, device=best_seq.device)
                best_seq = torch.cat([best_seq, padding])
            else:
                best_seq = best_seq[:max_length]

            outputs.append(best_seq)

        return torch.stack(outputs)

    def _has_excessive_repetition(self, seq, next_token):
        """Return True if appending ``next_token`` to ``seq`` repeats too much.

        Prefixes shorter than three tokens are always accepted.  Otherwise the
        extended sequence is rejected when any trigram occurs twice (checked
        once it has at least six tokens) or any token occurs more than twice.
        """
        if len(seq) < 3:
            return False
        sequence = seq.tolist() + [next_token]
        # Check for repeating trigrams
        if len(sequence) >= 6:
            trigrams = [tuple(sequence[i:i+3]) for i in range(len(sequence)-2)]
            trigram_counts = Counter(trigrams)
            if max(trigram_counts.values()) > 1:
                return True
        # Check for repeating words
        token_counts = Counter(sequence)
        if max(token_counts.values()) > 2:
            return True

        return False

    def compute_word_tfidf_embedding(self, word, numberbatch_df, idf_values):
        """
        Compute the IDF-weighted embedding for a single word.

        Returns the word's row of ``numberbatch_df`` scaled by its IDF value
        (1.0 when the word has none), or a zero vector when the word is not in
        the DataFrame index.
        """
        if word in numberbatch_df.index:
            embedding = numberbatch_df.loc[word].values
            idf_value = idf_values.get(word, 1.0)
            return embedding * idf_value
        return np.zeros(numberbatch_df.shape[1])

    def compute_caption_tfidf_embedding(self, words, numberbatch_df, idf_values):
        """
        Compute the normalized IDF-weighted embedding for a caption.

        Sums the weighted embeddings of all words that have a non-zero vector
        and divides by the sum of their IDF weights.  Returns a zero vector
        when no word qualifies.
        """
        embeddings = []
        weights = []
        for word in words:
            embedding = self.compute_word_tfidf_embedding(word, numberbatch_df, idf_values)
            if np.any(embedding):
                embeddings.append(embedding)
                weights.append(idf_values.get(word, 1.0))

        if not embeddings:
            return np.zeros(numberbatch_df.shape[1])  # Return a zero vector if no embeddings are valid

        embeddings = np.array(embeddings)
        weights = np.array(weights)
        normalized_embedding = np.sum(embeddings, axis=0) / (np.sum(weights) + 1e-8)  # Avoid division by zero
        return normalized_embedding

# Load all Embeddings and Setup


We will be loading the vocabulary we saved previously and adding the special tokens for generation to it. After this, there should be 18809 elements in the vocabulary dict.

In [None]:
vocabulary_path = os.path.join(your_path, f"{project_folder}/COCO/our_vocabulary.pkl")
with open(vocabulary_path, 'rb') as f:
    word_to_idx = pickle.load(f)

print("Vocabulary loaded!")
print(len(word_to_idx))
# Each special token takes the next free id, in this fixed order.
for special_token in ("<PAD>", "<UNK>", "<s>", "</s>"):
    word_to_idx[special_token] = len(word_to_idx)
print(len(word_to_idx))
# Reverse lookup: token id -> word.
idx_to_word = dict(zip(word_to_idx.values(), word_to_idx.keys()))

In [None]:
def normalize_token(token):
    """Strip leading/trailing ASCII punctuation from ``token`` and lowercase it."""
    return token.strip(string.punctuation).lower()


def encode_caption(caption, vocab, max_length=20):
    """Encode ``caption`` as exactly ``max_length`` token ids.

    The result is ``<s>``, one id per whitespace-separated word (``<UNK>`` for
    words missing from ``vocab``), ``</s>``, then ``<PAD>`` up to
    ``max_length``.  Captions that do not fit are truncated, and the last kept
    position is overwritten with ``</s>`` so every caption still ends with the
    end token.

    Args:
        caption: raw caption string.
        vocab: mapping word -> id containing ``<s>``, ``</s>``, ``<UNK>`` and
            ``<PAD>``.
        max_length: length of the returned list.

    Returns:
        List of ``max_length`` ints.
    """
    end_idx = vocab["</s>"]
    unk_idx = vocab["<UNK>"]
    # split() without an argument collapses whitespace runs, so doubled spaces
    # no longer produce empty "words" that would be encoded as <UNK>.
    word_ids = [vocab.get(normalize_token(word), unk_idx) for word in caption.split()]
    indices = [vocab["<s>"]] + word_ids + [end_idx]
    if len(indices) > max_length:
        indices = indices[:max_length]
        if indices:
            indices[-1] = end_idx
    else:
        indices += [vocab["<PAD>"]] * (max_length - len(indices))
    return indices

def decode_caption(indices, idx_to_word):
    """Turn token ids back into a space-joined string.

    Ids missing from ``idx_to_word`` become ``<UNK>``.  Everything from the
    first ``</s>`` onwards is discarded, and the first remaining token (the
    start marker in encoded captions) is dropped.
    """
    tokens = [idx_to_word.get(token_id, "<UNK>") for token_id in indices]
    if "</s>" in tokens:
        tokens = tokens[:tokens.index("</s>")]
    return " ".join(tokens[1:])

In [None]:
# hyperparams
clip_dim = 512          # input size of the model's image projection
kg_dim = 300            # input size of the model's KG projection
hidden_dim = 256        # width of the projections, word embeddings and LSTM
vocab_size = len(word_to_idx)  # includes the four special tokens added above
batch_size = 32
num_epochs = 75
learning_rate = 1e-4
# Fall back to the CPU when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
model_kg = LightweightCaptioningModel(clip_dim, kg_dim, hidden_dim, vocab_size, use_kg=True).to(device)

# Ignore padding positions in the loss.  The id is looked up instead of being
# hard-coded so it stays correct if the vocabulary size ever changes.
criterion = nn.CrossEntropyLoss(ignore_index=word_to_idx["<PAD>"])
optimizer = optim.Adam(model_kg.parameters(), lr=learning_rate)

In [None]:
# Data
# Paths to the precomputed COCO embeddings, IDF table and ConceptNet Numberbatch vectors.
coco_embeddings_path = os.path.join(your_path, f"{project_folder}/COCO/coco_embeddings.npz")
idf_values_path = os.path.join(your_path, f"{project_folder}/COCO/idf_values.csv")
numberbatch_embeddings_path = os.path.join(your_path, f'{project_folder}/ConceptNet_Data_Container/numberbatch-en-19.08.txt.gz')

# allow_pickle=True is required for object arrays; only load archives you trust.
coco_data = np.load(coco_embeddings_path, allow_pickle=True)
coco_image_features = torch.tensor(coco_data["image_features"]).float().to(device)
coco_text_features = torch.tensor(coco_data["text_features"]).float().to(device)
coco_captions = coco_data["text_captions"]  # Use tokenized captions
coco_image_ids = coco_data["image_ids"]

In [None]:
# Precomputed KG embeddings for the COCO captions; the archive also holds a
# "captions" array that the next cell uses to look embeddings up by caption.
kg_coco_embedding_path = os.path.join(your_path, f"{project_folder}/COCO/kg_coco_embeddings.npz")
kg_coco_embeddings = np.load(kg_coco_embedding_path)
kg_embeddings = kg_coco_embeddings["kg_embeddings"]

We filter out the duplicated image embeddings, as each image occurs multiple times in the set, paired with different reference captions.

In [None]:
# Map every caption to the row(s) holding its KG embedding.
kg_caption_to_index = defaultdict(list)
for idx, caption in enumerate(kg_coco_embeddings["captions"]):
    kg_caption_to_index[caption].append(idx)

unique_image_features = []
unique_text_features = []
unique_kg_embeddings = []
unique_captions = []
unique_image_ids = []
# Set used for the membership test: checking against the growing list made the
# loop quadratic in the number of images.
# (assumes image ids are hashable scalars — TODO confirm)
seen_image_ids = set()

# Keep only the first (image, caption) pair of every image.
for i, (image_id, caption) in enumerate(zip(coco_image_ids, coco_captions)):
    if image_id in seen_image_ids:
        continue
    seen_image_ids.add(image_id)

    kg_indices = kg_caption_to_index.get(caption, [])
    if kg_indices:
        kg_embedding = kg_embeddings[kg_indices[0]]
    else:
        print(f"No KG embedding found for caption: {caption}")
        kg_embedding = np.zeros_like(kg_embeddings[0])

    unique_image_features.append(coco_image_features[i])
    unique_text_features.append(coco_text_features[i])
    unique_kg_embeddings.append(kg_embedding)
    unique_captions.append(caption)
    unique_image_ids.append(image_id)

unique_image_features = torch.stack(unique_image_features)
unique_text_features = torch.stack(unique_text_features)
unique_kg_embeddings = np.array(unique_kg_embeddings)
unique_captions = np.array(unique_captions, dtype=object)
unique_image_ids = np.array(unique_image_ids)

print(unique_image_features.shape)
print(unique_text_features.shape)
print(unique_kg_embeddings.shape)
print(len(unique_captions))
print(len(unique_image_ids))

In [None]:
# Where the de-duplicated COCO arrays are written and later re-read.
filtered_coco_embeddings_path = os.path.join(your_path, f"{project_folder}/COCO/filtered_coco_embeddings.npz")

In [None]:
# Persist the de-duplicated arrays; tensors are moved to the CPU and converted
# to NumPy first because np.savez_compressed stores NumPy arrays.
np.savez_compressed(
    filtered_coco_embeddings_path,
    image_features=unique_image_features.cpu().numpy(),
    text_features=unique_text_features.cpu().numpy(),
    kg_embeddings=unique_kg_embeddings,
    captions=unique_captions,
    image_ids=unique_image_ids
)

print(f"Saved filtered embeddings and captions to {filtered_coco_embeddings_path}")

In [None]:
# Reload the filtered set.  Note that `kg_embeddings` and `coco_captions` are
# rebound here and now refer to the de-duplicated arrays.
# allow_pickle=True is needed because the captions were saved as an object array.
all_coco_embeddings = np.load(filtered_coco_embeddings_path, allow_pickle=True)
image_features = torch.tensor(all_coco_embeddings["image_features"]).float().to(device)
text_features = torch.tensor(all_coco_embeddings["text_features"]).float().to(device)
kg_embeddings = all_coco_embeddings["kg_embeddings"]
coco_captions = all_coco_embeddings["captions"]
image_ids = all_coco_embeddings["image_ids"]

# Training

In [None]:
class CaptionDataset(Dataset):
    """Dataset of (image feature, text feature, KG embedding, caption ids).

    All four sources are indexed with the same position, so they must be
    aligned and of equal length; the dataset length is ``len(captions)``.
    """

    def __init__(self, clip_image_features, clip_text_features, kg_embeddings, captions, vocab, max_length=20):
        self.clip_image_features = clip_image_features
        self.clip_text_features = clip_text_features
        self.kg_embeddings = kg_embeddings
        self.captions = captions
        self.vocab = vocab
        self.max_length = max_length

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        """Return ``(image_feature, text_feature, kg_embedding, caption)`` for ``idx``.

        The first three are float32 tensors; ``caption`` is a long tensor of
        ``max_length`` token ids produced by ``encode_caption``.
        """
        # as_tensor avoids the copy (and the "copy construct from a tensor"
        # warning) that torch.tensor() triggers when the source is already a
        # tensor, while still converting NumPy rows.
        image_feature = torch.as_tensor(self.clip_image_features[idx], dtype=torch.float32)
        text_feature = torch.as_tensor(self.clip_text_features[idx], dtype=torch.float32)
        kg_embedding = torch.as_tensor(self.kg_embeddings[idx], dtype=torch.float32)

        caption = encode_caption(self.captions[idx], self.vocab, self.max_length)
        caption = torch.tensor(caption, dtype=torch.long)

        return image_feature, text_feature, kg_embedding, caption

# Dataloader initialization
max_length = 20
dataset = CaptionDataset(
    clip_image_features=image_features.squeeze(1),  # CLIP image embeddings
    clip_text_features=text_features.squeeze(1),   # CLIP text embeddings
    kg_embeddings=kg_embeddings,        # Precomputed KG embeddings
    captions=coco_captions,             # Captions
    vocab=word_to_idx,                  # Vocabulary
    max_length=max_length               # Maximum length for padding
)
# Use the `batch_size` hyperparameter defined above rather than repeating the literal.
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
def train(model, data_loader, criterion, optimizer, num_epochs, device):
    """Train ``model`` with teacher forcing and return the mean loss per epoch.

    Args:
        model: captioning model returning ``(logits, target_captions)`` when
            called with ``captions``.
        data_loader: yields ``(image_features, text_features, kg_embeddings,
            captions)`` batches; the text features are not used.
        criterion: loss taking ``(N, vocab)`` logits and ``(N,)`` targets.
        optimizer: optimizer over the model's parameters.
        num_epochs: number of passes over ``data_loader``.
        device: device the batches are moved to.

    Returns:
        List with one average batch loss per epoch.

    The padding id is read from the module-level ``word_to_idx``.
    """
    model.train()
    losses = []
    pad_idx = word_to_idx["<PAD>"]
    # Guard the per-epoch average against an empty loader.
    num_batches = max(len(data_loader), 1)

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for batch in tqdm(data_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
            clip_image_features, _, kg_embeddings, captions = batch
            clip_image_features = clip_image_features.to(device)
            kg_embeddings = kg_embeddings.to(device)
            captions = captions.to(device)

            # Number of real tokens per caption (kept on the CPU for packing).
            lengths = (captions != pad_idx).sum(dim=1).cpu()

            optimizer.zero_grad()
            logits, target_captions = model(
                clip_image_features,
                kg_embedding=kg_embeddings,
                captions=captions,
                lengths=lengths
            )

            # The logits only span the longest caption of the batch.  The
            # target positions beyond that hold nothing but <PAD> (padding is
            # appended at the end of a caption) and would be masked out below,
            # so the targets are trimmed instead of allocating a
            # (batch, pad, vocab) block of dummy logits every step.
            target_captions = target_captions[:, :logits.size(1)]

            # Score only the non-padding positions.
            keep = target_captions != pad_idx
            loss = criterion(logits[keep], target_captions[keep])
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        mean_loss = epoch_loss / num_batches
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {mean_loss}")
        losses.append(mean_loss)

    return losses

In [None]:
# train model (on the device selected above, so CPU-only runtimes work too)
losses = train(model_kg, data_loader, criterion, optimizer, num_epochs=num_epochs, device=device)

In [None]:
# save trained model
# ("path/to/model.pth" is a placeholder — replace it with your checkpoint location.)
torch.save(model_kg.state_dict(), os.path.join(your_path, f"path/to/model.pth"))

In [None]:
# Plot the per-epoch average training loss returned by train().
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss Over Epochs")
plt.show()

# Evaluation

## Model Loading

In [None]:
model_kg = LightweightCaptioningModel(clip_dim, kg_dim, hidden_dim, vocab_size, use_kg=True).to(device)
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only runtime.
model_kg.load_state_dict(torch.load(os.path.join(your_path, f"path/to/model.pth"), map_location=device))

## Loading all Embeddings

In [None]:
# Gzipped ConceptNet Numberbatch text file (English, 19.08).
numberbatch_embeddings_path = os.path.join(your_path, f'{project_folder}/ConceptNet_Data_Container/numberbatch-en-19.08.txt.gz')

def load_numberbatch_embeddings(file_path):
    """Read a gzipped Numberbatch text file into a DataFrame indexed by term.

    The first line of the file is skipped.  Every following line is split on
    whitespace into a term followed by its float components.
    """
    labels, vectors = [], []

    with gzip.open(file_path, 'rt', encoding='utf8') as handle:
        next(handle)  # skip the first line
        for raw_line in handle:
            fields = raw_line.strip().split()
            labels.append(fields[0])
            vectors.append([float(component) for component in fields[1:]])

    return pd.DataFrame(vectors, index=labels)


# Load all Numberbatch vectors into memory; used for the KG features at inference.
numberbatch_df = load_numberbatch_embeddings(numberbatch_embeddings_path)

print(f"Loaded {len(numberbatch_df)} embeddings.")
print(numberbatch_df.head())

In [None]:
# IDF table stored with the Viz-Wiz captions (rebinds the COCO path set earlier).
idf_values_path = os.path.join(your_path, f"{project_folder}/Viz-Wiz Captions/idf_values.csv")
def load_idf(idf_path):
    """Read a CSV with ``word`` and ``idf`` columns into a ``{word: idf}`` dict."""
    table = pd.read_csv(idf_path)
    return {word: idf for word, idf in zip(table['word'], table['idf'])}
# word -> IDF weight, passed to the model as `idf_values` during evaluation.
idf_vals = load_idf(idf_values_path)

In [None]:
# Set up testing dataloader
# Locations of the validation images and annotation file, relative to `your_path`.
viz_wiz_images_dir = f"{project_folder}/Viz-Wiz Captions/val"
viz_wiz_annotations_path = f"{project_folder}/Viz-Wiz Captions/val.json"

total_wiz_images_dir = os.path.join(your_path, viz_wiz_images_dir)
total_wiz_annotations_path = os.path.join(your_path, viz_wiz_annotations_path)

with open(total_wiz_annotations_path, 'r') as f:
    viz_wiz_data = json.load(f)

# Index the image records by their id.
images = {img['id']: img for img in viz_wiz_data['images']}
annotations = viz_wiz_data['annotations']
print(f"Loaded {len(images)} images and {len(annotations)} annotations.")

In [None]:
# Precomputed KG embeddings for the Viz-Wiz captions.
# Note: this rebinds `kg_embeddings`, which previously held the COCO embeddings.
kg_viz_wiz_embedding_path = os.path.join(your_path, f"{project_folder}/Viz-Wiz Captions/kg_viz_wiz_embeddings.npz")
kg_viz_wiz_embeddings = np.load(kg_viz_wiz_embedding_path)
kg_embeddings = kg_viz_wiz_embeddings["kg_embeddings"]

Note that if you use our precomputed embeddings for `viz_wiz_embeddings_path`, ignore the first 5004 entries (they are erroneous); you can do this by slicing `[5004:]` on each array.

In [None]:
viz_wiz_embeddings_path = os.path.join(your_path, f"{project_folder}/Viz-Wiz Captions/viz_wiz_embeddings.npz")
viz_wiz_data = np.load(viz_wiz_embeddings_path, allow_pickle=True)
# The first 5004 entries of every array are discarded.
usable_entries = slice(5004, None)
viz_wiz_image_features = torch.tensor(viz_wiz_data["image_features"][usable_entries]).float().to(device)
viz_wiz_text_features = torch.tensor(viz_wiz_data["text_features"][usable_entries]).float().to(device)
viz_wiz_captions = viz_wiz_data["text_captions"][usable_entries]  # Use tokenized captions
viz_wiz_image_ids = viz_wiz_data["image_ids"][usable_entries]

## Dataset Setup

In [None]:
class GroupedCaptionDataset(Dataset):
    """Dataset yielding one item per image with all of its reference captions.

    The inputs are flat, caption-level arrays in which each image contributes
    ``num_captions_per_image`` consecutive rows; this layout is validated on
    construction via ``check_img_ordering``.
    """

    def __init__(self, clip_image_features, captions, vocab, img_ids, max_length=20):
        self.clip_image_features = clip_image_features
        self.captions = captions
        self.vocab = vocab
        self.max_length = max_length
        self.num_captions_per_image = 5
        self.check_img_ordering(img_ids)
        self.img_ids = img_ids

    def __len__(self):
        return len(self.captions) // self.num_captions_per_image

    def __getitem__(self, idx):
        """Return ``(image_feature, captions, img_ids)`` for image ``idx``.

        ``image_feature`` and ``img_ids`` carry a leading dimension of size 1;
        ``captions`` stacks the encoded reference captions of the image.
        """
        start_idx = idx * self.num_captions_per_image
        indices = range(start_idx, start_idx + self.num_captions_per_image)

        # Take only first instance of image features since they're identical.
        # as_tensor avoids copy-constructing (and the resulting warning) when
        # the features are already a tensor.
        image_feature = torch.as_tensor(self.clip_image_features[start_idx], dtype=torch.float32).unsqueeze(0)
        img_ids = torch.tensor(self.img_ids[start_idx], dtype=torch.int32).unsqueeze(0)

        # Stack all captions of this image
        captions = torch.stack([torch.tensor(encode_caption(self.captions[i], self.vocab, self.max_length), dtype=torch.long) for i in indices])

        return image_feature, captions, img_ids

    def check_img_ordering(self, array):
        """Validate that ``array`` consists of runs of identical image ids.

        Every run must be ``num_captions_per_image`` long.

        Returns:
            True when the layout is valid.

        Raises:
            ValueError: if the length is not a multiple of the group size or a
                group mixes different ids.
        """
        group_size = self.num_captions_per_image
        if len(array) % group_size != 0:
            print("len( array)", len(array))
            raise ValueError(f"The array length is not divisible by {group_size}.")

        grouped = np.asarray(array).reshape(len(array) // group_size, group_size)
        # Check if all ids in each group equal the group's first id
        if not np.all(grouped == grouped[:, :1]):
            raise ValueError("The array does not satisfy the condition.")
        return True

In [None]:
# One dataset item per Viz-Wiz image, holding all of its reference captions.
test_dataset = GroupedCaptionDataset(
    clip_image_features=viz_wiz_image_features.squeeze(1),
    captions=viz_wiz_captions,
    vocab=word_to_idx,
    img_ids=viz_wiz_image_ids,
    max_length=20
)
# shuffle=False keeps the evaluation order fixed.
test_data_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

## Inference

In [None]:
def evaluate(model, data_loader, vocab, device, metrics=['bleu']):
    """Generate a caption for every image and collect it with its references.

    Images with at least one reference containing "quality issue" are skipped.

    Args:
        model: captioning model; called without captions so it runs inference.
        data_loader: yields ``(image_features, captions, img_ids)`` batches.
        vocab: id -> word mapping handed to ``decode_caption``.
        device: device the batches are moved to.
        metrics: unused; kept for backward compatibility.

    Returns:
        ``(total_samples_used, all_predictions, all_references)`` where
        ``all_predictions`` holds ``(img_id, caption)`` tuples and
        ``all_references`` holds ``(img_id, [captions])`` tuples in the same
        order.  Image ids are strings.

    Reads the module-level ``idx_to_word``, ``word_to_idx``,
    ``numberbatch_df`` and ``idf_vals``.
    """
    model.eval()
    all_predictions = []  # Will store (img_id, prediction) tuples
    all_references = []   # Will store (img_id, [references]) tuples
    total_samples_used = 0
    special_tokens = ("<PAD>", "<s>", "</s>", "<UNK>")

    def content_words(token_ids):
        # Map ids to words, dropping the special tokens.
        words = (idx_to_word[token_id] for token_id in token_ids.tolist())
        return [word for word in words if word not in special_tokens]

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            clip_image_features, captions, img_ids = batch
            clip_image_features = clip_image_features.to(device)
            captions = captions.to(device)

            # True for every image whose references are all usable.
            keep = torch.tensor(
                [
                    not any(
                        "quality issue" in decode_caption(reference.tolist(), vocab).lower()
                        for reference in image_captions
                    )
                    for image_captions in captions
                ],
                dtype=torch.bool,
            )
            # A fully filtered batch would make the model stack an empty list.
            if not keep.any():
                continue

            # Filter ALL per-image tensors with the same mask.  The ids must be
            # filtered too, otherwise zip() below pairs the kept predictions
            # with the ids of the first images in the batch.
            captions = captions[keep.to(captions.device)]
            clip_image_features = clip_image_features[keep.to(clip_image_features.device)]
            img_ids = img_ids[keep]
            # Keep track of total samples actually used
            total_samples_used += captions.size(0)

            # Get model predictions
            outputs = model(clip_image_features.squeeze(1),
                          idx_to_word=idx_to_word,
                          numberbatch_df=numberbatch_df,
                          idf_values=idf_vals,
                          word_to_idx=word_to_idx)

            for pred_seq, ref_captions, img_id in zip(outputs, captions, img_ids):
                image_key = str(img_id.item())  # We make the img_id a string for CIDER and SPICE

                # Store prediction with image ID
                pred_tokens = content_words(pred_seq)
                if pred_tokens:
                    all_predictions.append((image_key, " ".join(pred_tokens)))
                else:
                    print("Warning: Empty prediction sequence")
                    all_predictions.append((image_key, "<EMPTY>"))

                # Format references; only non-empty ones are kept
                refs = []
                for ref_caption in ref_captions:
                    ref_tokens = content_words(ref_caption)
                    if ref_tokens:
                        refs.append(" ".join(ref_tokens))
                all_references.append((image_key, refs))

    # Debug prints
    print(f"\nNumber of predictions: {len(all_predictions)}")
    print(f"Number of references: {len(all_references)}")
    print("\nSample predictions:")
    for i in range(min(5, len(all_predictions))):
        print(f"Image ID {all_predictions[i][0]}:")
        print(f"Prediction: '{all_predictions[i][1]}'")
        print(f"References: {all_references[i][1]}\n")
    return total_samples_used, all_predictions, all_references

In [None]:
# Run inference with the KG model.  `idx_to_word` is passed as `vocab`, which
# evaluate() hands to decode_caption().
total_samples_used, eval_all_predictions, eval_all_references = evaluate(model_kg, test_data_loader, idx_to_word, device)

In [None]:
# References: for every image, a list of tokenised reference captions (image id dropped)
non_tuple_references = [
    [reference.split() for reference in entry[-1]]
    for entry in eval_all_references
]
print("References", non_tuple_references[0])

# Candidates: for every image, the tokenised predicted caption (image id dropped)
non_tuple_candidates = [entry[-1].split() for entry in eval_all_predictions]
print(f"Candidate: {non_tuple_candidates[0]}")

## Evaluation Metrics

In [None]:
# BLEU-1 Calculation Function
def calculate_bleu1(references, candidates):
    """Corpus-level BLEU using unigram weights only, with NLTK ``method1`` smoothing.

    ``references`` holds, per example, a list of tokenised references;
    ``candidates`` holds one tokenised hypothesis per example.
    """
    unigram_only = (1.0, 0, 0, 0)
    return corpus_bleu(
        references,
        candidates,
        weights=unigram_only,
        smoothing_function=SmoothingFunction().method1,
    )

# ROUGE-L Calculation Function
def calculate_rouge(references, candidates):
    """Return one ROUGE-L F-measure per example.

    Each candidate is scored against ALL of its references and the best score
    is kept, so the metric uses the same multi-reference data as BLEU instead
    of only the first reference.  Examples without any reference score 0.0.

    Args:
        references: per example, a list of tokenised reference captions.
        candidates: one tokenised hypothesis per example.
    """
    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
    rouge_scores = []
    for refs, cand in zip(references, candidates):
        prediction = " ".join(cand)
        best = max(
            (scorer.score(" ".join(ref), prediction)["rougeL"].fmeasure for ref in refs),
            default=0.0,
        )
        rouge_scores.append(best)
    return rouge_scores

# Bootstrap for Confidence Intervals
def single_bootstrap(metric_fn, references, candidates, seed=None):
    """Evaluate ``metric_fn`` on one resample (with replacement) of the data.

    References and candidates are resampled with the same indices so the pairs
    stay aligned.

    Args:
        metric_fn: callable ``(references, candidates) -> score``.
        references, candidates: aligned sequences of equal length.
        seed: optional seed for a private random generator; when omitted the
            global NumPy random state is used.
    """
    rng = np.random if seed is None else np.random.default_rng(seed)
    indices = rng.choice(len(candidates), len(candidates), replace=True)
    sampled_references = [references[i] for i in indices]
    sampled_candidates = [candidates[i] for i in indices]
    return metric_fn(sampled_references, sampled_candidates)

def bootstrap_metric_parallel(metric_fn, references, candidates, n_bootstraps=500, n_jobs=6):
    """Bootstrap ``metric_fn`` in a process pool.

    Returns:
        ``(mean, 2.5th percentile, 97.5th percentile)`` of the bootstrap scores.
    """
    # Forked workers start with a copy of the parent's NumPy random state, so
    # without explicit seeds different workers would draw identical resamples.
    # The seeds come from the parent's generator, so seeding NumPy beforehand
    # still makes the whole procedure reproducible.
    seeds = np.random.randint(0, 2**31 - 1, size=n_bootstraps)
    with Pool(n_jobs) as pool:
        scores = pool.starmap(
            single_bootstrap,
            [(metric_fn, references, candidates, int(seed)) for seed in seeds]
        )
    mean_score = np.mean(scores)
    lower_ci = np.percentile(scores, 2.5)
    upper_ci = np.percentile(scores, 97.5)
    return mean_score, lower_ci, upper_ci

# BLEU-1
bleu_mean, bleu_ci_lower, bleu_ci_upper = bootstrap_metric_parallel(calculate_bleu1, non_tuple_references, non_tuple_candidates)
bleu_results = {"mean": bleu_mean, "95% CI": (bleu_ci_lower, bleu_ci_upper)}
print(f"with kg BLEU-1: {bleu_mean:.4f} (95% CI: {bleu_ci_lower:.4f}, {bleu_ci_upper:.4f})")

# ROUGE-L
rouge_scores = calculate_rouge(non_tuple_references, non_tuple_candidates)
rouge_mean = np.mean(rouge_scores)
# Bootstrap the MEAN: percentiles of the per-example scores describe the spread
# of individual captions, not a confidence interval for the reported mean.
rouge_boot_means = np.random.choice(rouge_scores, size=(500, len(rouge_scores)), replace=True).mean(axis=1)
rouge_ci_lower = np.percentile(rouge_boot_means, 2.5)
rouge_ci_upper = np.percentile(rouge_boot_means, 97.5)
rouge_results = {"mean": rouge_mean, "95% CI": (rouge_ci_lower, rouge_ci_upper)}
print(f"with kg ROUGE-L: {rouge_mean:.4f} (95% CI: {rouge_ci_lower:.4f}, {rouge_ci_upper:.4f})")

In [None]:
# Evaluate stored model with no KG for comparison tests below
# NOTE: "path/to/model.pth" is a placeholder.  It must point to a checkpoint
# trained with use_kg=False — a KG checkpoint contains extra kg_proj /
# feature_combine weights that this model does not define.
model_no_kg = LightweightCaptioningModel(clip_dim, kg_dim, hidden_dim, vocab_size, use_kg=False).to(device)
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only runtime.
model_no_kg.load_state_dict(torch.load(os.path.join(your_path, f"path/to/model.pth"), map_location=device))
CTRL_total_samples_used, CTRL_eval_all_predictions, CTRL_eval_all_references = evaluate(model_no_kg, test_data_loader, idx_to_word, device)

In [None]:
# References: for every image, a list of tokenised reference captions (image id dropped)
CTRL_non_tuple_references = [
    [reference.split() for reference in entry[-1]]
    for entry in CTRL_eval_all_references
]
print("References", CTRL_non_tuple_references[0])

# Candidates: for every image, the tokenised predicted caption (image id dropped)
CTRL_non_tuple_candidates = [entry[-1].split() for entry in CTRL_eval_all_predictions]
print(f"Candidate: {CTRL_non_tuple_candidates[0]}")

In [None]:
# BLEU-1
bleu_mean, bleu_ci_lower, bleu_ci_upper = bootstrap_metric_parallel(calculate_bleu1, CTRL_non_tuple_references, CTRL_non_tuple_candidates)
bleu_results = {"mean": bleu_mean, "95% CI": (bleu_ci_lower, bleu_ci_upper)}
print(f"no kg BLEU-1: {bleu_mean:.4f} (95% CI: {bleu_ci_lower:.4f}, {bleu_ci_upper:.4f})")

# ROUGE-L
rouge_scores = calculate_rouge(CTRL_non_tuple_references, CTRL_non_tuple_candidates)
rouge_mean = np.mean(rouge_scores)
# Bootstrap the MEAN: percentiles of the per-example scores describe the spread
# of individual captions, not a confidence interval for the reported mean.
rouge_boot_means = np.random.choice(rouge_scores, size=(500, len(rouge_scores)), replace=True).mean(axis=1)
rouge_ci_lower = np.percentile(rouge_boot_means, 2.5)
rouge_ci_upper = np.percentile(rouge_boot_means, 97.5)
rouge_results = {"mean": rouge_mean, "95% CI": (rouge_ci_lower, rouge_ci_upper)}
print(f"no kg ROUGE-L: {rouge_mean:.4f} (95% CI: {rouge_ci_lower:.4f}, {rouge_ci_upper:.4f})")

In [None]:
# Per-example scores for both models.  The KG run's references are used for
# both, as in the comparison below.
bleu1_with_kg = [
    calculate_bleu1([non_tuple_references[i]], [candidate])
    for i, candidate in enumerate(non_tuple_candidates)
]
bleu1_without_kg = [
    calculate_bleu1([non_tuple_references[i]], [candidate])
    for i, candidate in enumerate(CTRL_non_tuple_candidates)
]
rouge_with_kg = calculate_rouge(non_tuple_references, non_tuple_candidates)
rouge_without_kg = calculate_rouge(non_tuple_references, CTRL_non_tuple_candidates)

In [None]:
def plot_distribution(group1, group2):
    """Overlay the score histograms of two groups and return Cohen's d.

    ``group1`` is labelled "With KG" and ``group2`` "Without KG"; dashed lines
    mark the two means.

    Returns:
        Cohen's d of ``group1`` versus ``group2`` using the pooled standard
        deviation, or ``nan`` when it is undefined (fewer than three samples
        in total or zero pooled variance).  The value was previously computed
        and then discarded.
    """
    group1 = np.asarray(group1, dtype=float)
    group2 = np.asarray(group2, dtype=float)

    plt.figure(figsize=(4, 4))
    pooled = np.concatenate([group1, group2])
    n_bins = len(np.unique(pooled))
    # Shared bin edges: with a plain bin count each histogram derives its own
    # edges from its own range, which makes the overlaid bars incomparable.
    bin_edges = np.histogram_bin_edges(pooled, bins=n_bins)

    plt.hist(group1, bins=bin_edges, alpha=0.5, density=True,
             label='With KG', color='blue')
    plt.hist(group2, bins=bin_edges, alpha=0.5, density=True,
             label='Without KG', color='red')

    plt.axvline(np.mean(group1), color='blue', linestyle='--')
    plt.axvline(np.mean(group2), color='red', linestyle='--')

    # Calculate Cohen's d
    n1, n2 = len(group1), len(group2)
    cohens_d = float('nan')
    if n1 > 1 and n2 > 1:
        var1, var2 = np.var(group1, ddof=1), np.var(group2, ddof=1)
        pooled_std = np.sqrt(((n1-1)*var1 + (n2-1)*var2) / (n1+n2-2))
        if pooled_std > 0:
            cohens_d = (np.mean(group1) - np.mean(group2)) / pooled_std

    plt.title("Score Distribution")
    plt.xlabel('Score')
    plt.ylabel('Density')
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.show()

    return cohens_d

In [None]:
# BLEU-1: per-example score distributions, with KG vs. without KG
plot_distribution(bleu1_with_kg, bleu1_without_kg)

In [None]:
# ROUGE-L: per-example score distributions, with KG vs. without KG
plot_distribution(rouge_with_kg, rouge_without_kg)

In [None]:
# Keyed by image id: the reference captions and the single predicted caption
# (wrapped in a list) for the KG model.
gts = {image_key: references for image_key, references in eval_all_references}
res = {image_key: [prediction] for image_key, prediction in eval_all_predictions}

In [None]:
# Calculate CIDEr
# compute_score returns a pair; only the first element is reported here.
cider_scorer = Cider()
cider_score, _ = cider_scorer.compute_score(gts, res)
print(f"CIDEr Score: {cider_score}")

# Calculate SPICE
spice_scorer = Spice()
spice_score, _ = spice_scorer.compute_score(gts, res)
print(f"SPICE Score: {spice_score}")