In [None]:

# Environment setup: upgrade pyarrow and constrain pydantic to the 2.x line below 2.12.
# NOTE(review): presumably these pins work around version conflicts in the hosted
# runtime -- confirm they are still required.
!pip install -q --upgrade "pyarrow>=21.0.0"
!pip install -q "pydantic>=2.0,<2.12"

# Hugging Face transformers plus ftfy/regex/tqdm, then OpenAI CLIP straight from GitHub.
!pip install -q transformers ftfy regex tqdm
!pip install -q git+https://github.com/openai/CLIP.git

In [None]:
import clip
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, get_linear_schedule_with_warmup # GPT2LMHeadModel: the language-generation component
from torch.optim import AdamW
from PIL import Image
from tqdm import tqdm
from collections import defaultdict
import os
import requests
import random
import pandas as pd

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:

# Vision backbone: CLIP ViT-B/32 together with its image preprocessing transform.
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

# Language backbone: pretrained GPT-2 tokenizer and LM head model.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")

# Register the padding / begin / end markers used when captions are tokenized,
# then grow GPT-2's embedding table to match the enlarged vocabulary.
special_tokens = {
    "pad_token": "[PAD]",
    "bos_token": "<|startoftext|>",
    "eos_token": "<|endoftext|>",
}
tokenizer.add_special_tokens(special_tokens)
gpt2_model.resize_token_embeddings(len(tokenizer))

# Both backbones are frozen; only modules created later can receive gradients.
# NOTE(review): this also freezes the embeddings of the newly added tokens -- confirm intended.
for frozen_backbone in (clip_model, gpt2_model):
    for weight in frozen_backbone.parameters():
        weight.requires_grad = False


In [None]:

# Dataset locations (Kaggle copy of Flickr8k).
IMAGE_DIR = "/kaggle/input/d/adityajn105/flickr8k/Images"
CAPTIONS_FILE = "/kaggle/input/d/adityajn105/flickr8k/captions.txt"

# The captions file is read as CSV and must expose 'image' and 'caption' columns.
df = pd.read_csv(CAPTIONS_FILE)
print(f"Total captions: {len(df)}")

# Group captions by image file name. Zipping the two columns avoids building a
# pandas Series per row, which is what DataFrame.iterrows() does.
image_to_captions = defaultdict(list)
for image_name, caption in zip(df['image'], df['caption']):
    image_to_captions[image_name].append(caption)

all_images = list(image_to_captions.keys())
print(f"Total unique images: {len(all_images)}")

# Deterministic split in first-appearance order: 80% train, 10% validation,
# and whatever remains goes to test.
train_size = int(0.8 * len(all_images))
val_size = int(0.1 * len(all_images))

train_images = all_images[:train_size]
val_images = all_images[train_size:train_size + val_size]
test_images = all_images[train_size + val_size:]

print(f"Train images: {len(train_images)}, Val images: {len(val_images)}, Test images: {len(test_images)}")

