In [2]:
import math
import numpy as np
import pandas as pd
import torch
import os
import random
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import clip
from PIL import Image
from pathlib import Path
from tqdm.auto import tqdm
import re
from clip.simple_tokenizer import SimpleTokenizer
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F

%matplotlib inline

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
n_epochs = 20
testsize = 0.1
temp = torch.tensor(0.07)  # Temperature scaling parameter
batch_size = 256
model_save_dr = "/home/hice1/asubramanian91/scratch/iMET/models/"
prob = 0  # Contamination level: set to 0 if no corruption

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

# Get embedding size from the text projection (this is your final embedding dimension, e.g. 512)
embed_dim = model.text_projection.shape[1]
print("Embedding dimension:", embed_dim)

Embedding dimension: 512


In [5]:
df_data = pd.read_csv('/home/hice1/asubramanian91/scratch/iMET/Small_train_set/cleaned_train.csv', index_col='id')
df_labels = pd.read_csv('/home/hice1/asubramanian91/scratch/iMET/Small_train_set/sampled_labels.csv', index_col="attribute_id")
image_dir = "/home/hice1/asubramanian91/scratch/iMET/Small_train_set/small_train"
image_names = sorted(os.listdir(image_dir))
train_names, test_names, _, _ = train_test_split(image_names, image_names, test_size=testsize, random_state=42)

In [6]:
_tokenizer = SimpleTokenizer()
def tokenize(texts, context_length: int = 77) -> torch.LongTensor:
    if isinstance(texts, str):
        texts = [texts]
    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    for i, tokens in enumerate(all_tokens):
        n = min(len(tokens), context_length)
        result[i, :n] = torch.tensor(tokens)[:n]
        if len(tokens) > context_length:
            result[i, -1] = tokens[-1]
    return result

In [7]:
class MyDataset(Dataset):
    def __init__(self, df_data, train_names, df_labels):
        super().__init__()
        self.df_data = df_data
        self.train_names = train_names
        self.df_labels = df_labels
        
    def __len__(self):
        return len(self.train_names)
    
    def __getitem__(self, idx):
        image_name = self.train_names[idx]
        image_path = os.path.join(image_dir, image_name)
        # Get the label(s) corresponding to the image (as string)
        text_ids_string = self.df_data.loc[image_name[:-4]].iloc[0]
        # With probability `prob`, corrupt the text by replacing it with another random image's labels
        if torch.rand(1).item() < prob:
            random_name = random.choice(self.train_names)
            text_ids_string = self.df_data.loc[random_name[:-4]].iloc[0]
        text_ids_string_list = text_ids_string.split(" ")
        # Extract text; assume format 'id::text'
        text_list = [self.df_labels.loc[int(ii)].iloc[0].split("::")[1] for ii in text_ids_string_list]
        image = preprocess(Image.open(image_path).convert("RGB"))
        text_all = ', '.join(text_list)
        text = tokenize([text_all])[0]
        return image, text


In [8]:
dstrain = MyDataset(df_data, train_names, df_labels)
dltrain = DataLoader(dstrain, batch_size=batch_size, num_workers=1)
dstest = MyDataset(df_data, test_names, df_labels)
dltest = DataLoader(dstest, batch_size=batch_size, num_workers=1)

In [9]:
class ProjectionWithDropout(nn.Module):
    def __init__(self, in_features, out_features, dropout_p=0.2, bias=True):
        super(ProjectionWithDropout, self).__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.dropout = nn.Dropout(p=dropout_p)
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if bias:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)
    def forward(self, x):
        x = x.to(self.weight.dtype)
        out = F.linear(x, self.weight, self.bias)
        out = self.dropout(out)
        return out
    def __rmatmul__(self, x):
        return self.forward(x)

In [10]:

dropout_prob = 0.2  # Set your desired dropout probability

# ------------------
# MODIFY PROJECTION HEADS (for Visual and Text) TO USE THE CUSTOM MODULE
# ------------------

# --- Visual Projection Head ---
if hasattr(model.visual, 'proj'):
    if not isinstance(model.visual.proj, nn.Module):
        orig_weight = model.visual.proj  # Originally a tensor.
        # Check shape of original weight.
        # For example, if its shape is (768, 512), then we need to transpose it
        # so that the desired weight shape is (512, 768) (i.e., in_features=768, out_features=512).
        if orig_weight.shape[0] == 768:
            new_weight = orig_weight.t()  # new_weight shape: (512,768)
        else:
            new_weight = orig_weight  # Otherwise, assume it's already (512,768)
        in_features = new_weight.size(1)  # expected: 768
        out_features = new_weight.size(0)  # expected: 512
        new_image_proj = ProjectionWithDropout(in_features, out_features, dropout_p=dropout_prob, bias=True)
        new_image_proj.weight = torch.nn.Parameter(new_weight.clone().to(model.dtype))
        new_image_proj.bias = torch.nn.Parameter(torch.zeros(out_features, device=new_weight.device, dtype=model.dtype))
        if "proj" in model.visual._parameters:
            del model.visual._parameters["proj"]
        model.visual.proj = new_image_proj.to(model.dtype)

# --- Text Projection Head ---
if hasattr(model, 'text_projection'):
    if not isinstance(model.text_projection, nn.Module):
        orig_weight = model.text_projection  # Originally a tensor.
        # Check its shape: if shape[0] equals 768 (and desired output is 512), transpose it.
        if orig_weight.shape[0] == 768:
            new_weight = orig_weight.t()  # new_weight shape: (512,768)
        else:
            new_weight = orig_weight
        in_features = new_weight.size(1)  # expected: 768
        out_features = new_weight.size(0)  # expected: 512
        new_text_proj = ProjectionWithDropout(in_features, out_features, dropout_p=dropout_prob, bias=True)
        new_text_proj.weight = torch.nn.Parameter(new_weight.clone().to(model.dtype))
        new_text_proj.bias = torch.nn.Parameter(torch.zeros(out_features, device=new_weight.device, dtype=model.dtype))
        if "text_projection" in model._parameters:
            del model._parameters["text_projection"]
        model.text_projection = new_text_proj.to(model.dtype)

