# In [None]:
# ============================================================================
# UNIFIED ATTACK EVALUATION - FedAvg, FedProx, FedBN, FedPer
# Three Attack Levels: Weak, Mid, Strong
# ============================================================================
print("Installing dependencies...")
import subprocess
import sys

# Install through the running interpreter (`python -m pip`) rather than a bare
# `pip`, so packages land in the environment this kernel actually imports from.
for _requirement in ("medmnist", "git+https://github.com/openai/CLIP.git", "scikit-learn"):
    subprocess.run([sys.executable, "-m", "pip", "install", _requirement, "--quiet"], check=True)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from medmnist.dataset import PathMNIST, TissueMNIST, OrganAMNIST, OCTMNIST
import clip
from tqdm import tqdm
import json
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, mean_squared_error
import os
import shutil

print("✅ All dependencies loaded!")

# ============================================================================
# CONFIGURATION - CHANGE ONLY THESE
# ============================================================================
ATTACK_LEVEL = "strong"  # Options: "weak", "mid", "strong"
DATASET_NAME = "organamnist"  # Options: "pathmnist", "tissuemnist", "organamnist", "octmnist"
BATCH_SIZE = 32

# Model checkpoint paths, keyed by the model names evaluated in main().
# NOTE(review): the 'fedbn' entry points at a *pathmnist* checkpoint while
# DATASET_NAME is "organamnist". If that checkpoint really has a 9-class head,
# loading it into the 11-class FedBN model raises a size-mismatch error even
# with strict=False — confirm the intended path.
MODEL_PATHS = {
    'fedavg': "/kaggle/input/example/pytorch/default/1/Example/checkpoints_fedavg_organamnist/final_global_model.pth",
    'fedprox': "/kaggle/input/example/pytorch/default/1/Example/checkpoints_fedprox_organamnist/final_global_model.pth",
    'fedbn': "/kaggle/input/example3/pytorch/default/1/weight_history/fedbn/checkpoints_fedbn_pathmnist/final_global_model.pth",
    'fedper': "/kaggle/input/example3/pytorch/default/1/weight_history/fedper/checkpoints_fedper_organamnist/final_global_model.pth",
}
# ============================================================================

# Device used for models, data batches and checkpoint loading.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Attack hyper-parameters for each strength level.
# Perturbations live in the normalized input space produced by main()'s
# Normalize(mean=0.5, std=0.5) transform, i.e. pixel values in [-1, 1].
# Note: alpha * steps is below eps for every level (0.006 < 0.010,
# 0.010 < 0.030, 0.030 < 0.050), so BIM and MI-FGSM, which start from the
# clean image, cannot reach the eps-ball boundary within the configured steps.
ATTACK_LEVELS = dict(
    weak=dict(name='attack-level-weak', eps=0.010, steps=10, alpha=0.0006),
    mid=dict(name='attack-level-mid', eps=0.030, steps=10, alpha=0.0010),
    strong=dict(name='attack-level-strong', eps=0.050, steps=10, alpha=0.0030),
)

# Resolve the selected level into the module-level constants read by the attacks.
attack_config = ATTACK_LEVELS[ATTACK_LEVEL]
EPS, ALPHA, STEPS = attack_config['eps'], attack_config['alpha'], attack_config['steps']
MAIN_DIR = f"/kaggle/working/{attack_config['name']}(eps={EPS:.3f},alpha={ALPHA:.4f})"

# One print of newline-joined rows emits exactly the same text as one print per row.
print("\n".join([
    f"\n{'='*70}",
    f" ATTACK CONFIGURATION: {ATTACK_LEVEL.upper()}",
    '='*70,
    f"  Epsilon (eps):   {EPS}",
    f"  Alpha:           {ALPHA}",
    f"  Steps:           {STEPS}",
    f"  Main Directory:  {MAIN_DIR}",
    '='*70,
]))

# Dataset configurations.
# 'class_names' must list the MedMNIST label names in label-index order:
# CLIPFedAvgClassifier turns entry i into the text prompt scored as logit i.
# (A checkpoint that stores its own 'text_features' buffer overrides the
# prompts built from these names when it is loaded with load_state_dict.)
DATASET_CONFIGS = {
    'pathmnist': {
        'num_classes': 9,
        'class': PathMNIST,
        'class_names': ["adipose", "background", "debris", "lymphocytes",
                        "mucus", "smooth muscle", "normal colon mucosa",
                        "cancer-associated stroma", "colorectal adenocarcinoma epithelium"]
    },
    'tissuemnist': {
        'num_classes': 8,
        'class': TissueMNIST,
        'class_names': ["collecting duct, connecting tubule", "distal convoluted tubule",
                        "glomerular endothelial cells", "interstitial endothelial cells",
                        "leukocytes", "podocytes", "proximal tubule segments",
                        "thick ascending limb"]
    },
    'organamnist': {
        'num_classes': 11,
        'class': OrganAMNIST,
        'class_names': ["bladder", "femur-left", "femur-right", "heart",
                        "kidney-left", "kidney-right", "liver", "lung-left",
                        "lung-right", "pancreas", "spleen"]
    },
    'octmnist': {
        'num_classes': 4,
        'class': OCTMNIST,
        'class_names': ["choroidal neovascularization", "diabetic macular edema",
                        "drusen", "normal"]
    }
}

