In [7]:
import os
import requests
from pycocotools.coco import COCO
from zipfile import ZipFile
from tqdm import tqdm

# Download COCO dataset annotations
annotations_url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
annotations_path = "annotations_trainval2017.zip"
annotations_dir = "annotations"

if not os.path.exists(annotations_path):
    print(f"Downloading {annotations_url}...")
    response = requests.get(annotations_url)
    with open(annotations_path, 'wb') as f:
        f.write(response.content)

    # Extract the zip file
    with ZipFile(annotations_path, 'r') as zip_ref:
        zip_ref.extractall(annotations_dir)

# Load annotations
coco = COCO(os.path.join(annotations_dir, "annotations/instances_train2017.json"))

# Select 1,000 images
image_ids = coco.getImgIds()[:1000]

# Function to download images from COCO
def download_coco_images(coco, img_ids, save_dir):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    for img_id in tqdm(img_ids):
        img_info = coco.loadImgs(img_id)[0]
        img_url = img_info["coco_url"]
        img_path = os.path.join(save_dir, img_info["file_name"])
        if not os.path.exists(img_path):
            response = requests.get(img_url)
            with open(img_path, 'wb') as f:
                f.write(response.content)

# Download images
download_coco_images(coco, image_ids, "coco_images")

Downloading http://images.cocodataset.org/annotations/annotations_trainval2017.zip...
loading annotations into memory...
Done (t=11.21s)
creating index...
index created!


100%|██████████| 1000/1000 [25:27<00:00,  1.53s/it]


In [3]:
!pip install pycocotools

Defaulting to user installation because normal site-packages is not writeable
Collecting pycocotools
  Downloading pycocotools-2.0.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.1 kB)
Downloading pycocotools-2.0.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (443 kB)
Installing collected packages: pycocotools
Successfully installed pycocotools-2.0.8


In [4]:
!pip install datasets

Defaulting to user installation because normal site-packages is not writeable


In [2]:
import os
import pandas as pd
from pycocotools.coco import COCO
from tqdm import tqdm

# Load COCO captions annotations
annotations_dir = "annotations"
captions_file = os.path.join(annotations_dir, "annotations/captions_train2017.json")
coco = COCO(captions_file)

# Select 1,000 image IDs (same as before)
image_ids = coco.getImgIds()[:1000]

# Create a DataFrame with image paths and captions
data = []
for img_id in tqdm(image_ids, desc="Processing images"):
    img_info = coco.loadImgs(img_id)[0]
    ann_ids = coco.getAnnIds(imgIds=img_id)  # Get annotation IDs for this image
    anns = coco.loadAnns(ann_ids)  # Load annotations (captions)
    caption = anns[0]["caption"] if anns else "No caption available"  # Use first caption
    data.append({"image_path": os.path.join("coco_images", img_info["file_name"]), "caption": caption})

# Save to DataFrame
df = pd.DataFrame(data)
df.to_csv("coco_dataset.csv", index=False)

print("Dataset prepared and saved to 'coco_dataset.csv'")
print(df.head())  # Preview the first few rows

loading annotations into memory...
Done (t=1.04s)
creating index...
index created!


Processing images: 100%|██████████| 1000/1000 [00:00<00:00, 72460.51it/s]

Dataset prepared and saved to 'coco_dataset.csv'
                     image_path  \
0  coco_images/000000391895.jpg   
1  coco_images/000000522418.jpg   
2  coco_images/000000184613.jpg   
3  coco_images/000000318219.jpg   
4  coco_images/000000554625.jpg   

                                             caption  
0  A man with a red helmet on a small moped on a ...  
1  A woman wearing a net on her head cutting a ca...  
2  A child holding a flowered umbrella and pettin...  
3  A young boy standing in front of a computer ke...  
4  a boy wearing headphones using one computer in...  





In [3]:
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import BertTokenizer

class CocoDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.data = pd.read_csv(csv_file)
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data["image_path"][idx]
        caption = self.data["caption"][idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        text = self.tokenizer(caption, return_tensors="pt", padding="max_length", max_length=32, truncation=True)
        return image, text["input_ids"].squeeze(), text["attention_mask"].squeeze()

# Create DataLoader
dataset = CocoDataset("coco_dataset.csv")
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)

In [5]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from transformers import ViTModel, DistilBertModel, DistilBertTokenizer  # Updated to DistilBertModel
from torchvision import transforms
from PIL import Image, UnidentifiedImageError
import pandas as pd
from tqdm import tqdm

# Function to select best GPU
def get_best_gpu():
    if not torch.cuda.is_available():
        print("No CUDA GPUs available. Falling back to CPU.")
        return torch.device("cpu")
    
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print("No GPUs detected. Falling back to CPU.")
        return torch.device("cpu")
    
    print(f"Detected {num_gpus} CUDA GPUs.")
    free_memory = []
    for i in range(num_gpus):
        torch.cuda.set_device(i)
        total_memory = torch.cuda.get_device_properties(i).total_memory
        allocated_memory = torch.cuda.memory_allocated(i)
        reserved_memory = torch.cuda.memory_reserved(i)
        free_mem = total_memory - (allocated_memory + reserved_memory)
        free_memory.append((i, free_mem))
        print(f"GPU {i}: Total Memory = {total_memory / 1024**3:.2f} GiB, "
              f"Free Memory = {free_mem / 1024**3:.2f} GiB")
    
    best_gpu_idx, best_free_mem = max(free_memory, key=lambda x: x[1])
    print(f"Selecting GPU {best_gpu_idx} with {best_free_mem / 1024**3:.2f} GiB free memory.")
    return torch.device(f"cuda:{best_gpu_idx}")

# Dataset with error handling
class CocoDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.data = pd.read_csv(csv_file)
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
        self.bad_images = []  # Track problematic images

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data["image_path"][idx]
        caption = self.data["caption"][idx]
        
        # Handle image loading with error checking
        try:
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"Image not found: {img_path}")
            image = Image.open(img_path).convert("RGB")
        except (UnidentifiedImageError, FileNotFoundError, Exception) as e:
            print(f"Warning: Skipping {img_path} due to error: {e}")
            self.bad_images.append(img_path)
            # Return a dummy image and text to keep the batch size consistent
            dummy_image = torch.zeros(3, 224, 224)  # Dummy tensor for invalid images
            text = self.tokenizer("Invalid image", return_tensors="pt", padding="max_length", max_length=32, truncation=True)
            return dummy_image, text["input_ids"].squeeze(), text["attention_mask"].squeeze()

        if self.transform:
            image = self.transform(image)
        text = self.tokenizer(caption, return_tensors="pt", padding="max_length", max_length=32, truncation=True)
        return image, text["input_ids"].squeeze(), text["attention_mask"].squeeze()

dataset = CocoDataset("coco_dataset.csv")
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)  # Reduced batch size

# CLIP Model
class CLIP(nn.Module):
    def __init__(self):
        super(CLIP, self).__init__()
        self.vision_encoder = ViTModel.from_pretrained("facebook/deit-small-patch16-224")
        self.text_encoder = DistilBertModel.from_pretrained("distilbert-base-uncased")  # Updated to DistilBertModel
        self.projection_dim = 512
        self.vision_proj = nn.Linear(384, self.projection_dim)  # DeiT-Small outputs 384-dim
        self.text_proj = nn.Linear(768, self.projection_dim)    # DistilBERT outputs 768-dim

    def forward(self, images, input_ids, attention_mask):
        vision_outputs = self.vision_encoder(images).last_hidden_state[:, 0, :]
        text_outputs = self.text_encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        vision_embeds = self.vision_proj(vision_outputs)
        text_embeds = self.text_proj(text_outputs)
        return vision_embeds, text_embeds

def contrastive_loss(vision_embeds, text_embeds, temperature=0.07):
    logits = torch.matmul(vision_embeds, text_embeds.T) / temperature
    labels = torch.arange(len(vision_embeds)).to(vision_embeds.device)
    loss_i = nn.CrossEntropyLoss()(logits, labels)
    loss_t = nn.CrossEntropyLoss()(logits.T, labels)
    return (loss_i + loss_t) / 2

# Training
device = get_best_gpu()
clip_model = CLIP().to(device)
optimizer = torch.optim.Adam(clip_model.parameters(), lr=1e-4)
scaler = GradScaler()

# Gradient accumulation settings
accumulation_steps = 4

for epoch in range(5):
    clip_model.train()
    total_loss = 0
    for i, (images, input_ids, attention_mask) in enumerate(tqdm(dataloader, desc=f"CLIP Epoch {epoch+1}")):
        images, input_ids, attention_mask = images.to(device), input_ids.to(device), attention_mask.to(device)
        optimizer.zero_grad()
        with autocast():
            vision_embeds, text_embeds = clip_model(images, input_ids, attention_mask)
            loss = contrastive_loss(vision_embeds, text_embeds)
            loss = loss / accumulation_steps  # Normalize loss for gradient accumulation
        scaler.scale(loss).backward()
        
        if (i + 1) % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        
        total_loss += loss.item() * accumulation_steps  # Accumulate total loss
    
    print(f"CLIP Epoch {epoch+1}, Average Loss: {total_loss / len(dataloader):.4f}")

    # Log bad images after each epoch
    if dataset.bad_images:
        print(f"Bad images skipped in Epoch {epoch+1}: {dataset.bad_images}")

torch.save(clip_model.state_dict(), "clip_model.pth")

Detected 4 CUDA GPUs.
GPU 0: Total Memory = 10.75 GiB, Free Memory = 9.72 GiB
GPU 1: Total Memory = 10.75 GiB, Free Memory = 10.75 GiB
GPU 2: Total Memory = 10.75 GiB, Free Memory = 10.75 GiB
GPU 3: Total Memory = 10.75 GiB, Free Memory = 10.75 GiB
Selecting GPU 1 with 10.75 GiB free memory.


Some weights of ViTModel were not initialized from the model checkpoint at facebook/deit-small-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
CLIP Epoch 1:  80%|███████▉  | 399/500 [00:33<00:08, 11.40it/s]



CLIP Epoch 1: 100%|██████████| 500/500 [00:42<00:00, 11.76it/s]


CLIP Epoch 1, Average Loss: 10.8808


CLIP Epoch 2:  67%|██████▋   | 337/500 [00:27<00:13, 11.97it/s]



CLIP Epoch 2: 100%|██████████| 500/500 [00:41<00:00, 11.94it/s]


CLIP Epoch 2, Average Loss: 1.1382


CLIP Epoch 3:  88%|████████▊ | 441/500 [00:37<00:04, 14.69it/s]



CLIP Epoch 3: 100%|██████████| 500/500 [00:42<00:00, 11.87it/s]


CLIP Epoch 3, Average Loss: 0.8262


CLIP Epoch 4:  26%|██▌       | 130/500 [00:11<00:30, 12.31it/s]



CLIP Epoch 4: 100%|██████████| 500/500 [00:42<00:00, 11.77it/s]


CLIP Epoch 4, Average Loss: 0.7509


CLIP Epoch 5:  93%|█████████▎| 466/500 [00:38<00:02, 11.87it/s]



CLIP Epoch 5: 100%|██████████| 500/500 [00:41<00:00, 11.98it/s]


CLIP Epoch 5, Average Loss: 0.7471


In [6]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from transformers import BlipForImageTextRetrieval, BlipProcessor
from torchvision import transforms
from PIL import Image, UnidentifiedImageError
import pandas as pd
from tqdm import tqdm

# Dataset with error handling
class CocoDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.data = pd.read_csv(csv_file)
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        self.bad_images = []

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data["image_path"][idx]
        caption = self.data["caption"][idx]
        
        try:
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"Image not found: {img_path}")
            image = Image.open(img_path).convert("RGB")
        except (UnidentifiedImageError, FileNotFoundError, Exception) as e:
            print(f"Warning: Skipping {img_path} due to error: {e}")
            self.bad_images.append(img_path)
            dummy_image = torch.zeros(3, 224, 224)
            text = self.processor.tokenizer("Invalid image", return_tensors="pt", padding="max_length", max_length=32, truncation=True)
            return dummy_image, text["input_ids"].squeeze(), text["attention_mask"].squeeze()

        if self.transform:
            image = self.transform(image)
        text = self.processor.tokenizer(caption, return_tensors="pt", padding="max_length", max_length=32, truncation=True)
        return image, text["input_ids"].squeeze(), text["attention_mask"].squeeze()

