<a href="https://colab.research.google.com/github/1hamzaiqbal/MFCLIP_acv/blob/hamza%2Fdiscrim/vit_generator_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# 1) Mount Google Drive (Run this first!)
# Later cells read and write checkpoints under /content/drive/MyDrive, so the
# mount has to be in place before they run.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)


In [2]:
# 2) Setup Repo & Dependencies
!nvidia-smi
%cd /content
import os
if not os.path.exists("MFCLIP_acv"):
    !git clone -b hamza/discrim https://github.com/1hamzaiqbal/MFCLIP_acv
%cd MFCLIP_acv
!git pull origin hamza/discrim  # Ensure latest code

!pip install torch torchvision timm einops yacs tqdm opencv-python scikit-learn scipy pyyaml ruamel.yaml pytorch-ignite foolbox pandas matplotlib seaborn wilds ftfy


In [3]:
# 3) Setup Data & Checkpoint
import shutil
import os
from torchvision.datasets import OxfordIIITPet
from torchvision import transforms
from pathlib import Path

# Download Dataset
root = Path("/content/data/oxford_pets")
root.mkdir(parents=True, exist_ok=True)
_ = OxfordIIITPet(root=str(root), download=True, transform=transforms.ToTensor())

# Fetch Annotations
%cd /content
!mkdir -p /content/data/oxford_pets
!wget -q https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
!wget -q https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
!tar -xf images.tar.gz -C /content/data/oxford_pets
!tar -xf annotations.tar.gz -C /content/data/oxford_pets

# Copy Checkpoint from Drive (Optional if only visualizing)
src_ckpt = "/content/drive/MyDrive/grad/comp_vision/hanson_loss/oxford_pets/RN50_ArcFace_oxford_pets.pth"
dst_ckpt = "/content/data/oxford_pets/RN50_ArcFace.pth"

if os.path.exists(src_ckpt):
    shutil.copy(src_ckpt, dst_ckpt)
    print(f"Successfully copied checkpoint to {dst_ckpt}")
else:
    print(f"WARNING: Checkpoint not found at {src_ckpt}. Training will fail, but Visualization will work if you have a trained generator.")


In [4]:
# 4) Train ViT Generator (Skip this if you already have a trained model)
# Runs the repository's main.py from the repo root with the flags listed below.
# --root is /content/data, the parent of the oxford_pets folder that the
# data-setup cell populates.
# NOTE(review): the surrogate checkpoint copied to
# /content/data/oxford_pets/RN50_ArcFace.pth is presumably what
# `--surrogate RN50 --head ArcFace` loads -- confirm against main.py.
# NOTE(review): later cells look for vit_generator.pt and
# vit_generator_history.json in /content/data/oxford_pets; presumably this run
# writes them -- confirm against main.py.
%cd /content/MFCLIP_acv
!python main.py \
  --flag train_unet \
  --generator vit \
  --dataset oxford_pets \
  --root /content/data \
  --config-file configs/trainers/CoOp/rn50.yaml \
  --dataset-config-file configs/datasets/oxford_pets.yaml \
  --trainer ZeroshotCLIP \
  --surrogate RN50 \
  --head ArcFace \
  --num_epoch 300 \
  --bs 64 \
  --lr 0.01 \
  --optimizer SGD \
  --ratio 0.2 \
  --device cuda:0


In [5]:
# 5) Save Artifacts and Plot History
# Backs up the trained generator and its training history to Google Drive,
# then draws loss and accuracy against epoch on a shared x-axis.
import json
import matplotlib.pyplot as plt
import shutil
import os

# Training outputs on the Colab VM and the Drive folder that keeps them.
src_model = "/content/data/oxford_pets/vit_generator.pt"
src_history = "/content/data/oxford_pets/vit_generator_history.json"
dst_dir = "/content/drive/MyDrive/grad/comp_vision/hanson_loss/oxford_pets"

os.makedirs(dst_dir, exist_ok=True)

# --- Generator weights ---
if not os.path.exists(src_model):
    print("Model file not found! (Did you skip training?)")
else:
    shutil.copy(src_model, os.path.join(dst_dir, "vit_generator.pt"))
    print(f"Saved model to {dst_dir}/vit_generator.pt")

# --- Training history (copied first, then plotted) ---
if not os.path.exists(src_history):
    print("History file not found! (Did you skip training?)")
