
# پروژه سوم بینایی کامپیوتر — ViT-Base-16 در برابر ResNet-50 (PyTorch, CIFAR-100)

این نوت‌بوک طبق **بخش‌های تمرین** پیاده‌سازی شده است:
1) **Setup & Data Preparation** (دانلود/پیش‌پردازش/افزودن اغتشاش/تقسیم‌بندی)
2) **Model Preparation** (ViT-Base-16 از `timm` و ResNet-50 از `torchvision` + تعویض هد)
3) **Training & Fine-tuning** (مرحله‌ی فریز هد + فاین‌تیون کامل، AdamW/SGD، EarlyStopping، ReduceLROnPlateau)
4) **Evaluation Metrics** (Top-1 روی train/val/test + زمان‌ها + پارامترها + FLOPs)
5) **Robustness Testing (Optional)** (نویز گاوسی و انسداد مربعی / Occlusion)
6) **Interpretability (ViT)** (Attention Rollout)
7) **Plots & Report Artifacts** (نمودارها و فایل‌های متنی نتایج در `./outputs`)

> **Dataset:** CIFAR-100  
> **Image Size:** 224×224 (برای سازگاری با ViT و ResNet)


In [None]:

# اگر پکیج‌ها را ندارید، این را اجرا کنید (در Colab ممکن است لازم باشد).
# !pip install timm thop matplotlib


In [None]:

import os, time, math, random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import timm
from timm.models.vision_transformer import VisionTransformer
import matplotlib.pyplot as plt

# thop is optional: FLOPs are reported only when it can be imported.
try:
    from thop import profile
except Exception:
    THOP_AVAILABLE = False
else:
    THOP_AVAILABLE = True

# Single global device used by all training/evaluation helpers below.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", DEVICE)

def set_seed(seed: int = 42):
    """Seed Python, NumPy and torch (CPU + all CUDA devices) and force deterministic cuDNN.

    Args:
        seed: value passed to every RNG seeding function.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)
    # Trade cuDNN autotuning speed for reproducible kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class AverageMeter:
    """Running weighted mean of a scalar (e.g. per-batch loss or accuracy)."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the accumulated weighted sum and sample count."""
        self.sum = 0.0
        self.cnt = 0

    def update(self, v, n=1):
        """Add value ``v`` with weight ``n`` (typically the batch size)."""
        self.sum += v * n
        self.cnt += n

    @property
    def avg(self):
        """Weighted mean so far; the count is floored at 1 so an empty meter yields 0.0."""
        return self.sum / max(1, self.cnt)

def accuracy(outputs, targets):
    """Return top-1 accuracy as a Python float.

    ``outputs`` is indexed along dim 1 for the arg-max, so it is expected to be
    ``[batch, num_classes]`` logits; ``targets`` holds the class indices.
    """
    predicted = outputs.max(dim=1).indices
    correct = predicted == targets
    return correct.float().mean().item()

def count_params(model):
    """Return ``(total, trainable)`` parameter-element counts of ``model``."""
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable


## 1) Setup & Data Preparation

In [None]:

# Per-channel statistics of the CIFAR-100 training images (RGB order).
CIFAR_MEAN = (0.5071, 0.4865, 0.4409)
CIFAR_STD  = (0.2673, 0.2564, 0.2762)

def get_transforms(img_size=224, train=True):
    """Build the torchvision preprocessing pipeline.

    Every pipeline resizes to ``img_size x img_size``, converts to a tensor and
    normalizes with the CIFAR-100 statistics. When ``train`` is True, random
    horizontal flip, padded random crop and color jitter are inserted after the
    resize.
    """
    resize = transforms.Resize((img_size, img_size))
    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ]
    if not train:
        return transforms.Compose([resize, *to_normalized_tensor])
    augmentations = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(img_size, padding=4),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    ]
    return transforms.Compose([resize, *augmentations, *to_normalized_tensor])