# BLIP Model with Contrastive Loss
class BLIP(nn.Module):
    def __init__(self):
        super(BLIP, self).__init__()
        self.blip = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-image-captioning-base")
        self.projection_dim = 512
        self.vision_proj = nn.Linear(768, self.projection_dim)  # ViT output dim
        self.text_proj = nn.Linear(768, self.projection_dim)    # Text encoder output dim

    def forward(self, images, input_ids, attention_mask):
        # Vision encoding
        vision_outputs = self.blip.vision_model(pixel_values=images).last_hidden_state[:, 0, :]
        vision_embeds = self.vision_proj(vision_outputs)
        
        # Text encoding
        text_outputs = self.blip.text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        text_embeds = self.text_proj(text_outputs)
        
        return vision_embeds, text_embeds

def contrastive_loss(vision_embeds, text_embeds, temperature=0.07):
    logits = torch.matmul(vision_embeds, text_embeds.T) / temperature
    labels = torch.arange(len(vision_embeds)).to(vision_embeds.device)
    loss_i = nn.CrossEntropyLoss()(logits, labels)
    loss_t = nn.CrossEntropyLoss()(logits.T, labels)
    return (loss_i + loss_t) / 2

# Setup
dataset = CocoDataset("coco_dataset.csv")
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)  # Reduced batch size

device = torch.device("cuda:3")  # Hardcoded to GPU 3 (adjust if needed)
torch.cuda.empty_cache()  # Clear memory

print(f"Using device: {device}")
blip_model = BLIP().to(device)
optimizer = torch.optim.Adam(blip_model.parameters(), lr=1e-4)
scaler = GradScaler()

# Training Loop
for epoch in range(5):
    blip_model.train()
    total_loss = 0
    for images, input_ids, attention_mask in tqdm(dataloader, desc=f"BLIP Epoch {epoch+1}"):
        images, input_ids, attention_mask = images.to(device), input_ids.to(device), attention_mask.to(device)
        optimizer.zero_grad()
        with autocast():
            vision_embeds, text_embeds = blip_model(images, input_ids, attention_mask)
            loss = contrastive_loss(vision_embeds, text_embeds)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    print(f"BLIP Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")
    if dataset.bad_images:
        print(f"Bad images skipped in Epoch {epoch+1}: {list(set(dataset.bad_images))}")

torch.save(blip_model.state_dict(), "blip_model.pth")
print("BLIP model saved to blip_model.pth")

preprocessor_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/506 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Using device: cuda:3


config.json:   0%|          | 0.00/4.56k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

Some weights of BlipForImageTextRetrieval were not initialized from the model checkpoint at Salesforce/blip-image-captioning-base and are newly initialized: ['itm_head.bias', 'itm_head.weight', 'text_encoder.embeddings.LayerNorm.bias', 'text_encoder.embeddings.LayerNorm.weight', 'text_encoder.embeddings.position_embeddings.weight', 'text_encoder.embeddings.word_embeddings.weight', 'text_encoder.encoder.layer.0.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.0.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.0.attention.output.dense.bias', 'text_encoder.encoder.layer.0.attention.output.dense.weight', 'text_encoder.encoder.layer.0.attention.self.key.bias', 'text_encoder.encoder.layer.0.attention.self.key.weight', 'text_encoder.encoder.layer.0.attention.self.query.bias', 'text_encoder.encoder.layer.0.attention.self.query.weight', 'text_encoder.encoder.layer.0.attention.self.value.bias', 'text_encoder.encoder.layer.0.attention.self.value.weight', 'text_encoder.

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

BLIP Epoch 1:  39%|███▊      | 193/500 [00:30<00:43,  7.07it/s]



BLIP Epoch 1: 100%|██████████| 500/500 [01:13<00:00,  6.78it/s]


BLIP Epoch 1, Average Loss: 1.3196


BLIP Epoch 2:  17%|█▋        | 84/500 [00:11<00:55,  7.52it/s]



BLIP Epoch 2: 100%|██████████| 500/500 [01:09<00:00,  7.22it/s]


BLIP Epoch 2, Average Loss: 0.7007


BLIP Epoch 3:  87%|████████▋ | 434/500 [00:59<00:07,  8.31it/s]



BLIP Epoch 3: 100%|██████████| 500/500 [01:09<00:00,  7.25it/s]


BLIP Epoch 3, Average Loss: 0.6958


BLIP Epoch 4:  71%|███████   | 356/500 [00:48<00:20,  7.20it/s]



BLIP Epoch 4: 100%|██████████| 500/500 [01:09<00:00,  7.20it/s]


BLIP Epoch 4, Average Loss: 0.6948


BLIP Epoch 5:  39%|███▉      | 194/500 [00:27<00:45,  6.67it/s]



BLIP Epoch 5: 100%|██████████| 500/500 [01:08<00:00,  7.27it/s]


BLIP Epoch 5, Average Loss: 0.6904
BLIP model saved to blip_model.pth


In [7]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from transformers import ViltProcessor, ViltModel
from torchvision import transforms
from PIL import Image, UnidentifiedImageError
import pandas as pd
from tqdm import tqdm

# Dataset with error handling
class CocoDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.data = pd.read_csv(csv_file)
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
        self.bad_images = []

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data["image_path"][idx]
        caption = self.data["caption"][idx]
        
        try:
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"Image not found: {img_path}")
            image = Image.open(img_path).convert("RGB")
        except (UnidentifiedImageError, FileNotFoundError, Exception) as e:
            print(f"Warning: Skipping {img_path} due to error: {e}")
            self.bad_images.append(img_path)
            dummy_image = torch.zeros(3, 224, 224)
            text = self.processor.tokenizer("Invalid image", return_tensors="pt", padding="max_length", max_length=32, truncation=True)
            return dummy_image, text["input_ids"].squeeze(), text["attention_mask"].squeeze()

        if self.transform:
            image = self.transform(image)
        text = self.processor.tokenizer(caption, return_tensors="pt", padding="max_length", max_length=32, truncation=True)
        return image, text["input_ids"].squeeze(), text["attention_mask"].squeeze()

# ViLT Model with Contrastive Loss
class VILT(nn.Module):
    def __init__(self):
        super(VILT, self).__init__()
        self.vilt = ViltModel.from_pretrained("dandelin/vilt-b32-mlm")
        self.projection_dim = 512
        self.proj = nn.Linear(768, self.projection_dim)  # ViLT output dim is 768

    def forward(self, images, input_ids, attention_mask):
        # ViLT processes images and text together
        outputs = self.vilt(pixel_values=images, input_ids=input_ids, attention_mask=attention_mask)
        embeds = self.proj(outputs.last_hidden_state[:, 0, :])  # CLS token
        # For contrastive loss, use same embeddings for vision and text (simplified)
        return embeds, embeds

def contrastive_loss(vision_embeds, text_embeds, temperature=0.07):
    logits = torch.matmul(vision_embeds, text_embeds.T) / temperature
    labels = torch.arange(len(vision_embeds)).to(vision_embeds.device)
    loss_i = nn.CrossEntropyLoss()(logits, labels)
    loss_t = nn.CrossEntropyLoss()(logits.T, labels)
    return (loss_i + loss_t) / 2

# Setup
dataset = CocoDataset("coco_dataset.csv")
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)  # Reduced batch size

device = torch.device("cuda:3")  # Hardcoded to GPU 3 (adjust if needed)
torch.cuda.empty_cache()  # Clear memory

print(f"Using device: {device}")
vilt_model = VILT().to(device)
optimizer = torch.optim.Adam(vilt_model.parameters(), lr=1e-4)
scaler = GradScaler()

# Training Loop
for epoch in range(5):
    vilt_model.train()
    total_loss = 0
    for images, input_ids, attention_mask in tqdm(dataloader, desc=f"ViLT Epoch {epoch+1}"):
        images, input_ids, attention_mask = images.to(device), input_ids.to(device), attention_mask.to(device)
        optimizer.zero_grad()
        with autocast():
            vision_embeds, text_embeds = vilt_model(images, input_ids, attention_mask)
            loss = contrastive_loss(vision_embeds, text_embeds)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    print(f"ViLT Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")
    if dataset.bad_images:
        print(f"Bad images skipped in Epoch {epoch+1}: {list(set(dataset.bad_images))}")

torch.save(vilt_model.state_dict(), "vilt_model.pth")
print("ViLT model saved to vilt_model.pth")

preprocessor_config.json:   0%|          | 0.00/251 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/320 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Using device: cuda:3


config.json:   0%|          | 0.00/653 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/543M [00:00<?, ?B/s]

ViLT Epoch 1:   2%|▏         | 12/500 [00:01<00:48, 10.15it/s]

model.safetensors:   0%|          | 0.00/543M [00:00<?, ?B/s]

ViLT Epoch 1:  83%|████████▎ | 415/500 [00:40<00:07, 10.89it/s]



ViLT Epoch 1: 100%|██████████| 500/500 [00:48<00:00, 10.35it/s]


ViLT Epoch 1, Average Loss: nan


ViLT Epoch 2:  61%|██████    | 305/500 [00:28<00:17, 11.09it/s]



ViLT Epoch 2: 100%|██████████| 500/500 [00:46<00:00, 10.83it/s]


ViLT Epoch 2, Average Loss: 0.1915


ViLT Epoch 3:  55%|█████▌    | 277/500 [00:25<00:20, 11.13it/s]



ViLT Epoch 3: 100%|██████████| 500/500 [00:45<00:00, 10.90it/s]


ViLT Epoch 3, Average Loss: 0.1519


ViLT Epoch 4:  29%|██▊       | 143/500 [00:13<00:31, 11.24it/s]



ViLT Epoch 4: 100%|██████████| 500/500 [00:46<00:00, 10.82it/s]


ViLT Epoch 4, Average Loss: 0.0693


ViLT Epoch 5:  80%|███████▉  | 398/500 [00:36<00:08, 11.52it/s]



ViLT Epoch 5: 100%|██████████| 500/500 [00:45<00:00, 11.05it/s]


ViLT Epoch 5, Average Loss: 0.0620
ViLT model saved to vilt_model.pth


In [8]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from torchvision import transforms
from PIL import Image, UnidentifiedImageError
import pandas as pd
from tqdm import tqdm

# Multi-Head Attention (simplified)
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size = x.size(0)
        qkv = self.qkv(x).chunk(3, dim=-1)
        q, k, v = [t.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) for t in qkv]
        attn = (q @ k.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.out(out)

# Vision Transformer Encoder
class ViTEncoder(nn.Module):
    def __init__(self, img_size=224, patch_size=16, d_model=256, num_heads=8, num_layers=6):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, d_model))
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.layers = nn.ModuleList([
            nn.ModuleList([
                nn.LayerNorm(d_model),
                MultiHeadAttention(d_model, num_heads),
                nn.LayerNorm(d_model),
                nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
            ]) for _ in range(num_layers)
        ])

    def forward(self, x):
        x = self.patch_embed(x).flatten(2).transpose(1, 2)  # [B, num_patches, d_model]
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1) + self.pos_embed
        for norm1, attn, norm2, ff in self.layers:
            x = x + attn(norm1(x))
            x = x + ff(norm2(x))
        return x[:, 0]  # CLS token

# Text Transformer Encoder
class TextEncoder(nn.Module):
    def __init__(self, vocab_size=30522, d_model=256, num_heads=8, num_layers=6, max_len=32):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model))
        self.layers = nn.ModuleList([
            nn.ModuleList([
                nn.LayerNorm(d_model),
                MultiHeadAttention(d_model, num_heads),
                nn.LayerNorm(d_model),
                nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
            ]) for _ in range(num_layers)
        ])

    def forward(self, input_ids, attention_mask):
        x = self.embedding(input_ids) + self.pos_embed
        for norm1, attn, norm2, ff in self.layers:
            x = x + attn(norm1(x)) * attention_mask.unsqueeze(-1)
            x = x + ff(norm2(x)) * attention_mask.unsqueeze(-1)
        return x[:, 0]  # CLS token

