In [None]:
# ============================================================================
# COMPLETE STANDALONE ATTACK EVALUATION - SINGLE CELL (WITH EXTENDED METRICS)
# ============================================================================
print("Installing dependencies...")
import subprocess
import os

# Install into the environment of the *running* interpreter. A bare "pip" found
# on PATH can belong to a different Python than the notebook kernel, in which
# case the imports below would not see the freshly installed packages.
import sys

for _requirement in ("medmnist", "git+https://github.com/openai/CLIP.git", "scikit-learn"):
    subprocess.run([sys.executable, "-m", "pip", "install", _requirement, "--quiet"], check=True)

# Now import everything
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from medmnist.dataset import PathMNIST, TissueMNIST, OrganAMNIST, OCTMNIST
import clip
from tqdm import tqdm
import json
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, mean_squared_error

print("✅ All dependencies loaded!")

# ========================================
# CONFIGURATION - edit these before running
# ========================================
MODEL_PATH = (
    "/kaggle/input/example/pytorch/default/1/Example/"
    "checkpoints_fedavg_organamnist/final_global_model.pth"
)
DATASET_NAME = "organamnist"  # key into DATASET_CONFIGS (lower-cased before lookup)
BATCH_SIZE = 32
SAVE_DIR = "/kaggle/working/attack_results"  # the JSON and PNG results are written here
# ========================================

# Create the MedMNIST download folder and the results folder up front.
for _folder in ('./data/medmnist', SAVE_DIR):
    os.makedirs(_folder, exist_ok=True)

# Use the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset configs, keyed by lower-case MedMNIST flag:
#   num_classes - number of labels in the dataset
#   class       - medmnist dataset class used to load the test split
#   class_names - one name per integer label, index-aligned with that label;
#                 used to build the CLIP text prompts in CLIPMedMNISTClassifier.
# Keep the name order identical to medmnist's INFO[<flag>]['label'] mapping: a
# misplaced or invented name silently produces the wrong prompt for that class.
# (A checkpoint that stores its own 'text_features' buffer overrides the prompts
# built from these names when it is loaded.)
DATASET_CONFIGS = {
    'pathmnist': {
        'num_classes': 9,
        'class': PathMNIST,
        'class_names': ["adipose", "background", "debris", "lymphocytes",
                        "mucus", "smooth muscle", "normal colon mucosa",
                        "cancer-associated stroma", "colorectal adenocarcinoma epithelium"]
    },
    'tissuemnist': {
        'num_classes': 8,
        'class': TissueMNIST,
        'class_names': ["collecting duct, connecting tubule", "distal convoluted tubule",
                        "glomerular endothelial cells", "interstitial endothelial cells",
                        "leukocytes", "podocytes", "proximal tubule segments",
                        "thick ascending limb"]
    },
    'organamnist': {
        'num_classes': 11,
        'class': OrganAMNIST,
        'class_names': ["bladder", "femur-left", "femur-right", "heart",
                        "kidney-left", "kidney-right", "liver", "lung-left",
                        "lung-right", "pancreas", "spleen"]
    },
    'octmnist': {
        'num_classes': 4,
        'class': OCTMNIST,
        'class_names': ["choroidal neovascularization", "diabetic macular edema",
                        "drusen", "normal"]
    }
}

