In [1]:
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch
from torch import nn
from torch.optim import AdamW
import torch
import time
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from torch.nn.utils import clip_grad_norm_
from sklearn.metrics import accuracy_score

# Load pretrained CLIP
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").cuda()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")



  from .autonotebook import tqdm as notebook_tqdm
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [2]:
# Load MIMIC-CXR dataset
train_dataset = load_dataset("itsanmolgupta/mimic-cxr-dataset", split="train[:90%]")  
val_datset = load_dataset("itsanmolgupta/mimic-cxr-dataset", split="train[:10%]")
train_dataset = train_dataset.filter(lambda x: x['impression'] is not None)
val_dataset = val_datset.filter(lambda x: x['impression'] is not None)

In [3]:
train_dataset

Dataset({
    features: ['image', 'findings', 'impression'],
    num_rows: 27561
})

In [4]:
val_dataset

Dataset({
    features: ['image', 'findings', 'impression'],
    num_rows: 3062
})

In [5]:
train_dataset[1]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
 'findings': 'Lung volumes remain low. There are innumerable bilateral scattered small pulmonary nodules which are better demonstrated on recent CT. Mild pulmonary vascular congestion is stable. The cardiomediastinal silhouette and hilar contours are unchanged. Small pleural effusion in the right middle fissure is new. There is no new focal opacity to suggest pneumonia. There is no pneumothorax. ',
 'impression': 'Low lung volumes and mild pulmonary vascular congestion is unchanged. New small right fissural pleural effusion. No new focal opacities to suggest pneumonia.'}

In [6]:
# Preprocessing function
def preprocess(examples):
    images = [img.convert("RGB") for img in examples["image"]]
    texts = [text if text else "" for text in examples["impression"]]
    inputs = processor(text=texts, images=images, return_tensors="pt", padding=True, truncation=True)
    return inputs

# Custom PyTorch dataset wrapper
class CLIPDataset(torch.utils.data.Dataset):
    def __init__(self, hf_dataset):
        self.dataset = hf_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        example = self.dataset[idx]
        image = example["image"].convert("RGB")
        text = example["impression"]
        return {"image": image, "text": text}

# Collate function using CLIPProcessor
def collate_fn(batch):
    images = [b["image"] for b in batch]
    texts = [b["text"] for b in batch]
    return processor(text=texts, images=images, return_tensors="pt", padding=True, truncation=True)



In [11]:
# DataLoader
train_ds = CLIPDataset(train_dataset)
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, collate_fn=collate_fn)
# Validation DataLoader
val_dataset = CLIPDataset(val_dataset)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)
# Optimizer
optimizer = AdamW(model.parameters(), lr=5e-6)

In [12]:
import torch
import time
import os
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, ReduceLROnPlateau
# from transformers import AdamW

## CLAUDE SUGGESTION:

In [13]:
# CLAUDE SUGGESTION:



def do_train(model, train_loader, optimizer, epoch, device, scheduler=None, max_grad_norm=1.0):
    model.train()
    total_loss = 0.0
    start_time = time.time()
    
    for batch_idx, batch in enumerate(train_loader):
        batch = {k: v.to(device) for k, v in batch.items()}
        
        # Forward pass
        outputs = model(**batch, return_loss=True)
        loss = outputs.loss
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping for stability
        if max_grad_norm > 0:
            clip_grad_norm_(model.parameters(), max_grad_norm)
            
        optimizer.step()
        optimizer.zero_grad()
        
        total_loss += loss.item()
        
        # Report progress for long epochs
        if batch_idx % 50 == 0:
            print(f"  Batch {batch_idx}/{len(train_loader)} - Loss: {loss.item():.4f}")
    
    # Step the scheduler if provided
    if scheduler is not None:
        scheduler.step()
    
    train_time = time.time() - start_time
    avg_train_loss = total_loss / len(train_loader)
    
    return avg_train_loss, train_time