# ============================================================================
# MODEL ARCHITECTURES
# ============================================================================

class CLIPFedPerClassifier(nn.Module):
    """FedPer classifier: partially fine-tuned CLIP image encoder + personal MLP head.

    All CLIP parameters are frozen except the last two visual transformer
    blocks and the visual projection (the layers shared across clients). The
    ``head`` MLP is the client-local, personalized part.

    Args:
        num_classes: number of output logits.
        clip_model_name: CLIP backbone name passed to ``clip.load``.
        dropout: dropout rate after the first head layer (the later layers use
            0.7x and 0.5x of it).
        device: device CLIP is loaded on.
    """

    def __init__(self, num_classes, clip_model_name="ViT-B/32", dropout=0.3, device='cuda'):
        super().__init__()

        self.num_classes = num_classes
        self.dropout = dropout

        # Pretrained CLIP; only its image encoder is used in forward().
        self.clip_model, _ = clip.load(clip_model_name, device=device)
        self.feature_dim = self.clip_model.visual.output_dim

        self._setup_fedper_layers()

        # Personalized MLP head (not aggregated across clients).
        self.head = nn.Sequential(
            nn.Linear(self.feature_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout * 0.7),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),

            nn.Linear(256, num_classes)
        )

        self._initialize_weights()

    def _setup_fedper_layers(self):
        """Freeze CLIP, then unfreeze the shared layers (last 2 visual blocks + proj)."""
        for param in self.clip_model.parameters():
            param.requires_grad = False

        visual = self.clip_model.visual
        # Only transformer-based visual encoders have residual blocks to unfreeze.
        if hasattr(visual, 'transformer'):
            num_shared = 2  # the last 2 blocks are the shared/aggregated ones
            for block in visual.transformer.resblocks[-num_shared:]:
                for param in block.parameters():
                    param.requires_grad = True

        # The final visual projection is shared as well.
        if getattr(visual, 'proj', None) is not None:
            visual.proj.requires_grad = True

    def _initialize_weights(self):
        """Kaiming-init the head's Linear layers; BatchNorm to weight=1, bias=0."""
        for m in self.head.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, images):
        """Return class logits for *images*.

        Gradients flow through CLIP here (no ``torch.no_grad``), so the
        unfrozen shared blocks can be trained.
        """
        image_features = self.clip_model.encode_image(images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        image_features = image_features.float()
        return self.head(image_features)

    def get_shared_params(self):
        """Return the trainable CLIP parameters (the layers aggregated across clients)."""
        return [param for param in self.clip_model.parameters() if param.requires_grad]

    def get_personalized_params(self):
        """Return the trainable parameters of the client-local head."""
        return [p for p in self.head.parameters() if p.requires_grad]

    def get_trainable_params(self):
        """Return every parameter of the model that requires grad."""
        return [p for p in self.parameters() if p.requires_grad]


class CLIPFedAvgClassifier(nn.Module):
    """FedAvg: CLIP image encoder scored against fixed class-prompt text embeddings.

    The text tower is frozen. Logits are 100 x the cosine similarity between the
    normalized image embedding and each normalized prompt embedding
    ("a microscopic image of <class name>"), in ``class_names`` order.

    Args:
        num_classes: number of classes (stored on the instance).
        device: device CLIP is loaded on and prompts are tokenized to.
        class_names: names used to build one text prompt per class. If empty,
            no ``text_features`` buffer is registered.
    """
    def __init__(self, num_classes, device, class_names):
        super().__init__()
        self.clip_model, _ = clip.load("ViT-B/32", device=device)
        self.num_classes = num_classes

        # Freeze every text-side module and tensor.
        clip_model = self.clip_model
        for text_module in (clip_model.transformer, clip_model.token_embedding, clip_model.ln_final):
            for weight in text_module.parameters():
                weight.requires_grad = False
        for text_tensor in (clip_model.positional_embedding, clip_model.text_projection):
            text_tensor.requires_grad = False

        if not class_names:
            return

        # Pre-compute the unit-norm prompt embeddings once and keep them as a buffer.
        prompts = [f"a microscopic image of {label}" for label in class_names]
        with torch.no_grad():
            tokens = clip.tokenize(prompts).to(device)
            prompt_embeddings = clip_model.encode_text(tokens)
            prompt_embeddings /= prompt_embeddings.norm(dim=-1, keepdim=True)
        self.register_buffer('text_features', prompt_embeddings)

    def forward(self, images):
        """Return scaled cosine-similarity logits of *images* against the class prompts."""
        embedded = self.clip_model.encode_image(images)
        unit_embedded = embedded / embedded.norm(dim=-1, keepdim=True)
        return 100.0 * unit_embedded @ self.text_features.T


class CLIPFedProxClassifier(nn.Module):
    """FedProx: frozen CLIP image encoder + trainable MLP head.

    Args:
        num_classes: number of output logits.
        clip_model_name: CLIP backbone name passed to ``clip.load``.
        dropout: dropout rate after the first head layer (the later layers use
            0.7x and 0.5x of it).
        device: device CLIP is loaded on. ``None`` (default) picks CUDA when
            available, else CPU — the same rule used for the module-level
            ``device``. Previously this class silently read that global;
            the parameter makes the dependency explicit and matches the
            FedAvg/FedPer constructors.
    """
    def __init__(self, num_classes, clip_model_name="ViT-B/32", dropout=0.5, device=None):
        super().__init__()
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.clip_model, _ = clip.load(clip_model_name, device=device)

        # CLIP stays frozen; only the head is trained.
        for param in self.clip_model.parameters():
            param.requires_grad = False

        feature_dim = self.clip_model.visual.output_dim

        self.head = nn.Sequential(
            nn.Linear(feature_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout * 0.7),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),

            nn.Linear(256, num_classes)
        )

    def forward(self, images):
        """Return head logits on unit-normalized CLIP image features.

        The CLIP pass runs under ``torch.no_grad()``, so no gradient reaches
        *images* through this method (attacks use AttackableWrapper instead).
        """
        with torch.no_grad():
            feats = self.clip_model.encode_image(images)
            feats = feats / feats.norm(dim=-1, keepdim=True)
        feats = feats.float()
        return self.head(feats)


