In [None]:
# ============================================================================
# COMPLETE STANDALONE ATTACK EVALUATION - SINGLE CELL
# ============================================================================
print("Installing dependencies...")
import subprocess
import os
import sys

# Install into the interpreter running this kernel. A bare "pip" is looked up
# on PATH and may belong to a different Python environment, in which case the
# imports below would still fail after a "successful" install.
subprocess.run([sys.executable, "-m", "pip", "install", "medmnist", "--quiet"], check=True)
subprocess.run([sys.executable, "-m", "pip", "install", "git+https://github.com/openai/CLIP.git", "--quiet"], check=True)

# Now import everything
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from medmnist.dataset import PathMNIST, TissueMNIST, OrganAMNIST, OCTMNIST
import clip
from tqdm import tqdm
import json
import matplotlib.pyplot as plt
import numpy as np

print("✅ All dependencies loaded!")

# ========================================
# CONFIGURATION - CHANGE THESE
# ========================================
MODEL_PATH = "/kaggle/input/example/pytorch/default/1/Example/checkpoints_fedavg_organamnist/final_global_model.pth"
DATASET_NAME = "organamnist"   # key into DATASET_CONFIGS (case-insensitive)
BATCH_SIZE = 32
SAVE_DIR = "/kaggle/working/attack_results"
# ========================================

os.makedirs('./data/medmnist', exist_ok=True)   # MedMNIST download root
os.makedirs(SAVE_DIR, exist_ok=True)            # JSON + plot output
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset configs
# Per-dataset settings, keyed by DATASET_NAME.lower():
#   num_classes - number of classes; should equal len(class_names)
#   class       - medmnist dataset class used to load the test split
#   class_names - turned into CLIP text prompts by CLIPMedMNISTClassifier; the
#                 list position of a name is the logit index it scores.
# NOTE(review): the tissuemnist and organamnist lists appear to differ (in names
# and order) from the official MedMNIST label maps in medmnist.INFO[flag]['label'].
# They are left as-is because they must match the prompts used at training time,
# and the model's registered text_features buffer is overwritten by
# load_state_dict when the checkpoint is loaded. Confirm against the training
# script before changing them.
DATASET_CONFIGS = {
    'pathmnist': {
        'num_classes': 9,
        'class': PathMNIST,
        'class_names': ["adipose", "background", "debris", "lymphocytes",
                       "mucus", "smooth muscle", "normal colon mucosa",
                       "cancer-associated stroma", "colorectal adenocarcinoma epithelium"]
    },
    'tissuemnist': {
        'num_classes': 8,
        'class': TissueMNIST,
        'class_names': ["collecting duct", "thick ascending limb",
                       "distal convoluted tubule", "proximal tubule",
                       "glomerular tuft", "blood vessel", "macula densa",
                       "interstitial fibrosis"]
    },
    'organamnist': {
        'num_classes': 11,
        'class': OrganAMNIST,
        'class_names': ["bladder", "femur-left", "femur-right", "heart",
                       "kidneys", "liver", "lungs", "pancreas",
                       "pelvis", "spleen", "kidney cyst"]
    },
    'octmnist': {
        'num_classes': 4,
        'class': OCTMNIST,
        'class_names': ["choroidal neovascularization", "diabetic macular edema",
                       "drusen", "normal"]
    }
}

