In [None]:
# If you're on Google Colab, use `!pip install -r requirements.txt` instead
%pip install -r requirements.txt

In [None]:
# Report whether PyTorch can see a CUDA GPU and, if so, which device and CUDA build it uses
import torch

has_cuda = torch.cuda.is_available()
print("CUDA Available:", has_cuda)
if has_cuda:
    print("GPU Name:", torch.cuda.get_device_name(0))
    print("PyTorch using CUDA version:", torch.version.cuda)

In [None]:
# Run only if using COCO dataset, if using Flickr, the images are pulled directly 
!wget http://images.cocodataset.org/zips/train2014.zip
!wget http://images.cocodataset.org/zips/val2014.zip
!wget http://images.cocodataset.org/zips/test2014.zip

!unzip train2014.zip -d coco
!unzip test2014.zip -d coco
!unzip val2014.zip -d coco

In [None]:
from datasets import load_dataset
from transformers import AutoProcessor, CLIPModel
from tqdm import trange
import torch
import numpy as np
import os
from PIL import Image
import json

#################
# Configuration #
#################
SAVE_DIR = "embedded_data"
os.makedirs(SAVE_DIR, exist_ok=True)

CLIP_MODEL_ID = "openai/clip-vit-large-patch14"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 128
COCO_ROOT = "coco"  # Local folder where COCO images are stored

#####################
# Dataset Selection #
#####################

# Switch between available datasets here, code supports both:
# DATASET_NAME = "jxie/flickr8k"
DATASET_NAME = "yerevann/coco-karpathy"
SPLIT = "validation+test+train"

# Load dataset
ds = load_dataset(DATASET_NAME, split=SPLIT)
print(f"Loaded dataset '{DATASET_NAME}' with {len(ds)} entries")
print(f"Dataset shape: {ds.shape}")

#############
# Load CLIP #
#############

processor = AutoProcessor.from_pretrained(CLIP_MODEL_ID)
clip = CLIPModel.from_pretrained(CLIP_MODEL_ID).to(DEVICE).eval()

#####################
# Build Caption Map #
#####################

# caption_map[i] = list of reference captions for dataset row i. The embedding cell
# writes the embedding of ds[i] to row i as well, so the two stay aligned by index.
# Only the caption columns are read: iterating full rows would also decode every
# Flickr8k image just to collect its text.
if DATASET_NAME == "jxie/flickr8k":
    # Flickr8k: each image has 5 caption columns -> one tuple of 5 captions per row
    caption_rows = zip(*(ds[f"caption_{j}"] for j in range(5)))
else:
    # COCO: list of captions stored under 'sentences'
    caption_rows = ds["sentences"]

caption_map = {i: list(captions) for i, captions in enumerate(caption_rows)}
assert len(caption_map) == len(ds), "every dataset row needs an entry in caption_map"

# Save captions to disk (JSON turns the int row indices into string keys)
with open(os.path.join(SAVE_DIR, "captions.json"), "w") as f:
    json.dump(caption_map, f)

In [None]:
####################
# HELPER FUNCTIONS #
####################

@torch.no_grad()
def embed_images(images):
    '''
    Embed a batch of images with the CLIP vision tower, keeping per-patch features.

    The CLS token is dropped and each patch vector is L2-normalised.

    Args:
        images: list of PIL images (anything the CLIP processor accepts)

    Returns:
        NumPy array [batch, patches, embed_dim] in float16 CPU format
        (to reduce RAM and disk usage)
    '''
    # Preprocess images to CLIP format and move to device
    inputs = processor(images=images, return_tensors="pt").to(DEVICE)

    # Mixed precision on GPU: heavy matmuls run in float16 while precision-sensitive
    # ops stay in float32 (faster, less memory). It is disabled off-GPU instead of
    # warning. torch.autocast replaces the deprecated torch.cuda.amp.autocast.
    use_amp = str(DEVICE).startswith("cuda")
    with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
        # Only the final layer is used, so the per-layer hidden states are not requested
        out = clip.vision_model(pixel_values=inputs["pixel_values"])

    # Remove the CLS token (position 0).
    # CLIP output: CLS token + patch embeddings. CLS is a global summary vector
    # (used for classification); we want local visual features here.
    patch_tokens = out.last_hidden_state[:, 1:, :]

    # L2-normalise each patch vector for stability
    tokens = torch.nn.functional.normalize(patch_tokens, p=2, dim=-1)

    # Convert to float16: halves RAM and disk usage
    return tokens.half().cpu().numpy()

