# Model Definition


In [None]:
from PIL import Image

import torch
from torch.utils.data import Dataset

class ImageCaptionDataset(Dataset):
    """Dataset yielding ``(image_features, caption_ids)`` pairs for caption training.

    Each row of ``dataframe`` must provide an ``'image'`` entry (passed to
    ``feature_extractor``) and a ``'caption'`` text string.

    Args:
        dataframe: table indexed positionally via ``.iloc``.
        clip_model: despite the name, used as a *tokenizer*: it is called with
            ``text=...`` and must return a mapping containing ``"input_ids"``.
        feature_extractor: image preprocessor returning model inputs for
            ``vision_transformer``.
        vision_transformer: image encoder whose output has ``last_hidden_state``.
        device: device the features and token ids are moved to.
        max_len: every caption is padded / truncated to exactly this many tokens.
    """

    def __init__(self, dataframe, clip_model, feature_extractor, vision_transformer, device='cpu', max_len=402):
        self.df = dataframe
        self.clip_model = clip_model
        self.feature_extractor = feature_extractor
        self.vision_transformer = vision_transformer
        self.device = device
        # Stored so __getitem__ can use it (it was previously read as an undefined name).
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image_feats, caption_ids)`` for row ``idx``.

        ``image_feats`` is the encoder's last hidden state for this one image with
        the batch dim removed; ``caption_ids`` is a 1-D tensor of ``max_len`` ids.
        """
        row = self.df.iloc[idx]
        image = row['image']      # expected to be a PIL.Image (or whatever feature_extractor accepts)
        caption = row['caption']  # text string

        image_feats = self.get_image_patches(image)

        # Fixed-length encoding: pad every caption to max_len and truncate longer
        # ones, so captions in a batch always share one length.
        inputs = self.clip_model(
            text=caption,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_len,
            truncation=True,
        ).to(self.device)
        caption_ids = inputs["input_ids"].squeeze(0)  # (max_len,)

        return image_feats, caption_ids

    def get_image_patches(self, image):
        """Encode one image and return its last hidden state without the batch dim."""
        inputs = self.feature_extractor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.vision_transformer(**inputs)
        return outputs.last_hidden_state.squeeze(0)  # (num_tokens, hidden_dim)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``,
    with ``w_k = 10000 ** (-2k / d_model)``. Works for odd ``d_model`` too.

    Args:
        d_model: feature width of the inputs.
        max_len: longest sequence this module can encode.
    """

    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div_term)
        # There are d_model // 2 odd columns, one fewer than even ones when d_model is odd.
        pe[:, 1::2] = torch.cos(pos * div_term[: d_model // 2])
        # Buffer kept as (max_len, 1, d_model) so existing state_dicts still load.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x + pe`` for ``x`` of shape (batch, seq_len, d_model).

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(0)}"
            )
        return x + self.pe[:seq_len].transpose(0, 1)  # pe slice: (1, seq_len, d_model)


