<a href="https://colab.research.google.com/github/1hamzaiqbal/MFCLIP_acv/blob/hamza%2Fdiscrim/mf_clip_finetune_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# 1) Mount Google Drive (Run this first!)
# Later cells copy the pretrained checkpoint from, and save results to,
# /content/drive/MyDrive/..., so the mount must exist before they run.
from google.colab import drive

DRIVE_MOUNT_POINT = "/content/drive"
# force_remount=True re-establishes the mount even if one is already present.
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)


In [2]:
# 2) Setup Repo & Dependencies
!nvidia-smi
%cd /content
import os
if not os.path.exists("MFCLIP_acv"):
    !git clone -b hamza/discrim https://github.com/1hamzaiqbal/MFCLIP_acv
%cd MFCLIP_acv
!git pull origin hamza/discrim  # Ensure latest code

!pip install torch torchvision timm einops yacs tqdm opencv-python scikit-learn scipy pyyaml ruamel.yaml pytorch-ignite foolbox pandas matplotlib seaborn wilds ftfy


In [3]:
# 3) Setup Data & Checkpoint
import shutil
import os
from torchvision.datasets import OxfordIIITPet
from torchvision import transforms
from pathlib import Path

# Download Dataset
root = Path("/content/data/oxford_pets")
root.mkdir(parents=True, exist_ok=True)
_ = OxfordIIITPet(root=str(root), download=True, transform=transforms.ToTensor())

# Fetch Annotations
%cd /content
!mkdir -p /content/data/oxford_pets
!wget -q https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
!wget -q https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
!tar -xf images.tar.gz -C /content/data/oxford_pets
!tar -xf annotations.tar.gz -C /content/data/oxford_pets

# Copy Checkpoint from Drive
src_ckpt = "/content/drive/MyDrive/grad/comp_vision/hanson_loss/oxford_pets/RN50_ArcFace_oxford_pets.pth"
dst_ckpt = "/content/data/oxford_pets/RN50_ArcFace.pth"

if os.path.exists(src_ckpt):
    shutil.copy(src_ckpt, dst_ckpt)
    print(f"Successfully copied checkpoint to {dst_ckpt}")
else:
    raise FileNotFoundError(f"Checkpoint not found at {src_ckpt}. Please check Drive path.")


In [4]:
# 4) Train ViT Generator
%cd /content/MFCLIP_acv
!python main.py \
  --flag train_unet \
  --generator vit \
  --dataset oxford_pets \
  --root /content/data \
  --config-file configs/trainers/CoOp/rn50.yaml \
  --dataset-config-file configs/datasets/oxford_pets.yaml \
  --trainer ZeroshotCLIP \
  --surrogate RN50 \
  --head ArcFace \
  --num_epoch 300 \
  --bs 64 \
  --lr 0.01 \
  --optimizer SGD \
  --ratio 0.2 \
  --device cuda:0


In [5]:
# 5) Save Artifacts and Plot History
import json
import os
import shutil

import matplotlib.pyplot as plt

# Paths — assumed to be where main.py (cell 4) writes its outputs; verify against main.py.
src_model = "/content/data/oxford_pets/vit_generator.pt"
src_history = "/content/data/oxford_pets/vit_generator_history.json"
dst_dir = "/content/drive/MyDrive/grad/comp_vision/hanson_loss/oxford_pets"

# Copy to Drive so the results outlive the Colab VM.
os.makedirs(dst_dir, exist_ok=True)
if os.path.exists(src_model):
    shutil.copy(src_model, os.path.join(dst_dir, "vit_generator.pt"))
    print(f"Saved model to {dst_dir}/vit_generator.pt")
else:
    print("Model file not found!")

if os.path.exists(src_history):
    shutil.copy(src_history, os.path.join(dst_dir, "vit_generator_history.json"))
    print(f"Saved history to {dst_dir}/vit_generator_history.json")

    # Assumes the JSON holds equal-length lists under 'epoch', 'loss', 'acc' — TODO confirm.
    with open(src_history, "r") as f:
        history = json.load(f)

    loss_color = "tab:red"
    acc_color = "tab:blue"
    fig, ax1 = plt.subplots(figsize=(10, 6))

    # Left y-axis: loss.
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss", color=loss_color)
    (loss_line,) = ax1.plot(history["epoch"], history["loss"], color=loss_color, label="Loss")
    ax1.tick_params(axis="y", labelcolor=loss_color)

    # Right y-axis: accuracy, sharing the epoch x-axis.
    ax2 = ax1.twinx()
    ax2.set_ylabel("Accuracy", color=acc_color)
    (acc_line,) = ax2.plot(history["epoch"], history["acc"], color=acc_color, label="Accuracy")
    ax2.tick_params(axis="y", labelcolor=acc_color)

    # Each twin axis only knows its own line, so merge both into one legend.
    # It is placed on ax2, which is drawn on top, so ax2's line cannot hide it.
    ax2.legend(handles=[loss_line, acc_line], loc="best")
    # Target ax1 explicitly: after twinx() the pyplot "current axes" is ax2,
    # so plt.title / plt.grid would have applied to the accuracy axis.
    ax1.set_title("ViT Generator Training Progress")
    ax1.grid(True, alpha=0.3)
    plt.show()
else:
    print("History file not found! Did you run the training cell?")