def load_image(filename, root=None):
    '''
    Open a COCO image from disk, using its filename convention to locate the folder.

    Example:
    'COCO_val2014_000000184613.jpg'
        -> folder = 'val2014'
        -> full path = '<root>/val2014/COCO_val2014_000000184613.jpg'

    Args:
        filename: COCO image filename of the form 'COCO_<folder>_<id>.jpg'
        root: directory holding the extracted COCO folders; defaults to COCO_ROOT

    Returns:
        The image converted to RGB, fully loaded in memory (the file is closed).

    Raises:
        ValueError: if the filename has no '_'-separated folder component
        FileNotFoundError: if the file is missing, so a bad download fails loudly
            instead of silently corrupting training
    '''
    parts = filename.split("_")
    if len(parts) < 2:
        raise ValueError(f"Unexpected COCO filename (expected 'COCO_<folder>_<id>.jpg'): {filename}")
    folder = parts[1]
    base_dir = COCO_ROOT if root is None else root
    full_path = os.path.join(base_dir, folder, filename)

    if not os.path.exists(full_path):
        raise FileNotFoundError(f"COCO image not found: {full_path}")

    # Image.open is lazy and holds the file open; convert() returns an in-memory copy,
    # so the context manager can close the file before returning.
    with Image.open(full_path) as img:
        return img.convert("RGB")


#####################
# Create Embeddings #
#####################

# NOTE:
# Embedding the full COCO-Karpathy set was hitting resource limits on Colab, so only
# the FIRST len(ds)//5 rows (~20%) are embedded. This is a contiguous prefix, not every
# 5th row: embedding row i must line up with ds[i] / caption_map[i], which a prefix
# preserves. Because SPLIT concatenates validation+test+train, the prefix begins with
# the validation and test rows.
# If you want 100% coverage, set TOTAL = len(ds)
TOTAL = len(ds)//5 if DATASET_NAME == "yerevann/coco-karpathy" else len(ds)
print(f"Unique images selected for embedding: {TOTAL:,}")

# Determine the embedding shape by running one test forward pass
test_img = load_image(ds[0]['filename']) if DATASET_NAME != "jxie/flickr8k" else ds[0]["image"]
test_emb = embed_images([test_img])
TOKEN_SHAPE = test_emb.shape[1:]

print(f"Embedding shape per image: {TOKEN_SHAPE}")

# Create an on-disk array to store embeddings incrementally instead of all at once at the end
from numpy.lib.format import open_memmap

# Prepare memory-mapped array for incremental saving
SAVE_PATH = os.path.join(SAVE_DIR, "patch_embeds_float16.npy")

fp = open_memmap(
    SAVE_PATH,
    mode="w+",
    dtype=np.float16,
    shape=(TOTAL,) + TOKEN_SHAPE,
)

print(f"Preallocated memmap on disk: {SAVE_PATH}")
print(f"Full shape: {(TOTAL,) + TOKEN_SHAPE}\n")
print("Starting embedding process...\n")

idx = 0
for start in trange(0, TOTAL, BATCH_SIZE):
    end = min(start + BATCH_SIZE, TOTAL)
    batch_rows = ds[start:end]

    # Load image files for the current batch
    if DATASET_NAME == "jxie/flickr8k":
        batch = batch_rows["image"]  # PIL Images already present
    else:
        # COCO-Karpathy: images are resolved from the filename column
        batch = [load_image(name) for name in batch_rows["filename"]]

    # Embed via CLIP
    embeds = embed_images(batch)
    b = embeds.shape[0]

    # Write the batch straight to disk at its dataset row offset
    fp[idx:idx + b] = embeds
    idx += b

# Every preallocated row must have been filled; unfilled rows would later be
# treated as real embeddings without any error.
assert idx == TOTAL, f"Wrote {idx} embedding rows, expected {TOTAL}"

# Flush to disk and release the memmap
fp.flush()
del fp

print(f"Embedding complete: saved to: {SAVE_PATH}")
print("Output Folder:", SAVE_DIR)

In [None]:
import os, json, torch
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

##############################
# Load Embeddings + Captions #
##############################

# Must match SAVE_DIR in the embedding cell: captions.json and the .npy file were written there.
SAVE_DIR = "embedded_data"
# Rebuilt here so this cell also works on a fresh kernel once the embeddings exist on disk
SAVE_PATH = os.path.join(SAVE_DIR, "patch_embeds_float16.npy")
BART_MODEL_ID = "sshleifer/distilbart-cnn-12-6"
MAX_LEN = 64 # Max caption length (tokenized)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading embeddings...")
# Read-only memory map: rows are read from disk on access instead of loaded up front
embs = np.load(SAVE_PATH, mmap_mode="r")