# CLIP from Scratch
class CLIP(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_encoder = ViTEncoder()
        self.text_encoder = TextEncoder()
        self.projection_dim = 512
        self.vision_proj = nn.Linear(256, self.projection_dim)
        self.text_proj = nn.Linear(256, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        vision_embeds = self.vision_proj(self.vision_encoder(images))
        text_embeds = self.text_proj(self.text_encoder(input_ids, attention_mask))
        return vision_embeds, text_embeds

# Dataset
class CocoDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.data = pd.read_csv(csv_file)
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        from transformers import DistilBertTokenizer
        self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
        self.bad_images = []

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data["image_path"][idx]
        caption = self.data["caption"][idx]
        try:
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"Image not found: {img_path}")
            image = Image.open(img_path).convert("RGB")
        except (UnidentifiedImageError, FileNotFoundError, Exception) as e:
            print(f"Warning: Skipping {img_path} due to error: {e}")
            self.bad_images.append(img_path)
            dummy_image = torch.zeros(3, 224, 224)
            text = self.tokenizer("Invalid image", return_tensors="pt", padding="max_length", max_length=32, truncation=True)
            return dummy_image, text["input_ids"].squeeze(), text["attention_mask"].squeeze()
        if self.transform:
            image = self.transform(image)
        text = self.tokenizer(caption, return_tensors="pt", padding="max_length", max_length=32, truncation=True)
        return image, text["input_ids"].squeeze(), text["attention_mask"].squeeze()

def contrastive_loss(vision_embeds, text_embeds, temperature=0.07):
    logits = torch.matmul(vision_embeds, text_embeds.T) / temperature
    labels = torch.arange(len(vision_embeds)).to(vision_embeds.device)
    loss_i = nn.CrossEntropyLoss()(logits, labels)
    loss_t = nn.CrossEntropyLoss()(logits.T, labels)
    return (loss_i + loss_t) / 2

# Setup and Training
dataset = CocoDataset("coco_dataset.csv")
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)
device = torch.device("cuda:3")
torch.cuda.empty_cache()

print(f"Using device: {device}")
clip_model = CLIP().to(device)
optimizer = torch.optim.Adam(clip_model.parameters(), lr=1e-4)
scaler = GradScaler()

for epoch in range(5):  # More epochs may be needed
    clip_model.train()
    total_loss = 0
    for images, input_ids, attention_mask in tqdm(dataloader, desc=f"CLIP Epoch {epoch+1}"):
        images, input_ids, attention_mask = images.to(device), input_ids.to(device), attention_mask.to(device)
        optimizer.zero_grad()
        with autocast():
            vision_embeds, text_embeds = clip_model(images, input_ids, attention_mask)
            loss = contrastive_loss(vision_embeds, text_embeds)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    print(f"CLIP Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")
    if dataset.bad_images:
        print(f"Bad images skipped in Epoch {epoch+1}: {list(set(dataset.bad_images))}")

torch.save(clip_model.state_dict(), "clip_from_scratch.pth")
print("CLIP model saved to clip_from_scratch.pth")

Using device: cuda:3


CLIP Epoch 1:  93%|█████████▎| 465/500 [00:32<00:02, 14.97it/s]



CLIP Epoch 1: 100%|██████████| 500/500 [00:34<00:00, 14.30it/s]


CLIP Epoch 1, Average Loss: 7.6916


CLIP Epoch 2:  45%|████▍     | 223/500 [00:15<00:18, 14.87it/s]



CLIP Epoch 2: 100%|██████████| 500/500 [00:33<00:00, 14.77it/s]


CLIP Epoch 2, Average Loss: 2.9676


CLIP Epoch 3:  83%|████████▎ | 413/500 [00:27<00:05, 14.95it/s]



CLIP Epoch 3: 100%|██████████| 500/500 [00:33<00:00, 14.73it/s]


CLIP Epoch 3, Average Loss: 1.3232


CLIP Epoch 4:  87%|████████▋ | 433/500 [00:29<00:04, 15.26it/s]



CLIP Epoch 4: 100%|██████████| 500/500 [00:33<00:00, 14.74it/s]


CLIP Epoch 4, Average Loss: 1.1029


CLIP Epoch 5:  37%|███▋      | 185/500 [00:12<00:21, 14.78it/s]



CLIP Epoch 5: 100%|██████████| 500/500 [00:33<00:00, 14.71it/s]


CLIP Epoch 5, Average Loss: 1.4876
CLIP model saved to clip_from_scratch.pth


In [9]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from torchvision import transforms
from PIL import Image, UnidentifiedImageError
import pandas as pd
from tqdm import tqdm

# Reuse MultiHeadAttention, ViTEncoder, TextEncoder from CLIP
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size = x.size(0)
        qkv = self.qkv(x).chunk(3, dim=-1)
        q, k, v = [t.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) for t in qkv]
        attn = (q @ k.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.out(out)

class ViTEncoder(nn.Module):
    def __init__(self, img_size=224, patch_size=16, d_model=256, num_heads=8, num_layers=6):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, d_model))
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.layers = nn.ModuleList([
            nn.ModuleList([
                nn.LayerNorm(d_model),
                MultiHeadAttention(d_model, num_heads),
                nn.LayerNorm(d_model),
                nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
            ]) for _ in range(num_layers)
        ])

    def forward(self, x):
        x = self.patch_embed(x).flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1) + self.pos_embed
        for norm1, attn, norm2, ff in self.layers:
            x = x + attn(norm1(x))
            x = x + ff(norm2(x))
        return x[:, 0]

class TextEncoder(nn.Module):
    def __init__(self, vocab_size=30522, d_model=256, num_heads=8, num_layers=6, max_len=32):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model))
        self.layers = nn.ModuleList([
            nn.ModuleList([
                nn.LayerNorm(d_model),
                MultiHeadAttention(d_model, num_heads),
                nn.LayerNorm(d_model),
                nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
            ]) for _ in range(num_layers)
        ])

    def forward(self, input_ids, attention_mask):
        x = self.embedding(input_ids) + self.pos_embed
        for norm1, attn, norm2, ff in self.layers:
            x = x + attn(norm1(x)) * attention_mask.unsqueeze(-1)
            x = x + ff(norm2(x)) * attention_mask.unsqueeze(-1)
        return x[:, 0]

# BLIP from Scratch
class BLIP(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_encoder = ViTEncoder()
        self.text_encoder = TextEncoder()
        self.projection_dim = 512
        self.vision_proj = nn.Linear(256, self.projection_dim)
        self.text_proj = nn.Linear(256, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        vision_embeds = self.vision_proj(self.vision_encoder(images))
        text_embeds = self.text_proj(self.text_encoder(input_ids, attention_mask))
        return vision_embeds, text_embeds

# Dataset
class CocoDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.data = pd.read_csv(csv_file)
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        from transformers import DistilBertTokenizer
        self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
        self.bad_images = []

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data["image_path"][idx]
        caption = self.data["caption"][idx]
        try:
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"Image not found: {img_path}")
            image = Image.open(img_path).convert("RGB")
        except (UnidentifiedImageError, FileNotFoundError, Exception) as e:
            print(f"Warning: Skipping {img_path} due to error: {e}")
            self.bad_images.append(img_path)
            dummy_image = torch.zeros(3, 224, 224)
            text = self.tokenizer("Invalid image", return_tensors="pt", padding="max_length", max_length=32, truncation=True)
            return dummy_image, text["input_ids"].squeeze(), text["attention_mask"].squeeze()
        if self.transform:
            image = self.transform(image)
        text = self.tokenizer(caption, return_tensors="pt", padding="max_length", max_length=32, truncation=True)
        return image, text["input_ids"].squeeze(), text["attention_mask"].squeeze()

def contrastive_loss(vision_embeds, text_embeds, temperature=0.07):
    logits = torch.matmul(vision_embeds, text_embeds.T) / temperature
    labels = torch.arange(len(vision_embeds)).to(vision_embeds.device)
    loss_i = nn.CrossEntropyLoss()(logits, labels)
    loss_t = nn.CrossEntropyLoss()(logits.T, labels)
    return (loss_i + loss_t) / 2

# Setup and Training
dataset = CocoDataset("coco_dataset.csv")
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)
device = torch.device("cuda:3")
torch.cuda.empty_cache()

print(f"Using device: {device}")
blip_model = BLIP().to(device)
optimizer = torch.optim.Adam(blip_model.parameters(), lr=1e-4)
scaler = GradScaler()

for epoch in range(5):  # More epochs may be needed
    blip_model.train()
    total_loss = 0
    for images, input_ids, attention_mask in tqdm(dataloader, desc=f"BLIP Epoch {epoch+1}"):
        images, input_ids, attention_mask = images.to(device), input_ids.to(device), attention_mask.to(device)
        optimizer.zero_grad()
        with autocast():
            vision_embeds, text_embeds = blip_model(images, input_ids, attention_mask)
            loss = contrastive_loss(vision_embeds, text_embeds)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    print(f"BLIP Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")
    if dataset.bad_images:
        print(f"Bad images skipped in Epoch {epoch+1}: {list(set(dataset.bad_images))}")

torch.save(blip_model.state_dict(), "blip_from_scratch.pth")
print("BLIP model saved to blip_from_scratch.pth")

Using device: cuda:3


BLIP Epoch 1:  85%|████████▌ | 425/500 [00:29<00:05, 14.92it/s]



BLIP Epoch 1: 100%|██████████| 500/500 [00:34<00:00, 14.68it/s]


BLIP Epoch 1, Average Loss: 7.0565


BLIP Epoch 2:  61%|██████    | 303/500 [00:20<00:13, 14.79it/s]



BLIP Epoch 2: 100%|██████████| 500/500 [00:33<00:00, 14.85it/s]


BLIP Epoch 2, Average Loss: 1.8474


BLIP Epoch 3:  32%|███▏      | 161/500 [00:11<00:20, 16.15it/s]



BLIP Epoch 3: 100%|██████████| 500/500 [00:34<00:00, 14.58it/s]


BLIP Epoch 3, Average Loss: 1.5648


BLIP Epoch 4:  43%|████▎     | 217/500 [00:14<00:18, 14.90it/s]



BLIP Epoch 4: 100%|██████████| 500/500 [00:33<00:00, 14.83it/s]


BLIP Epoch 4, Average Loss: 1.3408


BLIP Epoch 5:  19%|█▉        | 95/500 [00:06<00:28, 14.08it/s]



BLIP Epoch 5: 100%|██████████| 500/500 [00:33<00:00, 14.88it/s]


BLIP Epoch 5, Average Loss: 1.5527
BLIP model saved to blip_from_scratch.pth


In [10]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from torchvision import transforms
from PIL import Image, UnidentifiedImageError
import pandas as pd
from tqdm import tqdm

# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size = x.size(0)
        qkv = self.qkv(x).chunk(3, dim=-1)
        q, k, v = [t.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) for t in qkv]
        attn = (q @ k.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.out(out)

# ViLT Encoder
class ViltEncoder(nn.Module):
    def __init__(self, img_size=224, patch_size=16, vocab_size=30522, d_model=256, num_heads=8, num_layers=6, max_len=32):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)
        self.text_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + max_len + 1, d_model))
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.layers = nn.ModuleList([
            nn.ModuleList([
                nn.LayerNorm(d_model),
                MultiHeadAttention(d_model, num_heads),
                nn.LayerNorm(d_model),
                nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
            ]) for _ in range(num_layers)
        ])

    def forward(self, images, input_ids, attention_mask):
        # Image patches
        img_embeds = self.patch_embed(images).flatten(2).transpose(1, 2)
        # Text tokens
        text_embeds = self.text_embed(input_ids)
        # Concatenate
        x = torch.cat([img_embeds, text_embeds], dim=1)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1) + self.pos_embed
        # Attention mask for images + text
        img_mask = torch.ones(x.size(0), img_embeds.size(1) + 1, device=x.device)
        full_mask = torch.cat([img_mask, attention_mask], dim=1)
        for norm1, attn, norm2, ff in self.layers:
            x = x + attn(norm1(x)) * full_mask.unsqueeze(-1)
            x = x + ff(norm2(x)) * full_mask.unsqueeze(-1)
        return x[:, 0]