In [None]:
class FlickrDataset(Dataset):
    """Image/caption dataset yielding (preprocessed image, tokenized caption, image path).

    Args:
        image_keys: Image file names; each one is joined onto the module-level
            ``IMAGE_DIR`` to locate the file.
        image_to_captions: Mapping from image file name to its list of captions.
        tokenizer: Callable tokenizer; invoked with ``max_length``, ``padding``,
            ``truncation`` and ``return_tensors="pt"``.
        preprocess: Callable applied to the RGB PIL image.
        train: When True a random caption is drawn on every access; otherwise the
            first caption of the image is always used.
        max_len: Length captions are padded/truncated to (default 40, the value
            that used to be hard-coded).
    """

    def __init__(self, image_keys, image_to_captions, tokenizer, preprocess, train=True, max_len=40):
        self.image_keys = image_keys
        self.image_to_captions = image_to_captions
        self.tokenizer = tokenizer
        self.preprocess = preprocess
        self.max_len = max_len
        self.train = train

    def __len__(self):
        """Return the number of images (one sample per image, not per caption)."""
        return len(self.image_keys)

    def __getitem__(self, idx):
        """Load, preprocess and tokenize the sample at position ``idx``.

        Returns:
            Tuple ``(image_processed, caption_tokens, image_path)``.
        """
        image_key = self.image_keys[idx]
        image_path = os.path.join(IMAGE_DIR, image_key)

        if self.train:
            caption = random.choice(self.image_to_captions[image_key])
        else:
            caption = self.image_to_captions[image_key][0]

        # convert() produces an independent in-memory image, so the underlying
        # file can be closed right away instead of waiting for garbage collection.
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")
        image_processed = self.preprocess(image)

        # The caption is wrapped in the begin/end markers before tokenization.
        caption_tokens = self.tokenizer(
            f"<|startoftext|>{caption}<|endoftext|>",
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        return image_processed, caption_tokens, image_path
       

In [None]:
class MappingNetwork(nn.Module):
    """Turn one CLIP embedding per sample into ``prefix_length`` GPT-sized vectors.

    A learned set of prefix queries attends, through a transformer decoder, to the
    linearly projected CLIP embedding, which serves as a single memory token.
    """

    def __init__(self, clip_embedding_dim: int, gpt_embedding_dim: int,
                 prefix_length: int = 10, num_layers: int = 8, num_heads: int = 8):
        super().__init__()
        self.prefix_length = prefix_length

        # Learned queries shared by every sample in a batch.
        learned_queries = torch.randn(1, prefix_length, gpt_embedding_dim)
        self.prefix_const = nn.Parameter(learned_queries)

        # Projects the CLIP embedding into the GPT embedding width.
        self.clip_projection = nn.Linear(clip_embedding_dim, gpt_embedding_dim)

        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=gpt_embedding_dim, nhead=num_heads),
            num_layers=num_layers,
        )

    def forward(self, clip_embedding: torch.Tensor) -> torch.Tensor:
        """Return prefix embeddings of shape [batch, prefix_length, gpt_embedding_dim]."""
        n_samples = clip_embedding.size(0)

        # The decoder is sequence-first, so memory is [1, batch, dim].
        memory = self.clip_projection(clip_embedding)[None]

        # [batch, prefix_len, dim] -> [prefix_len, batch, dim]
        queries = self.prefix_const.expand(n_samples, -1, -1).transpose(0, 1)

        decoded = self.transformer_decoder(tgt=queries, memory=memory)

        # Back to batch-first for the caller.
        return decoded.transpose(0, 1)

In [None]:
class ClipCapModel(nn.Module):
    """Captioning model: CLIP image encoder -> mapping network -> GPT-2 language model.

    Args:
        clip_model: Model exposing ``visual.output_dim`` and ``encode_image``.
        gpt2_model: Causal LM exposing ``config.hidden_size``, ``transformer.wte``
            and a forward accepting ``inputs_embeds``, ``attention_mask``, ``labels``.
        prefix_length: Number of prefix embeddings produced per image.
    """

    def __init__(self, clip_model, gpt2_model, prefix_length=10):
        super().__init__()
        self.gpt2 = gpt2_model
        self.clip = clip_model

        clip_embedding_dim = clip_model.visual.output_dim
        gpt_embedding_dim = gpt2_model.config.hidden_size

        self.mapping_network = MappingNetwork(clip_embedding_dim, gpt_embedding_dim, prefix_length=prefix_length)

    def forward(self, image_features, caption_tokens):
        """Compute the language-modelling loss of the captions given the images.

        Args:
            image_features: Batch of preprocessed images, passed to ``clip.encode_image``.
            caption_tokens: Mapping with ``'input_ids'`` and ``'attention_mask'``
                tensors; dimension 1 of each is squeezed away before use.

        Returns:
            The ``loss`` attribute of the GPT-2 output.
        """
        input_ids = caption_tokens['input_ids'].squeeze(1)
        attention_mask = caption_tokens['attention_mask'].squeeze(1)

        caption_embeddings = self.gpt2.transformer.wte(input_ids)

        # The image encoder is only used as a feature extractor.
        with torch.no_grad():
            image_embeddings = self.clip.encode_image(image_features).float()

        prefix_embeddings = self.mapping_network(image_embeddings)

        # [batch, prefix_len, dim] followed by [batch, caption_len, dim]
        combined_embeddings = torch.cat([prefix_embeddings, caption_embeddings], dim=1)

        batch_size, prefix_length = prefix_embeddings.shape[:2]

        # -100 is the label value the LM loss ignores. The prefix has no target
        # tokens, and padded caption positions must not contribute either --
        # otherwise the loss is dominated by predicting padding.
        ignore_labels = torch.full(
            (batch_size, prefix_length), -100,
            dtype=input_ids.dtype, device=input_ids.device,
        )
        caption_labels = input_ids.masked_fill(attention_mask == 0, -100)
        labels = torch.cat([ignore_labels, caption_labels], dim=1)

        # Prefix positions are always attended to; built on the same device and
        # dtype as the caption mask instead of relying on a global device.
        prefix_mask = torch.ones(
            (batch_size, prefix_length),
            dtype=attention_mask.dtype, device=attention_mask.device,
        )
        combined_mask = torch.cat([prefix_mask, attention_mask], dim=1)

        outputs = self.gpt2(inputs_embeds=combined_embeddings, attention_mask=combined_mask, labels=labels)

        return outputs.loss