class CLIPFedBNClassifier(nn.Module):
    """FedBN: frozen CLIP image encoder + small head with BatchNorm layers.

    Args:
        num_classes: number of output logits.
        clip_model_name: CLIP backbone name passed to ``clip.load``.
        dropout: dropout rate between the two linear layers.
        device: device CLIP is loaded on. ``None`` (default) picks CUDA when
            available, else CPU — the same rule used for the module-level
            ``device``, which this class previously read implicitly.
    """
    def __init__(self, num_classes, clip_model_name="ViT-B/32", dropout=0.5, device=None):
        super().__init__()
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.clip_model, _ = clip.load(clip_model_name, device=device)

        # CLIP stays frozen; only the head below is trained.
        for param in self.clip_model.parameters():
            param.requires_grad = False

        feature_dim = self.clip_model.visual.output_dim

        self.fc1 = nn.Linear(feature_dim, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, num_classes)
        self.bn2 = nn.BatchNorm1d(num_classes)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """Return logits: fc1 -> bn1 -> ReLU -> dropout -> fc2 -> bn2 on CLIP features.

        Features are NOT L2-normalized here (unlike the FedProx head), and the
        CLIP pass runs under ``torch.no_grad()``.
        """
        with torch.no_grad():
            img_feat = self.clip_model.encode_image(x).float()
        x = self.fc1(img_feat)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.bn2(x)
        return x


# ============================================================================
# ATTACK FUNCTIONS 
# ============================================================================