# Model class
class CLIPMedMNISTClassifier(nn.Module):
    """CLIP ViT-B/32 image encoder scored against one fixed text embedding per class.

    The logit for class ``c`` is ``100 * cos(image_embedding, text_features[c])``.

    Args:
        num_classes: number of classes (rows of ``text_features``).
        device: device handed to ``clip.load`` and used for the prompt tensors.
        class_names: optional class names; each one becomes the prompt
            ``"a microscopic image of <name>"``. When omitted, a zero
            ``text_features`` buffer of shape (num_classes, text embed dim) is
            registered as a placeholder, so that the real embeddings can be
            restored from a checkpoint via ``load_state_dict``.

    Raises:
        ValueError: if ``class_names`` is given and its length differs from
            ``num_classes``.
    """

    def __init__(self, num_classes, device, class_names=None):
        super(CLIPMedMNISTClassifier, self).__init__()
        self.clip_model, _ = clip.load("ViT-B/32", device=device)
        self.num_classes = num_classes

        # Freeze the text side (transformer, token/positional embeddings, final
        # LayerNorm, text projection); the image encoder keeps requires_grad.
        for param in self.clip_model.transformer.parameters():
            param.requires_grad = False
        for param in self.clip_model.token_embedding.parameters():
            param.requires_grad = False
        for param in self.clip_model.ln_final.parameters():
            param.requires_grad = False
        self.clip_model.positional_embedding.requires_grad = False
        self.clip_model.text_projection.requires_grad = False

        if class_names:
            if len(class_names) != num_classes:
                raise ValueError(
                    f"got {len(class_names)} class names for num_classes={num_classes}"
                )
            with torch.no_grad():
                text_tokens = clip.tokenize([f"a microscopic image of {c}" for c in class_names]).to(device)
                text_features = self.clip_model.encode_text(text_tokens)
                text_features /= text_features.norm(dim=-1, keepdim=True)
        else:
            # Without any buffer, forward() would raise AttributeError. Register a
            # placeholder sized (num_classes, text projection width) instead.
            embed_dim = self.clip_model.text_projection.shape[1]
            text_features = torch.zeros(
                num_classes, embed_dim, dtype=self.clip_model.dtype, device=device
            )
        # A (persistent) buffer: saved in and restored from state_dicts.
        self.register_buffer('text_features', text_features)

    def forward(self, images):
        """Return one logit per text_features row: 100 x cosine similarity."""
        image_features = self.clip_model.encode_image(images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return 100.0 * image_features @ self.text_features.T

# Attack functions
def fgsm(model, x, y, device, eps=0.03):
    """Fast Gradient Sign Method: one L-inf step of size ``eps``.

    Args:
        model: classifier returning logits, or an object with a ``.logits`` attribute.
        x: input batch; the result is clamped to [-1, 1], the range produced by
            this script's Normalize(mean=0.5, std=0.5) transform.
        y: integer class labels, one per sample.
        device: unused; accepted so that all attacks share one signature.
        eps: perturbation size in normalized units (0.03 here is 0.015 in [0, 1]
            pixel units, because the normalization scales by 2).

    Returns:
        Detached adversarial batch with the same shape as ``x``.
    """
    x = x.clone().detach().requires_grad_(True)
    out = model(x)
    out = out.logits if hasattr(out, 'logits') else out
    loss = F.cross_entropy(out, y)
    # autograd.grad returns only d(loss)/dx. Unlike loss.backward(), it never
    # fills (or accumulates across batches) the model parameters' .grad buffers.
    grad, = torch.autograd.grad(loss, x)
    return torch.clamp(x.detach() + eps * grad.sign(), -1, 1)

def pgd(model, x, y, device, eps=0.03, alpha=0.01, steps=10):
    """Projected Gradient Descent (L-inf) starting from a uniform random point in the eps-ball.

    Args:
        model: classifier returning logits, or an object with a ``.logits`` attribute.
        x: input batch, assumed to lie in [-1, 1].
        y: integer class labels, one per sample.
        device: unused; accepted so that all attacks share one signature.
        eps: L-inf radius of the allowed perturbation (normalized units).
        alpha: size of each signed-gradient ascent step.
        steps: number of ascent steps.

    Returns:
        Detached adversarial batch within ``eps`` of ``x`` and inside [-1, 1].
    """
    x = x.detach()
    x_adv = x + torch.empty_like(x).uniform_(-eps, eps)
    x_adv = torch.clamp(x_adv, -1, 1)
    for _ in range(steps):
        x_adv.requires_grad_(True)
        out = model(x_adv)
        out = out.logits if hasattr(out, 'logits') else out
        loss = F.cross_entropy(out, y)
        # Input gradient only; model parameter .grad buffers stay untouched.
        grad, = torch.autograd.grad(loss, x_adv)
        x_adv = x_adv.detach() + alpha * grad.sign()
        # Project onto the eps-ball around x, then onto the valid input range.
        x_adv = torch.clamp(x_adv, x - eps, x + eps)
        x_adv = torch.clamp(x_adv, -1, 1)
    return x_adv.detach()

def bim(model, x, y, device, eps=0.03, alpha=0.01, steps=10):
    """Basic Iterative Method: iterated FGSM from the clean input, projected to the eps-ball.

    Args:
        model: classifier returning logits, or an object with a ``.logits`` attribute.
        x: input batch, assumed to lie in [-1, 1].
        y: integer class labels, one per sample.
        device: unused; accepted so that all attacks share one signature.
        eps: L-inf radius of the allowed perturbation (normalized units).
        alpha: size of each signed-gradient step.
        steps: number of steps.

    Returns:
        Detached adversarial batch within ``eps`` of ``x`` and inside [-1, 1].
    """
    x = x.detach()
    x_adv = x.clone()
    for _ in range(steps):
        x_adv.requires_grad_(True)
        out = model(x_adv)
        out = out.logits if hasattr(out, 'logits') else out
        loss = F.cross_entropy(out, y)
        # Input gradient only; model parameter .grad buffers stay untouched.
        grad, = torch.autograd.grad(loss, x_adv)
        x_adv = x_adv.detach() + alpha * grad.sign()
        # Project onto the eps-ball around x, then onto the valid input range.
        x_adv = torch.clamp(x_adv, x - eps, x + eps)
        x_adv = torch.clamp(x_adv, -1, 1)
    return x_adv.detach()

def mifgsm(model, x, y, device, eps=0.03, alpha=0.01, steps=10, decay=1.0):
    """Momentum Iterative FGSM: iterated signed steps along an accumulated, L1-normalised gradient.

    Args:
        model: classifier returning logits, or an object with a ``.logits`` attribute.
        x: input batch (batch dimension first), assumed to lie in [-1, 1].
        y: integer class labels, one per sample.
        device: unused; accepted so that all attacks share one signature.
        eps: L-inf radius of the allowed perturbation (normalized units).
        alpha: size of each signed step.
        steps: number of steps.
        decay: momentum decay factor applied to the previous accumulated gradient.

    Returns:
        Detached adversarial batch within ``eps`` of ``x`` and inside [-1, 1].
    """
    x = x.detach()
    # zeros_like keeps momentum on x's device and dtype; no separate .to(device).
    momentum = torch.zeros_like(x)
    x_adv = x.clone()
    for _ in range(steps):
        x_adv.requires_grad_(True)
        out = model(x_adv)
        out = out.logits if hasattr(out, 'logits') else out
        loss = F.cross_entropy(out, y)
        # Input gradient only; model parameter .grad buffers stay untouched.
        grad, = torch.autograd.grad(loss, x_adv)
        # Per-sample normalisation by mean |grad| over every non-batch dimension.
        sample_dims = tuple(range(1, grad.dim()))
        grad = grad / (torch.mean(torch.abs(grad), dim=sample_dims, keepdim=True) + 1e-8)
        momentum = decay * momentum + grad
        x_adv = x_adv.detach() + alpha * momentum.sign()
        # Project onto the eps-ball around x, then onto the valid input range.
        x_adv = torch.clamp(x_adv, x - eps, x + eps)
        x_adv = torch.clamp(x_adv, -1, 1)
    return x_adv.detach()

def deepfool(model, x, y, device, steps=30, overshoot=0.02):
    """Untargeted DeepFool (L2), run on the whole batch at once.

    Each step linearises the classifier around the current point. It moves every
    still-correct sample across the closest linearised boundary between its
    predicted class and another class, scaled by ``1 + overshoot``. Samples stop
    being updated as soon as they are misclassified (including samples that are
    misclassified from the start, which are returned unchanged). The perturbation
    is unbounded; only the [-1, 1] input range is enforced.

    Per-sample gradients come from a single backward pass per class over the
    whole batch. This is valid because each sample's logits depend only on that
    sample's input, so d(sum_i f_i)/dx_i = df_i/dx_i. It needs K backward passes
    per step for the batch, instead of K per step per sample.

    Args:
        model: classifier returning logits, or an object with a ``.logits`` attribute.
        x: input batch (batch dimension first), assumed to lie in [-1, 1].
        y: integer class labels, one per sample.
        device: unused; accepted so that all attacks share one signature.
        steps: maximum number of linearisation steps.
        overshoot: extra fraction added to each step so the boundary is crossed.

    Returns:
        Detached adversarial batch with the same shape as ``x``.
    """
    x_adv = x.clone().detach()
    active = torch.ones(x.size(0), dtype=torch.bool, device=x.device)
    for _ in range(steps):
        idx = active.nonzero(as_tuple=True)[0]
        if idx.numel() == 0:
            break
        pert = x_adv[idx].clone().requires_grad_(True)
        out = model(pert)
        out = out.logits if hasattr(out, 'logits') else out
        pred = out.argmax(dim=1)
        fooled = pred != y[idx]
        active[idx[fooled]] = False
        if fooled.all():
            break

        out_pred = out.gather(1, pred.unsqueeze(1)).squeeze(1)
        best_dist = torch.full((idx.numel(),), float('inf'), device=pert.device)
        best_w = torch.zeros_like(pert)
        for k in range(out.size(1)):
            diff = out[:, k] - out_pred
            # w_k = grad(f_k - f_pred) per sample; zero where k == pred.
            w_k, = torch.autograd.grad(diff.sum(), pert, retain_graph=True)
            dist = diff.detach().float().abs() / (w_k.flatten(1).norm(dim=1) + 1e-8)
            # NaN distances compare False, so they are never selected.
            better = (dist < best_dist) & (pred != k)
            best_dist = torch.where(better, dist, best_dist)
            best_w[better] = w_k[better]

        has_dir = torch.isfinite(best_dist)
        bshape = (-1,) + (1,) * (pert.dim() - 1)
        step_len = torch.where(has_dir, best_dist, torch.zeros_like(best_dist)) + 1e-4
        r = step_len.view(bshape) * best_w / (best_w.flatten(1).norm(dim=1).view(bshape) + 1e-8)
        new = torch.clamp(pert.detach() + (1 + overshoot) * r, -1, 1)

        update = ~fooled & has_dir
        x_adv[idx[update]] = new[update]
        # No usable direction: stop attacking that sample.
        active[idx[~fooled & ~has_dir]] = False
    return x_adv

# Load data
divider = '=' * 70
print(f"\n{divider}")
print(" LOADING DATA")
print(divider)
config = DATASET_CONFIGS[DATASET_NAME.lower()]
# Resize to 224x224 and replicate the single grayscale channel to 3 channels
# for CLIP's image encoder. Normalize(0.5, 0.5) maps pixels from [0, 1] to
# [-1, 1], which is the range the attack functions clamp to.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])
dataset_cls = config['class']
test_dataset = dataset_cls(
    root='./data/medmnist',
    split='test',
    transform=transform,
    download=True,
)
# shuffle=False keeps batch order deterministic across the attack runs.
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)
print(f"✅ Loaded {len(test_dataset)} test samples | Classes: {config['num_classes']}")