In [None]:

# Training hyper-parameters.
EPOCHS = 30
BATCH_SIZE = 32
LEARNING_RATE = 1e-5
WEIGHT_DECAY = 1e-5
# Never request more loader workers than the machine has CPUs (upper bound 16).
NUM_WORKERS = min(16, os.cpu_count() or 1)

train_dataset = FlickrDataset(train_images, image_to_captions, tokenizer, preprocess)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

val_dataset = FlickrDataset(val_images, image_to_captions, tokenizer, preprocess, train=False)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

model = ClipCapModel(clip_model, gpt2_model).to(device)

# Only the mapping network is optimised.
optimizer = AdamW(model.mapping_network.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

for epoch in range(EPOCHS):
    print(f"\n--- Epoch {epoch + 1}/{EPOCHS} ---")

    model.train()
    # model.train() switches every submodule to training mode. The backbones are
    # not being optimised, so put them back in eval mode to keep dropout off
    # while the mapping network trains.
    model.gpt2.eval()
    model.clip.eval()

    total_train_loss = 0
    for images, captions, _ in tqdm(train_dataloader, desc="Training"):
        images = images.to(device)
        captions = {key: val.to(device) for key, val in captions.items()}

        optimizer.zero_grad()

        loss = model(images, captions)
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_dataloader)
    print(f"Average Training Loss: {avg_train_loss:.4f}")

    # Validation pass: no gradient tracking, everything in eval mode.
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for images, captions, _ in tqdm(val_dataloader, desc="Validating"):
            images = images.to(device)
            captions = {key: val.to(device) for key, val in captions.items()}

            loss = model(images, captions)
            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / len(val_dataloader)
    print(f"Average Validation Loss: {avg_val_loss:.4f}")

    # One checkpoint per epoch, containing the mapping network weights only.
    torch.save(model.mapping_network.state_dict(), f"mapping_network_epoch_{epoch+1}.pth")

print("\nTraining complete!")

In [None]:
def generate_caption(image_path, model, max_length=40, num_beams=5):
    """Generate a caption for a single image file using beam search.

    Args:
        image_path: Path of the image to caption.
        model: Captioning model exposing ``clip``, ``mapping_network`` and ``gpt2``.
        max_length: Forwarded unchanged to ``generate`` as ``max_length``.
        num_beams: Beam width forwarded to ``generate``.

    Returns:
        The decoded caption string with special tokens removed.
    """
    model.eval()

    # Run on whichever device the mapping network lives on.
    target_device = next(model.mapping_network.parameters()).device

    # convert() yields an in-memory copy, so the file can be closed immediately.
    with Image.open(image_path) as image_file:
        image = image_file.convert("RGB")
    image_processed = preprocess(image).unsqueeze(0).to(target_device)

    with torch.no_grad():
        # Use the backbones owned by `model` rather than module-level globals, so
        # the function stays correct for any model instance passed in.
        image_embedding = model.clip.encode_image(image_processed).float()
        prefix_embeddings = model.mapping_network(image_embedding)

        output_ids = model.gpt2.generate(
            inputs_embeds=prefix_embeddings,
            max_length=max_length,
            num_beams=num_beams,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            early_stopping=True
        )

        caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return caption