class AttackableWrapper(nn.Module):
    """Gradient-enabled forward pass over a trained classifier, for crafting attacks.

    ``CLIPFedProxClassifier`` and ``CLIPFedBNClassifier`` run their CLIP encoder
    inside ``torch.no_grad()``, which cuts the graph between the input image and
    the loss. This wrapper re-implements those forward passes (and FedAvg's)
    without that context, so the loss gradient w.r.t. the input can be computed.
    Any other ``model_type`` falls back to the model's own ``forward``.

    Parameter ``requires_grad`` flags are deliberately left untouched. An input
    gradient only needs the input tensor to require grad. Enabling grads on the
    CLIP visual encoder would make every backward pass also compute (and keep)
    weight gradients that the attacks never use.

    Args:
        model: the classifier to attack.
        model_type: one of 'fedprox', 'fedbn', 'fedavg'; anything else uses
            ``model(x)`` directly.
    """

    def __init__(self, model, model_type):
        super().__init__()
        self.model = model
        self.model_type = model_type
        # Kept for backward compatibility with restore_frozen_state(). This
        # wrapper records nothing here because it never changes any flags.
        self.original_grad_states = {}

    def forward(self, x):
        """Forward pass of the wrapped model WITHOUT a torch.no_grad() context."""
        model = self.model

        if self.model_type == 'fedprox':
            feats = model.clip_model.encode_image(x)
            feats = feats / feats.norm(dim=-1, keepdim=True)
            return model.head(feats.float())

        if self.model_type == 'fedbn':
            h = model.clip_model.encode_image(x).float()
            h = model.drop(model.relu(model.bn1(model.fc1(h))))
            return model.bn2(model.fc2(h))

        if self.model_type == 'fedavg':
            # Cosine similarity against the stored class-prompt embeddings.
            image_features = model.clip_model.encode_image(x)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            return 100.0 * image_features @ model.text_features.T

        return model(x)

    def restore_frozen_state(self):
        """Restore any ``requires_grad`` flags recorded in ``original_grad_states``.

        The wrapper itself records nothing, so this is a no-op unless a caller
        populated ``original_grad_states`` with
        ``{parameter_name: requires_grad}`` entries keyed by
        ``model.named_parameters()`` names.
        """
        for name, param in self.model.named_parameters():
            if name in self.original_grad_states:
                param.requires_grad = self.original_grad_states[name]


def fgsm_attack(model, x, y, model_type, eps):
    """Fast Gradient Sign Method: one signed-gradient step of size *eps*.

    The model is put in eval mode while crafting. In train mode, BatchNorm
    layers would use batch statistics and overwrite their running statistics
    with adversarial-batch values, and dropout would randomize the gradient.
    Both would corrupt every later evaluation of the same model.

    Args:
        model: classifier to attack.
        x: input batch in the normalized [-1, 1] range (the result is clamped to it).
        y: integer class targets for *x*.
        model_type: forwarded to AttackableWrapper to select the forward pass.
        eps: step size / L-inf budget.

    Returns:
        Detached adversarial batch, or *x* detached if no input gradient exists.
    """
    attack_model = AttackableWrapper(model, model_type)
    attack_model.eval()
    x = x.detach()
    try:
        x_adv = x.clone().requires_grad_(True)
        loss = F.cross_entropy(attack_model(x_adv), y)
        # autograd.grad computes only the input gradient; parameter .grad
        # buffers are left untouched.
        (grad,) = torch.autograd.grad(loss, x_adv, allow_unused=True)
        if grad is None:
            return x

        with torch.no_grad():
            x_adv = torch.clamp(x + eps * grad.sign(), -1, 1)
        return x_adv.detach()
    finally:
        attack_model.restore_frozen_state()


def pgd_attack(model, x, y, model_type, eps, alpha, steps):
    """Projected Gradient Descent (L-inf) with a uniform random start.

    Each step moves *alpha* along the gradient sign, then projects back into
    the *eps*-ball around *x* and the valid [-1, 1] input range. The model is
    kept in eval mode so crafting neither updates BatchNorm running statistics
    nor applies dropout.

    Args:
        model: classifier to attack.
        x: clean input batch in the normalized [-1, 1] range.
        y: integer class targets for *x*.
        model_type: forwarded to AttackableWrapper to select the forward pass.
        eps: L-inf budget.
        alpha: per-step size.
        steps: number of iterations. The loop stops early if no input gradient exists.

    Returns:
        Detached adversarial batch.
    """
    attack_model = AttackableWrapper(model, model_type)
    attack_model.eval()
    x = x.detach()
    try:
        with torch.no_grad():
            delta = torch.empty_like(x).uniform_(-eps, eps)
            x_adv = torch.clamp(x + delta, -1, 1)

        for _ in range(steps):
            x_adv = x_adv.detach().requires_grad_(True)
            loss = F.cross_entropy(attack_model(x_adv), y)
            (grad,) = torch.autograd.grad(loss, x_adv, allow_unused=True)
            if grad is None:
                break

            with torch.no_grad():
                x_adv = x_adv + alpha * grad.sign()
                # Project back into the eps-ball, then into the valid input range.
                eta = torch.clamp(x_adv - x, -eps, eps)
                x_adv = torch.clamp(x + eta, -1, 1)
    finally:
        attack_model.restore_frozen_state()

    return x_adv.detach()