def get_dataloaders(data_dir="./data", img_size=224, batch_size=128, num_workers=4, val_split=0.1, seed=42):
    """Create CIFAR-100 train / validation / test DataLoaders.

    The validation set is a seeded random ``val_split`` fraction of the official
    training split. It is served with the deterministic *evaluation* transforms,
    not the training augmentations. Validation accuracy drives checkpointing,
    LR scheduling and early stopping, so it must not be randomly augmented.

    Args:
        data_dir: root directory for the torchvision CIFAR-100 download.
        img_size: side length images are resized to.
        batch_size: batch size for all three loaders.
        num_workers: DataLoader worker processes.
        val_split: fraction of the 50k training images held out for validation.
        seed: seed of the generator that draws the train/val index split.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the train loader shuffles.
    """
    train_full = datasets.CIFAR100(root=data_dir, train=True, download=True, transform=get_transforms(img_size, True))
    # Second view of the same training files with eval transforms. Validation
    # indices point into this copy so val images are never augmented.
    # (This keeps a second in-memory copy of the training arrays.)
    val_full = datasets.CIFAR100(root=data_dir, train=True, download=True, transform=get_transforms(img_size, False))
    test_set = datasets.CIFAR100(root=data_dir, train=False, download=True, transform=get_transforms(img_size, False))

    n_total = len(train_full)
    n_val = int(n_total * val_split)
    n_train = n_total - n_val
    gen = torch.Generator().manual_seed(seed)
    # random_split draws the seeded permutation; reuse its val indices on the
    # eval-transform dataset so the split itself is unchanged.
    train_subset, val_on_train_tf = torch.utils.data.random_split(train_full, [n_train, n_val], generator=gen)
    val_subset = torch.utils.data.Subset(val_full, val_on_train_tf.indices)

    # Pinned memory only helps host->GPU copies; skip it on CPU-only machines.
    pin = torch.cuda.is_available()
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin)
    val_loader   = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin)
    test_loader  = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin)
    return train_loader, val_loader, test_loader


## 2) Model Preparation (ViT-Base-16 و ResNet-50)

In [None]:

def build_model(model_name: str, num_classes=100, pretrained=True):
    """Create a classifier with a ``num_classes``-way head.

    ``"vit"``: timm ``vit_base_patch16_224`` created with ``num_classes`` directly.
    ``"resnet"``: torchvision ResNet-50 (IMAGENET1K_V2 weights when ``pretrained``)
    whose ``fc`` layer is replaced by a fresh ``nn.Linear``.

    Raises:
        ValueError: for any other ``model_name``.
    """
    if model_name not in ("vit", "resnet"):
        raise ValueError("model_name must be 'vit' or 'resnet'")

    if model_name == "vit":
        vit = timm.create_model("vit_base_patch16_224", pretrained=pretrained, num_classes=num_classes)
        assert isinstance(vit, VisionTransformer)
        return vit

    weights = models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
    resnet = models.resnet50(weights=weights)
    resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)
    return resnet

def freeze_backbone(model, model_name):
    """Disable gradients for every parameter except the classification head.

    The head is ``model.head`` for ``"vit"`` and ``model.fc`` for anything else.
    """
    model.requires_grad_(False)
    classifier = model.head if model_name == "vit" else model.fc
    classifier.requires_grad_(True)

def unfreeze_all(model):
    """Re-enable gradients on every parameter of ``model`` (in place, returns None)."""
    model.requires_grad_(True)


## 3) Training & Fine-tuning (+ EarlyStopping, Scheduler)

In [None]:

class EarlyStopping:
    """Stop training after ``patience`` consecutive steps without improvement.

    The monitored metric is treated as higher-is-better. It must exceed the best
    value by more than 1e-8 to count as an improvement. ``stop`` becomes True
    once the no-improvement counter reaches ``patience``.
    """

    def __init__(self, patience=5, verbose=True):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best = float("-inf")
        self.stop = False

    def step(self, metric):
        """Record one observation of the metric and update ``counter``/``stop``."""
        if metric > self.best + 1e-8:
            self.best, self.counter = metric, 0
            return
        self.counter += 1
        if self.verbose:
            print(f"EarlyStopping counter: {self.counter}/{self.patience}")
        if self.counter >= self.patience:
            self.stop = True