# ViLT from Scratch
class VILT(nn.Module):
    def __init__(self):
        super().__init__()
        self.vilt = ViltEncoder()
        self.projection_dim = 512
        self.proj = nn.Linear(256, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        embeds = self.proj(self.vilt(images, input_ids, attention_mask))
        return embeds, embeds  # Same embeddings for contrastive loss

# Dataset
class CocoDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.data = pd.read_csv(csv_file)
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        from transformers import DistilBertTokenizer
        self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
        self.bad_images = []

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data["image_path"][idx]
        caption = self.data["caption"][idx]
        try:
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"Image not found: {img_path}")
            image = Image.open(img_path).convert("RGB")
        except (UnidentifiedImageError, FileNotFoundError, Exception) as e:
            print(f"Warning: Skipping {img_path} due to error: {e}")
            self.bad_images.append(img_path)
            dummy_image = torch.zeros(3, 224, 224)
            text = self.tokenizer("Invalid image", return_tensors="pt", padding="max_length", max_length=32, truncation=True)
            return dummy_image, text["input_ids"].squeeze(), text["attention_mask"].squeeze()
        if self.transform:
            image = self.transform(image)
        text = self.tokenizer(caption, return_tensors="pt", padding="max_length", max_length=32, truncation=True)
        return image, text["input_ids"].squeeze(), text["attention_mask"].squeeze()

def contrastive_loss(vision_embeds, text_embeds, temperature=0.07):
    logits = torch.matmul(vision_embeds, text_embeds.T) / temperature
    labels = torch.arange(len(vision_embeds)).to(vision_embeds.device)
    loss_i = nn.CrossEntropyLoss()(logits, labels)
    loss_t = nn.CrossEntropyLoss()(logits.T, labels)
    return (loss_i + loss_t) / 2

# Setup and Training
dataset = CocoDataset("coco_dataset.csv")
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)
device = torch.device("cuda:3")
torch.cuda.empty_cache()

print(f"Using device: {device}")
vilt_model = VILT().to(device)
optimizer = torch.optim.Adam(vilt_model.parameters(), lr=1e-4)
scaler = GradScaler()

for epoch in range(5):  # More epochs may be needed
    vilt_model.train()
    total_loss = 0
    for images, input_ids, attention_mask in tqdm(dataloader, desc=f"ViLT Epoch {epoch+1}"):
        images, input_ids, attention_mask = images.to(device), input_ids.to(device), attention_mask.to(device)
        optimizer.zero_grad()
        with autocast():
            vision_embeds, text_embeds = vilt_model(images, input_ids, attention_mask)
            loss = contrastive_loss(vision_embeds, text_embeds)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    print(f"ViLT Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")
    if dataset.bad_images:
        print(f"Bad images skipped in Epoch {epoch+1}: {list(set(dataset.bad_images))}")

torch.save(vilt_model.state_dict(), "vilt_from_scratch.pth")
print("ViLT model saved to vilt_from_scratch.pth")

Using device: cuda:3


ViLT Epoch 1:  89%|████████▊ | 443/500 [00:19<00:02, 23.29it/s]



ViLT Epoch 1: 100%|██████████| 500/500 [00:22<00:00, 22.35it/s]


ViLT Epoch 1, Average Loss: 26.7046


ViLT Epoch 2:  81%|████████  | 406/500 [00:18<00:04, 22.70it/s]



ViLT Epoch 2: 100%|██████████| 500/500 [00:22<00:00, 21.89it/s]


ViLT Epoch 2, Average Loss: 6.8063


ViLT Epoch 3:  18%|█▊        | 91/500 [00:04<00:18, 22.50it/s]



ViLT Epoch 3: 100%|██████████| 500/500 [00:22<00:00, 21.84it/s]


ViLT Epoch 3, Average Loss: 19.1642


ViLT Epoch 4:  74%|███████▍  | 371/500 [00:16<00:05, 22.59it/s]



ViLT Epoch 4: 100%|██████████| 500/500 [00:22<00:00, 22.37it/s]


ViLT Epoch 4, Average Loss: 3.5421


ViLT Epoch 5:  33%|███▎      | 166/500 [00:07<00:15, 21.97it/s]



ViLT Epoch 5: 100%|██████████| 500/500 [00:22<00:00, 21.87it/s]


ViLT Epoch 5, Average Loss: 0.7040
ViLT model saved to vilt_from_scratch.pth


In [2]:
import os
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, UnidentifiedImageError
from transformers import ViTModel, DistilBertModel, DistilBertTokenizer, BlipForImageTextRetrieval, BlipProcessor, ViltProcessor, ViltModel
from tqdm import tqdm
import numpy as np

# Environment settings
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Dataset
class CocoDataset(Dataset):
    def __init__(self, csv_file, transform=None, model_type="clip"):
        self.data = pd.read_csv(csv_file)
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.model_type = model_type
        if model_type == "blip":
            self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        elif model_type == "vilt":
            self.processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
        else:  # clip or from-scratch
            self.processor = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
        self.bad_images = []

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data["image_path"][idx]
        caption = self.data["caption"][idx]
        try:
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"Image not found: {img_path}")
            image = Image.open(img_path).convert("RGB")
        except (UnidentifiedImageError, FileNotFoundError, Exception) as e:
            print(f"Warning: Skipping {img_path} due to error: {e}")
            self.bad_images.append(img_path)
            dummy_image = torch.zeros(3, 224, 224)
            if self.model_type in ["blip", "vilt"]:
                text = self.processor.tokenizer("Invalid image", padding="max_length", max_length=32, truncation=True, return_tensors="pt")
            else:
                text = self.processor("Invalid image", padding="max_length", max_length=32, truncation=True, return_tensors="pt")
            return dummy_image, text["input_ids"].squeeze(), text["attention_mask"].squeeze()

        if self.transform:
            image = self.transform(image)
        if self.model_type == "blip":
            text = self.processor.tokenizer(caption, padding="max_length", max_length=32, truncation=True, return_tensors="pt")
        elif self.model_type == "vilt":
            text = self.processor.tokenizer(caption, padding="max_length", max_length=32, truncation=True, return_tensors="pt")
        else:
            text = self.processor(caption, padding="max_length", max_length=32, truncation=True, return_tensors="pt")
        return image, text["input_ids"].squeeze(), text["attention_mask"].squeeze()

# Pre-trained CLIP Model
class CLIP(nn.Module):
    def __init__(self):
        super(CLIP, self).__init__()
        self.vision_encoder = ViTModel.from_pretrained("facebook/deit-small-patch16-224")
        self.text_encoder = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.projection_dim = 512
        self.vision_proj = nn.Linear(384, self.projection_dim)
        self.text_proj = nn.Linear(768, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        vision_outputs = self.vision_encoder(images).last_hidden_state[:, 0, :]
        text_outputs = self.text_encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        vision_embeds = self.vision_proj(vision_outputs)
        text_embeds = self.text_proj(text_outputs)
        return vision_embeds, text_embeds

# Pre-trained BLIP Model
class BLIP(nn.Module):
    def __init__(self):
        super(BLIP, self).__init__()
        self.blip = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-image-captioning-base")
        self.projection_dim = 512
        self.vision_proj = nn.Linear(768, self.projection_dim)
        self.text_proj = nn.Linear(768, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        vision_outputs = self.blip.vision_model(pixel_values=images).last_hidden_state[:, 0, :]
        text_outputs = self.blip.text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        vision_embeds = self.vision_proj(vision_outputs)
        text_embeds = self.text_proj(text_outputs)
        return vision_embeds, text_embeds

# Pre-trained ViLT Model
class VILT(nn.Module):
    def __init__(self):
        super(VILT, self).__init__()
        self.vilt = ViltModel.from_pretrained("dandelin/vilt-b32-mlm")
        self.projection_dim = 512
        self.proj = nn.Linear(768, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        outputs = self.vilt(pixel_values=images, input_ids=input_ids, attention_mask=attention_mask)
        embeds = self.proj(outputs.last_hidden_state[:, 0, :])
        return embeds, embeds

# From-Scratch Models
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size = x.size(0)
        qkv = self.qkv(x).chunk(3, dim=-1)
        q, k, v = [t.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) for t in qkv]
        attn = (q @ k.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.out(out)

class ViTEncoder(nn.Module):
    def __init__(self, img_size=224, patch_size=16, d_model=256, num_heads=8, num_layers=6):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, d_model))
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.layers = nn.ModuleList([
            nn.ModuleList([
                nn.LayerNorm(d_model),
                MultiHeadAttention(d_model, num_heads),
                nn.LayerNorm(d_model),
                nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
            ]) for _ in range(num_layers)
        ])

    def forward(self, x):
        x = self.patch_embed(x).flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1) + self.pos_embed
        for norm1, attn, norm2, ff in self.layers:
            x = x + attn(norm1(x))
            x = x + ff(norm2(x))
        return x[:, 0]

class TextEncoder(nn.Module):
    def __init__(self, vocab_size=30522, d_model=256, num_heads=8, num_layers=6, max_len=32):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model))
        self.layers = nn.ModuleList([
            nn.ModuleList([
                nn.LayerNorm(d_model),
                MultiHeadAttention(d_model, num_heads),
                nn.LayerNorm(d_model),
                nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
            ]) for _ in range(num_layers)
        ])

    def forward(self, input_ids, attention_mask):
        x = self.embedding(input_ids) + self.pos_embed
        for norm1, attn, norm2, ff in self.layers:
            x = x + attn(norm1(x)) * attention_mask.unsqueeze(-1)
            x = x + ff(norm2(x)) * attention_mask.unsqueeze(-1)
        return x[:, 0]

class CLIPFromScratch(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_encoder = ViTEncoder()
        self.text_encoder = TextEncoder()
        self.projection_dim = 512
        self.vision_proj = nn.Linear(256, self.projection_dim)
        self.text_proj = nn.Linear(256, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        vision_embeds = self.vision_proj(self.vision_encoder(images))
        text_embeds = self.text_proj(self.text_encoder(input_ids, attention_mask))
        return vision_embeds, text_embeds

class BLIPFromScratch(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_encoder = ViTEncoder()
        self.text_encoder = TextEncoder()
        self.projection_dim = 512
        self.vision_proj = nn.Linear(256, self.projection_dim)
        self.text_proj = nn.Linear(256, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        vision_embeds = self.vision_proj(self.vision_encoder(images))
        text_embeds = self.text_proj(self.text_encoder(input_ids, attention_mask))
        return vision_embeds, text_embeds

class ViltEncoder(nn.Module):
    def __init__(self, img_size=224, patch_size=16, vocab_size=30522, d_model=256, num_heads=8, num_layers=6, max_len=32):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)
        self.text_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + max_len + 1, d_model))
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.layers = nn.ModuleList([
            nn.ModuleList([
                nn.LayerNorm(d_model),
                MultiHeadAttention(d_model, num_heads),
                nn.LayerNorm(d_model),
                nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
            ]) for _ in range(num_layers)
        ])

    def forward(self, images, input_ids, attention_mask):
        img_embeds = self.patch_embed(images).flatten(2).transpose(1, 2)
        text_embeds = self.text_embed(input_ids)
        x = torch.cat([img_embeds, text_embeds], dim=1)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1) + self.pos_embed
        img_mask = torch.ones(x.size(0), img_embeds.size(1) + 1, device=x.device)
        full_mask = torch.cat([img_mask, attention_mask], dim=1)
        for norm1, attn, norm2, ff in self.layers:
            x = x + attn(norm1(x)) * full_mask.unsqueeze(-1)
            x = x + ff(norm2(x)) * full_mask.unsqueeze(-1)
        return x[:, 0]