def bim_attack(model, x, y, model_type, eps, alpha, steps):
    """Basic Iterative Method: iterated FGSM from the clean input, projected to the eps-ball.

    The model is kept in eval mode so crafting neither updates BatchNorm
    running statistics nor applies dropout.

    Args:
        model: classifier to attack.
        x: clean input batch in the normalized [-1, 1] range.
        y: integer class targets for *x*.
        model_type: forwarded to AttackableWrapper to select the forward pass.
        eps: L-inf budget.
        alpha: per-step size.
        steps: number of iterations. The loop stops early if no input gradient exists.

    Returns:
        Detached adversarial batch.
    """
    attack_model = AttackableWrapper(model, model_type)
    attack_model.eval()
    x = x.detach()
    x_adv = x.clone()
    try:
        for _ in range(steps):
            x_adv = x_adv.detach().requires_grad_(True)
            loss = F.cross_entropy(attack_model(x_adv), y)
            (grad,) = torch.autograd.grad(loss, x_adv, allow_unused=True)
            if grad is None:
                break

            with torch.no_grad():
                x_adv = x_adv + alpha * grad.sign()
                # Project back into the eps-ball, then into the valid input range.
                eta = torch.clamp(x_adv - x, -eps, eps)
                x_adv = torch.clamp(x + eta, -1, 1)
    finally:
        attack_model.restore_frozen_state()

    return x_adv.detach()


def mifgsm_attack(model, x, y, model_type, eps, alpha, steps, decay=1.0):
    """Momentum Iterative FGSM: BIM whose step direction is the sign of a gradient momentum.

    Each step normalizes the gradient per sample by its mean absolute value
    and accumulates it into ``momentum`` with factor *decay*. It then steps
    *alpha* along ``sign(momentum)`` and projects back into the *eps*-ball and
    the [-1, 1] range. The model is kept in eval mode so crafting neither
    updates BatchNorm running statistics nor applies dropout.

    Args:
        model: classifier to attack.
        x: clean input batch in the normalized [-1, 1] range.
        y: integer class targets for *x*.
        model_type: forwarded to AttackableWrapper to select the forward pass.
        eps: L-inf budget.
        alpha: per-step size.
        steps: number of iterations. The loop stops early if no input gradient exists.
        decay: momentum decay factor.

    Returns:
        Detached adversarial batch.
    """
    attack_model = AttackableWrapper(model, model_type)
    attack_model.eval()
    x = x.detach()
    momentum = torch.zeros_like(x)
    x_adv = x.clone()
    try:
        for _ in range(steps):
            x_adv = x_adv.detach().requires_grad_(True)
            loss = F.cross_entropy(attack_model(x_adv), y)
            (grad,) = torch.autograd.grad(loss, x_adv, allow_unused=True)
            if grad is None:
                break

            with torch.no_grad():
                # Per-sample normalization over every non-batch dimension.
                sample_dims = tuple(range(1, grad.dim()))
                grad = grad / (grad.abs().mean(dim=sample_dims, keepdim=True) + 1e-8)
                momentum = decay * momentum + grad

                x_adv = x_adv + alpha * momentum.sign()
                # Project back into the eps-ball, then into the valid input range.
                eta = torch.clamp(x_adv - x, -eps, eps)
                x_adv = torch.clamp(x + eta, -1, 1)
    finally:
        attack_model.restore_frozen_state()

    return x_adv.detach()

# ============================================================================
# EVALUATION FUNCTION
# ============================================================================

def _predict(model, imgs):
    """Return ``(preds, probs)`` NumPy arrays for *imgs* with *model* in eval mode.

    Logits are cast to float32 before the softmax, so a model that emits
    half-precision logits does not compute probabilities in half precision.
    """
    model.eval()
    with torch.no_grad():
        logits = model(imgs).float()
        probs = F.softmax(logits, dim=1)
        preds = logits.argmax(dim=1)
    return preds.cpu().numpy(), probs.cpu().numpy()


def _classification_metrics(labels, preds, probs, num_classes):
    """Return rounded accuracy/precision/recall/F1 (percent, weighted) and one-hot RMSE."""
    accuracy = 100 * (preds == labels).mean()
    precision = 100 * precision_score(labels, preds, average='weighted', zero_division=0)
    recall = 100 * recall_score(labels, preds, average='weighted', zero_division=0)
    f1 = 100 * f1_score(labels, preds, average='weighted', zero_division=0)
    # RMSE between the predicted probability vector and the one-hot true label.
    rmse = np.sqrt(mean_squared_error(np.eye(num_classes)[labels], probs))
    return {
        'accuracy': round(float(accuracy), 2),
        'precision': round(float(precision), 2),
        'recall': round(float(recall), 2),
        'f1_score': round(float(f1), 2),
        'rmse': round(float(rmse), 4),
    }