def epoch_pass(model, loader, optimizer=None, scaler=None):
    """Run one pass over ``loader``; train when an optimizer is given, else evaluate.

    Args:
        model: network already placed on ``DEVICE``.
        loader: iterable of ``(inputs, targets)`` batches.
        optimizer: if not None, the model is put in train mode and updated.
        scaler: optional GradScaler. Mixed-precision autocast is used only when
            a scaler is given, it is enabled, and ``DEVICE`` is CUDA. A
            ``GradScaler(enabled=False)`` therefore turns AMP off.

    Returns:
        ``(mean_loss, mean_accuracy, elapsed_seconds)`` weighted by batch size.
    """
    train = optimizer is not None
    model.train(train)
    # A disabled GradScaler is still "not None", so ask it whether it is enabled.
    use_amp = (
        train
        and scaler is not None
        and scaler.is_enabled()
        and DEVICE.type == "cuda"
    )
    lm, am = AverageMeter(), AverageMeter()
    t0 = time.perf_counter()
    for x, y in loader:
        x, y = x.to(DEVICE, non_blocking=True), y.to(DEVICE, non_blocking=True)
        if train:
            optimizer.zero_grad(set_to_none=True)
            with torch.autocast(device_type=DEVICE.type, enabled=use_amp):
                out = model(x)
                loss = F.cross_entropy(out, y)
            if scaler is not None:
                # A disabled scaler passes the loss/step through unscaled.
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
        else:
            with torch.no_grad():
                out = model(x)
                loss = F.cross_entropy(out, y)
        batch = x.size(0)
        am.update(accuracy(out.detach(), y), batch)
        lm.update(loss.item(), batch)
    return lm.avg, am.avg, time.perf_counter() - t0

@torch.no_grad()
def evaluate(model, loader):
    """Compute batch-size-weighted mean cross-entropy and top-1 accuracy over ``loader``.

    The model is switched to eval mode and no gradients are tracked.

    Returns:
        ``(mean_loss, mean_accuracy)``.
    """
    model.eval()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    for inputs, labels in loader:
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        logits = model(inputs)
        batch = inputs.size(0)
        loss_meter.update(F.cross_entropy(logits, labels).item(), batch)
        acc_meter.update(accuracy(logits, labels), batch)
    return loss_meter.avg, acc_meter.avg


## 4) Plots & Report Files

In [None]:

def plot_curves(hist, out_dir="./outputs"):
    """Save train/val accuracy and loss curves from ``hist`` as PNGs in ``out_dir``.

    ``hist`` must provide the lists ``train_acc``, ``val_acc``, ``train_loss`` and
    ``val_loss`` (one value per epoch). Writes ``acc_curve.png`` and ``loss_curve.png``.
    """
    os.makedirs(out_dir, exist_ok=True)
    panels = (
        (('train_acc', 'val_acc'), 'Accuracy', 'acc_curve.png'),
        (('train_loss', 'val_loss'), 'Loss', 'loss_curve.png'),
    )
    for series_keys, ylabel, filename in panels:
        plt.figure()
        for key in series_keys:
            plt.plot(hist[key], label=key)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.title(ylabel)
        plt.legend()
        plt.savefig(os.path.join(out_dir, filename), dpi=150)
        plt.close()


## 5) Robustness Testing (اختیاری)

In [None]:

@torch.no_grad()
def eval_with_noise(model, loader, noise_std=0.1):
    """Top-1 accuracy when zero-mean Gaussian noise is added to each input batch.

    The noise is added to the tensors exactly as the loader yields them (i.e.
    after any normalization in the loader's transforms). The result is then
    clamped to [-5, 5]. Noise is drawn from the global torch RNG on the CPU.
    """
    model.eval()
    meter = AverageMeter()
    for images, labels in loader:
        noisy = (images + noise_std * torch.randn_like(images)).clamp(-5, 5)
        noisy, labels = noisy.to(DEVICE), labels.to(DEVICE)
        meter.update(accuracy(model(noisy), labels), noisy.size(0))
    return meter.avg

@torch.no_grad()
def eval_with_occlusion(model, loader, erase_size=32):
    """Top-1 accuracy when a random ``erase_size`` square of every image is set to 0.

    Batches are unpacked as ``[B, C, H, W]``. Each sample gets its own square
    position, drawn from the global torch RNG. The fill value 0.0 is applied to
    the loader's (possibly normalized) tensors.
    """
    model.eval()
    meter = AverageMeter()
    for images, labels in loader:
        batch, _, height, width = images.shape
        tops = torch.randint(0, height - erase_size + 1, (batch,))
        lefts = torch.randint(0, width - erase_size + 1, (batch,))
        for idx, (top, left) in enumerate(zip(tops.tolist(), lefts.tolist())):
            images[idx, :, top:top + erase_size, left:left + erase_size] = 0.0
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        meter.update(accuracy(model(images), labels), images.size(0))
    return meter.avg