model.to(device)

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [11]:
optimzr = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.2)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimzr, 1e-2, total_steps=n_epochs * (2 * len(dltrain) - 1),
    base_momentum=0.0, max_momentum=0.5, pct_start=0.1, div_factor=1e2, final_div_factor=1e4
)
criterion = torch.nn.CrossEntropyLoss()

In [12]:
# Helper: Rolling Mean for tracking loss
class RollingMean():
    def __init__(self):
        self.n = 0
        self.mean = 0
    def update(self, value):
        self.mean = (self.mean * self.n + value) / (self.n + 1)
        self.n += 1
    def result(self):
        return self.mean

In [13]:
def mc_dropout_forward(model, images, texts, num_samples=10):
    model.eval()
    # Force dropout layers to remain active by setting them to train mode.
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()
    image_embeddings = []
    text_embeddings = []
    with torch.no_grad():
        for _ in range(num_samples):
            I_e = model.encode_image(images.to(device).to(model.dtype))
            T_e = model.encode_text(texts.to(device))  # Do not convert texts to model.dtype
            I_e = I_e / I_e.norm(2, dim=1, keepdim=True)
            T_e = T_e / T_e.norm(2, dim=1, keepdim=True)
            image_embeddings.append(I_e)
            text_embeddings.append(T_e)
    image_embeddings = torch.stack(image_embeddings, dim=0)  # (num_samples, B, D)
    text_embeddings = torch.stack(text_embeddings, dim=0)      # (num_samples, B, D)
    image_mean = image_embeddings.mean(dim=0)
    text_mean = text_embeddings.mean(dim=0)
    image_var = image_embeddings.var(dim=0)
    text_var = text_embeddings.var(dim=0)
    return image_mean, image_var, text_mean, text_var

In [14]:
loss_test_list = []

for epoch in range(n_epochs):
    if epoch % 5 == 0 and epoch > 0:
        torch.save(model.state_dict(), os.path.join(model_save_dr, f"{100 * prob}_{epoch}.pth"))
    
    loss_train_epoch = []
    model.train()  # Enable dropout during training
    with tqdm(total=2 * len(dltrain) - 1, desc=f"Epoch {epoch} Training") as bar:
        loss_mean_tracker = RollingMean()
        for images, texts in dltrain:
            images = images.to(device).to(model.dtype)
            texts = texts.to(device)  # tokenized text (integers) remain unchanged
            
            # Forward pass through CLIP encoders
            I_e = model.encode_image(images.to(device).to(model.dtype))
            T_e = model.encode_text(texts.to(device))  # Keep texts as LongTensor
            I_e = I_e / I_e.norm(2, dim=1, keepdim=True)
            T_e = T_e / T_e.norm(2, dim=1, keepdim=True)
            
            optimzr.zero_grad()
            scale = torch.exp(temp.to(device))
            logits_i = I_e @ T_e.T * scale
            logits_t = T_e @ I_e.T * scale
            labels = torch.arange(images.size(0), device=device)
            loss_i = criterion(logits_i, labels)
            loss_t = criterion(logits_t, labels)
            loss = (loss_i + loss_t) / 2
            
            loss.backward()
            optimzr.step()
            scheduler.step()
            
            loss_train_epoch.append(loss.item())
            loss_mean_tracker.update(loss.item())
            bar.update(1)
            bar.set_description(f"Train Loss: {loss_mean_tracker.result():.4f}")
    
    # Validation using MC dropout inference
    loss_val_epoch = []
    with tqdm(total=2 * len(dltest) - 1, desc=f"Epoch {epoch} Validation") as bar:
        for images, texts in dltest:
            images = images.to(device).to(model.dtype)
            texts = texts.to(device)
            
            I_e_mean, I_e_var, T_e_mean, T_e_var = mc_dropout_forward(model, images, texts, num_samples=10)
            logits_i = I_e_mean @ T_e_mean.T * scale
            logits_t = T_e_mean @ I_e_mean.T * scale
            labels = torch.arange(images.size(0), device=device)
            loss_i = criterion(logits_i, labels)
            loss_t = criterion(logits_t, labels)
            loss_val = (loss_i + loss_t) / 2
            loss_val_epoch.append(loss_val.item())
            bar.update(1)
    
    avg_train_loss = np.mean(loss_train_epoch)
    avg_val_loss = np.mean(loss_val_epoch)
    loss_test_list.append(avg_val_loss)
    print(f"Epoch {epoch} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

Train Loss: 5.5337:  51%|█████     | 45/89 [18:58<18:33, 25.31s/it]
Epoch 0 Validation:  56%|█████▌    | 5/9 [06:42<05:21, 80.47s/it]


Epoch 0 | Train Loss: 5.5337 | Val Loss: 5.5283


Train Loss: 5.4954:  51%|█████     | 45/89 [18:30<18:05, 24.67s/it]
Epoch 1 Validation:  56%|█████▌    | 5/9 [06:38<05:19, 79.77s/it]


Epoch 1 | Train Loss: 5.4954 | Val Loss: 5.4199


Train Loss: 5.4123:  45%|████▍     | 40/89 [16:42<20:28, 25.07s/it]