# Restore the checkpoint written after the final training epoch. This is the last
# epoch, not necessarily the one with the lowest validation loss.
# map_location lets a checkpoint saved on GPU be restored on a CPU-only runtime.
model.mapping_network.load_state_dict(
    torch.load(f"mapping_network_epoch_{EPOCHS}.pth", map_location=device)
)
print(f"Loaded mapping network weights from epoch {EPOCHS} for inference.")

In [None]:
import matplotlib.pyplot as plt
import random

# Qualitative check: caption five randomly chosen test images and show each one
# next to its first reference caption.
test_sample_keys = random.sample(test_images, 5)

for sample_key in test_sample_keys:
    sample_path = os.path.join(IMAGE_DIR, sample_key)

    predicted = generate_caption(sample_path, model)
    reference = image_to_captions[sample_key][0]
    title_text = f"Generated: {predicted}\nGround Truth 1: {reference}"

    plt.imshow(Image.open(sample_path))
    plt.title(title_text)
    plt.axis('off')
    plt.show()

In [None]:
def generate_caption_batch(image_batch, model, max_length=40, num_beams=5):
    """Generate captions for a batch of already-preprocessed images.

    Args:
        image_batch: Batch of preprocessed images, passed to ``encode_image``.
        model: Captioning model exposing ``clip``, ``mapping_network`` and ``gpt2``.
        max_length: Forwarded unchanged to ``generate`` as ``max_length``.
        num_beams: Beam width forwarded to ``generate``.

    Returns:
        List of decoded caption strings with special tokens removed.
    """
    model.eval()

    with torch.no_grad():
        # Use the backbones owned by `model` rather than module-level globals, so
        # the function stays correct for any model instance passed in.
        image_embeddings = model.clip.encode_image(image_batch).float()
        prefix_embeddings = model.mapping_network(image_embeddings)

        output_ids = model.gpt2.generate(
            inputs_embeds=prefix_embeddings,
            max_length=max_length,
            num_beams=num_beams,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            early_stopping=True
        )

        # Decode the whole batch of outputs and return a list of captions.
        captions = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    return captions

In [None]:
# Caption-evaluation toolkit; provides the Cider and Spice scorers imported in the next cell.
!pip install -q git+https://github.com/salaniz/pycocoevalcap

In [None]:
from tqdm import tqdm
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.spice.spice import Spice
import os
from PIL import Image

# Rebuild the captioning model and restore the checkpoint written after the final
# training epoch. map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
model = ClipCapModel(clip_model, gpt2_model).to(device)
model.mapping_network.load_state_dict(
    torch.load(f"mapping_network_epoch_{EPOCHS}.pth", map_location=device)
)

model.eval()
gts = {}  # Ground truth captions
res = {}  # Model-generated captions

image_to_captions_val = {key: image_to_captions[key] for key in val_images}

for images, _, image_paths in tqdm(val_dataloader, desc="Evaluating on validation set"):
    images = images.to(device)
    gen_captions = generate_caption_batch(images, model)

    for image_path, gen_caption in zip(image_paths, gen_captions):
        image_key = os.path.basename(image_path)

        # 1. Store the caption produced by the model.
        res[image_key] = [gen_caption.strip()]

        # 2. Fetch ALL reference captions for this image from the original dictionary.
        if image_key in image_to_captions_val:
            gts[image_key] = image_to_captions_val[image_key]

# --- Compute CIDEr ---
cider_scorer = Cider()
cider_score, _ = cider_scorer.compute_score(gts, res)

# --- Compute SPICE ---
spice_scorer = Spice()
spice_score, _ = spice_scorer.compute_score(gts, res)

print(f"üîπ CIDEr: {cider_score:.4f}")
print(f"üî∏ SPICE: {spice_score:.4f}")