## 6) Interpretability (ViT Attention Rollout)

In [None]:

class VitAttentionHook:
    """Collect per-block attention probabilities from a timm ViT via forward hooks.

    A forward hook is registered on ``blk.attn.attn_drop`` for every block in
    ``model.blocks``. The hook's *input* is the attention matrix fed into
    dropout; it is detached, moved to the CPU and appended to ``attn_maps``
    (one entry per block, in execution order).

    timm attention modules that expose ``fused_attn=True`` compute attention
    with a fused kernel and never call ``attn_drop``, so no maps would be
    captured. The flag is switched off while the hooks are installed, and its
    previous value is restored by :meth:`remove`.
    """

    def __init__(self, model: "VisionTransformer"):
        self.handles = []
        self.attn_maps = []
        self._fused_restore = []  # (attention module, previous fused_attn value)
        for blk in model.blocks:
            attn_mod = blk.attn
            if getattr(attn_mod, "fused_attn", False):
                self._fused_restore.append((attn_mod, attn_mod.fused_attn))
                attn_mod.fused_attn = False
            self.handles.append(attn_mod.attn_drop.register_forward_hook(self._hook))

    def _hook(self, module, inp, out):
        # inp[0] is the attention tensor entering dropout.
        self.attn_maps.append(inp[0].detach().cpu())

    def remove(self):
        """Remove all hooks and restore any ``fused_attn`` flags changed in ``__init__``."""
        for h in self.handles:
            h.remove()
        self.handles = []
        for attn_mod, previous in self._fused_restore:
            attn_mod.fused_attn = previous
        self._fused_restore = []

def attention_rollout(attn_list, discard_ratio=0.0):
    """Combine per-layer attention maps with the attention-rollout rule.

    For each layer, heads are averaged (``mean(dim=1)``, so inputs are assumed
    to be ``[B, heads, N, N]``). Optionally the lowest ``discard_ratio``
    fraction of the averaged weights is zeroed per sample. The identity is
    added (residual path) and rows are re-normalized. The per-layer matrices
    are then chained with batched matmul in list order.

    Args:
        attn_list: non-empty sequence of attention tensors, one per block.
        discard_ratio: fraction in [0, 1) of the smallest weights to drop per
            sample before adding the identity. Flattened entry 0 (token 0 to
            itself) is always kept. 0.0 disables discarding.

    Returns:
        ``[B, N, N]`` rollout matrix.

    Raises:
        ValueError: if ``attn_list`` is empty or ``discard_ratio`` is not in [0, 1).
    """
    if not 0.0 <= discard_ratio < 1.0:
        raise ValueError(f"discard_ratio must be in [0, 1), got {discard_ratio}")
    result = None
    for attn in attn_list:
        attn = attn.mean(dim=1)  # [B,N,N]
        if discard_ratio > 0:
            flat = attn.reshape(attn.size(0), -1).clone()
            k = int(flat.size(-1) * discard_ratio)
            if k > 0:
                kept_first = flat[:, 0].clone()
                _, lowest = flat.topk(k, dim=-1, largest=False)
                flat.scatter_(-1, lowest, 0.0)
                flat[:, 0] = kept_first
            attn = flat.view_as(attn)
        eye = torch.eye(attn.size(-1), dtype=attn.dtype, device=attn.device)
        attn = attn + eye.unsqueeze(0).expand_as(attn)
        attn = attn / attn.sum(dim=-1, keepdim=True)
        result = attn if result is None else torch.bmm(result, attn)
    if result is None:
        raise ValueError("attn_list is empty: no attention maps were captured")
    return result