def evaluate_model_with_attacks(model, test_loader, model_type, num_classes):
    """Evaluate *model* on clean and adversarial versions of *test_loader*.

    Runs FGSM, PGD, BIM and MI-FGSM with the module-level EPS/ALPHA/STEPS.
    Attack failures are best-effort: a failing batch falls back to its clean
    predictions, and the number of such batches is printed so the inflated
    adversarial numbers are visible.

    Args:
        model: classifier to evaluate.
        test_loader: yields ``(images, labels)``. Labels are flattened, so a
            trailing singleton dimension is accepted (assumed single-label
            targets — confirm for new datasets).
        model_type: forwarded to the attacks / AttackableWrapper.
        num_classes: number of classes (for the one-hot RMSE).

    Returns:
        One dict per attack with 'attack', 'clean_metrics',
        'adversarial_metrics' and 'attack_success_rate'. The ASR is the
        percentage of samples whose prediction changed under attack,
        including samples that were already misclassified.

    Raises:
        ValueError: if *test_loader* yields no batches.
    """
    attacks = {
        'FGSM': lambda x, y: fgsm_attack(model, x, y, model_type, EPS),
        'PGD': lambda x, y: pgd_attack(model, x, y, model_type, EPS, ALPHA, STEPS),
        'BIM': lambda x, y: bim_attack(model, x, y, model_type, EPS, ALPHA, STEPS),
        'MI-FGSM': lambda x, y: mifgsm_attack(model, x, y, model_type, EPS, ALPHA, STEPS),
    }

    results = []

    for attack_name, attack_fn in attacks.items():
        print(f"\n  [{attack_name}]")

        label_parts, clean_pred_parts, clean_prob_parts = [], [], []
        adv_pred_parts, adv_prob_parts = [], []
        failed_batches, first_error = 0, None

        for imgs, lbls in tqdm(test_loader, desc=f"    {attack_name:12s}", ncols=100, leave=False):
            imgs = imgs.to(device)
            # reshape(-1) instead of squeeze(): squeeze() would also collapse a
            # final batch of one sample into a 0-d label tensor.
            lbls = lbls.to(device).reshape(-1).long()

            batch_preds, batch_probs = _predict(model, imgs)
            try:
                adv_imgs = attack_fn(imgs, lbls)
                adv_batch_preds, adv_batch_probs = _predict(model, adv_imgs)
            except Exception as exc:  # best-effort, but counted and reported below
                failed_batches += 1
                if first_error is None:
                    first_error = exc
                adv_batch_preds, adv_batch_probs = batch_preds, batch_probs

            label_parts.append(lbls.cpu().numpy())
            clean_pred_parts.append(batch_preds)
            clean_prob_parts.append(batch_probs)
            adv_pred_parts.append(adv_batch_preds)
            adv_prob_parts.append(adv_batch_probs)

        if not label_parts:
            raise ValueError("test_loader yielded no batches")
        if failed_batches:
            print(f"    ⚠️  {failed_batches} batch(es) failed under {attack_name}; "
                  f"clean predictions were used for them (first error: {first_error!r})")

        labels = np.concatenate(label_parts)
        clean_preds = np.concatenate(clean_pred_parts)
        adv_preds = np.concatenate(adv_pred_parts)

        clean_metrics = _classification_metrics(
            labels, clean_preds, np.concatenate(clean_prob_parts), num_classes)
        adv_metrics = _classification_metrics(
            labels, adv_preds, np.concatenate(adv_prob_parts), num_classes)
        asr = round(float(100 * (clean_preds != adv_preds).mean()), 2)

        print(f"    Clean: {clean_metrics['accuracy']:.2f}% | "
              f"Adv: {adv_metrics['accuracy']:.2f}% | ASR: {asr:.2f}%")

        results.append({
            'attack': attack_name,
            'clean_metrics': clean_metrics,
            'adversarial_metrics': adv_metrics,
            'attack_success_rate': asr,
        })

    return results


# ============================================================================
# VISUALIZATION
# ============================================================================

def _label_attack_axis(ax, positions, names):
    """Put rotated attack names on fixed x ticks and a light y grid on *ax*."""
    # Setting the ticks explicitly avoids matplotlib's FixedFormatter warning
    # that set_xticklabels() raises on categorical axes.
    ax.set_xticks(positions)
    ax.set_xticklabels(names, rotation=15, ha='right')
    ax.grid(alpha=0.3, axis='y')


def _paired_metric_bars(ax, positions, names, clean, adv, colors, ylabel, title):
    """Draw side-by-side clean/adversarial bars for one metric on *ax*."""
    clean_color, adv_color = colors
    ax.bar(positions - 0.2, clean, 0.4, label='Clean', color=clean_color, alpha=0.8)
    ax.bar(positions + 0.2, adv, 0.4, label='Adversarial', color=adv_color, alpha=0.8)
    ax.set_ylabel(ylabel, fontweight='bold')
    ax.set_title(title)
    _label_attack_axis(ax, positions, names)
    ax.legend()