class VILTFromScratch(nn.Module):
    def __init__(self):
        super().__init__()
        self.vilt = ViltEncoder()
        self.projection_dim = 512
        self.proj = nn.Linear(256, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        embeds = self.proj(self.vilt(images, input_ids, attention_mask))
        return embeds, embeds

# Evaluation function
def evaluate_model(model, dataloader, device, model_name, model_type):
    model.eval()
    correct = 0
    total = 0
    similarities = []
    with torch.no_grad():
        for images, input_ids, attention_mask in tqdm(dataloader, desc=f"Evaluating {model_name}"):
            images, input_ids, attention_mask = images.to(device), input_ids.to(device), attention_mask.to(device)
            vision_embeds, text_embeds = model(images, input_ids, attention_mask)
            
            # Normalize embeddings
            vision_embeds = vision_embeds / vision_embeds.norm(dim=-1, keepdim=True)
            text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
            
            # Compute cosine similarities
            logits = torch.matmul(vision_embeds, text_embeds.T)
            labels = torch.arange(len(images)).to(device)
            
            # Image-to-text retrieval (top-1 accuracy)
            pred_i2t = logits.argmax(dim=1)
            correct += (pred_i2t == labels).sum().item()
            
            # Text-to-image retrieval (top-1 accuracy)
            pred_t2i = logits.T.argmax(dim=1)
            correct += (pred_t2i == labels).sum().item()
            
            total += len(images) * 2  # Count both i2t and t2i
            
            # Average cosine similarity for matching pairs
            sim = logits.diag().cpu().numpy()
            similarities.extend(sim)
    
    accuracy = correct / total
    avg_similarity = np.mean(similarities)
    return {"model": model_name, "accuracy": accuracy, "avg_similarity": avg_similarity}

# Main evaluation
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.empty_cache()
    print(f"Using device: {device}")

    # Models and their configurations
    models = [
        ("CLIP", CLIP, "clip_model.pth", "clip"),
        ("BLIP", BLIP, "blip_model.pth", "blip"),
        ("VILT", VILT, "vilt_model.pth", "vilt"),
        ("CLIPFromScratch", CLIPFromScratch, "clip_from_scratch.pth", "clip"),
        ("BLIPFromScratch", BLIPFromScratch, "blip_from_scratch.pth", "clip"),
        ("VILTFromScratch", VILTFromScratch, "vilt_from_scratch.pth", "vilt")
    ]

    results = []
    for model_name, model_class, model_path, model_type in models:
        print(f"\nLoading {model_name}...")
        # Load dataset
        dataset = CocoDataset("coco_dataset.csv", model_type=model_type)
        dataloader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=2)  # Reduced batch size
        
        # Initialize and load model
        model = model_class().to(device)
        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location=device))
        else:
            print(f"Warning: {model_path} not found. Using untrained model.")
        
        # Evaluate
        result = evaluate_model(model, dataloader, device, model_name, model_type)
        results.append(result)
        
        # Log bad images
        if dataset.bad_images:
            print(f"Bad images skipped in {model_name}: {list(set(dataset.bad_images))}")

    # Print results
    print("\nPerformance Comparison:")
    print("-" * 50)
    print(f"{'Model':<20} {'Accuracy':<12} {'Avg Similarity':<12}")
    print("-" * 50)
    for result in results:
        print(f"{result['model']:<20} {result['accuracy']:<12.4f} {result['avg_similarity']:<12.4f}")
    
    # Save results to CSV
    results_df = pd.DataFrame(results)
    results_df.to_csv("model_performance.csv", index=False)
    print("\nResults saved to 'model_performance.csv'")

if __name__ == "__main__":
    main()

Using device: cuda

Loading CLIP...


Some weights of ViTModel were not initialized from the model checkpoint at facebook/deit-small-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Evaluating CLIP:  74%|███████▍  | 93/125 [00:03<00:00, 32.65it/s]



Evaluating CLIP: 100%|██████████| 125/125 [00:04<00:00, 29.76it/s]



Loading BLIP...


Some weights of BlipForImageTextRetrieval were not initialized from the model checkpoint at Salesforce/blip-image-captioning-base and are newly initialized: ['itm_head.bias', 'itm_head.weight', 'text_encoder.embeddings.LayerNorm.bias', 'text_encoder.embeddings.LayerNorm.weight', 'text_encoder.embeddings.position_embeddings.weight', 'text_encoder.embeddings.word_embeddings.weight', 'text_encoder.encoder.layer.0.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.0.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.0.attention.output.dense.bias', 'text_encoder.encoder.layer.0.attention.output.dense.weight', 'text_encoder.encoder.layer.0.attention.self.key.bias', 'text_encoder.encoder.layer.0.attention.self.key.weight', 'text_encoder.encoder.layer.0.attention.self.query.bias', 'text_encoder.encoder.layer.0.attention.self.query.weight', 'text_encoder.encoder.layer.0.attention.self.value.bias', 'text_encoder.encoder.layer.0.attention.self.value.weight', 'text_encoder.



Evaluating BLIP: 100%|██████████| 125/125 [00:08<00:00, 15.24it/s]



Loading VILT...


Evaluating VILT:  75%|███████▌  | 94/125 [00:04<00:01, 25.00it/s]



Evaluating VILT: 100%|██████████| 125/125 [00:05<00:00, 22.71it/s]



Loading CLIPFromScratch...


Evaluating CLIPFromScratch:  77%|███████▋  | 96/125 [00:03<00:00, 31.97it/s]



Evaluating CLIPFromScratch: 100%|██████████| 125/125 [00:04<00:00, 29.87it/s]



Loading BLIPFromScratch...


Evaluating BLIPFromScratch:  74%|███████▍  | 93/125 [00:03<00:01, 31.41it/s]



Evaluating BLIPFromScratch: 100%|██████████| 125/125 [00:04<00:00, 28.37it/s]



Loading VILTFromScratch...


Evaluating VILTFromScratch:  74%|███████▍  | 93/125 [00:03<00:01, 31.96it/s]



Evaluating VILTFromScratch: 100%|██████████| 125/125 [00:04<00:00, 29.31it/s]


Performance Comparison:
--------------------------------------------------
Model                Accuracy     Avg Similarity
--------------------------------------------------
CLIP                 0.1955       -0.0341     
BLIP                 0.2125       0.1609      
VILT                 1.0000       1.0000      
CLIPFromScratch      0.1250       -0.1224     
BLIPFromScratch      0.1255       0.0022      
VILTFromScratch      1.0000       1.0000      

Results saved to 'model_performance.csv'





In [3]:
import os
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, UnidentifiedImageError
from transformers import ViTModel, DistilBertModel, DistilBertTokenizer, BlipForImageTextRetrieval, BlipProcessor, ViltProcessor, ViltModel
from tqdm import tqdm
import numpy as np

# Environment settings
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Dataset
class CocoDataset(Dataset):
    def __init__(self, csv_file, transform=None, model_type="clip"):
        self.data = pd.read_csv(csv_file)
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.model_type = model_type
        if model_type == "blip":
            self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        elif model_type == "vilt":
            self.processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
        else:  # clip or from-scratch
            self.processor = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
        self.bad_images = []

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data["image_path"][idx]
        caption = self.data["caption"][idx]
        try:
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"Image not found: {img_path}")
            image = Image.open(img_path).convert("RGB")
        except (UnidentifiedImageError, FileNotFoundError, Exception) as e:
            print(f"Warning: Skipping {img_path} due to error: {e}")
            self.bad_images.append(img_path)
            dummy_image = torch.zeros(3, 224, 224)
            if self.model_type in ["blip", "vilt"]:
                text = self.processor.tokenizer("Invalid image", padding="max_length", max_length=32, truncation=True, return_tensors="pt")
            else:
                text = self.processor("Invalid image", padding="max_length", max_length=32, truncation=True, return_tensors="pt")
            return dummy_image, text["input_ids"].squeeze(), text["attention_mask"].squeeze()

        if self.transform:
            image = self.transform(image)
        if self.model_type == "blip":
            text = self.processor.tokenizer(caption, padding="max_length", max_length=32, truncation=True, return_tensors="pt")
        elif self.model_type == "vilt":
            text = self.processor.tokenizer(caption, padding="max_length", max_length=32, truncation=True, return_tensors="pt")
        else:
            text = self.processor(caption, padding="max_length", max_length=32, truncation=True, return_tensors="pt")
        return image, text["input_ids"].squeeze(), text["attention_mask"].squeeze()

# Pre-trained CLIP Model
class CLIP(nn.Module):
    def __init__(self):
        super(CLIP, self).__init__()
        self.vision_encoder = ViTModel.from_pretrained("facebook/deit-small-patch16-224")
        self.text_encoder = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.projection_dim = 512
        self.vision_proj = nn.Linear(384, self.projection_dim)
        self.text_proj = nn.Linear(768, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        vision_outputs = self.vision_encoder(images).last_hidden_state[:, 0, :]
        text_outputs = self.text_encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        vision_embeds = self.vision_proj(vision_outputs)
        text_embeds = self.text_proj(text_outputs)
        return vision_embeds, text_embeds

# Pre-trained BLIP Model
class BLIP(nn.Module):
    def __init__(self):
        super(BLIP, self).__init__()
        self.blip = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-image-captioning-base")
        self.projection_dim = 512
        self.vision_proj = nn.Linear(768, self.projection_dim)
        self.text_proj = nn.Linear(768, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        vision_outputs = self.blip.vision_model(pixel_values=images).last_hidden_state[:, 0, :]
        text_outputs = self.blip.text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        vision_embeds = self.vision_proj(vision_outputs)
        text_embeds = self.text_proj(text_outputs)
        return vision_embeds, text_embeds

# Pre-trained ViLT Model
class VILT(nn.Module):
    def __init__(self):
        super(VILT, self).__init__()
        self.vilt = ViltModel.from_pretrained("dandelin/vilt-b32-mlm")
        self.projection_dim = 512
        self.proj = nn.Linear(768, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        outputs = self.vilt(pixel_values=images, input_ids=input_ids, attention_mask=attention_mask)
        embeds = self.proj(outputs.last_hidden_state[:, 0, :])
        return embeds, embeds

# From-Scratch Models
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size = x.size(0)
        qkv = self.qkv(x).chunk(3, dim=-1)
        q, k, v = [t.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) for t in qkv]
        attn = (q @ k.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.out(out)

class ViTEncoder(nn.Module):
    def __init__(self, img_size=224, patch_size=16, d_model=256, num_heads=8, num_layers=6):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, d_model))
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.layers = nn.ModuleList([
            nn.ModuleList([
                nn.LayerNorm(d_model),
                MultiHeadAttention(d_model, num_heads),
                nn.LayerNorm(d_model),
                nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
            ]) for _ in range(num_layers)
        ])

    def forward(self, x):
        x = self.patch_embed(x).flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1) + self.pos_embed
        for norm1, attn, norm2, ff in self.layers:
            x = x + attn(norm1(x))
            x = x + ff(norm2(x))
        return x[:, 0]

class TextEncoder(nn.Module):
    def __init__(self, vocab_size=30522, d_model=256, num_heads=8, num_layers=6, max_len=32):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model))
        self.layers = nn.ModuleList([
            nn.ModuleList([
                nn.LayerNorm(d_model),
                MultiHeadAttention(d_model, num_heads),
                nn.LayerNorm(d_model),
                nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
            ]) for _ in range(num_layers)
        ])

    def forward(self, input_ids, attention_mask):
        x = self.embedding(input_ids) + self.pos_embed
        for norm1, attn, norm2, ff in self.layers:
            x = x + attn(norm1(x)) * attention_mask.unsqueeze(-1)
            x = x + ff(norm2(x)) * attention_mask.unsqueeze(-1)
        return x[:, 0]

class CLIPFromScratch(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_encoder = ViTEncoder()
        self.text_encoder = TextEncoder()
        self.projection_dim = 512
        self.vision_proj = nn.Linear(256, self.projection_dim)
        self.text_proj = nn.Linear(256, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        vision_embeds = self.vision_proj(self.vision_encoder(images))
        text_embeds = self.text_proj(self.text_encoder(input_ids, attention_mask))
        return vision_embeds, text_embeds

class BLIPFromScratch(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_encoder = ViTEncoder()
        self.text_encoder = TextEncoder()
        self.projection_dim = 512
        self.vision_proj = nn.Linear(256, self.projection_dim)
        self.text_proj = nn.Linear(256, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        vision_embeds = self.vision_proj(self.vision_encoder(images))
        text_embeds = self.text_proj(self.text_encoder(input_ids, attention_mask))
        return vision_embeds, text_embeds

class ViltEncoder(nn.Module):
    def __init__(self, img_size=224, patch_size=16, vocab_size=30522, d_model=256, num_heads=8, num_layers=6, max_len=32):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)
        self.text_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + max_len + 1, d_model))
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.layers = nn.ModuleList([
            nn.ModuleList([
                nn.LayerNorm(d_model),
                MultiHeadAttention(d_model, num_heads),
                nn.LayerNorm(d_model),
                nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
            ]) for _ in range(num_layers)
        ])

    def forward(self, images, input_ids, attention_mask):
        img_embeds = self.patch_embed(images).flatten(2).transpose(1, 2)
        text_embeds = self.text_embed(input_ids)
        x = torch.cat([img_embeds, text_embeds], dim=1)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1) + self.pos_embed
        img_mask = torch.ones(x.size(0), img_embeds.size(1) + 1, device=x.device)
        full_mask = torch.cat([img_mask, attention_mask], dim=1)
        for norm1, attn, norm2, ff in self.layers:
            x = x + attn(norm1(x)) * full_mask.unsqueeze(-1)
            x = x + ff(norm2(x)) * full_mask.unsqueeze(-1)
        return x[:, 0]