@torch.no_grad()
def visualize_vit_attention(model, img_tensor, out_path="vit_attention.png", discard_ratio=0.0):
    """Save an attention-rollout heat map for the first image in ``img_tensor``.

    Runs one forward pass with :class:`VitAttentionHook` installed. It then
    rolls out the captured maps, takes token 0's attention over the remaining
    tokens and reshapes it to a ``g x g`` grid. The grid is upsampled
    bilinearly to the input's spatial size and written to ``out_path``.

    Args:
        model: a timm ``VisionTransformer``; other models only print a message.
        img_tensor: input batch (assumed ``[B, 3, H, W]``, already normalized).
        out_path: destination image file; its parent directory is created if needed.
        discard_ratio: forwarded to ``attention_rollout``.

    Raises:
        ValueError: if the number of non-CLS tokens is not a perfect square.
    """
    model.eval()
    if not isinstance(model, VisionTransformer):
        print("Visualization only for ViT."); return
    hook = VitAttentionHook(model)
    try:
        model(img_tensor.to(DEVICE))
    finally:
        # Always detach the hooks, even if the forward pass fails.
        hook.remove()
    attn = attention_rollout(hook.attn_maps, discard_ratio)
    cls = attn[:, 0, 1:]
    batch, num_tokens = cls.shape
    grid = math.isqrt(num_tokens)
    if grid * grid != num_tokens:
        raise ValueError(f"cannot reshape {num_tokens} patch tokens into a square grid")
    maps = cls.reshape(batch, 1, grid, grid)
    maps = F.interpolate(maps, size=tuple(img_tensor.shape[-2:]), mode='bilinear', align_corners=False)
    heat = maps[0, 0].cpu().numpy()

    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig = plt.figure()  # fresh figure so nothing from earlier plots leaks in
    plt.imshow(heat, cmap='jet', alpha=0.7); plt.axis('off'); plt.title('ViT Attention Rollout')
    plt.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.close(fig)


## 7) Evaluation Metrics (FLOPs/Params/Timing)

In [None]:

def compute_flops(model, img_size=224):
    """Estimate forward-pass FLOPs for one ``1 x 3 x img_size x img_size`` input using thop.

    thop's ``profile`` returns multiply-accumulates (MACs); they are doubled to
    report FLOPs. NOTE(review): thop counts via module hooks. Matmuls written
    inline in a module's forward (e.g. ViT attention ``q @ k``) are presumably
    not counted, so treat the ViT number as a lower bound — confirm.

    Returns:
        Integer FLOPs, or None when thop is not installed.
    """
    if not THOP_AVAILABLE:
        return None
    dummy = torch.randn(1, 3, img_size, img_size).to(DEVICE)
    # Only a forward pass is needed; don't build an autograd graph.
    with torch.no_grad():
        macs, _params = profile(model, (dummy,), verbose=False)
    return int(macs * 2)