def create_visualization(results, save_path, model_name):
    """Save a 3x3 figure comparing clean vs adversarial metrics for each attack.

    Layout:
    - Row 1: accuracy, precision, recall.
    - Row 2: F1, RMSE, attack success rate.
    - Row 3: clean-minus-adversarial drop in accuracy, precision and F1.

    The figure is always closed, even if plotting or saving fails.

    Args:
        results: list of per-attack dicts as returned by evaluate_model_with_attacks.
        save_path: output image path.
        model_name: used in the figure title.
    """
    names = [r['attack'] for r in results]
    x = np.arange(len(names))

    fig = plt.figure(figsize=(18, 12))
    try:
        gs = fig.add_gridspec(3, 3, hspace=0.35, wspace=0.35)
        fig.suptitle(f'{model_name.upper()} - Attack Evaluation (eps={EPS}, alpha={ALPHA})',
                     fontsize=16, fontweight='bold')

        # (metric key, grid cell, (clean color, adversarial color), y label, title)
        paired_plots = [
            ('accuracy', gs[0, 0], ('#2ecc71', '#e74c3c'), 'Accuracy (%)', 'Accuracy Comparison'),
            ('precision', gs[0, 1], ('#3498db', '#e67e22'), 'Precision (%)', 'Precision Comparison'),
            ('recall', gs[0, 2], ('#9b59b6', '#34495e'), 'Recall (%)', 'Recall Comparison'),
            ('f1_score', gs[1, 0], ('#1abc9c', '#c0392b'), 'F1-Score (%)', 'F1-Score Comparison'),
            ('rmse', gs[1, 1], ('#f39c12', '#d35400'), 'RMSE', 'RMSE Comparison'),
        ]
        for metric, cell, colors, ylabel, title in paired_plots:
            _paired_metric_bars(
                fig.add_subplot(cell), x, names,
                [r['clean_metrics'][metric] for r in results],
                [r['adversarial_metrics'][metric] for r in results],
                colors, ylabel, title,
            )

        # Attack success rate
        ax = fig.add_subplot(gs[1, 2])
        ax.bar(x, [r['attack_success_rate'] for r in results],
               color=['#e74c3c', '#e67e22', '#f39c12', '#d35400'], alpha=0.8)
        ax.set_ylabel('ASR (%)', fontweight='bold')
        ax.set_title('Attack Success Rate')
        _label_attack_axis(ax, x, names)

        # Metric degradation (clean minus adversarial)
        drop_plots = [
            ('accuracy', 'Accuracy Drop', gs[2, 0], '#e74c3c'),
            ('precision', 'Precision Drop', gs[2, 1], '#3498db'),
            ('f1_score', 'F1-Score Drop', gs[2, 2], '#9b59b6'),
        ]
        for metric, title, cell, color in drop_plots:
            ax = fig.add_subplot(cell)
            drops = [r['clean_metrics'][metric] - r['adversarial_metrics'][metric] for r in results]
            ax.bar(x, drops, color=color, alpha=0.8)
            ax.set_ylabel('Drop (%)', fontweight='bold')
            ax.set_title(title)
            _label_attack_axis(ax, x, names)
            ax.axhline(y=0, color='black', linestyle='--', linewidth=0.8, alpha=0.5)

        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)


# ============================================================================
# MAIN EXECUTION
# ============================================================================

def _load_state_dict(path):
    """Load the checkpoint at *path* and return the model state dict inside it.

    A checkpoint may be a bare state dict or a dict wrapping one under
    'server_state_dict' or 'model_state_dict' (checked in that order).
    """
    checkpoint = torch.load(path, map_location=device)
    if isinstance(checkpoint, dict):
        for key in ('server_state_dict', 'model_state_dict'):
            if key in checkpoint:
                return checkpoint[key]
    return checkpoint


def _format_metrics(metrics):
    """Format accuracy/precision/recall/F1 from one metrics dict on a single line."""
    return (f"Acc:{metrics['accuracy']:>5.1f}% Prec:{metrics['precision']:>5.1f}% "
            f"Rec:{metrics['recall']:>5.1f}% F1:{metrics['f1_score']:>5.1f}%")