class VILTFromScratch(nn.Module):
    def __init__(self):
        super().__init__()
        self.vilt = ViltEncoder()
        self.projection_dim = 512
        self.proj = nn.Linear(256, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        embeds = self.proj(self.vilt(images, input_ids, attention_mask))
        return embeds, embeds

# Evaluation function
def evaluate_model(model, dataloader, device, model_name, model_type):
    model.eval()
    topk_correct = {k: 0 for k in range(1, 6)}  # Track top-1 to top-5
    total = 0
    similarities = []
    with torch.no_grad():
        for images, input_ids, attention_mask in tqdm(dataloader, desc=f"Evaluating {model_name}"):
            images, input_ids, attention_mask = images.to(device), input_ids.to(device), attention_mask.to(device)
            vision_embeds, text_embeds = model(images, input_ids, attention_mask)
            
            # Normalize embeddings
            vision_embeds = vision_embeds / vision_embeds.norm(dim=-1, keepdim=True)
            text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
            
            # Compute cosine similarities
            logits = torch.matmul(vision_embeds, text_embeds.T)
            labels = torch.arange(len(images)).to(device)
            
            # Top-k accuracy for i2t and t2i
            for k in range(1, 6):
                # Image-to-text
                _, pred_i2t = logits.topk(k, dim=1)
                correct_i2t = pred_i2t.eq(labels.view(-1, 1).expand_as(pred_i2t)).sum().item()
                topk_correct[k] += correct_i2t
                
                # Text-to-image
                _, pred_t2i = logits.T.topk(k, dim=1)
                correct_t2i = pred_t2i.eq(labels.view(-1, 1).expand_as(pred_t2i)).sum().item()
                topk_correct[k] += correct_t2i
            
            total += len(images) * 2  # Count both i2t and t2i
            
            # Average cosine similarity for matching pairs
            sim = logits.diag().cpu().numpy()
            similarities.extend(sim)
    
    results = {
        "model": model_name,
        "top1_accuracy": topk_correct[1] / total,
        "top2_accuracy": topk_correct[2] / total,
        "top3_accuracy": topk_correct[3] / total,
        "top4_accuracy": topk_correct[4] / total,
        "top5_accuracy": topk_correct[5] / total,
        "top1_error": 1 - (topk_correct[1] / total),
        "top2_error": 1 - (topk_correct[2] / total),
        "top3_error": 1 - (topk_correct[3] / total),
        "top4_error": 1 - (topk_correct[4] / total),
        "top5_error": 1 - (topk_correct[5] / total),
        "avg_similarity": np.mean(similarities)
    }
    return results

# Main evaluation
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.empty_cache()
    print(f"Using device: {device}")

    # Models and their configurations
    models = [
        ("CLIP", CLIP, "clip_model.pth", "clip"),
        ("BLIP", BLIP, "blip_model.pth", "blip"),
        ("VILT", VILT, "vilt_model.pth", "vilt"),
        ("CLIPFromScratch", CLIPFromScratch, "clip_from_scratch.pth", "clip"),
        ("BLIPFromScratch", BLIPFromScratch, "blip_from_scratch.pth", "clip"),
        ("VILTFromScratch", VILTFromScratch, "vilt_from_scratch.pth", "vilt")
    ]

    results = []
    for model_name, model_class, model_path, model_type in models:
        print(f"\nLoading {model_name}...")
        # Load dataset
        dataset = CocoDataset("coco_dataset.csv", model_type=model_type)
        dataloader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=2)
        
        # Initialize and load model
        model = model_class().to(device)
        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location=device))
        else:
            print(f"Warning: {model_path} not found. Using untrained model.")
        
        # Evaluate
        result = evaluate_model(model, dataloader, device, model_name, model_type)
        results.append(result)
        
        # Log bad images
        if dataset.bad_images:
            print(f"Bad images skipped in {model_name}: {list(set(dataset.bad_images))}")

    # Print results
    print("\nPerformance Comparison:")
    print("-" * 80)
    print(f"{'Model':<20} {'Top-1 Err':<12} {'Top-2 Err':<12} {'Top-3 Err':<12} {'Top-4 Err':<12} {'Top-5 Err':<12} {'Avg Sim':<12}")
    print("-" * 80)
    for result in results:
        print(f"{result['model']:<20} {result['top1_error']:<12.4f} {result['top2_error']:<12.4f} {result['top3_error']:<12.4f} {result['top4_error']:<12.4f} {result['top5_error']:<12.4f} {result['avg_similarity']:<12.4f}")
    
    # Save results to CSV
    results_df = pd.DataFrame(results)
    results_df.to_csv("model_performance_topk.csv", index=False)
    print("\nResults saved to 'model_performance_topk.csv'")

if __name__ == "__main__":
    main()

Using device: cuda

Loading CLIP...


Some weights of ViTModel were not initialized from the model checkpoint at facebook/deit-small-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Evaluating CLIP:  74%|███████▎  | 92/125 [00:03<00:01, 30.99it/s]



Evaluating CLIP: 100%|██████████| 125/125 [00:04<00:00, 27.54it/s]



Loading BLIP...


Some weights of BlipForImageTextRetrieval were not initialized from the model checkpoint at Salesforce/blip-image-captioning-base and are newly initialized: ['itm_head.bias', 'itm_head.weight', 'text_encoder.embeddings.LayerNorm.bias', 'text_encoder.embeddings.LayerNorm.weight', 'text_encoder.embeddings.position_embeddings.weight', 'text_encoder.embeddings.word_embeddings.weight', 'text_encoder.encoder.layer.0.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.0.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.0.attention.output.dense.bias', 'text_encoder.encoder.layer.0.attention.output.dense.weight', 'text_encoder.encoder.layer.0.attention.self.key.bias', 'text_encoder.encoder.layer.0.attention.self.key.weight', 'text_encoder.encoder.layer.0.attention.self.query.bias', 'text_encoder.encoder.layer.0.attention.self.query.weight', 'text_encoder.encoder.layer.0.attention.self.value.bias', 'text_encoder.encoder.layer.0.attention.self.value.weight', 'text_encoder.



Evaluating BLIP: 100%|██████████| 125/125 [00:08<00:00, 14.87it/s]



Loading VILT...


Evaluating VILT:  76%|███████▌  | 95/125 [00:04<00:01, 23.65it/s]



Evaluating VILT: 100%|██████████| 125/125 [00:05<00:00, 21.90it/s]



Loading CLIPFromScratch...


Evaluating CLIPFromScratch:  74%|███████▍  | 93/125 [00:03<00:00, 33.39it/s]



Evaluating CLIPFromScratch: 100%|██████████| 125/125 [00:04<00:00, 31.10it/s]



Loading BLIPFromScratch...


Evaluating BLIPFromScratch:  78%|███████▊  | 97/125 [00:03<00:00, 33.23it/s]



Evaluating BLIPFromScratch: 100%|██████████| 125/125 [00:04<00:00, 30.11it/s]



Loading VILTFromScratch...


Evaluating VILTFromScratch:  78%|███████▊  | 97/125 [00:03<00:00, 33.48it/s]



Evaluating VILTFromScratch: 100%|██████████| 125/125 [00:04<00:00, 30.50it/s]


Performance Comparison:
--------------------------------------------------------------------------------
Model                Top-1 Err    Top-2 Err    Top-3 Err    Top-4 Err    Top-5 Err    Avg Sim     
--------------------------------------------------------------------------------
CLIP                 0.8045       0.6435       0.5015       0.3745       0.2600       -0.0341     
BLIP                 0.7875       0.5995       0.4595       0.3210       0.2145       0.1609      
VILT                 0.0000       0.0000       0.0000       0.0000       0.0000       1.0000      
CLIPFromScratch      0.8750       0.7505       0.6245       0.5000       0.3715       -0.1224     
BLIPFromScratch      0.8745       0.7480       0.6235       0.4975       0.3770       0.0022      
VILTFromScratch      0.0000       0.0000       0.0000       0.0000       0.0000       1.0000      

Results saved to 'model_performance_topk.csv'





In [5]:
import os
import pandas as pd
from pycocotools.coco import COCO
from tqdm import tqdm

# Load annotations
annotations_dir = "annotations"
instances_file = os.path.join(annotations_dir, "annotations/instances_train2017.json")
coco = COCO(instances_file)

# Load coco_dataset.csv to get the 1,000 images
dataset_df = pd.read_csv("coco_dataset.csv")
image_paths = dataset_df["image_path"].tolist()

# Map image filenames to COCO image IDs
filename_to_id = {img["file_name"]: img["id"] for img in coco.loadImgs(coco.getImgIds())}

# Get categories for each image
data = []
for img_path in tqdm(image_paths, desc="Processing metadata"):
    filename = os.path.basename(img_path)
    img_id = filename_to_id.get(filename)
    if img_id is None:
        print(f"Warning: Image {filename} not found in annotations")
        category = "unknown"
    else:
        ann_ids = coco.getAnnIds(imgIds=img_id)
        anns = coco.loadAnns(ann_ids)
        # Get the first category (simplified; you could aggregate multiple categories)
        category_ids = [ann["category_id"] for ann in anns]
        if category_ids:
            category = coco.loadCats(category_ids[0])[0]["name"]
        else:
            category = "none"
    data.append({"image_path": img_path, "category": category})

# Save to CSV
metadata_df = pd.DataFrame(data)
metadata_df.to_csv("coco_metadata.csv", index=False)
print("Metadata saved to 'coco_metadata.csv'")
print(metadata_df.head())

loading annotations into memory...
Done (t=12.68s)
creating index...
index created!


Processing metadata: 100%|██████████| 1000/1000 [00:00<00:00, 22419.60it/s]

Metadata saved to 'coco_metadata.csv'
                     image_path    category
0  coco_images/000000391895.jpg  motorcycle
1  coco_images/000000522418.jpg      person
2  coco_images/000000184613.jpg         cow
3  coco_images/000000318219.jpg      person
4  coco_images/000000554625.jpg          tv





In [9]:
import os
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import entropy
import cv2
from transformers import ViTModel, DistilBertModel, DistilBertTokenizer, BlipForImageTextRetrieval, BlipProcessor, ViltProcessor, ViltModel

# Environment settings
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Dataset
class CocoDataset(Dataset):
    def __init__(self, csv_file, transform=None, model_type="clip"):
        self.data = pd.read_csv(csv_file)
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.model_type = model_type
        if model_type == "blip":
            self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        elif model_type == "vilt":
            self.processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
        else:  # clip or from-scratch
            self.processor = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
        self.bad_images = []

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data["image_path"][idx]
        caption = self.data["caption"][idx]
        try:
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"Image not found: {img_path}")
            image = Image.open(img_path).convert("RGB")
        except (UnidentifiedImageError, FileNotFoundError, Exception) as e:
            print(f"Warning: Skipping {img_path} due to error: {e}")
            self.bad_images.append(img_path)
            dummy_image = torch.zeros(3, 224, 224)
            if self.model_type in ["blip", "vilt"]:
                text = self.processor.tokenizer("Invalid image", padding="max_length", max_length=32, truncation=True, return_tensors="pt")
            else:
                text = self.processor("Invalid image", padding="max_length", max_length=32, truncation=True, return_tensors="pt")
            return dummy_image, text["input_ids"].squeeze(), text["attention_mask"].squeeze()

        if self.transform:
            image = self.transform(image)
        if self.model_type == "blip":
            text = self.processor.tokenizer(caption, padding="max_length", max_length=32, truncation=True, return_tensors="pt")
        elif self.model_type == "vilt":
            text = self.processor.tokenizer(caption, padding="max_length", max_length=32, truncation=True, return_tensors="pt")
        else:
            text = self.processor(caption, padding="max_length", max_length=32, truncation=True, return_tensors="pt")
        return image, text["input_ids"].squeeze(), text["attention_mask"].squeeze()

