# Dataset

In [2]:
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CocoCaptions
import torchvision.transforms as T


image_path = "coco/images/val2017"
annotations = "coco/annotations/captions_val2017.json"

transform = T.Compose([
    T.Resize((224,224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

coco_dataset = CocoCaptions(root=image_path, annFile=annotations, transform=transform)

loading annotations into memory...
Done (t=0.06s)
creating index...
index created!


In [3]:
from transformers import CLIPTokenizer


model_id = "openai/clip-vit-base-patch32"
tokenizer = CLIPTokenizer.from_pretrained(model_id)

tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

In [4]:
class CocoTokenizedDataset(Dataset):
    def __init__(self, base_dataset, tokenizer, max_length=32):
        self.base_dataset = base_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        image, captions = self.base_dataset[idx]
        caption = captions[0]

        tokens = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        input_ids = tokens.input_ids.squeeze(0)
        attention_mask = tokens.attention_mask.squeeze(0)

        return image, input_ids, attention_mask


tokenized_coco = CocoTokenizedDataset(coco_dataset, tokenizer)

In [5]:
def collate_fn(batch):
    images, input_ids, attention_masks = zip(*batch)
    images = torch.stack(images)
    input_ids = torch.stack(input_ids)
    attention_masks = torch.stack(attention_masks)
    return images, input_ids, attention_masks

# Model

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18


# Image Encoder
class ImageEncoder(nn.Module):
    def __init__(self, embed_dim=256):
        super().__init__()
        base_model = resnet18(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(base_model.children())[:-1])  # Output: (batch, 512, 1, 1)
        self.fc = nn.Linear(512, embed_dim)

    def forward(self, x):
        features = self.feature_extractor(x).squeeze(-1).squeeze(-1)
        embedding = self.fc(features)  # (batch, embed_dim)
        embedding = F.normalize(embedding, dim=-1)  # L2-Norm
        return embedding


# Text Encoder
class TextEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim=256, max_len=32, n_heads=4, n_layers=2):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=n_heads)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.fc = nn.Linear(embed_dim, embed_dim)
        self.max_len = max_len

    def forward(self, input_ids, attention_mask=None):
        x = self.token_embedding(input_ids)
        x = x + self.pos_embedding[:, :x.size(1), :]
        x = x.permute(1, 0, 2)

        if attention_mask is not None:
            key_padding_mask = ~attention_mask.bool()
        else:
            key_padding_mask = None

        encoded = self.transformer_encoder(x, src_key_padding_mask=key_padding_mask) 
        encoded = encoded.permute(1, 0, 2)

        pooled = encoded.mean(dim=1)

        embedding = self.fc(pooled)
        embedding = F.normalize(embedding, dim=-1)
        return embedding

In [7]:
class MiniCLIP(nn.Module):
    def __init__(self, vocab_size, embed_dim=256, max_len=32):
        super().__init__()
        self.image_encoder = ImageEncoder(embed_dim)
        self.text_encoder = TextEncoder(vocab_size, embed_dim, max_len)

    def forward(self, images, input_ids, attention_mask):
        image_embeds = self.image_encoder(images)
        text_embeds = self.text_encoder(input_ids, attention_mask)
        return image_embeds, text_embeds

# Training

In [8]:
def contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    """InfoNCE Loss"""
    logits = torch.matmul(image_embeds, text_embeds.T)  # (batch, batch)
    logits = logits / temperature

    labels = torch.arange(len(image_embeds)).to(image_embeds.device)
    loss_i2t = F.cross_entropy(logits, labels)      # Image → Text
    loss_t2i = F.cross_entropy(logits.T, labels)    # Text → Image

    loss = (loss_i2t + loss_t2i) / 2
    return loss

In [10]:
batch_size = 32
dataloader = DataLoader(tokenized_coco, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [11]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [14]:
from tqdm import tqdm


# Model params
embed_dim = 256
vocab_size = tokenizer.vocab_size
max_len = 32

model = MiniCLIP(vocab_size=vocab_size, embed_dim=embed_dim, max_len=max_len).to(device)

# Training params
temperature = 0.07
lr = 1e-4
epochs = 5

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
model.train()
for epoch in range(epochs):
    total_loss = 0.0
    loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
    for images, input_ids, attention_masks in loop:
        images = images.to(device)
        input_ids = input_ids.to(device)
        attention_masks = attention_masks.to(device)

        image_embeds, text_embeds = model(images, input_ids, attention_masks)
        loss = contrastive_loss(image_embeds, text_embeds, temperature)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")

Epoch 1/5: 100%|██████████| 157/157 [03:53<00:00,  1.49s/it]


Epoch 1/5 - Loss: 2.7697


Epoch 2/5: 100%|██████████| 157/157 [04:01<00:00,  1.54s/it]


Epoch 2/5 - Loss: 1.3046


Epoch 3/5: 100%|██████████| 157/157 [03:52<00:00,  1.48s/it]


Epoch 3/5 - Loss: 0.4696


Epoch 4/5: 100%|██████████| 157/157 [03:58<00:00,  1.52s/it]


Epoch 4/5 - Loss: 0.1962


Epoch 5/5: 100%|██████████| 157/157 [03:59<00:00,  1.52s/it]

Epoch 5/5 - Loss: 0.1140





In [18]:
import os

save_dir = 'models/mini-clip'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

model.eval()
all_img_embeds = []
all_txt_embeds = []

with torch.no_grad():
    for images, input_ids, attention_masks in tqdm(dataloader):
        images = images.to(device)
        input_ids = input_ids.to(device)
        attention_masks = attention_masks.to(device)

        img_emb, txt_emb = model(images, input_ids, attention_masks)
        all_img_embeds.append(img_emb.cpu())
        all_txt_embeds.append(txt_emb.cpu())

image_embeddings = torch.cat(all_img_embeds, dim=0)
text_embeddings = torch.cat(all_txt_embeds, dim=0)

torch.save(image_embeddings, f"{save_dir}/image_embeds.pt")
torch.save(text_embeddings, f"{save_dir}/text_embeds.pt")

100%|██████████| 157/157 [03:41<00:00,  1.41s/it]


In [19]:
torch.save(model.state_dict(), f"{save_dir}/mini_clip_model.pt")