def _print_summary(results):
    """Print the per-attack metric and ASR/RMSE tables for one model's results."""
    print("\n  DETAILED SUMMARY:")
    print(f"  {'-'*90}")
    print(f"  {'Attack':<12} | {'Clean':<40} | {'Adversarial':<40}")
    print(f"  {'-'*90}")
    for r in results:
        clean_str = _format_metrics(r['clean_metrics'])
        adv_str = _format_metrics(r['adversarial_metrics'])
        print(f"  {r['attack']:<12} | {clean_str:<40} | {adv_str:<40}")
    print(f"  {'-'*90}")
    print(f"\n  {'Attack':<12} | {'ASR':>8} | {'Clean RMSE':>12} | {'Adv RMSE':>12}")
    print(f"  {'-'*50}")
    for r in results:
        print(f"  {r['attack']:<12} | {r['attack_success_rate']:>7.2f}% | "
              f"{r['clean_metrics']['rmse']:>12.4f} | {r['adversarial_metrics']['rmse']:>12.4f}")
    print(f"  {'-'*50}")


def main():
    """Evaluate every configured federated model under all attacks and save the results.

    For each model whose checkpoint exists, this writes
    ``MAIN_DIR/<model>_<DATASET_NAME>_attack_results/attack_results.json`` and
    ``attack_results.png``. A model that fails to load or evaluate is reported
    and skipped, and its memory is released before the next model.
    """
    os.makedirs(MAIN_DIR, exist_ok=True)
    data_root = '/kaggle/working/medmnist_data'
    os.makedirs(data_root, exist_ok=True)

    # Load dataset
    print(f"\n{'='*70}")
    print(" LOADING DATASET")
    print('='*70)

    config = DATASET_CONFIGS[DATASET_NAME.lower()]
    num_classes = config['num_classes']
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        # Maps [0, 1] pixels to [-1, 1]; the attacks clamp to this same range.
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    test_dataset = config['class'](root=data_root, split='test',
                                   download=True, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    print(f"✅ Loaded {len(test_dataset)} test samples | Classes: {num_classes}")

    # Constructor and arguments for each evaluated model.
    model_configs = {
        'fedprox': {
            'class': CLIPFedProxClassifier,
            'init_args': {'num_classes': num_classes, 'dropout': 0.5},
        },
        'fedbn': {
            'class': CLIPFedBNClassifier,
            'init_args': {'num_classes': num_classes, 'dropout': 0.5},
        },
        'fedavg': {
            'class': CLIPFedAvgClassifier,
            'init_args': {'num_classes': num_classes, 'device': device,
                          'class_names': config['class_names']},
        },
        'fedper': {
            'class': CLIPFedPerClassifier,
            'init_args': {'num_classes': num_classes, 'dropout': 0.3, 'device': device},
        },
    }

    all_results = {}

    # Each name must be a key of both MODEL_PATHS and model_configs.
    for model_name in ('fedprox', 'fedper', 'fedbn', 'fedavg'):
        print(f"\n{'='*70}")
        print(f" EVALUATING: {model_name.upper()}")
        print('='*70)

        model_path = MODEL_PATHS.get(model_name)
        if model_path is None or not os.path.exists(model_path):
            print(f"⚠️  Model not found: {model_path}")
            print(f"   Skipping {model_name}...")
            continue

        model = None
        try:
            model_config = model_configs[model_name]
            model = model_config['class'](**model_config['init_args'])

            # strict=False tolerates missing/unexpected keys (it still raises on
            # shape mismatches), so report them instead of silently evaluating
            # a partially initialized model.
            incompatible = model.load_state_dict(_load_state_dict(model_path), strict=False)
            if incompatible.missing_keys or incompatible.unexpected_keys:
                print(f"⚠️  {model_name}: {len(incompatible.missing_keys)} missing / "
                      f"{len(incompatible.unexpected_keys)} unexpected state-dict keys")
            model.to(device)
            model.eval()

            print(f"✅ {model_name.upper()} model loaded successfully")

            results = evaluate_model_with_attacks(model, test_loader, model_name, num_classes)
            all_results[model_name] = results

            save_dir = os.path.join(MAIN_DIR, f"{model_name}_{DATASET_NAME}_attack_results")
            os.makedirs(save_dir, exist_ok=True)

            with open(f"{save_dir}/attack_results.json", 'w') as f:
                json.dump(results, f, indent=2)

            create_visualization(results, f"{save_dir}/attack_results.png", model_name)

            _print_summary(results)
            print(f"\n  ✅ Results saved to: {save_dir}/")

        except Exception as e:
            print(f"❌ Error evaluating {model_name}: {e}")
        finally:
            # Free the model even when evaluation failed, so one broken
            # checkpoint does not hold GPU memory for the remaining models.
            del model
            torch.cuda.empty_cache()

    # Final summary
    print(f"\n{'='*70}")
    print(" EVALUATION COMPLETE!")
    print('='*70)
    print(f"📁 All results saved in: {MAIN_DIR}/")
    print(f"\nProcessed models: {', '.join(all_results.keys())}")
    print('='*70)


if __name__ == "__main__":
    main()