print("Loading captions...")
with open(os.path.join(SAVE_DIR, "captions.json"), "r") as f:
    caption_map = json.load(f)

# Embedding row i pairs with caption_map[str(i)]; fail early if any embedded row lacks captions
missing = [i for i in range(len(embs)) if str(i) not in caption_map]
assert not missing, f"{len(missing)} embedded images have no captions (e.g. {missing[:5]})"

# Split by IMAGE (not by caption) so no image's captions leak across splits
all_imgs = np.arange(len(embs))

# First split: 80% train, 20% val+test
train_imgs, tmp = train_test_split(all_imgs, test_size=0.2, random_state=42)
# Second split: split the reserved 20% in half: 10% val, 10% test
val_imgs, test_imgs = train_test_split(tmp, test_size=0.5, random_state=42)

# Each image has several captions. The split is expanded to one entry per caption;
# only the image index is repeated, never the embedding itself.
def expand_ids(imgs):
    '''
    Flatten image indices into parallel (image index, caption) lists.

    Args:
        imgs: iterable of image row indices; captions are looked up as caption_map[str(index)]

    Returns:
        (img_idx, caps): two equal-length lists where img_idx[k] is the image for caps[k]
    '''
    pairs = [(img, cap) for img in imgs for cap in caption_map[str(img)]]
    img_idx = [img for img, _ in pairs]
    caps = [cap for _, cap in pairs]
    return img_idx, caps

# One (image index, caption) entry per caption, for each split
(train_img_idx, train_caps), (val_img_idx, val_caps), (test_img_idx, test_caps) = (
    expand_ids(train_imgs),
    expand_ids(val_imgs),
    expand_ids(test_imgs),
)

print("Expanded Caption Instances (Image duplicated per caption):")
for _label, _caps in (
    ("Train captions: ", train_caps),
    ("Val captions:   ", val_caps),
    ("Test captions:  ", test_caps),
):
    print(f"{_label}{len(_caps):,}")

In [None]:
from transformers import AutoTokenizer

# DistilBART tokenizer: encodes the training captions and decodes generated token ids
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=BART_MODEL_ID)