# Model class
class CLIPMedMNISTClassifier(nn.Module):
    """CLIP ViT-B/32 classifier that scores images against fixed class prompts.

    Logits are ``100 * cosine(image_embedding, prompt_embedding)``. The image
    embedding is L2-normalised in ``forward``. There is one precomputed,
    L2-normalised prompt embedding per class, built from
    ``"a microscopic image of {class name}"``.

    Args:
        num_classes: Number of classes. Must equal ``len(class_names)`` when
            ``class_names`` is given.
        device: Device passed to ``clip.load`` and used for the prompt tokens.
        class_names: Names used to build the per-class text prompts. When
            omitted or empty, no ``text_features`` buffer is created. One must
            then be registered before ``forward`` is called.

    Raises:
        ValueError: if ``class_names`` is given and its length differs from
            ``num_classes``.
    """

    def __init__(self, num_classes, device, class_names=None):
        super().__init__()
        # Validate before the (slow) CLIP download/load.
        if class_names and len(class_names) != num_classes:
            raise ValueError(
                f"Got {len(class_names)} class names for num_classes={num_classes}"
            )
        self.clip_model, _ = clip.load("ViT-B/32", device=device)
        self.num_classes = num_classes

        # Freeze the text-side CLIP components; the prompt features are computed
        # once below and never trained.
        for module in (self.clip_model.transformer,
                       self.clip_model.token_embedding,
                       self.clip_model.ln_final):
            for param in module.parameters():
                param.requires_grad = False
        self.clip_model.positional_embedding.requires_grad = False
        self.clip_model.text_projection.requires_grad = False

        if class_names:
            with torch.no_grad():
                prompts = [f"a microscopic image of {c}" for c in class_names]
                text_tokens = clip.tokenize(prompts).to(device)
                text_features = self.clip_model.encode_text(text_tokens)
                text_features /= text_features.norm(dim=-1, keepdim=True)
            # A buffer (not a parameter): it is part of the state dict, so a
            # loaded checkpoint's prompt features replace these.
            self.register_buffer('text_features', text_features)

    def forward(self, images):
        """Return logits of shape (batch, number of prompts) for ``images``.

        Raises:
            AttributeError: if no ``text_features`` buffer exists (the model was
                built without ``class_names`` and none was registered since).
        """
        if not hasattr(self, 'text_features'):
            raise AttributeError(
                "CLIPMedMNISTClassifier has no 'text_features' buffer; pass "
                "class_names to the constructor or register the buffer before "
                "calling forward()."
            )
        image_features = self.clip_model.encode_image(images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return 100.0 * image_features @ self.text_features.T

# Attack functions
def fgsm(model, x, y, device, eps=0.03):
    """Fast Gradient Sign Method: one signed-gradient ascent step of size ``eps``.

    Args:
        model: Classifier returning logits, or an object with a ``.logits`` attribute.
        x: Input batch. The final clamp assumes inputs live in [-1, 1], which is
            the range produced by this script's Normalize(0.5, 0.5).
        y: Integer class labels, one per sample.
        device: Unused; kept so every attack shares the same signature.
        eps: L-inf step size in the normalised input space.

    Returns:
        Detached adversarial batch, clipped to [-1, 1]. ``x`` is not modified.
    """
    x_in = x.detach().clone().requires_grad_(True)
    out = model(x_in)
    out = out.logits if hasattr(out, 'logits') else out
    loss = F.cross_entropy(out, y)
    # Differentiate w.r.t. the input only: unlike loss.backward(), this never
    # computes or accumulates gradients into the model's parameters.
    (grad,) = torch.autograd.grad(loss, x_in)
    return torch.clamp(x_in.detach() + eps * grad.sign(), -1, 1)

def pgd(model, x, y, device, eps=0.03, alpha=0.01, steps=10):
    """Projected Gradient Descent (L-inf) with a uniform random start.

    Args:
        model: Classifier returning logits, or an object with a ``.logits`` attribute.
        x: Clean input batch (values assumed in [-1, 1]).
        y: Integer class labels, one per sample.
        device: Unused; kept so every attack shares the same signature.
        eps: Radius of the L-inf ball around ``x``.
        alpha: Step size of each signed-gradient step.
        steps: Number of iterations.

    Returns:
        Detached adversarial batch within ``eps`` of ``x`` and within [-1, 1].
    """
    x = x.detach()  # the projection bounds must not carry autograd history
    x_adv = x + torch.empty_like(x).uniform_(-eps, eps)
    x_adv = torch.clamp(x_adv, -1, 1)
    for _ in range(steps):
        x_adv.requires_grad_(True)
        out = model(x_adv)
        out = out.logits if hasattr(out, 'logits') else out
        loss = F.cross_entropy(out, y)
        # Input-only gradient: model parameter .grad buffers are left untouched.
        (grad,) = torch.autograd.grad(loss, x_adv)
        x_adv = x_adv.detach() + alpha * grad.sign()
        x_adv = torch.clamp(x_adv, x - eps, x + eps)  # project onto the eps-ball
        x_adv = torch.clamp(x_adv, -1, 1)
    return x_adv.detach()

def bim(model, x, y, device, eps=0.03, alpha=0.01, steps=10):
    """Basic Iterative Method: iterated FGSM steps projected onto the eps-ball.

    Same as PGD here but starting from the clean input (no random start).

    Args:
        model: Classifier returning logits, or an object with a ``.logits`` attribute.
        x: Clean input batch (values assumed in [-1, 1]).
        y: Integer class labels, one per sample.
        device: Unused; kept so every attack shares the same signature.
        eps: Radius of the L-inf ball around ``x``.
        alpha: Step size of each signed-gradient step.
        steps: Number of iterations.

    Returns:
        Detached adversarial batch within ``eps`` of ``x`` and within [-1, 1].
    """
    x = x.detach()
    x_adv = x.clone()
    for _ in range(steps):
        x_adv.requires_grad_(True)
        out = model(x_adv)
        out = out.logits if hasattr(out, 'logits') else out
        loss = F.cross_entropy(out, y)
        # Input-only gradient: model parameter .grad buffers are left untouched.
        (grad,) = torch.autograd.grad(loss, x_adv)
        x_adv = x_adv.detach() + alpha * grad.sign()
        x_adv = torch.clamp(x_adv, x - eps, x + eps)
        x_adv = torch.clamp(x_adv, -1, 1)
    return x_adv.detach()

def mifgsm(model, x, y, device, eps=0.03, alpha=0.01, steps=10, decay=1.0):
    """Momentum Iterative FGSM within an L-inf ball of radius ``eps``.

    Each step L1-normalises the input gradient per sample (mean absolute value
    over all non-batch dimensions). It adds this to a decayed momentum
    accumulator and moves ``alpha`` along the accumulator's sign.

    Args:
        model: Classifier returning logits, or an object with a ``.logits`` attribute.
        x: Clean input batch of any rank >= 2 (batch first; values assumed in [-1, 1]).
        y: Integer class labels, one per sample.
        device: Unused; kept so every attack shares the same signature.
        eps: Radius of the L-inf ball around ``x``.
        alpha: Step size.
        steps: Number of iterations.
        decay: Momentum decay factor.

    Returns:
        Detached adversarial batch within ``eps`` of ``x`` and within [-1, 1].
    """
    x = x.detach()
    momentum = torch.zeros_like(x)
    x_adv = x.clone()
    for _ in range(steps):
        x_adv.requires_grad_(True)
        out = model(x_adv)
        out = out.logits if hasattr(out, 'logits') else out
        loss = F.cross_entropy(out, y)
        # Input-only gradient: model parameter .grad buffers are left untouched.
        (grad,) = torch.autograd.grad(loss, x_adv)
        # Normalise over every non-batch dim (identical to (1, 2, 3) for images).
        reduce_dims = tuple(range(1, grad.dim()))
        grad = grad / (grad.abs().mean(dim=reduce_dims, keepdim=True) + 1e-8)
        momentum = decay * momentum + grad
        x_adv = x_adv.detach() + alpha * momentum.sign()
        x_adv = torch.clamp(x_adv, x - eps, x + eps)
        x_adv = torch.clamp(x_adv, -1, 1)
    return x_adv.detach()

def deepfool(model, x, y, device, steps=30):
    """DeepFool attack, applied one sample at a time.

    For each sample still classified as its label, every step linearises the
    logits around the current point. It picks the class ``k`` whose decision
    boundary is closest (``|f_k - f_pred| / ||grad_k - grad_pred||``) and steps
    just past it (1.02 overshoot). It stops once the prediction no longer matches
    the label, after ``steps`` iterations, or when no usable direction exists.

    Args:
        model: Classifier returning logits, or an object with a ``.logits`` attribute.
        x: Input batch (values assumed in [-1, 1]).
        y: Integer class labels, one per sample.
        device: Unused; kept so every attack shares the same signature.
        steps: Maximum iterations per sample.

    Returns:
        Perturbed copy of ``x``. Samples that were already misclassified are
        returned unchanged.
    """
    perturbed = x.detach().clone()
    for idx in range(x.size(0)):
        label = y[idx].item()
        pert = x[idx:idx+1].detach().clone()
        for _ in range(steps):
            pert.requires_grad_(True)
            out = model(pert)
            out = out.logits if hasattr(out, 'logits') else out
            pred = out.max(1)[1].item()
            if pred != label:
                break
            # Per-class input gradients via autograd.grad: nothing is written
            # into the model's parameter .grad buffers, and no try/except is
            # needed around the gradient computation.
            (grad_pred,) = torch.autograd.grad(out[0, pred], pert, retain_graph=True)
            min_dist, min_grad = 1e10, None
            for k in range(out.size(1)):
                if k == pred:
                    continue
                (grad_k,) = torch.autograd.grad(out[0, k], pert, retain_graph=True)
                w_k = grad_k - grad_pred
                f_k = (out[0, k] - out[0, pred]).item()
                dist = abs(f_k) / (w_k.norm().item() + 1e-8)
                if dist < min_dist:
                    min_dist, min_grad = dist, w_k
            if min_grad is None:
                break
            r = (min_dist + 1e-4) * min_grad / (min_grad.norm() + 1e-8)
            pert = torch.clamp(pert.detach() + 1.02 * r, -1, 1)
        perturbed[idx:idx+1] = pert.detach()
    return perturbed

# Load data: build the test split of the selected MedMNIST dataset.
_banner = '=' * 70
print(f"\n{_banner}")
print(" LOADING DATA")
print(_banner)

config = DATASET_CONFIGS[DATASET_NAME.lower()]

# 224x224, 3-channel tensors normalised from [0, 1] to [-1, 1]; the attacks
# clamp adversarial images to that same range.
# NOTE(review): Grayscale(3) also applies to RGB datasets (e.g. PathMNIST),
# discarding colour; presumably this matches training - confirm.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
]
transform = transforms.Compose(_preprocess_steps)