# Pre-trained Models
class CLIP(nn.Module):
    def __init__(self):
        super(CLIP, self).__init__()
        self.vision_encoder = ViTModel.from_pretrained("facebook/deit-small-patch16-224", output_attentions=True)
        self.text_encoder = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.projection_dim = 512
        self.vision_proj = nn.Linear(384, self.projection_dim)
        self.text_proj = nn.Linear(768, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        vision_outputs = self.vision_encoder(images)
        vision_embeds = self.vision_proj(vision_outputs.last_hidden_state[:, 0, :])
        text_outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
        text_embeds = self.text_proj(text_outputs.last_hidden_state[:, 0, :])
        return vision_embeds, text_embeds, vision_outputs.attentions

class BLIP(nn.Module):
    def __init__(self):
        super(BLIP, self).__init__()
        self.blip = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-image-captioning-base")
        self.blip.vision_model.config.output_attentions = True
        self.projection_dim = 512
        self.vision_proj = nn.Linear(768, self.projection_dim)
        self.text_proj = nn.Linear(768, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        vision_outputs = self.blip.vision_model(pixel_values=images)
        vision_embeds = self.vision_proj(vision_outputs.last_hidden_state[:, 0, :])
        text_outputs = self.blip.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_embeds = self.text_proj(text_outputs.last_hidden_state[:, 0, :])
        return vision_embeds, text_embeds, vision_outputs.attentions

class VILT(nn.Module):
    def __init__(self):
        super(VILT, self).__init__()
        self.vilt = ViltModel.from_pretrained("dandelin/vilt-b32-mlm", output_attentions=True)
        self.projection_dim = 512
        self.proj = nn.Linear(768, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        outputs = self.vilt(pixel_values=images, input_ids=input_ids, attention_mask=attention_mask)
        embeds = self.proj(outputs.last_hidden_state[:, 0, :])
        return embeds, embeds, outputs.attentions

# From-Scratch Models
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size = x.size(0)
        qkv = self.qkv(x).chunk(3, dim=-1)
        q, k, v = [t.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) for t in qkv]
        attn = (q @ k.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return out, attn

class ViTEncoder(nn.Module):
    def __init__(self, img_size=224, patch_size=16, d_model=256, num_heads=8, num_layers=6):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, d_model))
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.layers = nn.ModuleList([
            nn.ModuleList([
                nn.LayerNorm(d_model),
                MultiHeadAttention(d_model, num_heads),
                nn.LayerNorm(d_model),
                nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
            ]) for _ in range(num_layers)
        ])

    def forward(self, x):
        x = self.patch_embed(x).flatten(2).transpose(1, 2)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1) + self.pos_embed
        attentions = []
        for norm1, attn, norm2, ff in self.layers:
            x_norm = norm1(x)
            attn_output, attn_weights = attn(x_norm)
            x = x + attn_output
            x = x + ff(norm2(x))
            attentions.append(attn_weights)
        return x[:, 0], attentions

class TextEncoder(nn.Module):
    def __init__(self, vocab_size=30522, d_model=256, num_heads=8, num_layers=6, max_len=32):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model))
        self.layers = nn.ModuleList([
            nn.ModuleList([
                nn.LayerNorm(d_model),
                MultiHeadAttention(d_model, num_heads),
                nn.LayerNorm(d_model),
                nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
            ]) for _ in range(num_layers)
        ])

    def forward(self, input_ids, attention_mask):
        x = self.embedding(input_ids) + self.pos_embed
        attentions = []
        for norm1, attn, norm2, ff in self.layers:
            x_norm = norm1(x)
            attn_output, attn_weights = attn(x_norm)
            x = x + attn_output * attention_mask.unsqueeze(-1)
            x = x + ff(norm2(x)) * attention_mask.unsqueeze(-1)
            attentions.append(attn_weights)
        return x[:, 0], attentions

class CLIPFromScratch(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_encoder = ViTEncoder()
        self.text_encoder = TextEncoder()
        self.projection_dim = 512
        self.vision_proj = nn.Linear(256, self.projection_dim)
        self.text_proj = nn.Linear(256, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        vision_embeds, vision_attentions = self.vision_encoder(images)
        text_embeds, text_attentions = self.text_encoder(input_ids, attention_mask)
        vision_embeds = self.vision_proj(vision_embeds)
        text_embeds = self.text_proj(text_embeds)
        return vision_embeds, text_embeds, vision_attentions

class BLIPFromScratch(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_encoder = ViTEncoder()
        self.text_encoder = TextEncoder()
        self.projection_dim = 512
        self.vision_proj = nn.Linear(256, self.projection_dim)
        self.text_proj = nn.Linear(256, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        vision_embeds, vision_attentions = self.vision_encoder(images)
        text_embeds, text_attentions = self.text_encoder(input_ids, attention_mask)
        vision_embeds = self.vision_proj(vision_embeds)
        text_embeds = self.text_proj(text_embeds)
        return vision_embeds, text_embeds, vision_attentions

class ViltEncoder(nn.Module):
    def __init__(self, img_size=224, patch_size=16, vocab_size=30522, d_model=256, num_heads=8, num_layers=6, max_len=32):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)
        self.text_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + max_len + 1, d_model))
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.layers = nn.ModuleList([
            nn.ModuleList([
                nn.LayerNorm(d_model),
                MultiHeadAttention(d_model, num_heads),
                nn.LayerNorm(d_model),
                nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
            ]) for _ in range(num_layers)
        ])

    def forward(self, images, input_ids, attention_mask):
        img_embeds = self.patch_embed(images).flatten(2).transpose(1, 2)
        text_embeds = self.text_embed(input_ids)
        x = torch.cat([img_embeds, text_embeds], dim=1)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1) + self.pos_embed
        img_mask = torch.ones(x.size(0), img_embeds.size(1) + 1, device=x.device)
        full_mask = torch.cat([img_mask, attention_mask], dim=1)
        attentions = []
        for norm1, attn, norm2, ff in self.layers:
            x_norm = norm1(x)
            attn_output, attn_weights = attn(x_norm)
            x = x + attn_output * full_mask.unsqueeze(-1)
            x = x + ff(norm2(x)) * full_mask.unsqueeze(-1)
            attentions.append(attn_weights)
        return x[:, 0], attentions

class VILTFromScratch(nn.Module):
    def __init__(self):
        super().__init__()
        self.vilt = ViltEncoder()
        self.projection_dim = 512
        self.proj = nn.Linear(256, self.projection_dim)

    def forward(self, images, input_ids, attention_mask):
        embeds, attentions = self.vilt(images, input_ids, attention_mask)
        embeds = self.proj(embeds)
        return embeds, embeds, attentions

# Robustness Testing
def robustness_test(model, dataloader, device, model_name, model_type, noise_levels=[0.01, 0.05, 0.1], blur_levels=[1, 3, 5]):
    model.eval()
    results = {
        "model": model_name,
        "noise_results": {},
        "blur_results": {},
        "occlusion_results": {}
    }
    
    def add_gaussian_noise(image, std):
        noise = torch.normal(mean=0.0, std=std, size=image.shape, device=image.device)
        return torch.clamp(image + noise, 0, 1)
    
    def add_gaussian_blur(image, kernel_size):
        image_np = image.permute(1, 2, 0).cpu().numpy()
        blurred = cv2.GaussianBlur(image_np, (kernel_size, kernel_size), 0)
        return torch.tensor(blurred).permute(2, 0, 1).to(image.device)
    
    def add_occlusion(image, patch_size=32):
        _, h, w = image.shape
        x = np.random.randint(0, w - patch_size)
        y = np.random.randint(0, h - patch_size)
        image[:, y:y+patch_size, x:x+patch_size] = 0
        return image
    
    def evaluate_perturbed(dataloader, perturbation_fn, levels, perturbation_type):
        level_results = {}
        for level in levels:
            top1_correct = 0
            total = 0
            similarities = []
            for images, input_ids, attention_mask in tqdm(dataloader, desc=f"{perturbation_type} {level}"):
                images = images.to(device)
                images_perturbed = torch.stack([perturbation_fn(img, level) for img in images])
                input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
                
                with torch.no_grad():
                    vision_embeds, text_embeds, _ = model(images_perturbed, input_ids, attention_mask)
                    vision_embeds = vision_embeds / vision_embeds.norm(dim=-1, keepdim=True)
                    text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
                    logits = torch.matmul(vision_embeds, text_embeds.T)
                    labels = torch.arange(len(images)).to(device)
                    
                    _, pred_i2t = logits.topk(1, dim=1)
                    correct_i2t = pred_i2t.eq(labels.view(-1, 1)).sum().item()
                    _, pred_t2i = logits.T.topk(1, dim=1)
                    correct_t2i = pred_t2i.eq(labels.view(-1, 1)).sum().item()
                    
                    top1_correct += correct_i2t + correct_t2i
                    total += len(images) * 2
                    sim = logits.diag().cpu().numpy()
                    similarities.extend(sim)
            
            level_results[f"{perturbation_type}_{level}"] = {
                "top1_accuracy": top1_correct / total,
                "top1_error": 1 - (top1_correct / total),
                "avg_similarity": np.mean(similarities)
            }
        return level_results
    
    results["noise_results"] = evaluate_perturbed(dataloader, add_gaussian_noise, noise_levels, "noise")
    results["blur_results"] = evaluate_perturbed(dataloader, add_gaussian_blur, blur_levels, "blur")
    results["occlusion_results"] = evaluate_perturbed(dataloader, add_occlusion, [32, 64], "occlusion")
    
    return results

# Interpretability Analysis
def interpretability_analysis(model, dataloader, device, model_name, model_type, num_samples=5):
    model.eval()
    attention_maps = []
    entropies = []
    num_patches = 196  # Default for 224x224 images with 16x16 patches
    grid_size = int(np.sqrt(num_patches))  # 14 for 196 patches

    for i, (images, input_ids, attention_mask) in enumerate(dataloader):
        if i >= num_samples:
            break
        images, input_ids, attention_mask = images.to(device), input_ids.to(device), attention_mask.to(device)
        
        with torch.no_grad():
            vision_embeds, text_embeds, attentions = model(images, input_ids, attention_mask)
            attn_weights = attentions[-1]  # Shape: (batch, num_heads, seq_len, seq_len)
            attn_weights = attn_weights.mean(dim=1)  # Average over heads: (batch, seq_len, seq_len)

            if model_type == "vilt":
                # ViLT: seq_len = 1 (CLS) + 196 (image patches) + text_tokens
                # Take attention to image patches only (indices 1 to 196)
                cls_attn = attn_weights[:, 0, 1:197]  # Shape: (batch, 196)
            else:
                # CLIP, BLIP, from-scratch: seq_len = 1 (CLS) + 196 (patches)
                cls_attn = attn_weights[:, 0, 1:]  # Shape: (batch, 196)

            # Debug: Print shape to confirm
            print(f"{model_name} cls_attn shape: {cls_attn.shape}")

            # Ensure cls_attn has the expected number of patches
            if cls_attn.shape[1] != num_patches:
                print(f"Warning: Expected {num_patches} patches, got {cls_attn.shape[1]}. Skipping visualization for {model_name}.")
                continue

            # Compute attention entropy
            attn_probs = cls_attn.cpu().numpy()
            sample_entropy = [entropy(probs) for probs in attn_probs if probs.sum() > 0]
            entropies.extend(sample_entropy)
            
            # Reshape attention for visualization
            attn_map = cls_attn.view(-1, grid_size, grid_size).cpu().numpy()
            attention_maps.append(attn_map)
    
    # Save attention visualization
    if attention_maps:
        plt.figure(figsize=(15, 5))
        for i, attn_map in enumerate(attention_maps[:num_samples]):
            plt.subplot(1, num_samples, i + 1)
            sns.heatmap(attn_map[0], cmap="viridis")
            plt.title(f"Sample {i+1}")
        plt.tight_layout()
        plt.savefig(f"{model_name}_attention_maps.png")
        plt.close()
    
    return {
        "model": model_name,
        "avg_attention_entropy": np.mean(entropies) if entropies else 0.0,
        "attention_maps_path": f"{model_name}_attention_maps.png" if attention_maps else "None"
    }