# Load model
print(f"\n{'='*70}")
print(" LOADING MODEL")
print('='*70)
model = CLIPMedMNISTClassifier(config['num_classes'], device, config['class_names'])
# NOTE(review): torch.load unpickles the file; only point MODEL_PATH at trusted checkpoints.
checkpoint = torch.load(MODEL_PATH, map_location=device)
# The checkpoint may be a pickled module, a bare state_dict, or a wrapper dict
# holding the weights under one of these keys (checked in this order).
if isinstance(checkpoint, nn.Module):
    state_dict = checkpoint.state_dict()
else:
    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        for key in ('server_state_dict', 'model_state_dict'):
            if key in checkpoint:
                state_dict = checkpoint[key]
                break
model.load_state_dict(state_dict)
model.eval()
# The attacks only need gradients w.r.t. the input images. Freezing every
# parameter stops backward from computing (and accumulating) weight gradients
# for the whole CLIP image encoder on every attack step.
model.requires_grad_(False)
print("✅ Model loaded successfully")

# Run attacks
print(f"\n{'='*70}")
print(" RUNNING 5 ATTACKS")
print('='*70)

attacks = {
    'FGSM': lambda x, y: fgsm(model, x, y, device),
    'PGD': lambda x, y: pgd(model, x, y, device),
    'BIM': lambda x, y: bim(model, x, y, device),
    'MI-FGSM': lambda x, y: mifgsm(model, x, y, device),
    'DeepFool': lambda x, y: deepfool(model, x, y, device)
}