_dataset_cls = config['class']
test_dataset = _dataset_cls(root='./data/medmnist', split='test', download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

_n_test, _n_classes = len(test_dataset), config['num_classes']
print(f"✅ Loaded {_n_test} test samples | Classes: {_n_classes}")

# Load model
print(f"\n{'='*70}")
print(" LOADING MODEL")
print('='*70)
model = CLIPMedMNISTClassifier(config['num_classes'], device, config['class_names'])
# NOTE(security): torch.load unpickles the file; only point MODEL_PATH at
# checkpoints you trust.
checkpoint = torch.load(MODEL_PATH, map_location=device)
# The weights may be wrapped under 'server_state_dict' or 'model_state_dict'
# (checked in that order); otherwise the checkpoint itself is the state dict.
state_dict = checkpoint
if isinstance(checkpoint, dict):
    for wrapper_key in ('server_state_dict', 'model_state_dict'):
        if wrapper_key in checkpoint:
            state_dict = checkpoint[wrapper_key]
            break
model.load_state_dict(state_dict)
model.eval()
# Evaluation only: freeze every weight. The attacks still differentiate with
# respect to the input images, but their backward passes no longer compute or
# accumulate gradients for the model's parameters.
model.requires_grad_(False)
print("✅ Model loaded successfully")

# Run attacks
print(f"\n{'='*70}")
print(" RUNNING 5 ATTACKS WITH EXTENDED METRICS")
print('='*70)

# Each entry maps a display name to attack(images, labels) -> adversarial images.
attacks = {
    'FGSM': lambda x, y: fgsm(model, x, y, device),
    'PGD': lambda x, y: pgd(model, x, y, device),
    'BIM': lambda x, y: bim(model, x, y, device),
    'MI-FGSM': lambda x, y: mifgsm(model, x, y, device),
    'DeepFool': lambda x, y: deepfool(model, x, y, device)
}


def _predict(batch):
    """Return (predicted labels, softmax probabilities) for ``batch`` as numpy arrays.

    Logits are upcast to float32 before the softmax, so the probabilities (and
    the RMSE derived from them) are not rounded to half precision if the model
    emits fp16 logits.
    """
    with torch.no_grad():
        logits = model(batch).float()
    return logits.max(1)[1].cpu().numpy(), F.softmax(logits, dim=1).cpu().numpy()


def _classification_metrics(labels, preds, probs, num_classes):
    """Return (accuracy, precision, recall, f1, rmse) for one set of predictions.

    The first four are percentages. Precision, recall and F1 are support-weighted
    averages over classes. RMSE compares the softmax probabilities with the
    one-hot encoded labels.
    """
    accuracy = 100 * float((preds == labels).mean())
    precision = 100 * float(precision_score(labels, preds, average='weighted', zero_division=0))
    recall = 100 * float(recall_score(labels, preds, average='weighted', zero_division=0))
    f1 = 100 * float(f1_score(labels, preds, average='weighted', zero_division=0))
    rmse = float(np.sqrt(mean_squared_error(np.eye(num_classes)[labels], probs)))
    return accuracy, precision, recall, f1, rmse


results = []
for name, attack_fn in attacks.items():
    print(f"\n[{name}]")

    # Per-sample predictions, labels and probabilities for clean / attacked inputs
    clean_preds, clean_labels, clean_probs = [], [], []
    adv_preds, adv_labels, adv_probs = [], [], []
    failed_batches = 0  # batches where the attack raised and clean outputs were reused

    for imgs, lbls in tqdm(test_loader, desc=f"  {name:12s}", ncols=100, leave=False):
        # Labels presumably arrive as (B, 1). reshape(-1) flattens them to (B,)
        # without collapsing a final batch of size 1 into a 0-d tensor, as
        # squeeze() would; that broke cross_entropy and list.extend().
        imgs, lbls = imgs.to(device), lbls.to(device).reshape(-1)
        batch_labels = lbls.cpu().numpy()

        # Clean predictions
        batch_clean_preds, batch_clean_probs = _predict(imgs)
        clean_preds.extend(batch_clean_preds)
        clean_labels.extend(batch_labels)
        clean_probs.extend(batch_clean_probs)

        # Adversarial predictions
        try:
            adv = attack_fn(imgs, lbls)
            batch_adv_preds, batch_adv_probs = _predict(adv)
        except Exception as err:
            # Keep evaluating, but report it. A failed batch reuses the clean
            # predictions, which inflates adversarial accuracy for that batch.
            # Catching Exception (not a bare except) lets Ctrl-C through.
            failed_batches += 1
            tqdm.write(f"  ⚠️  {name}: attack failed on a batch "
                       f"({type(err).__name__}: {err}); using clean predictions")
            batch_adv_preds, batch_adv_probs = batch_clean_preds, batch_clean_probs
        finally:
            # Attacks that call loss.backward() also fill parameter .grad
            # buffers; drop them so they don't pile up across batches.
            model.zero_grad(set_to_none=True)
        adv_preds.extend(batch_adv_preds)
        adv_labels.extend(batch_labels)
        adv_probs.extend(batch_adv_probs)

    # Convert to numpy
    clean_preds = np.asarray(clean_preds)
    clean_labels = np.asarray(clean_labels)
    adv_preds = np.asarray(adv_preds)
    adv_labels = np.asarray(adv_labels)
    clean_probs = np.asarray(clean_probs)
    adv_probs = np.asarray(adv_probs)

    num_classes = config['num_classes']
    clean_acc, clean_prec, clean_rec, clean_f1, clean_rmse = _classification_metrics(
        clean_labels, clean_preds, clean_probs, num_classes)
    adv_acc, adv_prec, adv_rec, adv_f1, adv_rmse = _classification_metrics(
        adv_labels, adv_preds, adv_probs, num_classes)

    # Attack Success Rate as defined here: % of samples whose predicted label
    # changed under attack. Samples that were already misclassified count too
    # if their prediction flips.
    asr = 100 * float((clean_preds != adv_preds).mean())

    print(f"  CLEAN METRICS:")
    print(f"    Accuracy:  {clean_acc:.2f}%")
    print(f"    Precision: {clean_prec:.2f}%")
    print(f"    Recall:    {clean_rec:.2f}%")
    print(f"    F1-Score:  {clean_f1:.2f}%")
    print(f"    RMSE:      {clean_rmse:.4f}")
    print(f"  ADVERSARIAL METRICS:")
    print(f"    Accuracy:  {adv_acc:.2f}%")
    print(f"    Precision: {adv_prec:.2f}%")
    print(f"    Recall:    {adv_rec:.2f}%")
    print(f"    F1-Score:  {adv_f1:.2f}%")
    print(f"    RMSE:      {adv_rmse:.4f}")
    print(f"  ASR: {asr:.2f}%")
    if failed_batches:
        print(f"  ⚠️  {failed_batches} batch(es) fell back to clean predictions")

    results.append({
        'attack': name,
        'clean_metrics': {
            'accuracy': round(clean_acc, 2),
            'precision': round(clean_prec, 2),
            'recall': round(clean_rec, 2),
            'f1_score': round(clean_f1, 2),
            'rmse': round(clean_rmse, 4)
        },
        'adversarial_metrics': {
            'accuracy': round(adv_acc, 2),
            'precision': round(adv_prec, 2),
            'recall': round(adv_rec, 2),
            'f1_score': round(adv_f1, 2),
            'rmse': round(adv_rmse, 4)
        },
        'attack_success_rate': round(asr, 2),
        'failed_batches': failed_batches
    })

# Persist the per-attack metrics (list of dicts) as pretty-printed JSON.
_results_json_path = f"{SAVE_DIR}/attack_results_extended.json"
with open(_results_json_path, 'w') as _json_file:
    json.dump(results, fp=_json_file, indent=2)

# Create comprehensive visualization: a 2x3 grid with one clean-vs-adversarial
# grouped bar chart per metric, plus an Attack Success Rate panel.
fig = plt.figure(figsize=(18, 10))
gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

names = [r['attack'] for r in results]
x = np.arange(len(names))

# (metric key, y-axis label, title, clean colour, adversarial colour, grid cell)
metric_panels = [
    ('accuracy', 'Accuracy (%)', 'Accuracy Comparison', '#2ecc71', '#e74c3c', (0, 0)),
    ('precision', 'Precision (%)', 'Precision Comparison', '#3498db', '#e67e22', (0, 1)),
    ('recall', 'Recall (%)', 'Recall Comparison', '#9b59b6', '#34495e', (0, 2)),
    ('f1_score', 'F1-Score (%)', 'F1-Score Comparison', '#1abc9c', '#c0392b', (1, 0)),
    ('rmse', 'RMSE', 'RMSE Comparison', '#f39c12', '#d35400', (1, 1)),
]
for key, ylabel, title, clean_color, adv_color, (row, col) in metric_panels:
    ax = fig.add_subplot(gs[row, col])
    clean_vals = [r['clean_metrics'][key] for r in results]
    adv_vals = [r['adversarial_metrics'][key] for r in results]
    ax.bar(x - 0.2, clean_vals, 0.4, label='Clean', alpha=0.8, color=clean_color)
    ax.bar(x + 0.2, adv_vals, 0.4, label='Adversarial', alpha=0.8, color=adv_color)
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(names, rotation=15)
    ax.legend()
    ax.grid(alpha=0.3, axis='y')

# Attack Success Rate panel. Numeric bar positions plus explicit set_xticks:
# calling set_xticklabels on a categorical axis without fixed ticks triggers
# Matplotlib's FixedFormatter/FixedLocator warning.
ax_asr = fig.add_subplot(gs[1, 2])
asr = [r['attack_success_rate'] for r in results]
asr_colors = ['#e74c3c', '#e67e22', '#f39c12', '#d35400', '#c0392b']
ax_asr.bar(x, asr, color=asr_colors, alpha=0.8)
ax_asr.set_ylabel('ASR (%)')
ax_asr.set_title('Attack Success Rate', fontweight='bold')
ax_asr.set_xticks(x)
ax_asr.set_xticklabels(names, rotation=15)
ax_asr.grid(alpha=0.3, axis='y')

fig.savefig(f"{SAVE_DIR}/attack_results_extended.png", dpi=300, bbox_inches='tight')
plt.close(fig)

# Print summary table: output file locations, then one row per attack.
_heavy_rule = '=' * 70
print(f"\n{_heavy_rule}")
print(" COMPLETE! - DETAILED RESULTS")
print(_heavy_rule)
for _label, _suffix in (("JSON", "json"), ("Plot", "png")):
    print(f"✅ {_label}: {SAVE_DIR}/attack_results_extended.{_suffix}")

# (header, column width, value accessor) for every numeric column. Values are
# printed one character narrower than the header so the trailing '%' fits.
_summary_columns = [
    ('Clean Acc', 9, lambda res: res['clean_metrics']['accuracy']),
    ('Adv Acc', 9, lambda res: res['adversarial_metrics']['accuracy']),
    ('Clean Prec', 10, lambda res: res['clean_metrics']['precision']),
    ('Adv Prec', 10, lambda res: res['adversarial_metrics']['precision']),
    ('Clean Rec', 10, lambda res: res['clean_metrics']['recall']),
    ('Adv Rec', 10, lambda res: res['adversarial_metrics']['recall']),
    ('Clean F1', 9, lambda res: res['clean_metrics']['f1_score']),
    ('Adv F1', 9, lambda res: res['adversarial_metrics']['f1_score']),
    ('ASR', 6, lambda res: res['attack_success_rate']),
]
_wide_rule = '=' * 120
print(f"\n{_wide_rule}")
_header_cells = [f"{'Attack':<12}"] + [f"{title:>{width}}" for title, width, _ in _summary_columns]
print(" | ".join(_header_cells))
print(_wide_rule)
for r in results:
    _row_cells = [f"{r['attack']:<12}"] + [
        f"{getter(r):>{width - 1}.2f}%" for _, width, getter in _summary_columns
    ]
    print(" | ".join(_row_cells))
print(_wide_rule)

Installing dependencies...
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 115.9/115.9 kB 4.3 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 363.4/363.4 MB 5.2 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.8/13.8 MB 119.4 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24.6/24.6 MB 93.6 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 883.7/883.7 kB 48.1 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 2.8 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 211.5/211.5 MB 8.7 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.3/56.3 MB 35.2 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 127.9/127.9 MB 15.0 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 207.5/207.5 MB 7.8 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 21.1/21.1 MB 9.4 MB/s eta 0:00:00


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
libcugraph-cu12 25.6.0 requires libraft-cu12==25.6.*, but you have libraft-cu12 25.2.0 which is incompatible.
pylibcugraph-cu12 25.6.0 requires pylibraft-cu12==25.6.*, but you have pylibraft-cu12 25.2.0 which is incompatible.
pylibcugraph-cu12 25.6.0 requires rmm-cu12==25.6.*, but you have rmm-cu12 25.2.0 which is incompatible.


   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.8/44.8 kB 1.7 MB/s eta 0:00:00
✅ All dependencies loaded!

 LOADING DATA


100%|██████████| 38.2M/38.2M [00:09<00:00, 4.14MB/s]


✅ Loaded 17778 test samples | Classes: 11

 LOADING MODEL


100%|████████████████████████████████████████| 338M/338M [00:01<00:00, 315MiB/s]


✅ Model loaded successfully

 RUNNING 5 ATTACKS WITH EXTENDED METRICS

[FGSM]


                                                                                                    

  CLEAN METRICS:
    Accuracy:  90.12%
    Precision: 90.20%
    Recall:    90.12%
    F1-Score:  89.95%
    RMSE:      0.1210
  ADVERSARIAL METRICS:
    Accuracy:  56.25%
    Precision: 54.72%
    Recall:    56.25%
    F1-Score:  54.64%
    RMSE:      0.2759
  ASR: 34.69%

[PGD]


  PGD         :   3%|█▏                                            | 14/556 [00:12<07:56,  1.14it/s]