In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from transformers import CLIPProcessor, CLIPModel, GPT2Config, GPT2LMHeadModel
from pycocotools.coco import COCO
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import nltk
from nltk.tokenize import word_tokenize
from PIL import Image
import requests
from io import BytesIO
import json
import os
from tqdm import tqdm
import re


In [2]:
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

In [3]:
# ✅ Download and Load COCO Annotations
ann_file = "captions_train2017.json"

if not os.path.exists(ann_file):
    !wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
    !unzip annotations_trainval2017.zip -d .
    ann_file = "annotations/captions_train2017.json"

coco = COCO(ann_file)  # Load COCO captions dataset


--2025-01-30 19:16:47--  http://images.cocodataset.org/annotations/annotations_trainval2017.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 16.15.177.69, 54.231.170.249, 3.5.21.92, ...
Connecting to images.cocodataset.org (images.cocodataset.org)|16.15.177.69|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 252907541 (241M) [application/zip]
Saving to: ‘annotations_trainval2017.zip’


2025-01-30 19:17:05 (13.8 MB/s) - ‘annotations_trainval2017.zip’ saved [252907541/252907541]

Archive:  annotations_trainval2017.zip
  inflating: ./annotations/instances_train2017.json  
  inflating: ./annotations/instances_val2017.json  
  inflating: ./annotations/captions_train2017.json  
  inflating: ./annotations/captions_val2017.json  
  inflating: ./annotations/person_keypoints_train2017.json  
  inflating: ./annotations/person_keypoints_val2017.json  
loading annotations into memory...
Done (t=0.93s)
creating index...
index created!


In [4]:
class Vocabulary:
    def __init__(self, freq_threshold=5):
        self.itos = {0: "<PAD>", 1: "<BOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {v: k for k, v in self.itos.items()}
        self.freq_threshold = freq_threshold

    def build_vocab(self, captions):
        counter = nltk.FreqDist()
        for caption in captions:
            tokens = word_tokenize(caption.lower())
            counter.update(tokens)

        idx = 4
        for word, count in counter.items():
            if count >= self.freq_threshold:
                self.stoi[word] = idx
                self.itos[idx] = word
                idx += 1

    def numericalize(self, text):
        tokens = word_tokenize(text.lower())
        return [self.stoi.get(token, self.stoi["<UNK>"]) for token in tokens]

# ✅ Build Vocabulary from COCO Captions
captions = [coco.anns[ann_id]["caption"] for ann_id in coco.anns.keys()]
vocab = Vocabulary(freq_threshold=5)
vocab.build_vocab(captions)  # Now `vocab` is properly initialized

print(f"Vocabulary Size: {len(vocab.stoi)}")  # ✅ Debugging to confirm vocab size


Vocabulary Size: 10322


In [15]:
class COCODataset(Dataset):
    def __init__(self, ann_file, transform_norm, transform_clip, vocab):
        self.coco = COCO(ann_file)
        self.ids = list(self.coco.anns.keys())[:5000]  # Subset for faster training
        self.transform_norm = transform_norm
        self.transform_clip = transform_clip
        self.vocab = vocab

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        ann_id = self.ids[idx]
        caption = self.coco.anns[ann_id]["caption"]
        img_id = self.coco.anns[ann_id]["image_id"]
        img_data = self.coco.loadImgs(img_id)[0]
        img_url = img_data["coco_url"]

        response = requests.get(img_url)
        image = Image.open(BytesIO(response.content)).convert("RGB")

        # ✅ Apply two different transformations: one for ResNet, one for CLIP
        image_norm = self.transform_norm(image)  # Normalized for ResNet/GPT
        image_raw = self.transform_clip(image)  # Unnormalized for CLIP

        numericalized_caption = [self.vocab.stoi["<BOS>"]] + \
                                self.vocab.numericalize(caption) + \
                                [self.vocab.stoi["<EOS>"]]

        return image_norm, image_raw, torch.tensor(numericalized_caption)


# ✅ Define `collate_fn` to Fix Padding Issue
def collate_fn(batch):
    images_norm = [item[0] for item in batch]  # ResNet normalized images
    images_raw = [item[1] for item in batch]  # Original images for CLIP
    captions = [item[2] for item in batch]

    images_norm = torch.stack(images_norm, dim=0)  # ResNet input
    images_raw = torch.stack(images_raw, dim=0)  # CLIP input
    captions = pad_sequence(captions, batch_first=True, padding_value=0)  # Pad captions

    return images_norm, images_raw, captions


# ✅ Normalized transform (for ResNet/GPT)
transform_norm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ✅ Raw transform (for CLIP - NO NORMALIZATION)
transform_clip = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])  # No normalization for CLIP!

# ✅ Initialize Dataset & DataLoader with both transforms
train_dataset = COCODataset(ann_file, transform_norm, transform_clip, vocab)
train_loader = DataLoader(
    train_dataset, batch_size=16, shuffle=True, num_workers=2, pin_memory=True, collate_fn=collate_fn
)


loading annotations into memory...
Done (t=1.46s)
creating index...
index created!


In [16]:
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to("cuda")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

class CLIPFeatureExtractor(nn.Module):
    def __init__(self, clip_model, input_dim=512, output_dim=768):
        super(CLIPFeatureExtractor, self).__init__()
        self.clip_model = clip_model
        self.projection = nn.Linear(input_dim, output_dim)  # Convert 512 → 768

    def forward(self, images_raw):  # ✅ Use unnormalized images
        inputs = clip_processor(images=images_raw, return_tensors="pt", do_rescale=False).to(images_raw.device)
        with torch.no_grad():
            features = self.clip_model.get_image_features(**inputs)  # (batch_size, 512)

        projected_features = self.projection(features)  # (batch_size, 768)
        return projected_features.unsqueeze(1)  # Ensure 3D shape (batch, 1, 768)

# ✅ Replace direct CLIP extraction with the new class
clip_feature_extractor = CLIPFeatureExtractor(clip_model).to(device)

In [13]:
class DecoderTransformer(nn.Module):
    def __init__(self, vocab_size, embed_size=768):
        super(DecoderTransformer, self).__init__()

        config = GPT2Config.from_pretrained("gpt2")
        config.add_cross_attention = True

        self.gpt2 = GPT2LMHeadModel(config)
        self.gpt2.resize_token_embeddings(vocab_size)

    def forward(self, captions, features):
        assert features.dim() == 3, f"Encoder output must be 3D, but got {features.shape}"
        outputs = self.gpt2(input_ids=captions, encoder_hidden_states=features)
        return outputs.logits


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embed_size = 768
vocab_size = len(vocab.stoi)

decoder = DecoderTransformer(vocab_size, embed_size).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=0).to(device)
optimizer = optim.AdamW(decoder.parameters(), lr=1e-4)

num_epochs = 5
for epoch in range(num_epochs):
    decoder.train()
    total_loss = 0

    for images_norm, images_raw, captions in tqdm(train_loader):
        images_norm, images_raw, captions = images_norm.to(device), images_raw.to(device), captions.to(device)

        # ✅ Extract features using CLIP with proper projection
        features = clip_feature_extractor(images_raw)  # Now correctly 768-dimensional

        input_captions = captions[:, :-1]
        target_captions = captions[:, 1:]

        outputs = decoder(input_captions, features)
        loss = criterion(outputs.view(-1, outputs.size(-1)), target_captions.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss / len(train_loader):.4f}")


100%|██████████| 313/313 [49:37<00:00,  9.51s/it]


Epoch [1/5], Loss: 4.5011


  5%|▌         | 16/313 [02:33<39:05,  7.90s/it]