results = []
for name, attack_fn in attacks.items():
    print(f"\n[{name}]")
    # clean_c / adv_c: samples classified correctly before / after the attack.
    # fooled: correct before AND wrong after (numerator of the success rate).
    # failed: batches on which the attack itself raised.
    clean_c, adv_c, fooled, total, failed = 0, 0, 0, 0, 0
    for imgs, lbls in tqdm(test_loader, desc=f"  {name:12s}", ncols=100, leave=False):
        # view(-1) rather than squeeze(): squeeze() collapses a final batch of
        # one sample into a 0-d label tensor, which cross_entropy rejects.
        imgs, lbls = imgs.to(device), lbls.to(device).view(-1)
        with torch.no_grad():
            clean_ok = model(imgs).max(1)[1] == lbls
        try:
            adv = attack_fn(imgs, lbls)
        except Exception as err:
            # Best effort: score a failed batch as if it had not been attacked
            # (its clean predictions). Never count it as 100% correct, and say so.
            failed += 1
            tqdm.write(f"  ⚠️ {name}: attack failed on a batch ({err!r}); scoring it as unattacked")
            adv_ok = clean_ok
        else:
            with torch.no_grad():
                adv_ok = model(adv).max(1)[1] == lbls
        clean_c += clean_ok.sum().item()
        adv_c += adv_ok.sum().item()
        fooled += (clean_ok & ~adv_ok).sum().item()
        total += lbls.size(0)

    clean_acc = 100 * clean_c / total if total else 0
    adv_acc = 100 * adv_c / total if total else 0
    # ASR = share of clean-correct samples that the attack turned wrong. Unlike
    # (clean_c - adv_c) / clean_c, this is not offset by wrong->right flips.
    asr = 100 * fooled / clean_c if clean_c > 0 else 0
    if failed:
        print(f"  ⚠️ {failed} batch(es) failed and were scored as unattacked")
    print(f"  ✅ Clean: {clean_acc:.2f}% | Adv: {adv_acc:.2f}% | ASR: {asr:.2f}%")
    results.append({'attack': name, 'clean_accuracy': round(clean_acc, 2),
                    'adversarial_accuracy': round(adv_acc, 2), 'attack_success_rate': round(asr, 2)})