# Fairness Analysis
def fairness_analysis(model, dataloader, device, model_name, model_type, metadata_csv="coco_metadata.csv"):
    model.eval()
    metadata = pd.read_csv(metadata_csv)
    subgroups = metadata["category"].unique()
    fairness_results = {sg: {"top1_correct": 0, "total": 0, "similarities": []} for sg in subgroups}
    
    data_idx = 0
    for images, input_ids, attention_mask in tqdm(dataloader, desc=f"Fairness {model_name}"):
        images, input_ids, attention_mask = images.to(device), input_ids.to(device), attention_mask.to(device)
        
        with torch.no_grad():
            vision_embeds, text_embeds, _ = model(images, input_ids, attention_mask)
            vision_embeds = vision_embeds / vision_embeds.norm(dim=-1, keepdim=True)
            text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
            logits = torch.matmul(vision_embeds, text_embeds.T)
            labels = torch.arange(len(images)).to(device)
            
            _, pred_i2t = logits.topk(1, dim=1)
            correct_i2t = pred_i2t.eq(labels.view(-1, 1)).sum().item()
            _, pred_t2i = logits.T.topk(1, dim=1)
            correct_t2i = pred_t2i.eq(labels.view(-1, 1)).sum().item()
            sim = logits.diag().cpu().numpy()
            
            for idx in range(len(images)):
                if data_idx >= len(metadata):
                    break
                sg = metadata.iloc[data_idx]["category"]
                fairness_results[sg]["top1_correct"] += (correct_i2t + correct_t2i) / 2
                fairness_results[sg]["total"] += 1
                fairness_results[sg]["similarities"].append(sim[idx])
                data_idx += 1
    
    results = {
        "model": model_name,
        "subgroup_metrics": {}
    }
    for sg in subgroups:
        total = fairness_results[sg]["total"]
        if total > 0:
            results["subgroup_metrics"][sg] = {
                "top1_accuracy": fairness_results[sg]["top1_correct"] / total,
                "top1_error": 1 - (fairness_results[sg]["top1_correct"] / total),
                "avg_similarity": np.mean(fairness_results[sg]["similarities"])
            }
    
    accuracies = [results["subgroup_metrics"][sg]["top1_accuracy"] for sg in subgroups if sg in results["subgroup_metrics"]]
    plt.figure(figsize=(10, 5))
    sns.barplot(x=subgroups, y=accuracies)
    plt.title(f"{model_name} Fairness Across Subgroups")
    plt.ylabel("Top-1 Accuracy")
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(f"{model_name}_fairness_plot.png")
    plt.close()
    
    results["fairness_plot_path"] = f"{model_name}_fairness_plot.png"
    return results

# Main function
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.empty_cache()
    print(f"Using device: {device}")

    models = [
        ("CLIP", CLIP, "clip_model.pth", "clip"),
        ("BLIP", BLIP, "blip_model.pth", "blip"),
        ("VILT", VILT, "vilt_model.pth", "vilt"),
        ("CLIPFromScratch", CLIPFromScratch, "clip_from_scratch.pth", "clip"),
        ("BLIPFromScratch", BLIPFromScratch, "blip_from_scratch.pth", "clip"),
        ("VILTFromScratch", VILTFromScratch, "vilt_from_scratch.pth", "vilt")
    ]

    all_results = []
    for model_name, model_class, model_path, model_type in models:
        print(f"\nProcessing {model_name}...")
        dataset = CocoDataset("coco_dataset.csv", model_type=model_type)
        dataloader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=2)
        
        model = model_class().to(device)
        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location=device))
        else:
            print(f"Warning: {model_path} not found. Using untrained model.")
        
        robustness_result = robustness_test(model, dataloader, device, model_name, model_type)
        all_results.append(robustness_result)
        
        interpretability_result = interpretability_analysis(model, dataloader, device, model_name, model_type)
        all_results.append(interpretability_result)
        
        fairness_result = fairness_analysis(model, dataloader, device, model_name, model_type)
        all_results.append(fairness_result)
    
    results_df = pd.DataFrame(all_results)
    results_df.to_csv("robustness_interpretability_fairness_results.csv", index=False)
    print("\nResults saved to 'robustness_interpretability_fairness_results.csv'")

if __name__ == "__main__":
    main()

Using device: cuda

Processing CLIP...


Some weights of ViTModel were not initialized from the model checkpoint at facebook/deit-small-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
noise 0.01: 100%|██████████| 125/125 [00:05<00:00, 23.55it/s]
noise 0.05: 100%|██████████| 125/125 [00:05<00:00, 21.41it/s]
noise 0.1: 100%|██████████| 125/125 [00:05<00:00, 23.93it/s]
blur 1: 100%|██████████| 125/125 [00:06<00:00, 18.60it/s]
blur 3: 100%|██████████| 125/125 [00:07<00:00, 17.55it/s]
blur 5: 100%|██████████| 125/125 [00:07<00:00, 17.05it/s]
occlusion 32: 100%|██████████| 125/125 [00:05<00:00, 23.45it/s]
occlusion 64: 100%|██████████| 125/125 [00:04<00:00, 25.25it/s]


CLIP cls_attn shape: torch.Size([8, 196])
CLIP cls_attn shape: torch.Size([8, 196])
CLIP cls_attn shape: torch.Size([8, 196])
CLIP cls_attn shape: torch.Size([8, 196])
CLIP cls_attn shape: torch.Size([8, 196])


Fairness CLIP: 100%|██████████| 125/125 [00:05<00:00, 22.01it/s]



Processing BLIP...


Some weights of BlipForImageTextRetrieval were not initialized from the model checkpoint at Salesforce/blip-image-captioning-base and are newly initialized: ['itm_head.bias', 'itm_head.weight', 'text_encoder.embeddings.LayerNorm.bias', 'text_encoder.embeddings.LayerNorm.weight', 'text_encoder.embeddings.position_embeddings.weight', 'text_encoder.embeddings.word_embeddings.weight', 'text_encoder.encoder.layer.0.attention.output.LayerNorm.bias', 'text_encoder.encoder.layer.0.attention.output.LayerNorm.weight', 'text_encoder.encoder.layer.0.attention.output.dense.bias', 'text_encoder.encoder.layer.0.attention.output.dense.weight', 'text_encoder.encoder.layer.0.attention.self.key.bias', 'text_encoder.encoder.layer.0.attention.self.key.weight', 'text_encoder.encoder.layer.0.attention.self.query.bias', 'text_encoder.encoder.layer.0.attention.self.query.weight', 'text_encoder.encoder.layer.0.attention.self.value.bias', 'text_encoder.encoder.layer.0.attention.self.value.weight', 'text_encoder.

BLIP cls_attn shape: torch.Size([8, 196])
BLIP cls_attn shape: torch.Size([8, 196])
BLIP cls_attn shape: torch.Size([8, 196])
BLIP cls_attn shape: torch.Size([8, 196])
BLIP cls_attn shape: torch.Size([8, 196])


Fairness BLIP: 100%|██████████| 125/125 [00:08<00:00, 13.94it/s]



Processing VILT...


noise 0.01: 100%|██████████| 125/125 [00:05<00:00, 21.32it/s]
noise 0.05: 100%|██████████| 125/125 [00:05<00:00, 21.62it/s]
noise 0.1: 100%|██████████| 125/125 [00:06<00:00, 20.74it/s]
blur 1: 100%|██████████| 125/125 [00:07<00:00, 15.83it/s]
blur 3: 100%|██████████| 125/125 [00:07<00:00, 16.95it/s]
blur 5: 100%|██████████| 125/125 [00:07<00:00, 16.72it/s]
occlusion 32: 100%|██████████| 125/125 [00:05<00:00, 21.50it/s]
occlusion 64: 100%|██████████| 125/125 [00:05<00:00, 21.92it/s]


VILT cls_attn shape: torch.Size([8, 81])
VILT cls_attn shape: torch.Size([8, 81])
VILT cls_attn shape: torch.Size([8, 81])
VILT cls_attn shape: torch.Size([8, 81])
VILT cls_attn shape: torch.Size([8, 81])


Fairness VILT: 100%|██████████| 125/125 [00:05<00:00, 21.55it/s]



Processing CLIPFromScratch...


noise 0.01: 100%|██████████| 125/125 [00:04<00:00, 27.92it/s]
noise 0.05: 100%|██████████| 125/125 [00:04<00:00, 26.18it/s]
noise 0.1: 100%|██████████| 125/125 [00:04<00:00, 26.95it/s]
blur 1: 100%|██████████| 125/125 [00:05<00:00, 23.24it/s]
blur 3: 100%|██████████| 125/125 [00:05<00:00, 23.78it/s]
blur 5: 100%|██████████| 125/125 [00:05<00:00, 22.43it/s]
occlusion 32: 100%|██████████| 125/125 [00:04<00:00, 28.30it/s]
occlusion 64: 100%|██████████| 125/125 [00:04<00:00, 28.43it/s]


CLIPFromScratch cls_attn shape: torch.Size([8, 196])
CLIPFromScratch cls_attn shape: torch.Size([8, 196])
CLIPFromScratch cls_attn shape: torch.Size([8, 196])
CLIPFromScratch cls_attn shape: torch.Size([8, 196])
CLIPFromScratch cls_attn shape: torch.Size([8, 196])


Fairness CLIPFromScratch: 100%|██████████| 125/125 [00:04<00:00, 28.36it/s]



Processing BLIPFromScratch...


noise 0.01: 100%|██████████| 125/125 [00:04<00:00, 27.51it/s]
noise 0.05: 100%|██████████| 125/125 [00:04<00:00, 26.90it/s]
noise 0.1: 100%|██████████| 125/125 [00:04<00:00, 26.56it/s]
blur 1: 100%|██████████| 125/125 [00:05<00:00, 23.28it/s]
blur 3: 100%|██████████| 125/125 [00:05<00:00, 22.36it/s]
blur 5: 100%|██████████| 125/125 [00:05<00:00, 21.69it/s]
occlusion 32: 100%|██████████| 125/125 [00:04<00:00, 26.72it/s]
occlusion 64: 100%|██████████| 125/125 [00:04<00:00, 27.95it/s]


BLIPFromScratch cls_attn shape: torch.Size([8, 196])
BLIPFromScratch cls_attn shape: torch.Size([8, 196])
BLIPFromScratch cls_attn shape: torch.Size([8, 196])
BLIPFromScratch cls_attn shape: torch.Size([8, 196])
BLIPFromScratch cls_attn shape: torch.Size([8, 196])


Fairness BLIPFromScratch: 100%|██████████| 125/125 [00:04<00:00, 27.18it/s]



Processing VILTFromScratch...


noise 0.01: 100%|██████████| 125/125 [00:04<00:00, 28.51it/s]
noise 0.05: 100%|██████████| 125/125 [00:04<00:00, 27.83it/s]
noise 0.1: 100%|██████████| 125/125 [00:04<00:00, 27.85it/s]
blur 1: 100%|██████████| 125/125 [00:05<00:00, 23.65it/s]
blur 3: 100%|██████████| 125/125 [00:05<00:00, 23.25it/s]
blur 5: 100%|██████████| 125/125 [00:05<00:00, 22.18it/s]
occlusion 32: 100%|██████████| 125/125 [00:04<00:00, 27.74it/s]
occlusion 64: 100%|██████████| 125/125 [00:04<00:00, 28.13it/s]


VILTFromScratch cls_attn shape: torch.Size([8, 196])
VILTFromScratch cls_attn shape: torch.Size([8, 196])
VILTFromScratch cls_attn shape: torch.Size([8, 196])
VILTFromScratch cls_attn shape: torch.Size([8, 196])
VILTFromScratch cls_attn shape: torch.Size([8, 196])


Fairness VILTFromScratch: 100%|██████████| 125/125 [00:04<00:00, 29.05it/s]



Results saved to 'robustness_interpretability_fairness_results.csv'


In [7]:
import pandas as pd

# Load dataset and metadata
dataset_df = pd.read_csv("coco_dataset.csv")
metadata_df = pd.read_csv("coco_metadata.csv")

# Filter out the corrupted image
corrupted_image = "coco_images/000000365426.jpg"
dataset_df = dataset_df[dataset_df["image_path"] != corrupted_image]
metadata_df = metadata_df[metadata_df["image_path"] != corrupted_image]

# Save updated files
dataset_df.to_csv("coco_dataset.csv", index=False)
metadata_df.to_csv("coco_metadata.csv", index=False)
print(f"Removed {corrupted_image} from dataset and metadata.")

Removed coco_images/000000365426.jpg from dataset and metadata.