In [None]:
class MaskedSelfAttention(nn.Module):
    """Multi-head self-attention with a causal mask.

    Position ``t`` may only attend to positions ``0..t``. Input and output are
    both (batch, seq_len, d_model).
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.head_dim = d_model // n_heads
        self.n_heads = n_heads

        # A single fused projection yields queries, keys and values together.
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch, steps, width = x.size()
        heads, hd = self.n_heads, self.head_dim

        # (B, T, 3D) -> (B, T, H, 3*hd); split the last axis into q/k/v and move
        # heads in front of time: each becomes (B, H, T, hd).
        fused = self.qkv_proj(x).view(batch, steps, heads, 3 * hd)
        query, key, value = (part.transpose(1, 2) for part in fused.chunk(3, dim=-1))

        scores = query @ key.transpose(-2, -1) / math.sqrt(hd)  # (B, H, T, T)

        # Strict upper triangle = future positions; they get -inf before softmax.
        future = torch.ones(steps, steps, device=x.device).tril() == 0
        scores = scores.masked_fill(future, float('-inf'))

        weights = scores.softmax(dim=-1)
        context = (weights @ value).transpose(1, 2).reshape(batch, steps, width)
        return self.out_proj(context)


In [None]:
class CustomDecoderBlock(nn.Module):
    """One post-norm decoder layer: causal self-attention followed by a ReLU MLP.

    Each sub-layer is applied as ``LayerNorm(h + sublayer(h))``; there is no
    cross-attention, so conditioning must already be inside the input sequence.
    """

    def __init__(self, d_model, n_heads, dim_ff):
        super().__init__()
        self.self_attn = MaskedSelfAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.ReLU(),
            nn.Linear(dim_ff, d_model),
        )
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Attention sub-layer with residual, then feed-forward sub-layer with residual.
        hidden = self.norm1(x + self.self_attn(x))
        return self.norm2(hidden + self.ff(hidden))


In [None]:
class CustomDecoder(nn.Module):
    """Decoder-only captioner conditioned on a single prepended image token.

    The image vector is projected to ``d_model`` and placed before the token
    embeddings. Positional encodings are added, the causal decoder blocks are
    applied, and the result is projected to vocabulary log-probabilities.
    """

    def __init__(self, vocab_size, image_feature_dim, d_model=768, n_heads=8, dim_ff=2048, num_layers=3, max_len=512):
        super().__init__()
        self.image_proj = nn.Linear(image_feature_dim, d_model)
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len + 1)  # +1 for image token
        self.layers = nn.ModuleList([
            CustomDecoderBlock(d_model, n_heads, dim_ff) for _ in range(num_layers)
        ])
        self.output_proj = nn.Linear(d_model, vocab_size)

    def forward(self, image_feats, input_ids):
        """
        Args:
            image_feats: (batch, image_feature_dim), or an encoder hidden state of
                shape (batch, num_tokens, image_feature_dim). For the latter, only
                token 0 is used; for a Hugging Face ViT this is the [CLS] token.
            input_ids: (batch, seq_len) token ids.

        Returns:
            Log-probabilities of shape (batch, seq_len + 1, vocab_size). Position 0
            is the output at the image token; position i + 1 is the output after
            reading input_ids[:, i].
        """
        # ImageCaptionDataset yields full ViT hidden states (num_patches + 1 tokens);
        # pool them to one vector per image so the projection below gets (B, D).
        if image_feats.dim() == 3:
            image_feats = image_feats[:, 0]

        img_token = self.image_proj(image_feats).unsqueeze(1)  # (batch, 1, d_model)
        x = self.embedding(input_ids)  # (batch, seq_len, d_model)
        x = torch.cat([img_token, x], dim=1)  # (batch, seq_len + 1, d_model)
        x = self.pos_enc(x)

        for layer in self.layers:
            x = layer(x)

        logits = self.output_proj(x)  # (batch, seq_len + 1, vocab_size)
        return F.log_softmax(logits, dim=-1)


In [None]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """Merge ``(image_feats, caption_ids)`` samples into batched tensors.

    Image features are stacked along a new leading dim, so they must all share
    one shape. Captions are right-padded with 0 to the longest one in the batch.

    Returns:
        ``(stacked_image_feats, padded_captions)`` with ``padded_captions`` of
        shape (batch_size, longest_caption_len).
    """
    feat_list, caption_list = zip(*batch)
    return (
        torch.stack(feat_list),
        pad_sequence(caption_list, batch_first=True, padding_value=0),
    )

In [None]:
from transformers import ViTFeatureExtractor
from torch.utils.data import DataLoader  # the cells above only import Dataset

# NOTE(review): recent transformers releases deprecate ViTFeatureExtractor in favour of
# ViTImageProcessor; switch once the installed version is confirmed to provide it.
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
device = 'cpu'

# train_df, clip_tokenizer and vision_transformer are presumably created in earlier
# cells not shown here. The dataset runs vision_transformer on `device`, so that
# model should already be on the same device.
dataset = ImageCaptionDataset(
    dataframe=train_df,
    clip_model=clip_tokenizer,
    feature_extractor=feature_extractor,
    vision_transformer=vision_transformer,
    device=device
)

dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn
)

In [None]:
# Sanity check: pull a single batch and print the tensor shapes.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    image_feats, caption_tokens = first_batch
    print(image_feats.shape)       # (batch_size, num_patches+1, hidden_dim)
    print(caption_tokens.shape)    # (batch_size, max_caption_len)

AttributeError: CLIPTokenizer has no attribute processor

In [None]:
import torch.nn as nn
from tqdm import tqdm

# Setup
device = 'cpu'

vocab_size = clip_tokenizer.vocab_size
# CustomDecoder requires the width of the image features it receives.
# NOTE(review): assumes vision_transformer is a Hugging Face model exposing
# config.hidden_size (the width of its last_hidden_state) — confirm.
model = CustomDecoder(
    vocab_size=vocab_size,
    image_feature_dim=vision_transformer.config.hidden_size,
)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# CustomDecoder already returns log-probabilities (F.log_softmax), so NLLLoss is the
# matching criterion; CrossEntropyLoss would apply a second, redundant log_softmax.
# Targets set to -100 (NLLLoss's default ignore_index) do not contribute to the loss.
criterion = nn.NLLLoss()

# Id the tokenizer pads with. Fall back to 0 (collate_fn's padding value) when the
# tokenizer defines no pad token.
pad_id = clip_tokenizer.pad_token_id if clip_tokenizer.pad_token_id is not None else 0
eos_id = clip_tokenizer.eos_token_id

num_epochs = 5


for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0

    for image_feats, captions in tqdm(dataloader):
        image_feats = image_feats.to(device)
        captions = captions.to(device)

        # Teacher forcing: the model reads tokens [0, n-1) and predicts tokens [1, n).
        inputs = captions[:, :-1]
        targets = captions[:, 1:].clone()
        # Once the *input* token is padding we are past the end of the caption, so
        # ignore those targets. Masking by input keeps the first end token as a
        # target even when the tokenizer pads with its end token.
        targets[inputs == pad_id] = -100
        if pad_id != eos_id:
            targets[targets == pad_id] = -100

        outputs = model(image_feats, inputs)  # (B, seq_len + 1, vocab_size)
        # Position 0 is the prepended image token. Output i + 1 (produced after
        # reading inputs[:, i]) is the one that predicts targets[:, i].
        outputs = outputs[:, 1:]

        # reshape rather than view: the slice above is not contiguous.
        loss = criterion(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Mean per-batch loss, so the number is comparable across dataset sizes.
    print(f"Epoch {epoch+1}: Loss = {epoch_loss / max(num_batches, 1):.4f}")


  0%|          | 0/18125 [00:00<?, ?it/s]


AttributeError: CLIPTokenizer has no attribute processor

In [None]:
def generate_caption(model, image_feats, tokenizer, max_len=20):
    """Greedily decode a caption for a single image.

    Args:
        model: module called as ``model(image_feats, input_ids)`` that returns
            per-position scores of shape (batch, positions, vocab); the last
            position is used to pick the next token.
        image_feats: features for ONE image, without a batch dim (one is added here).
        tokenizer: provides ``bos_token_id``, ``eos_token_id`` and ``decode``.
        max_len: maximum number of tokens generated after the BOS token.

    Returns:
        The decoded caption string, with special tokens skipped.

    Note:
        Switches ``model`` to eval mode and leaves it there.
    """
    model.eval()
    # Use the model's own device instead of depending on a global `device`.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    image_feats = image_feats.unsqueeze(0).to(device)  # add batch dim

    caption_ids = [tokenizer.bos_token_id]
    with torch.no_grad():
        for _ in range(max_len):
            input_ids = torch.tensor([caption_ids], dtype=torch.long, device=device)
            outputs = model(image_feats, input_ids)
            next_token = outputs[0, -1].argmax().item()  # greedy pick at the last step
            if next_token == tokenizer.eos_token_id:
                break
            caption_ids.append(next_token)

    return tokenizer.decode(caption_ids, skip_special_tokens=True)
