In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPProcessor, CLIPModel, get_cosine_schedule_with_warmup
from datasets import load_dataset
from datetime import datetime
import wandb
from torchvision import transforms
from typing import Literal, Optional, List
from tqdm.auto import tqdm
# Log in to Weights & Biases before any run is created. This executes
# unconditionally, i.e. regardless of the Config.use_wandb flag defined below.
wandb.login()

In [None]:
# Mount Google Drive into the Colab filesystem; main() later points the
# checkpoint directory at a folder under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Silence every FutureWarning for the remainder of the notebook session.
# NOTE(review): presumably aimed at the deprecation notices emitted by the
# torch.cuda.amp helpers used in CLIPTrainer — confirm before relying on it,
# since this also hides FutureWarnings from all other libraries.
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

In [None]:
class Config:
  """Hyper-parameters and runtime settings for fine-tuning CLIP on ROCO.

  Defaults are declared as class attributes. ``__init__`` copies them onto the
  instance so that ``vars(config)`` returns the complete configuration; that
  dict is what gets passed to ``wandb.init`` and stored inside checkpoints.
  Without the copy, ``vars()`` only sees attributes assigned after
  construction and the logged configuration is (almost) empty.
  """
  # --- model / data ---
  model_name = "openai/clip-vit-base-patch16"
  dataset_name = "eltorio/ROCO-radiology"
  max_length = 77   # tokenizer padding / truncation length used by ROCODataset
  image_size = 224

  # --- optimisation ---
  strategy: Literal["vision_only", "text_only", "last_30"] = "last_30"
  num_epochs = 5
  actual_batch_size = 32            # batch size handed to the DataLoaders
  gradient_accumulation_steps = 8   # micro-batches per optimizer step
  # Effective batch size; derived so it cannot drift from the two values above.
  batch_size = actual_batch_size * gradient_accumulation_steps
  learning_rate = 1e-5
  warmup_steps = 500
  weight_decay = 0.01
  max_grad_norm = 1.0

  # --- runtime ---
  device = "cuda" if torch.cuda.is_available() else "cpu"
  mixed_precision = True
  num_workers = 2

  # --- logging / checkpointing ---
  use_wandb = True
  checkpoint_dir = "./checkpoints"
  save_every_n_epochs = 1
  log_every_n_steps = 50

  eval_every_n_epochs = 1

  def __init__(self):
    # Materialise the class-level defaults on the instance (base classes first
    # so subclasses win). Private names, methods and descriptors are skipped.
    for klass in reversed(type(self).__mro__):
      for name, value in vars(klass).items():
        if name.startswith("_") or callable(value) or hasattr(value, "__get__"):
          continue
        setattr(self, name, value)

config = Config()


In [None]:
class ROCODataset(Dataset):
    """ROCO radiology image/caption pairs preprocessed for CLIP.

    Loads one split of ``config.dataset_name``, drops rows without an image or
    with an empty caption, optionally truncates to the first ``max_samples``
    rows, and converts each item into CLIP model inputs on access.
    """
    def __init__(self, split="train", processor=None, max_samples=None):
        """
        Args:
            split: dataset split name; augmentation is applied only for "train".
            processor: callable CLIP processor turning (text, image) into tensors.
                Required, despite the ``None`` default kept for compatibility.
            max_samples: if truthy, keep only the first ``max_samples`` rows
                (after filtering) — intended for quick debug runs.

        Raises:
            ValueError: if ``processor`` is None.
        """
        # Fail fast: otherwise the dataset is loaded and filtered first and the
        # error only appears later inside __getitem__ as an opaque
        # "'NoneType' object is not callable".
        if processor is None:
            raise ValueError(
                "ROCODataset requires a CLIP processor (got processor=None)."
            )

        self.processor = processor
        self.split = split

        print(f"Loading ROCO dataset ({split} split)...")
        dataset = load_dataset(config.dataset_name, split=split)

        print("Filtering invalid samples...")
        original_len = len(dataset)

        dataset = dataset.filter(
            lambda x: x["image"] is not None
            and x["caption"] is not None
            and len(x["caption"].strip()) > 0
        )

        print(f"Filtered {original_len - len(dataset)} invalid samples.")

        if max_samples:
            print(f"Selecting first {max_samples} samples for debug...")
            dataset = dataset.select(range(min(max_samples, len(dataset))))

        self.data = dataset
        print(f"Final dataset size: {len(self.data)} samples")

        # Mild geometric + photometric jitter; applied to training images only.
        self.augment = transforms.Compose([
            transforms.RandomAffine(
                degrees=10,
                translate=(0.05, 0.05),
                scale=(0.95, 1.05)
            ),
            transforms.ColorJitter(
                brightness=0.2,
                contrast=0.2
            )
        ])

    def __len__(self):
        """Number of samples remaining after filtering / truncation."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the CLIP inputs for sample ``idx``.

        Returns:
            dict with 'pixel_values', 'input_ids' and 'attention_mask', each
            with the processor's leading batch dimension squeezed away.
        """
        item = self.data[idx]

        image = item['image']
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if self.split == "train":
            image = self.augment(image)

        caption = item['caption']

        # Pad/truncate every caption to the same length so the default collate
        # function can stack samples into a batch.
        inputs = self.processor(
            text=caption,
            images=image,
            return_tensors="pt",
            padding="max_length",
            max_length=config.max_length,
            truncation=True
        )

        return {
            'pixel_values': inputs['pixel_values'].squeeze(0),
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0)
        }

In [None]:
class CLIPFineTuner:
  """Chooses which parameters of a CLIP model are trainable.

  Supported strategies:
    * "vision_only": vision encoder + visual projection.
    * "text_only":   text encoder + text projection.
    * "last_30":     last ~30% of the transformer layers of both encoders,
                     their final layer norms and both projections.

  The strategy is applied immediately on construction by toggling
  ``requires_grad`` on the model's parameters in place.
  """

  def __init__(self, model, strategy: str):
    self.model = model
    self.strategy = strategy
    self.apply_strategy()

  def freeze_all(self):
    """Freeze all parameters"""
    for param in self.model.parameters():
        param.requires_grad = False

  @staticmethod
  def _unfreeze(module):
    """Mark every parameter of ``module`` as trainable."""
    for param in module.parameters():
        param.requires_grad = True

  def get_layer_info(self):
    """Return ``(vision_layers, text_layers)`` as lists of encoder layers.

    A list is empty when the corresponding sub-model has no ``encoder``.
    """
    vision_layers = []
    text_layers = []

    if hasattr(self.model.vision_model, 'encoder'):
        vision_layers = list(self.model.vision_model.encoder.layers)
    if hasattr(self.model.text_model, 'encoder'):
        text_layers = list(self.model.text_model.encoder.layers)

    return vision_layers, text_layers

  def apply_strategy(self):
    """Freeze the whole model, then unfreeze what ``self.strategy`` selects.

    Raises:
        ValueError: for an unknown strategy. The check happens before any
            parameter is touched, so a typo does not leave the model frozen.
    """
    handlers = {
        "vision_only": self._apply_vision_only,
        "text_only": self._apply_text_only,
        "last_30": self._apply_last_30,
    }
    handler = handlers.get(self.strategy)
    if handler is None:
        raise ValueError(f"Unknown strategy: {self.strategy}")

    print(f"\n{'='*60}")
    print(f"Applying Fine-tuning Strategy: {self.strategy.upper()}")
    print(f"{'='*60}\n")

    self.freeze_all()
    handler()
    self._print_trainable_params()

  def _apply_vision_only(self):
    """Unfreeze the vision encoder and, if present, the visual projection."""
    print("Unfreezing: Vision Encoder")
    self._unfreeze(self.model.vision_model)

    if hasattr(self.model, 'visual_projection'):
        self._unfreeze(self.model.visual_projection)

  def _apply_text_only(self):
    """Unfreeze the text encoder and, if present, the text projection."""
    print("Unfreezing: Text Encoder")
    self._unfreeze(self.model.text_model)

    if hasattr(self.model, 'text_projection'):
        self._unfreeze(self.model.text_projection)

  def _apply_last_30(self):
    """Unfreeze last 30% of layers in both encoders"""
    vision_layers, text_layers = self.get_layer_info()

    # int() truncates, so the frozen share is rounded down and the trainable
    # tail can be slightly larger than 30% (e.g. 4 of 12 layers).
    vision_threshold = int(len(vision_layers) * 0.7)
    text_threshold = int(len(text_layers) * 0.7)

    print(f"Vision Encoder: {len(vision_layers)} layers total")
    print(f"  - Freezing first {vision_threshold} layers")
    print(f"  - Unfreezing last {len(vision_layers) - vision_threshold} layers")

    print(f"\nText Encoder: {len(text_layers)} layers total")
    print(f"  - Freezing first {text_threshold} layers")
    print(f"  - Unfreezing last {len(text_layers) - text_threshold} layers")

    for layer in vision_layers[vision_threshold:]:
        self._unfreeze(layer)

    if hasattr(self.model.vision_model, 'post_layernorm'):
        print("  - Unfreezing Vision Post-LayerNorm")
        self._unfreeze(self.model.vision_model.post_layernorm)

    for layer in text_layers[text_threshold:]:
        self._unfreeze(layer)

    if hasattr(self.model.text_model, 'final_layer_norm'):
        print("  - Unfreezing Text Final-LayerNorm")
        self._unfreeze(self.model.text_model.final_layer_norm)

    # The projection heads are always trained alongside the unfrozen layers.
    if hasattr(self.model, 'visual_projection'):
        self._unfreeze(self.model.visual_projection)

    if hasattr(self.model, 'text_projection'):
        self._unfreeze(self.model.text_projection)

  def _print_trainable_params(self):
    """Print summary of trainable parameters"""
    trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in self.model.parameters())
    # Guard against a parameter-less model instead of dividing by zero.
    percentage = 100 * trainable / total if total else 0.0

    print(f"\n{'='*60}")
    print(f"Trainable Parameters: {trainable:,} / {total:,} ({percentage:.2f}%)")
    print(f"{'='*60}\n")


In [None]:
class CLIPTrainer:
  """Trainer for CLIP fine-tuning

  Bundles the optimisation loop (gradient accumulation, optional CUDA mixed
  precision, gradient clipping, cosine LR schedule with warmup), validation
  with in-batch Recall@1 / Recall@5, and checkpoint save / restore.
  """

  def __init__(self, model, train_loader, val_loader, config):
    """Create optimizer, LR scheduler and (on CUDA) the AMP gradient scaler.

    Args:
        model: CLIP model; only parameters with ``requires_grad=True`` are
            handed to the optimizer.
        train_loader: loader yielding dicts with 'pixel_values', 'input_ids'
            and 'attention_mask'.
        val_loader: same batch format, used by ``validate``.
        config: settings object (device, learning_rate, weight_decay, ...).
    """
    self.model = model.to(config.device)
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.config = config
    trainable_params = [p for p in self.model.parameters() if p.requires_grad]
    self.optimizer = torch.optim.AdamW(
        trainable_params,
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )
    # train_epoch() also steps on the final, possibly partial, accumulation
    # group, so the per-epoch step count must be rounded UP; flooring makes
    # the schedule end before training does.
    accum_steps = config.gradient_accumulation_steps
    num_update_steps_per_epoch = (len(train_loader) + accum_steps - 1) // accum_steps
    max_train_steps = config.num_epochs * num_update_steps_per_epoch

    self.scheduler = get_cosine_schedule_with_warmup(
        self.optimizer,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=max_train_steps
    )
    # Mixed precision via torch.cuda.amp only makes sense on a CUDA device;
    # elsewhere fall back to the plain full-precision path.
    use_amp = bool(config.mixed_precision) and str(config.device).startswith("cuda")
    self.scaler = torch.cuda.amp.GradScaler() if use_amp else None

    self.global_step = 0
    self.best_val_loss = float('inf')

    os.makedirs(config.checkpoint_dir, exist_ok=True)

  @staticmethod
  def _config_as_dict(cfg):
    """Collect the public, non-callable settings of ``cfg`` into a dict.

    ``vars(cfg)`` alone misses values declared as class attributes, so class
    level defaults are merged first and instance overrides applied on top.
    """
    merged = {}
    for klass in reversed(type(cfg).__mro__):
        merged.update(vars(klass))
    merged.update(getattr(cfg, "__dict__", {}))
    return {
        key: value
        for key, value in merged.items()
        if not key.startswith("_")
        and not callable(value)
        and not hasattr(value, "__get__")
    }

  def train_epoch(self,epoch):
    """Run one pass over ``train_loader`` and return the mean batch loss.

    Gradients are accumulated over ``gradient_accumulation_steps``
    micro-batches; the last group of an epoch may be shorter and is then
    averaged over its real size.
    """
    self.model.train()
    accum_steps = self.config.gradient_accumulation_steps
    num_batches = len(self.train_loader)
    tail_size = num_batches % accum_steps   # size of a trailing partial group
    use_amp = self.scaler is not None
    total_loss = 0

    # Start from clean gradients in case a previous epoch was interrupted
    # in the middle of an accumulation group.
    self.optimizer.zero_grad()

    pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.config.num_epochs}")

    for step, batch in enumerate(pbar):
        pixel_values = batch['pixel_values'].to(self.config.device)
        input_ids = batch['input_ids'].to(self.config.device)
        attention_mask = batch['attention_mask'].to(self.config.device)

        # Divide by the number of micro-batches that actually contribute to
        # this optimizer step so the accumulated gradient is a true mean.
        in_tail = tail_size and step >= num_batches - tail_size
        group_size = tail_size if in_tail else accum_steps

        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = self.model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_loss=True
            )
            loss = outputs.loss / group_size

        if use_amp:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

        current_loss = outputs.loss.item()
        total_loss += current_loss

        if (step + 1) % accum_steps == 0 or (step + 1) == num_batches:
            if use_amp:
                # Clipping must see the real (unscaled) gradient magnitudes.
                self.scaler.unscale_(self.optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.max_grad_norm
            )
            if use_amp:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()

            self.optimizer.zero_grad()
            self.scheduler.step()
            self.global_step += 1

            if self.config.use_wandb and self.global_step % self.config.log_every_n_steps == 0:
                wandb.log({
                    'train/loss': current_loss,
                    'train/grad_norm': grad_norm.item(),
                    'train/learning_rate': self.scheduler.get_last_lr()[0],
                    'global_step': self.global_step
                })
        pbar.set_postfix({'loss':f"{current_loss:.4f}"})

    return total_loss / max(num_batches, 1)

  @torch.no_grad()
  def validate(self):
    """Validate with Recall@1 and Recall@5

    Recall is measured *within each batch*: every image is ranked against
    the captions of its own batch only (image-to-text direction), so the
    numbers depend on the validation batch size.

    Returns:
        (mean batch loss, Recall@1, Recall@5)

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    self.model.eval()
    total_loss = 0.0
    total_correct_r1 = 0
    total_correct_r5 = 0
    total_samples = 0
    num_batches = 0

    pbar = tqdm(self.val_loader, desc="Validation")

    for batch in pbar:
        pixel_values = batch['pixel_values'].to(self.config.device)
        input_ids = batch['input_ids'].to(self.config.device)
        attention_mask = batch['attention_mask'].to(self.config.device)

        with torch.cuda.amp.autocast(enabled=self.scaler is not None):
            outputs = self.model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_loss=True
            )
        logits = outputs.logits_per_image
        batch_size = logits.shape[0]
        # The matching caption of image i sits at column i.
        labels = torch.arange(batch_size, device=logits.device)

        pred = logits.argmax(dim=1)
        total_correct_r1 += (pred == labels).sum().item()

        # With fewer than 5 candidates every caption is inside the top-5, so
        # clamp k instead of silently falling back to Recall@1.
        _, top5_indices = logits.topk(min(5, batch_size), dim=1)
        total_correct_r5 += (top5_indices == labels.view(-1, 1)).any(dim=1).sum().item()

        total_loss += outputs.loss.item()
        total_samples += batch_size
        num_batches += 1

        # Running averages over the batches seen so far.
        pbar.set_postfix({
            'loss': f"{total_loss / num_batches:.4f}",
            'r1' : f"{total_correct_r1 / total_samples:.4f}",
            'r5' : f"{total_correct_r5 / total_samples:.4f}"
        })

    if num_batches == 0:
        raise ValueError("Validation loader produced no batches.")

    avg_loss = total_loss / num_batches
    avg_r1 = total_correct_r1 / total_samples
    avg_r5 = total_correct_r5 / total_samples
    return avg_loss, avg_r1, avg_r5

  def save_checkpoint(self, epoch, val_loss, is_best=False):
    """Save model checkpoint

    Writes ``checkpoint_epoch_{epoch+1}.pt`` into ``config.checkpoint_dir``.
    When ``is_best`` is true, also writes ``best_model.pt`` and a Hugging
    Face ``save_pretrained`` export under ``best_model_hf``.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'scheduler_state_dict': self.scheduler.state_dict(),
        'val_loss': val_loss,
        # Stored separately: val_loss is this epoch's loss, not the best one.
        'best_val_loss': self.best_val_loss,
        'global_step': self.global_step,
        'config': self._config_as_dict(self.config)
    }
    if self.scaler is not None:
        checkpoint['scaler_state_dict'] = self.scaler.state_dict()

    path = os.path.join(
        self.config.checkpoint_dir,
        f"checkpoint_epoch_{epoch+1}.pt"
    )
    torch.save(checkpoint, path)
    print(f"Saved checkpoint: {path}")

    if is_best:
        best_pt_path = os.path.join(self.config.checkpoint_dir, "best_model.pt")
        torch.save(checkpoint, best_pt_path)
        hf_save_dir = os.path.join(self.config.checkpoint_dir, "best_model_hf")
        self.model.save_pretrained(hf_save_dir)
        print(f"Saved best model (HF Ready): {hf_save_dir}")

  def train(self, start_epoch=0):
    """Main training loop

    Args:
        start_epoch: zero-based epoch to start from; pass the value returned
            by ``load_checkpoint`` to continue an interrupted run.
    """
    print("\n" + "="*60)
    print("Starting Training")
    print("="*60 + "\n")

    for epoch in range(start_epoch, self.config.num_epochs):

        train_loss = self.train_epoch(epoch)

        print(f"\nEpoch {epoch+1}/{self.config.num_epochs}")
        print(f"Train Loss: {train_loss:.4f}")

        val_loss = None
        is_best = False

        if (epoch + 1) % self.config.eval_every_n_epochs == 0:

            val_loss, val_r1, val_r5 = self.validate()

            print(f"Val Loss: {val_loss:.4f} | R@1: {val_r1:.4f} | R@5: {val_r5:.4f}")

            if self.config.use_wandb:
                wandb.log({
                    'val/loss': val_loss,
                    'val/recall_1': val_r1,
                    'val/recall_5': val_r5,
                    'epoch': epoch + 1
                })

            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                is_best = True
                print("New best model found!")

        should_save = ((epoch + 1) % self.config.save_every_n_epochs == 0) or is_best

        if should_save:
            # On epochs without validation, record the best loss seen so far.
            save_loss = val_loss if val_loss is not None else self.best_val_loss
            self.save_checkpoint(epoch, save_loss, is_best)

        print()

  def load_checkpoint(self, checkpoint_path):
    """Load a checkpoint to resume training

    Restores model, optimizer, scheduler and (when present) AMP scaler
    state plus the step / best-loss bookkeeping.

    Returns:
        The zero-based epoch index training should continue from.
    """
    print(f"Loading checkpoint from {checkpoint_path}...")
    # map_location lets a checkpoint written on GPU be restored on CPU.
    checkpoint = torch.load(checkpoint_path, map_location=self.config.device)

    self.model.load_state_dict(checkpoint['model_state_dict'])

    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    scaler_state = checkpoint.get('scaler_state_dict')
    if self.scaler is not None and scaler_state:
        self.scaler.load_state_dict(scaler_state)

    start_epoch = checkpoint['epoch'] + 1
    self.global_step = checkpoint['global_step']
    # Older checkpoints only carry that epoch's val_loss; prefer the true
    # best so a resumed run cannot overwrite best_model with a worse one.
    self.best_val_loss = checkpoint.get('best_val_loss', checkpoint['val_loss'])

    print(f"Resuming from Epoch {start_epoch}")
    return start_epoch

In [None]:
from google.colab import drive
import glob

def main():
    """Run the complete CLIP fine-tuning pipeline on ROCO inside Colab.

    Steps: mount Drive, redirect checkpoints to Drive, start the W&B run,
    load model and data, optionally resume from a checkpoint, then train /
    validate / checkpoint for ``config.num_epochs`` epochs.

    Raises:
        FileNotFoundError: if ``RESUME_FROM`` is set but the file is missing.
    """
    DEBUG = False        # True -> 100 training / 50 validation samples
    RESUME_FROM = None   # path of a checkpoint_epoch_N.pt to continue from

    if not os.path.exists('/content/drive'):
        print("Mounting Google Drive...")
        drive.mount('/content/drive')

    # A mistyped resume path must not silently restart from scratch and
    # overwrite the checkpoints already stored on Drive.
    if RESUME_FROM and not os.path.exists(RESUME_FROM):
        raise FileNotFoundError(f"Checkpoint to resume from not found: {RESUME_FROM}")

    drive_checkpoint_dir = "/content/drive/MyDrive/ROCO_FINETUNING"
    config.checkpoint_dir = drive_checkpoint_dir
    os.makedirs(config.checkpoint_dir, exist_ok=True)

    print("="*60)
    print("CLIP Fine-tuning on ROCO (Production Mode)")
    print(f"Checkpoints will save to: {config.checkpoint_dir}")
    print("="*60)

    if config.use_wandb:
        run_name = f"clip_full_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        if DEBUG: run_name += "_DEBUG"

        # vars(config) alone only contains attributes assigned on the
        # instance, so merge in the class-level defaults to log every
        # hyper-parameter.
        config_snapshot = {}
        for klass in reversed(type(config).__mro__):
            config_snapshot.update(vars(klass))
        config_snapshot.update(vars(config))
        config_snapshot = {
            key: value
            for key, value in config_snapshot.items()
            if not key.startswith("_")
            and not callable(value)
            and not hasattr(value, "__get__")
        }

        wandb.init(
            project="ROCOCLIP",
            config=config_snapshot,
            name=run_name,
            resume="allow"
        )

    # Everything after wandb.init sits inside try/finally so the run is
    # closed even when model or dataset loading fails.
    try:
        print("Loading CLIP...")
        processor = CLIPProcessor.from_pretrained(config.model_name)
        model = CLIPModel.from_pretrained(config.model_name)

        # Constructed for its side effect: sets requires_grad on the model.
        CLIPFineTuner(model, config.strategy)

        print("Loading Datasets...")
        train_limit = 100 if DEBUG else None
        val_limit = 50 if DEBUG else None

        train_dataset = ROCODataset(split="train", processor=processor, max_samples=train_limit)
        val_dataset = ROCODataset(split="validation", processor=processor, max_samples=val_limit)

        train_loader = DataLoader(
            train_dataset,
            batch_size=config.actual_batch_size,
            shuffle=True,
            num_workers=config.num_workers,
            pin_memory=True,
            drop_last=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=config.actual_batch_size,
            shuffle=False,
            num_workers=config.num_workers,
            pin_memory=True
        )

        trainer = CLIPTrainer(model, train_loader, val_loader, config)

        start_epoch = 0

        if RESUME_FROM:
            print(f"Resuming from: {RESUME_FROM}")
            start_epoch = trainer.load_checkpoint(RESUME_FROM)

        print(f"Starting Training from Epoch {start_epoch+1}...")

        try:
            for epoch in range(start_epoch, config.num_epochs):

                train_loss = trainer.train_epoch(epoch)
                print(f"\nEpoch {epoch+1}/{config.num_epochs} | Loss: {train_loss:.4f}")

                val_loss = None
                is_best = False

                if (epoch + 1) % config.eval_every_n_epochs == 0:
                    val_loss, val_r1, val_r5 = trainer.validate()
                    print(f"Val Loss: {val_loss:.4f} | R@1: {val_r1:.4f} | R@5: {val_r5:.4f}")

                    if config.use_wandb:
                        wandb.log({
                            'val/loss': val_loss,
                            'val/recall_1': val_r1,
                            'val/recall_5': val_r5,
                            'epoch': epoch+1
                        })

                    is_best = val_loss < trainer.best_val_loss
                    if is_best:
                        trainer.best_val_loss = val_loss
                        print("New best model found!")

                # Decided outside the eval branch so save_every_n_epochs is
                # honoured on epochs without validation as well.
                if ((epoch + 1) % config.save_every_n_epochs == 0) or is_best:
                    save_loss = val_loss if val_loss is not None else trainer.best_val_loss
                    trainer.save_checkpoint(epoch, save_loss, is_best)

                print()

            print("\n Full Training Completed Successfully!")

        except KeyboardInterrupt:
            print("\n Training interrupted by user.")
        except Exception as e:
            print(f"\n Error during training: {e}")
            raise
    finally:
        if config.use_wandb:
            wandb.finish()


In [None]:

# Start the training pipeline only when this code runs as the top-level
# program (not when it is imported as a module).
if __name__ == "__main__":
    main()

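Results from this notebook run (note: Recall@K is computed within each validation batch of 32 image–caption pairs, i.e. in-batch retrieval, not retrieval over the full validation set):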

train/loss - 0.31862


val/loss - 0.80609


val/recall_1 - 0.71976


val/recall_5 - 0.97505


In this notebook I changed a few things:
1. Changed the model from patch-32 to patch-16.
2. Increased the effective batch size from 128 to 256.
3. The actual (per-step) batch size is still 32.
4. Increased `gradient_accumulation_steps` from 4 to 8.
5. Added data augmentation.