else:
    shutil.copy(src_history, os.path.join(dst_dir, "vit_generator_history.json"))
    print(f"Saved history to {dst_dir}/vit_generator_history.json")

    with open(src_history, 'r') as f:
        history = json.load(f)

    def _draw_metric(ax, key, name, tint):
        """Plot history[key] against history['epoch'] on ax, coloured with tint."""
        ax.set_ylabel(name, color=tint)
        ax.plot(history['epoch'], history[key], color=tint, label=name)
        ax.tick_params(axis='y', labelcolor=tint)

    fig, ax1 = plt.subplots(figsize=(10, 6))
    ax1.set_xlabel('Epoch')
    _draw_metric(ax1, 'loss', 'Loss', 'tab:red')

    # Accuracy gets its own y-axis on the right, sharing the epoch axis.
    ax2 = ax1.twinx()
    _draw_metric(ax2, 'acc', 'Accuracy', 'tab:blue')

    plt.title("ViT Generator Training Progress")
    plt.grid(True, alpha=0.3)
    plt.show()


In [6]:
# 6) Visualize Adversarial Example
# This part of the cell makes the generator checkpoint available locally,
# builds the generator, and picks one image to perturb.
import glob
import os
import shutil
import sys

# `from model import ...` below only resolves when the cloned repo is
# importable. The training cell changes directory into the repo, but that cell
# is documented as skippable, so put the repo on sys.path explicitly.
# NOTE(review): presumably `model` lives at the repo root -- confirm.
REPO_DIR = "/content/MFCLIP_acv"
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

import torch
import matplotlib.pyplot as plt
import numpy as np
from model import ViTGenerator
from torchvision import transforms
from PIL import Image

# --- LOAD CHECKPOINT FROM DRIVE IF NEEDED ---
local_ckpt = "/content/data/oxford_pets/vit_generator.pt"
drive_ckpt = "/content/drive/MyDrive/grad/comp_vision/hanson_loss/oxford_pets/vit_generator.pt"

if not os.path.exists(local_ckpt):
    print(f"Local checkpoint not found at {local_ckpt}")
    if os.path.exists(drive_ckpt):
        print(f"Found checkpoint in Drive at {drive_ckpt}. Copying...")
        # The data folder may not exist yet if the data-setup cell was skipped.
        os.makedirs(os.path.dirname(local_ckpt), exist_ok=True)
        shutil.copy(drive_ckpt, local_ckpt)
        print("Copy complete.")
    else:
        print(f"WARNING: Checkpoint not found in Drive either ({drive_ckpt}). Visualization will use random weights!")

# Load Generator
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = ViTGenerator().to(device)

if os.path.exists(local_ckpt):
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    generator.load_state_dict(torch.load(local_ckpt, map_location=device))
    print("Generator loaded successfully.")
else:
    print("Using random weights (Generator not loaded).")
generator.eval()

# Load a sample image
img_path = "/content/data/oxford_pets/images/Abyssinian_1.jpg" # Example image
if not os.path.exists(img_path):
    # Fallback if specific image doesn't exist: take the first one in sorted
    # order so the same image is chosen on every run (glob order is arbitrary).
    images = sorted(glob.glob("/content/data/oxford_pets/images/*.jpg"))
    if images:
        img_path = images[0]
    else:
        print("No images found to visualize. Did you run the 'Setup Data' cell?")
        img_path = None

if img_path:
    # Preprocessing: fixed 224x224 resize, then PIL image -> float tensor.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    img = Image.open(img_path).convert('RGB')
    # Index with None to add the leading batch dimension.
    img_t = transform(img)[None].to(device)

    # Perturbation budget used for visualization (eps=10/255 approx 0.04).
    eps = 10/255
    with torch.no_grad():
        # Limit the generator output to [-eps, eps], then keep the perturbed
        # image inside the valid [0, 1] pixel range.
        noise = generator(img_t).clamp(-eps, eps)
        adv_img = (img_t + noise).clamp(0, 1)
    
    # Helper to plot
    def show_tensor(t, ax, title):
        """Draw a single-image tensor on a matplotlib-style axis.

        Args:
            t: tensor holding one image. After ``squeeze()`` it must be
                channel-first (C, H, W), because it is permuted to (H, W, C)
                before being handed to ``imshow``.
            ax: object providing ``imshow``, ``set_title`` and ``axis``.
            title: panel title. The exact title "Noise (Amplified)" also
                switches on min-max rescaling of the values to [0, 1].
        """
        im = t.squeeze().cpu().permute(1, 2, 0).numpy()
        if title == "Noise (Amplified)":
            lo = im.min()
            span = im.max() - lo
            im = im - lo
            # A constant tensor has zero range; dividing by it would produce
            # 0/0 = NaN pixels. In that case keep the shifted all-zero image.
            if span > 0:
                im = im / span
        ax.imshow(im)
        ax.set_title(title)
        ax.axis('off')

    # One row of three panels: clean input, rescaled perturbation, perturbed result.
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (img_t, "Original"),
        (noise, "Noise (Amplified)"),
        (adv_img, "Adversarial"),
    )
    for panel_ax, (tensor, panel_title) in zip(axs, panels):
        show_tensor(tensor, panel_ax, panel_title)
    plt.show()