@torch.no_grad()
def do_eval(model, val_loader, device):
    model.eval()
    total_loss = 0.0
    all_preds = []
    all_labels = []
    start_time = time.time()
    
    for batch in val_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        
        # Get loss
        outputs = model(**batch, return_loss=True)
        loss = outputs.loss
        total_loss += loss.item()
        
        # Get image and text features
        image_embeds = model.get_image_features(pixel_values=batch['pixel_values'])
        text_embeds = model.get_text_features(
            input_ids=batch['input_ids'], 
            attention_mask=batch['attention_mask']
        )
        
        # Normalize
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
        
        # Compute similarity
        logits = torch.matmul(image_embeds, text_embeds.T) * model.logit_scale.exp()
        
        # Prediction (diagonal elements should be highest)
        preds = torch.argmax(logits, dim=1)
        labels = torch.arange(len(preds)).to(device)
        
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())
    
    val_time = time.time() - start_time
    avg_val_loss = total_loss / len(val_loader)
    val_acc = accuracy_score(all_labels, all_preds)
    
    return avg_val_loss, val_acc, val_time

In [16]:
import torch.nn.functional as F


In [17]:
train_losses, val_losses, val_accuracies = [], [], []

EPOCHS=10
for epoch in range(EPOCHS):
    print(f"\n🚀 Epoch {epoch+1}/{EPOCHS}")

    avg_train_loss, train_time = do_train(model, train_loader, optimizer, epoch, device="cuda")
    avg_val_loss,val_acc, val_time = do_eval(model, val_loader, device="cuda")
    print(f"\n🚀Epoch {epoch+1} Summary:")
    print(f"✅ Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f} - Time: {train_time:.2f}s")
    print(f"📊 Validation Loss: {avg_val_loss:.4f} | val_Acc: {val_acc*100:.4f}% | Time: {val_time:.2f}s")

    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    val_accuracies.append(val_acc)



🚀 Epoch 1/10
  Batch 0/1723 - Loss: 3.3927
  Batch 50/1723 - Loss: 2.8823
  Batch 100/1723 - Loss: 2.7791
  Batch 150/1723 - Loss: 2.2495
  Batch 200/1723 - Loss: 2.6175
  Batch 250/1723 - Loss: 2.1857
  Batch 300/1723 - Loss: 2.8325
  Batch 350/1723 - Loss: 2.3628
  Batch 400/1723 - Loss: 2.3473
  Batch 450/1723 - Loss: 2.2286
  Batch 500/1723 - Loss: 2.4369
  Batch 550/1723 - Loss: 2.2994
  Batch 600/1723 - Loss: 1.9284
  Batch 650/1723 - Loss: 2.3107
  Batch 700/1723 - Loss: 1.9258
  Batch 750/1723 - Loss: 1.7282
  Batch 800/1723 - Loss: 2.5525
  Batch 850/1723 - Loss: 1.9190
  Batch 900/1723 - Loss: 1.8173
  Batch 950/1723 - Loss: 2.1459
  Batch 1000/1723 - Loss: 2.1178
  Batch 1050/1723 - Loss: 2.1140
  Batch 1100/1723 - Loss: 2.3142
  Batch 1150/1723 - Loss: 2.3100
  Batch 1200/1723 - Loss: 1.8118
  Batch 1250/1723 - Loss: 2.1673
  Batch 1300/1723 - Loss: 2.0456
  Batch 1350/1723 - Loss: 1.8644
  Batch 1400/1723 - Loss: 2.1974
  Batch 1450/1723 - Loss: 2.0318
  Batch 1500/1723 -

In [1]:
import matplotlib.pyplot as plt

plt.plot(train_losses, label="Train Loss")
plt.plot(val_losses, label="Val Loss")
plt.title("CLIP Fine-Tuning Loss")
plt.legend()
plt.show()

plt.plot(val_accuracies, label="Val Accuracy")
plt.title("CLIP Val Accuracy")
plt.legend()
plt.show()


NameError: name 'train_losses' is not defined