@torch.no_grad()
def measure_inference_time(model, img_size=224, repeats=50, warmup=10):
    """Return the mean wall-clock seconds of one forward pass on a single random image.

    The input is created on the device of the model's first parameter (CPU for
    a parameterless model). ``warmup`` untimed passes run first, and CUDA is
    synchronized around the timed region so queued kernels are included.

    Args:
        model: network to time (switched to eval mode).
        img_size: spatial size of the ``1 x 3 x img_size x img_size`` input.
        repeats: number of timed forward passes (must be >= 1).
        warmup: number of untimed passes before timing.

    Raises:
        ValueError: if ``repeats`` < 1.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    dummy = torch.randn(1, 3, img_size, img_size).to(device)
    for _ in range(warmup):
        model(dummy)
    if device.type == 'cuda':
        torch.cuda.synchronize(device)
    # perf_counter is the monotonic high-resolution clock intended for benchmarks.
    t0 = time.perf_counter()
    for _ in range(repeats):
        model(dummy)
    if device.type == 'cuda':
        torch.cuda.synchronize(device)
    return (time.perf_counter() - t0) / repeats


## 8) اجرای کامل (پیکربندی، آموزش، فاین‌تیون، ارزیابی، نمودارها)

In [None]:

# ===== Configuration =====
# All run settings live in this dict; the code below reads them from here.
cfg = {
    "model": "vit",        # "vit" or "resnet"
    "data_dir": "./data",
    "out_dir": "./outputs",
    "img_size": 224,
    "batch_size": 128,
    "epochs": 30,          # max epochs of the first loop (head-only when freeze_stage is True)
    "ft_epochs": 15,       # max epochs of the full fine-tuning loop
    "freeze_stage": True,  # stage 1: train only the head
    "finetune_stage": True,# stage 2: full fine-tuning
    "lr": None,            # if None, a suitable per-model default is chosen
    "ft_lr": None,         # stage-2 LR; if None, a per-model/optimizer default is chosen
    "weight_decay": 0.05,  # used by both the AdamW and the SGD optimizers
    "num_workers": 4,
    "val_split": 0.1,      # fraction of the CIFAR-100 train split held out for validation
    "seed": 42,
    "es_patience": 5,      # EarlyStopping patience (epochs without val-acc improvement)
    "amp": True,           # passed to GradScaler(enabled=...)
    "resnet_adamw": False  # ResNet only: AdamW if True, otherwise SGD with momentum 0.9
}

set_seed(cfg["seed"])
os.makedirs(cfg["out_dir"], exist_ok=True)

# ===== Data =====
train_loader, val_loader, test_loader = get_dataloaders(
    data_dir=cfg["data_dir"],
    img_size=cfg["img_size"],
    batch_size=cfg["batch_size"],
    num_workers=cfg["num_workers"],
    val_split=cfg["val_split"],
    seed=cfg["seed"]
)

# ===== Model =====
model = build_model(cfg["model"], num_classes=100, pretrained=True).to(DEVICE)
# Stage 1: freeze everything except the classification head.
if cfg["freeze_stage"]:
    freeze_backbone(model, cfg["model"])

total_params, trainable_params = count_params(model)
print(f"Total params: {total_params/1e6:.2f}M | Trainable: {trainable_params/1e6:.2f}M")

# ===== Optimizer and scheduler =====
# Only parameters with requires_grad=True are given to the optimizer, so with a
# frozen backbone only the head is optimized in this stage.
if cfg["model"] == "vit":
    lr = cfg["lr"] if cfg["lr"] is not None else (1e-3 if cfg["freeze_stage"] else 5e-5)
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=cfg["weight_decay"])
else:
    lr = cfg["lr"] if cfg["lr"] is not None else 1e-3
    if cfg["resnet_adamw"]:
        optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=cfg["weight_decay"])
    else:
        # NOTE(review): weight_decay=0.05 is a typical AdamW value but looks large as
        # SGD's coupled L2 penalty — confirm it is intended for the SGD path.
        optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, momentum=0.9, weight_decay=cfg["weight_decay"])

# Halve the LR when validation accuracy (mode='max') has not improved for 2 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)
early = EarlyStopping(patience=cfg["es_patience"], verbose=True)
# cfg["amp"] toggles loss scaling for mixed-precision training.
scaler = torch.cuda.amp.GradScaler(enabled=cfg["amp"])

history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
# best_val starts at -1.0 so the first epoch always writes a checkpoint.
best_val, best_path = -1.0, os.path.join(cfg["out_dir"], f"best_{cfg['model']}.pt")

print("=== Stage 1: Head-only training ===" if cfg["freeze_stage"] else "=== Single-stage training ===")
for epoch in range(1, cfg["epochs"]+1):
    tr_loss, tr_acc, tr_time = epoch_pass(model, train_loader, optimizer, scaler)
    va_loss, va_acc = evaluate(model, val_loader)
    scheduler.step(va_acc)

    history['train_loss'].append(tr_loss); history['train_acc'].append(tr_acc)
    history['val_loss'].append(va_loss);   history['val_acc'].append(va_acc)

    print(f"Epoch {epoch:03d} | train_loss={tr_loss:.4f} acc={tr_acc:.4f} | val_loss={va_loss:.4f} acc={va_acc:.4f} | {tr_time:.1f}s")

    # Checkpoint whenever validation accuracy beats the best seen so far.
    if va_acc > best_val:
        best_val = va_acc
        torch.save(model.state_dict(), best_path)
        print(f"  >> Saved best: {best_path} (val_acc={best_val:.4f})")

    early.step(va_acc)
    if early.stop:
        print("Early stopping triggered."); break

# ===== Stage 2: full fine-tuning =====
if cfg["finetune_stage"]:
    print("=== Stage 2: Full fine-tuning ===")
    # Resume from the best checkpoint so far rather than the last epoch's weights.
    if os.path.exists(best_path):
        model.load_state_dict(torch.load(best_path, map_location=DEVICE))
    unfreeze_all(model)

    # New optimizer over all parameters now that the backbone is trainable.
    if cfg["model"] == "vit":
        ft_lr = cfg["ft_lr"] if cfg["ft_lr"] is not None else 5e-5
        optimizer = torch.optim.AdamW(model.parameters(), lr=ft_lr, weight_decay=cfg["weight_decay"])
    else:
        if cfg["resnet_adamw"]:
            ft_lr = cfg["ft_lr"] if cfg["ft_lr"] is not None else 1e-4
            optimizer = torch.optim.AdamW(model.parameters(), lr=ft_lr, weight_decay=cfg["weight_decay"])
        else:
            ft_lr = cfg["ft_lr"] if cfg["ft_lr"] is not None else 1e-3
            optimizer = torch.optim.SGD(model.parameters(), lr=ft_lr, momentum=0.9, weight_decay=cfg["weight_decay"])

    # Fresh scheduler and early-stopping state for this stage. best_val and
    # best_path carry over, so a checkpoint is only written here if it beats
    # the best earlier validation accuracy.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)
    early = EarlyStopping(patience=cfg["es_patience"], verbose=True)

    for epoch in range(1, cfg["ft_epochs"]+1):
        tr_loss, tr_acc, tr_time = epoch_pass(model, train_loader, optimizer, scaler)
        va_loss, va_acc = evaluate(model, val_loader)
        scheduler.step(va_acc)

        # history keeps appending, so the saved curves span both stages.
        history['train_loss'].append(tr_loss); history['train_acc'].append(tr_acc)
        history['val_loss'].append(va_loss);   history['val_acc'].append(va_acc)

        print(f"[FT] Epoch {epoch:03d} | train_loss={tr_loss:.4f} acc={tr_acc:.4f} | val_loss={va_loss:.4f} acc={va_acc:.4f} | {tr_time:.1f}s")

        if va_acc > best_val:
            best_val = va_acc
            torch.save(model.state_dict(), best_path)
            print(f"  >> Saved best: {best_path} (val_acc={best_val:.4f})")

        early.step(va_acc)
        if early.stop:
            print("Early stopping (fine-tune) triggered."); break

# ===== Test evaluation =====
# Evaluate the checkpoint with the best validation accuracy across all stages.
model.load_state_dict(torch.load(best_path, map_location=DEVICE))
test_loss, test_acc = evaluate(model, test_loader)
print(f"TEST: loss={test_loss:.4f} acc={test_acc:.4f}")

# ===== Plots =====
plot_curves(history, cfg["out_dir"])

# ===== FLOPs and timing =====
flops = compute_flops(model, cfg["img_size"])
tot_p, tr_p = count_params(model)
inf_t = measure_inference_time(model, cfg["img_size"], repeats=50)  # seconds per single-image forward pass

with open(os.path.join(cfg["out_dir"], f"{cfg['model']}_metrics.txt"), "w") as f:
    f.write(f"Test loss: {test_loss:.4f}\nTest acc: {test_acc:.4f}\n")
    f.write(f"Total params: {tot_p}\nTrainable params: {tr_p}\n")
    f.write(f"Inference time per sample (s): {inf_t}\n")
    f.write(f"FLOPs (approx): {flops if flops is not None else 'thop not installed'}\n")

# ===== Robustness =====
# The noise std is applied to the loader's normalized tensors, not to raw pixels.
# NOTE(review): the noise and occlusion positions come from the global RNG without
# re-seeding here, so these numbers may differ between runs — consider set_seed() first.
noise_acc = eval_with_noise(model, test_loader, noise_std=0.1)
occ_acc   = eval_with_occlusion(model, test_loader, erase_size=32)
with open(os.path.join(cfg["out_dir"], f"{cfg['model']}_robustness.txt"), "w") as f:
    f.write(f"Gaussian noise std=0.1 acc: {noise_acc}\n")
    f.write(f"Occlusion 32x32 acc: {occ_acc}\n")

# ===== ViT interpretability =====
# Attention rollout for the first image of the first test batch only.
if cfg["model"] == "vit":
    imgs, _ = next(iter(test_loader))
    img0 = imgs[0:1].to(DEVICE)
    visualize_vit_attention(model, img0, out_path=os.path.join(cfg["out_dir"], "vit_attention.png"))

print("Done. Outputs in:", cfg["out_dir"])