# Save & plot
with open(f"{SAVE_DIR}/attack_results.json", 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
names = [r['attack'] for r in results]
clean = [r['clean_accuracy'] for r in results]
adv = [r['adversarial_accuracy'] for r in results]
asr = [r['attack_success_rate'] for r in results]
x = np.arange(len(names))

# Left panel: clean vs adversarial accuracy, side by side for each attack.
ax1.bar(x - 0.175, clean, 0.35, label='Clean', alpha=0.8, color='#2ecc71')
ax1.bar(x + 0.175, adv, 0.35, label='Adversarial', alpha=0.8, color='#e74c3c')
ax1.set_xlabel('Attack')
ax1.set_ylabel('Accuracy (%)')
ax1.set_title('Robustness', fontweight='bold')
ax1.set_xticks(x)
ax1.set_xticklabels(names, rotation=15)
ax1.legend()
ax1.grid(alpha=0.3, axis='y')

# Right panel: attack success rate. Bars go at numeric positions with fixed
# ticks (as on the left). Calling set_xticklabels on categorical bars without
# set_xticks makes Matplotlib warn about unfixed tick locations.
ax2.bar(x, asr, color=['#e74c3c', '#e67e22', '#f39c12', '#d35400', '#c0392b'], alpha=0.8)
ax2.set_xlabel('Attack')
ax2.set_ylabel('ASR (%)')
ax2.set_title('Attack Effectiveness', fontweight='bold')
ax2.set_xticks(x)
ax2.set_xticklabels(names, rotation=15)
ax2.grid(alpha=0.3, axis='y')
plt.tight_layout()
plt.savefig(f"{SAVE_DIR}/attack_results.png", dpi=300, bbox_inches='tight')
plt.close(fig)

print(f"\n{'='*70}")
print(" COMPLETE!")
print('='*70)
print(f"✅ JSON: {SAVE_DIR}/attack_results.json")
print(f"✅ Plot: {SAVE_DIR}/attack_results.png")
for r in results:
    print(f"  {r['attack']:10s} | Clean: {r['clean_accuracy']:5.2f}% | Adv: {r['adversarial_accuracy']:5.2f}% | ASR: {r['attack_success_rate']:5.2f}%")
print('='*70)

Installing dependencies...
✅ All dependencies loaded!

 LOADING DATA


100%|██████████| 38.2M/38.2M [00:41<00:00, 927kB/s] 


✅ Loaded 17778 test samples | Classes: 11

 LOADING MODEL


100%|████████████████████████████████████████| 338M/338M [00:01<00:00, 262MiB/s]


✅ Model loaded successfully

 RUNNING 5 ATTACKS

[FGSM]


                                                                                                    

  ✅ Clean: 90.12% | Adv: 56.25% | ASR: 37.59%

[PGD]


                                                                                                    

  ✅ Clean: 90.12% | Adv: 45.18% | ASR: 49.87%

[BIM]


                                                                                                    

  ✅ Clean: 90.12% | Adv: 44.76% | ASR: 50.34%

[MI-FGSM]


                                                                                                    

  ✅ Clean: 90.12% | Adv: 44.80% | ASR: 50.29%

[DeepFool]


  DeepFool    :   1%|▋                                            | 8/556 [02:19<2:28:55, 16.31s/it]