# Custom Dataset Class
class ImageDataset(Dataset):
    '''
    Dataset that pairs
        - Precomputed CLIP patch embeddings (one row per image)
        - Tokenized caption text (one dataset item per caption)

    Each item is a dict:
        patch_embeds: float16 tensor embeddings[img_idx[i]]
        labels:       caption token ids padded/truncated to max_length; padding
                      positions are set to -100 so the LM loss ignores them
        image_id:     image row index, kept for grouping references at eval time
    '''

    def __init__(self, img_idx, captions, embeddings, tokenizer, max_length=None):
        self.img_idx = img_idx          # image index per caption
        self.captions = captions        # raw text list
        self.embeddings = embeddings    # numpy array (possibly a read-only memmap) of CLIP embeddings
        self.tokenizer = tokenizer
        self.image_ids = img_idx        # used for eval stats
        self.max_length = max_length    # None -> use the notebook-level MAX_LEN at access time

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, i):
        # Select image ID for this caption
        idx = self.img_idx[i]
        txt = self.captions[i]

        # Copy the row out of the (possibly read-only, memory-mapped) array:
        # torch.from_numpy on a non-writable array warns and would alias the mapped file.
        vis = torch.from_numpy(np.array(self.embeddings[idx], dtype=np.float16))

        # Tokenize caption text for language modeling
        # - pad/truncate to a fixed sequence length
        # - return PyTorch tensor
        max_length = MAX_LEN if self.max_length is None else self.max_length
        tok = self.tokenizer(
            txt,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        labels = tok["input_ids"].squeeze(0)    # remove batch dimension

        # Padding must not count toward the loss; otherwise the model is mostly trained
        # (and early-stopped) on predicting pad tokens.
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is not None:
            labels = labels.masked_fill(labels == pad_id, -100)

        # Return model-ready training item
        return {
            "patch_embeds": vis,                         # (patch_tokens, embed_dim)
            "labels": labels,
            "image_id": idx                              # kept for BLEU/CIDEr eval
        }

# One dataset per split; all share the same memory-mapped embeddings and tokenizer
train_ds, val_ds, test_ds = (
    ImageDataset(split_idx, split_caps, embs, tokenizer)
    for split_idx, split_caps in (
        (train_img_idx, train_caps),
        (val_img_idx, val_caps),
        (test_img_idx, test_caps),
    )
)

In [None]:
import torch.nn as nn
# Imported here as well as in the training cell so this cell is self-contained:
# the class resolves both names when it is instantiated / called.
from transformers import AutoModelForSeq2SeqLM
from transformers.modeling_outputs import BaseModelOutput

class CLIP2DistilBART(nn.Module):
    '''
    A lightweight multimodal captioning model that:

    1. Accepts CLIP patch embeddings as input
    2. Projects them through a linear adapter
        - maps the CLIP embedding dimension to BART's d_model
    3. Feeds them into DistilBART as ready-made encoder outputs
        - BART's own encoder is not run; the adapted patches stand in for its states
        - the decoder attends to them and generates caption text

    Nothing is frozen here: the adapter and all BART parameters are trainable, but
    since the BART encoder is bypassed, only the adapter and the decoder-side
    parameters actually receive gradients.
    '''

    # Hugging Face Trainer looks this attribute up on the model in some checkpoint
    # loading paths (NOTE(review): e.g. when reporting missing keys for
    # load_best_model_at_end). Plain nn.Modules don't define it, so provide an empty list.
    _keys_to_ignore_on_save = []

    def __init__(self, bart_id, embed_dim):
        '''
        Args:
            bart_id: Hugging Face id or path of the seq2seq (DistilBART) checkpoint
            embed_dim: size of the last axis of the CLIP patch embeddings
        '''
        super().__init__()

        # Pretrained DistilBART; its decoder generates the captions
        self.bart = AutoModelForSeq2SeqLM.from_pretrained(bart_id)

        # Linear adapter, applied per patch: (num_patches, embed_dim) -> (num_patches, d_model)
        self.adapter = nn.Linear(embed_dim, self.bart.config.d_model)

    def forward(self, patch_embeds, labels=None):
        '''
        Inputs:
            patch_embeds: (batch, patch_tokens, CLIP_dim)
            labels: tokenized caption IDs. BART computes the LM loss from them and
                    derives its decoder inputs from them.
                    NOTE(review): with labels=None (and no decoder inputs) BART has
                    nothing to decode from and is expected to raise — for inference
                    call self.bart.generate(encoder_outputs=...) instead, as the
                    evaluation cell does.

        Returns:
            Standard Seq2SeqLMOutput from DistilBART
        '''
        # The stored embeddings are float16; upcast to match the adapter's float32
        # weights (under fp16 autocast the matmul itself may still run in half precision).
        enc = self.adapter(patch_embeds.float())
        # Wrap as a BART encoder output so the encoder stack is skipped
        enc_out = BaseModelOutput(last_hidden_state=enc)

        # Forward through BART
        return self.bart(
            encoder_outputs=enc_out,
            labels=labels,
        )

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForSeq2SeqLM, TrainingArguments, Trainer, EarlyStoppingCallback
from transformers.modeling_outputs import BaseModelOutput


# NOTE(review): not confirmed that transformers reads this variable; WrappedTrainer
# below also passes safe_serialization=False explicitly when saving.
os.environ["TRANSFORMERS_NO_SAFE_TENSORS"] = "1"

BART_MODEL_ID = "sshleifer/distilbart-cnn-12-6"
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Adapter input width = size of the last axis of the stored CLIP patch embeddings
embed_dim = embs.shape[-1]
model = CLIP2DistilBART(BART_MODEL_ID, embed_dim).to(DEVICE)

def collate(batch):
    '''
    Stack a list of ImageDataset items into a batch of model inputs.

    Only the keys CLIP2DistilBART.forward accepts are kept; "image_id" is dropped.
    '''
    model_keys = ("patch_embeds", "labels")
    return {key: torch.stack([item[key] for item in batch]) for key in model_keys}

args = TrainingArguments(
    output_dir="./caption_model",
    per_device_train_batch_size=128,
    # Evaluation only computes the loss; the much smaller default eval batch size
    # makes every per-epoch evaluation pass needlessly slow.
    per_device_eval_batch_size=128,
    num_train_epochs=10,
    learning_rate=5e-5,
    logging_steps=30,
    fp16=torch.cuda.is_available(),
    eval_strategy="epoch",
    save_strategy="epoch",
    # Each checkpoint holds the full BART weights plus optimizer state. Keep only the
    # newest two (the best checkpoint is protected when load_best_model_at_end=True)
    # instead of accumulating one per epoch and filling the disk.
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to=[],
)

class WrappedTrainer(Trainer):
    '''
    Trainer that saves the CLIP2DistilBART wrapper as separate pieces.

    save_model is used for the final save in this notebook (and by Trainer when it
    writes epoch checkpoints). Into the target directory it writes:
      - the BART weights/config via save_pretrained (non-safetensors format)
      - adapter.pt: state_dict of the linear adapter
      - config.pt:  adapter input/output sizes and the base BART model id
    '''

    def save_model(self, output_dir=None, _internal_call=False):
        # _internal_call is accepted for Trainer signature compatibility and is unused
        target_dir = output_dir if output_dir else self.args.output_dir
        os.makedirs(target_dir, exist_ok=True)

        wrapped = self.model
        adapter = wrapped.adapter

        wrapped.bart.save_pretrained(target_dir, safe_serialization=False)
        torch.save(adapter.state_dict(), f"{target_dir}/adapter.pt")

        meta = {
            "adapter_in": adapter.in_features,
            "adapter_out": adapter.out_features,
            "bart_model": BART_MODEL_ID,
        }
        torch.save(meta, f"{target_dir}/config.pt")

trainer = WrappedTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collate,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

trainer.train()

# WrappedTrainer checkpoints hold the BART weights (saved from model.bart, so without the
# "bart." key prefix) plus a separate adapter.pt. Neither matches this wrapper's
# state_dict, so NOTE(review): load_best_model_at_end is not expected to restore the best
# epoch by itself. Reload both pieces from the best checkpoint explicitly so the
# model saved and evaluated below is the best one. This is harmless if it was already restored.
best_ckpt = trainer.state.best_model_checkpoint
if best_ckpt is not None:
    best_bart = AutoModelForSeq2SeqLM.from_pretrained(best_ckpt)
    model.bart.load_state_dict(best_bart.state_dict())
    del best_bart
    adapter_state = torch.load(
        os.path.join(best_ckpt, "adapter.pt"), map_location="cpu", weights_only=True
    )
    model.adapter.load_state_dict(adapter_state)
    print(f"Restored best checkpoint: {best_ckpt}")

trainer.save_model("./caption_model")
tokenizer.save_pretrained("./caption_model")

In [None]:
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.cider.cider import Cider

import re

from tqdm import tqdm

# Caption decoding settings. distilbart-cnn is a summarization checkpoint whose config
# carries summarization generation defaults (NOTE(review): e.g. min_length=56 and
# length_penalty=2.0 — confirm in the checkpoint's config.json). An inherited min_length
# above max_length blocks EOS, so every caption would be cut off at max_length.
# Pin caption-appropriate values explicitly. max_length counts the decoder-start/BOS/EOS
# tokens too, so leave headroom.
GEN_KWARGS = dict(max_length=32, min_length=1, num_beams=5, length_penalty=1.0)

_NON_WORD = re.compile(r"[^\w\s]")


def normalize_caption(text):
    '''
    Lowercase a caption, replace punctuation with spaces and collapse whitespace.

    The BLEU/CIDEr scorers split captions on whitespace, so "A horse." would
    otherwise not match "a horse". This is a lightweight stand-in for
    pycocoevalcap's PTBTokenizer (which requires Java).
    '''
    return " ".join(_NON_WORD.sub(" ", text.lower()).split())


preds = {}
refs = {}

model.eval()

# test_ds has one row per (image, caption) pair, but a generated caption depends only on
# the image: decode each test image once and score it against all of its references.
first_row_per_image = {}
for row, img_id in enumerate(test_ds.image_ids):
    first_row_per_image.setdefault(int(img_id), row)

with torch.no_grad():  # the adapter forward would otherwise build an autograd graph
    for img_id, row in tqdm(first_row_per_image.items(), desc="Captioning test images"):
        vis = test_ds[row]["patch_embeds"].unsqueeze(0).to(DEVICE)

        enc = model.adapter(vis.float())
        out = model.bart.generate(
            encoder_outputs=BaseModelOutput(last_hidden_state=enc),
            **GEN_KWARGS,
        )
        pred = tokenizer.decode(out[0], skip_special_tokens=True)

        preds[img_id] = [normalize_caption(pred)]
        refs[img_id] = [normalize_caption(c) for c in caption_map[str(img_id)]]

assert preds, "No test images to evaluate"

bleu, _ = Bleu(4).compute_score(refs, preds)
cider, _ = Cider().compute_score(refs, preds)

print("BLEU-4:", bleu[3])
print("CIDEr:", cider)