How to Use This Version for Checkpoint Resume in Kaggle

Before running this notebook for the second time (to resume training from a checkpoint), follow the steps below carefully.

1. Download the latest checkpoint file
Download the file named checkpoint_latest.pth from the previous version of your notebook or experiment.

2. Upload the checkpoint to Kaggle Input Directory
Place the downloaded file inside your Kaggle input path, for example:
/kaggle/input/path1/pytorch/default/1/checkpoint_latest.pth

3. Run the following code cell before starting training
This code will copy the checkpoint file to the working directory (/kaggle/working/checkpoints) so that training can resume from the saved state.

4. Resume Training
After the checkpoint file is copied successfully, running the rest of the notebook will automatically start training from the previous checkpoint instead of starting from scratch.




## You can change the dataset and the attack type simply by changing their names in the `args.py` file; there is no need to modify anything else.

In [1]:
import os
import shutil

# Location of the uploaded checkpoint (Kaggle input path) and the folder it is copied into.
# NOTE(review): server.py resumes from os.path.join(args.checkpoint_dir, "checkpoint_latest.pth")
# and args.py defaults checkpoint_dir to './checkpoints_fedavg_organamnist' -- confirm that
# this destination is the directory the training run actually reads.
dst_dir = "/kaggle/working/checkpoints"
src = "/kaggle/input/path1/pytorch/default/1/checkpoint_latest.pth"
dst = os.path.join(dst_dir, "checkpoint_latest.pth")

if os.path.exists(src):
    print(f"‚úÖ Found source file: {src}")

    # Create the destination folder only when it is missing, and report which case applied.
    dir_was_missing = not os.path.exists(dst_dir)
    if dir_was_missing:
        os.makedirs(dst_dir)
    print(
        f"üìÇ Created destination directory: {dst_dir}"
        if dir_was_missing
        else f"üìÅ Destination directory already exists: {dst_dir}"
    )

    # Copy the checkpoint next to where training writes its own checkpoints.
    shutil.copy(src, dst)
    print(f"‚úÖ Copied file to: {dst}")

    # Show what the destination folder now contains as a quick sanity check.
    files = os.listdir(dst_dir)
    if not files:
        print("‚ö† Destination directory is empty (unexpected).")
    else:
        print("\nüìÑ Files in /kaggle/working/checkpoints:")
        for f in files:
            print(" ‚îú‚îÄ‚îÄ", f)
else:
    # Nothing to resume from: report it and leave the working directory untouched.
    print(f"‚ùå Source file not found: {src}")


‚ùå Source file not found: /kaggle/input/path1/pytorch/default/1/checkpoint_latest.pth


In [2]:
import os

# Project folder that holds the generated FedAvg source files.
base_dir = "/kaggle/working/FedAvg-PyTorch"
os.makedirs(base_dir, exist_ok=True)

# Module files expected inside the project folder.
files = ["main.py", "server.py", "client.py", "model.py", "get_data.py", "args.py"]

# Write a one-line placeholder for each module that is not on disk yet;
# files that already exist are reported and left untouched.
for file in files:
    file_path = os.path.join(base_dir, file)
    if os.path.exists(file_path):
        print(f" Already exists: {file_path}")
        continue
    with open(file_path, "w") as f:
        f.write(f"# {file} ‚Äî auto-created placeholder\n")
    print(f" Created: {file_path}")

print("\n Folder and files ready in:", base_dir)


 Already exists: /kaggle/working/FedAvg-PyTorch/main.py
 Already exists: /kaggle/working/FedAvg-PyTorch/server.py
 Already exists: /kaggle/working/FedAvg-PyTorch/client.py
 Already exists: /kaggle/working/FedAvg-PyTorch/model.py
 Already exists: /kaggle/working/FedAvg-PyTorch/get_data.py
 Already exists: /kaggle/working/FedAvg-PyTorch/args.py

 Folder and files ready in: /kaggle/working/FedAvg-PyTorch


In [3]:
# Delete the /kaggle/working/data directory tree; a missing directory is not an error
# (same outcome as `rm -rf`). Written in plain Python because the bare `rm ...` form relies
# on IPython automagic, which is only applied to single-line cells.
import shutil

shutil.rmtree("/kaggle/working/data", ignore_errors=True)


In [None]:
file_path = "/kaggle/working/FedAvg-PyTorch/args.py"

new_code = '''
# ========================================
# args.py - FedAvg Configuration (with Attacks)
# ========================================
import argparse
import torch

# Per-dataset settings, keyed by the lower-case MedMNIST dataset flag.
#   num_classes    - number of target classes
#   class_names    - class names ordered by label index (index i names label i);
#                    model.py turns these names into the CLIP text prompts
#   input_channels - channel count recorded for the dataset's images
DATASET_CONFIGS = {
    'pathmnist': {
        'num_classes': 9,
        'class_names': [
            "adipose tissue", "background", "debris", "lymphocytes",
            "mucus", "smooth muscle", "normal colon mucosa",
            # "colorectal" was misspelled "colorecal", which changed the text prompt.
            "cancer-associated stroma", "colorectal adenocarcinoma epithelium"
        ],
        'input_channels': 3
    },
    'tissuemnist': {
        'num_classes': 8,
        'class_names': [
            "collecting duct", "distal convoluted tubule",
            "glomerular endothelial cells", "interstitial endothelial cells",
            "leukocytes", "podocytes", "proximal tubule", "thick ascending limb"
        ],
        'input_channels': 1
    },
    'organamnist': {
        'num_classes': 11,
        'class_names': [
            "bladder", "femur-left", "femur-right", "heart",
            "kidney-left", "kidney-right", "liver",
            # MedMNIST labels 9 and 10 are pancreas and spleen; the previous
            # "spleen", "pelvis" pair mislabelled both classes.
            "lung-left", "lung-right", "pancreas", "spleen"
        ],
        'input_channels': 1
    },
    'octmnist': {
        'num_classes': 4,
        'class_names': [
            "choroidal neovascularization", "diabetic macular edema",
            "drusen", "normal"
        ],
        'input_channels': 1
    }
}

def args_parser(argv=None):
    """Build the FedAvg run configuration.

    Args:
        argv: optional list of command-line style tokens, e.g.
            ['--dataset', 'octmnist', '--K', '3'], used to override defaults.
            None (the default) parses an empty list, so every option takes its
            default value and the process's own sys.argv is never read.

    Returns:
        argparse.Namespace holding every option declared below plus
        num_classes, input_channels and class_names copied from the
        DATASET_CONFIGS entry of the selected dataset.

    Raises:
        ValueError: if the selected dataset has no entry in DATASET_CONFIGS.
    """
    parser = argparse.ArgumentParser(description="FedAvg - Config File")
    
    # Dataset
    parser.add_argument('--dataset', type=str, default='organamnist', 
                        choices=['pathmnist', 'tissuemnist', 'organamnist', 'octmnist'],
                        help='MedMNIST dataset to use')
    
    # Federated Learning Parameters
    parser.add_argument('--E', type=int, default=5, help='local epochs')
    parser.add_argument('--r', type=int, default=50, help='number of communication rounds')
    parser.add_argument('--K', type=int, default=5, help='total number of clients')
    parser.add_argument('--C', type=float, default=1, help='client sampling rate per round')
    parser.add_argument('--B', type=int, default=32, help='batch size')
    parser.add_argument('--use_combined', action='store_true')

    # Model parameters
    parser.add_argument('--clip_model', type=str, default='ViT-B/32')
    parser.add_argument('--freeze_clip', action='store_true')
    parser.add_argument('--dropout', type=float, default=0.5)

    # Optimizer Settings
    parser.add_argument('--lr', type=float, default=0.003)
    parser.add_argument('--optimizer', type=str, default='sgd')
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--momentum', type=float, default=0.9)

    # Checkpoint Settings
    parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints_fedavg_organamnist')
    parser.add_argument('--device', default=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    parser.add_argument('--dominant_ratio', type=float, default=0.7)

    # -------------------------------
    # Adversarial Attack Parameters
    # -------------------------------
    parser.add_argument('--attack_type', type=str, default='none', 
                        choices=['none', 'fgsm', 'pgd', 'cw'],
                        help='Type of adversarial attack during testing')
    parser.add_argument('--attack_epsilon', type=float, default=0.06,
                        help='Epsilon for FGSM/PGD attacks')
    parser.add_argument('--pgd_alpha', type=float, default=0.006,
                        help='Step size for PGD')
    parser.add_argument('--pgd_iters', type=int, default=10,
                        help='Number of PGD iterations')
    parser.add_argument('--cw_c', type=float, default=1.0,
                        help='C parameter for CW attack')
    parser.add_argument('--cw_max_iter', type=int, default=100,
                        help='Max iterations for CW attack')

    # An explicit token list is always passed so that arguments belonging to the
    # hosting process (e.g. a notebook kernel) are never parsed by accident.
    args = parser.parse_args(args=[] if argv is None else list(argv))
    
    # Auto-configure based on selected dataset
    if args.dataset not in DATASET_CONFIGS:
        raise ValueError(f"Dataset {args.dataset} not configured!")
    config = DATASET_CONFIGS[args.dataset]
    args.num_classes = config['num_classes']
    args.input_channels = config['input_channels']
    args.class_names = config['class_names']
    
    return args
'''

# Write the generated module as UTF-8 so that any non-ASCII character in the source text
# is stored the same way regardless of the platform's default encoding.
with open(file_path, "w", encoding="utf-8") as f:
    f.write(new_code)

print(" args.py updated with attack parameters!")

 args.py updated with attack parameters!


In [5]:
file_path = "/kaggle/working/FedAvg-PyTorch/get_data.py"

new_code = r'''
# ========================================
# get_data.py - True Balanced Non-IID Distribution
# Supports: pathmnist, tissuemnist, organamnist, octmnist
# ========================================
import os
import numpy as np
import torch
import pickle
from torch.utils.data import DataLoader, Subset, ConcatDataset
from torchvision import transforms
from medmnist import INFO
from medmnist.dataset import PathMNIST, TissueMNIST, OrganAMNIST, OCTMNIST

# Lookup table: lower-case dataset flag -> MedMNIST dataset class to instantiate.
DATASET_MAP = dict(
    pathmnist=PathMNIST,
    tissuemnist=TissueMNIST,
    organamnist=OrganAMNIST,
    octmnist=OCTMNIST,
)

def _to_rgb(img):
    """Return ``img`` converted to 3-channel RGB through its ``convert`` method."""
    # assumes img is a PIL image (this step runs before ToTensor) -- TODO confirm
    return img.convert("RGB")


def get_transforms():
    """Build the train and test preprocessing pipelines.

    Both pipelines resize to 224x224, force a 3-channel image, convert to a
    tensor and normalise each channel with mean 0.5 / std 0.5. The training
    pipeline additionally applies a random horizontal flip and a random
    rotation of up to 15 degrees.

    The 3-channel step converts to RGB instead of using
    ``Grayscale(num_output_channels=3)``: single-channel images are still
    expanded to three channels, but images that are already RGB (PathMNIST)
    keep their colour information instead of being reduced to grey.

    Returns:
        tuple: (train_transform, test_transform)
    """
    # A module-level function is used (not a lambda) so the transform stays picklable.
    to_three_channels = transforms.Lambda(_to_rgb)

    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        to_three_channels,  # ensures 3 channels without discarding colour
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        to_three_channels,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    
    return train_transform, test_transform

def balanced_noniid_split(dataset, num_clients, dominant_ratio=0.7):
    """Split dataset indices across clients with one dominant class per client.

    Client ``k`` is given ``dominant_ratio`` of the samples still available
    for class ``k % num_classes``; whatever is left of each class is dealt
    round-robin to the clients whose dominant class is a different one.

    Args:
        dataset: indexable dataset where ``dataset[i][1].item()`` is the
            integer label of sample ``i``.
            # NOTE(review): class ids are assumed to be 0..num_classes-1; verify
        num_clients: number of clients (must be at least 1).
        dominant_ratio: fraction of a class handed to its dominant client.

    Returns:
        list of ``num_clients`` lists holding the sample indices of each client.

    Raises:
        ValueError: if ``num_clients`` is smaller than 1.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    labels = np.array([dataset[i][1].item() for i in range(len(dataset))])
    num_classes = len(np.unique(labels))
    client_indices = [[] for _ in range(num_clients)]

    # Prepare class indices
    class_indices = {c: np.where(labels == c)[0].tolist() for c in range(num_classes)}
    for c in class_indices:
        np.random.shuffle(class_indices[c])

    # Step 1: assign dominant_ratio% of dominant class to primary client
    for client_id in range(num_clients):
        dominant_class = client_id % num_classes
        n_dominant = int(len(class_indices[dominant_class]) * dominant_ratio)
        if n_dominant > 0:
            client_indices[client_id].extend(class_indices[dominant_class][:n_dominant])
            class_indices[dominant_class] = class_indices[dominant_class][n_dominant:]

    # Step 2: distribute remaining samples equally among other clients
    for c in range(num_classes):
        remaining = class_indices[c]
        np.random.shuffle(remaining)
        other_clients = [i for i in range(num_clients) if i % num_classes != c]
        if not other_clients:
            # Every client is dominant for this class (e.g. a single client): the
            # modulo below would divide by zero and the leftovers would be lost,
            # so hand them out across all clients instead.
            other_clients = list(range(num_clients))
        for i, idx in enumerate(remaining):
            client_id = other_clients[i % len(other_clients)]
            client_indices[client_id].append(idx)

    # Shuffle each client's indices
    for i in range(num_clients):
        np.random.shuffle(client_indices[i])

    return client_indices

def _labels_of(dataset):
    """Return the labels of ``dataset`` as a flat int64 NumPy array.

    Uses the dataset's ``labels`` attribute when it holds exactly one value per
    sample; otherwise falls back to reading ``dataset[i][1].item()`` sample by
    sample, which also runs the image transform for every item.
    """
    stored = getattr(dataset, 'labels', None)
    if stored is not None:
        # presumably MedMNIST datasets expose their label array as `.labels` -- verify
        flat = np.asarray(stored).reshape(-1)
        if flat.shape[0] == len(dataset):
            return flat.astype(np.int64)
    return np.array([dataset[i][1].item() for i in range(len(dataset))], dtype=np.int64)


def load_medmnist_data(args):
    """Load the selected MedMNIST dataset and build the federated data loaders.

    The train and val splits are concatenated (both with the training
    transform) and partitioned across ``args.K`` clients by
    ``balanced_noniid_split``. The partition is cached as a pickle under
    ./data/medmnist, keyed by dataset, K and dominant_ratio, and reused on
    later calls. A per-client class distribution report is printed.

    Args:
        args: configuration namespace; reads dataset, num_classes, K,
            dominant_ratio and B.

    Returns:
        tuple: (client_loaders, test_loader) where client_loaders is a list with
        one shuffled DataLoader per client and test_loader iterates the test split.

    Raises:
        ValueError: if ``args.dataset`` is not a key of DATASET_MAP.
    """
    train_transform, test_transform = get_transforms()
    data_root = './data/medmnist'
    os.makedirs(data_root, exist_ok=True)

    data_flag = args.dataset.lower()
    if data_flag not in DATASET_MAP:
        raise ValueError(f"Dataset {data_flag} not supported. Choose from {list(DATASET_MAP.keys())}")

    n_classes = args.num_classes

    print("\n" + "="*60)
    print(f" Loading {data_flag.upper()} Dataset")
    print("="*60)
    print(f"Number of classes: {n_classes}")

    DatasetClass = DATASET_MAP[data_flag]

    # Load datasets
    train_dataset = DatasetClass(root=data_root, split='train', download=True, transform=train_transform)
    val_dataset = DatasetClass(root=data_root, split='val', download=True, transform=train_transform)
    test_dataset = DatasetClass(root=data_root, split='test', download=True, transform=test_transform)

    combined_train = ConcatDataset([train_dataset, val_dataset])
    print(f" Using Train+Val: {len(combined_train)} samples for federated learning")

    # Load or create client indices
    cache_file = f'./data/medmnist/client_indices_{data_flag}_K{args.K}_dr{args.dominant_ratio}.pkl'
    if os.path.exists(cache_file):
        print(f"\n Loading cached split from: {cache_file}")
        # The cache is a pickle written by this function; never point this path at a
        # file from an untrusted source, because unpickling can execute code.
        with open(cache_file, 'rb') as f:
            client_indices = pickle.load(f)
    else:
        print(f"\n Creating balanced Non-IID split (dominant_ratio={args.dominant_ratio})...")
        client_indices = balanced_noniid_split(combined_train, args.K, dominant_ratio=args.dominant_ratio)
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        with open(cache_file, 'wb') as f:
            pickle.dump(client_indices, f)
        print(f" Split cached to: {cache_file}")

    # Labels in the same order as combined_train (train first, then val). Reading them
    # once here avoids pushing every image through the transform pipeline again just
    # to print the distribution report.
    all_labels = np.concatenate([_labels_of(train_dataset), _labels_of(val_dataset)])

    # Print client-wise class distribution
    print("\n" + "="*60)
    print(" Client-wise Image Distribution")
    print("="*60)
    total_samples = 0
    for client_id, indices in enumerate(client_indices):
        client_labels = all_labels[np.asarray(indices, dtype=np.int64)]
        label_counts = np.bincount(client_labels, minlength=n_classes)
        total_client = len(indices)
        total_samples += total_client
        distribution_str = ", ".join([f"C{c}:{label_counts[c]}" for c in range(n_classes)])
        dominant_class = np.argmax(label_counts)
        # A client without samples has no meaningful dominant share; report 0%.
        dominant_pct = (label_counts[dominant_class] / total_client) * 100 if total_client else 0.0
        print(f"Client {client_id:2d}: {total_client:5d} samples | Dominant: Class {dominant_class} ({dominant_pct:.1f}%)")
        print(f"           [{distribution_str}]")

    print(f"\nTotal samples: {total_samples}")
    print("="*60 + "\n")

    # Pinned host memory is only requested when a CUDA device is available.
    use_pinned = torch.cuda.is_available()

    # Create DataLoaders
    client_loaders = [
        DataLoader(Subset(combined_train, indices), batch_size=args.B, shuffle=True,
                   num_workers=0, pin_memory=use_pinned)
        for indices in client_indices
    ]

    test_loader = DataLoader(test_dataset, batch_size=args.B, shuffle=False, num_workers=0, pin_memory=use_pinned)
    return client_loaders, test_loader


'''

# Write the generated module as UTF-8 so that any non-ASCII character in the source text
# is stored the same way regardless of the platform's default encoding.
with open(file_path, "w", encoding="utf-8") as f:
    f.write(new_code)

print(" get_data.py updated!")


 get_data.py updated!


In [6]:
file_path = "/kaggle/working/FedAvg-PyTorch/client.py"
new_code = """
# ========================================
# client.py - Client-side helper for FedAvg
# ========================================
import torch
from torch import nn
import torch.optim as optim
from tqdm import tqdm
import copy

class Client:
    '''Local trainer for one federated client.

    Wraps a model, its training loader (and an optional validation loader) and
    runs SGD with momentum 0.9 on a cross-entropy loss.
    '''

    def __init__(self, model, train_loader, device, val_loader=None, lr=0.0001, weight_decay=5e-4):
        '''Move the model to the device and create the loss and the SGD optimizer.

        Args:
            model: module to train; its name attribute is used in log messages.
            train_loader: iterable of (images, labels) training batches.
            device: device the model and the batches are moved to.
            val_loader: optional iterable of (images, labels) validation batches.
            lr: SGD learning rate.
            weight_decay: SGD weight decay.
        '''
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.lr = lr
        self.weight_decay = weight_decay
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9, weight_decay=self.weight_decay)

    @staticmethod
    def _logits(outputs):
        '''Return the logits from a model output that is a tensor, a tuple, or an object with a logits attribute.'''
        if hasattr(outputs, 'logits'):
            return outputs.logits
        if isinstance(outputs, tuple):
            # Plain tuples have no logits attribute; take the first element, as server.py does.
            return outputs[0]
        return outputs

    @staticmethod
    def _flat_labels(labels):
        '''Flatten labels to one dimension.

        reshape(-1) is used instead of squeeze(): squeeze() turns a batch holding a
        single sample into a 0-dim tensor, which breaks labels.size(0) and the loss.
        '''
        return labels.reshape(-1)
    
    def compute_accuracy(self, loader):
        '''Return the accuracy in percent of the model over loader (0 when loader yields no samples).'''
        self.model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(self.device), labels.to(self.device)
                labels = self._flat_labels(labels)
                outputs = self._logits(self.model(images))
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return 100 * correct / total if total > 0 else 0
    
    def train(self, epochs=1):
        '''Train the model for the given number of local epochs.

        After every epoch the mean training loss and the training accuracy are
        printed, plus validation loss and accuracy when a validation loader was given.

        Returns:
            the state_dict of the trained model.
        '''
        print(f"\\n Training {self.model.name} with FedAvg objective...")
        for epoch in range(epochs):
            self.model.train()
            total_loss = 0.0
            pbar = tqdm(self.train_loader,
                        desc=f"  Epoch {epoch+1}/{epochs}",
                        ncols=100,
                        leave=False)
            
            for images, labels in pbar:
                images, labels = images.to(self.device), labels.to(self.device)
                labels = self._flat_labels(labels)
                
                self.optimizer.zero_grad()
                outputs = self._logits(self.model(images))
                
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
                
                total_loss += loss.item()
                pbar.set_postfix({'loss': f'{loss.item():.4f}'})
            
            # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
            avg_loss = total_loss / max(len(self.train_loader), 1)
            train_acc = self.compute_accuracy(self.train_loader)
            log_msg = f"  Epoch {epoch+1:2d}/{epochs} | Loss: {avg_loss:.4f} | Train Acc: {train_acc:.2f}%"
            
            if self.val_loader is not None:
                val_acc = self.compute_accuracy(self.val_loader)
                val_loss = 0
                self.model.eval()
                with torch.no_grad():
                    for images, labels in self.val_loader:
                        images, labels = images.to(self.device), labels.to(self.device)
                        labels = self._flat_labels(labels)
                        outputs = self._logits(self.model(images))
                        loss = self.criterion(outputs, labels)
                        val_loss += loss.item()
                val_loss /= max(len(self.val_loader), 1)
                log_msg += f" | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%"
            
            print(log_msg)
        
        print(f" {self.model.name} local training complete (FedAvg)\\n")
        return self.model.state_dict()
"""

# Write the generated module as UTF-8 so that any non-ASCII character in the source text
# is stored the same way regardless of the platform's default encoding.
with open(file_path, "w", encoding="utf-8") as f:
    f.write(new_code)

print(" client.py updated!")


 client.py updated!


In [7]:
file_path = "/kaggle/working/FedAvg-PyTorch/attacks.py"

new_code = """
# ========================================
# attacks.py - Adversarial Attacks for FedAvg
# ========================================
import torch
import torch.nn as nn

class AdversarialAttacks:
    '''FGSM, PGD and Carlini-Wagner attacks against a classifier.

    All three attacks keep the adversarial images inside [-1, 1]
    (FGSM/PGD by clamping, CW through a tanh re-parameterisation).
    '''

    def __init__(self, model, device):
        '''Store the model under attack, the device to work on and a cross-entropy loss.'''
        self.model = model
        self.device = device
        self.criterion = nn.CrossEntropyLoss()
    
    def fgsm_attack(self, images, labels, epsilon=0.03):
        '''Single-step FGSM: move every pixel by epsilon along the sign of the loss gradient.

        Returns the detached adversarial batch, or the unmodified images when no
        gradient reached the input.
        '''
        self.model.eval()
        
        # Clone and enable gradients
        images = images.clone().detach().to(self.device)
        labels = labels.to(self.device)
        images.requires_grad = True
        
        # Forward pass with gradient tracking
        outputs = self.model(images)
        loss = self.criterion(outputs, labels)
        
        # Compute gradients
        self.model.zero_grad()
        loss.backward()
        
        # Check if gradients exist
        if images.grad is None:
            print("    No gradients computed - attack failed")
            return images.detach()
        
        # Create perturbation
        perturbation = epsilon * images.grad.sign()
        
        # Create adversarial images
        adv_images = images + perturbation
        adv_images = torch.clamp(adv_images, -1, 1)
        
        return adv_images.detach()
    
    def pgd_attack(self, images, labels, epsilon=0.03, alpha=0.007, iters=10):
        '''Iterative PGD with a uniform random start inside the epsilon ball.

        Each of the iters steps moves by alpha along the gradient sign and then
        projects the perturbation back to [-epsilon, epsilon] and the image to [-1, 1].
        '''
        self.model.eval()
        
        images = images.clone().detach().to(self.device)
        labels = labels.to(self.device)
        
        # Initialize with random perturbation
        adv_images = images.clone()
        adv_images = adv_images + torch.empty_like(adv_images).uniform_(-epsilon, epsilon)
        adv_images = torch.clamp(adv_images, -1, 1)
        
        for i in range(iters):
            adv_images.requires_grad = True
            
            # Forward pass
            outputs = self.model(adv_images)
            loss = self.criterion(outputs, labels)
            
            # Compute gradients
            self.model.zero_grad()
            if adv_images.grad is not None:
                adv_images.grad.zero_()
            
            loss.backward()
            
            # Check gradients
            if adv_images.grad is None:
                print(f"    No gradients at iteration {i+1}")
                break
            
            # Update with gradient ascent
            with torch.no_grad():
                adv_images = adv_images.detach() + alpha * adv_images.grad.sign()
                perturbation = torch.clamp(adv_images - images, -epsilon, epsilon)
                adv_images = images + perturbation
                adv_images = torch.clamp(adv_images, -1, 1)
        
        return adv_images.detach()
    
    def cw_attack(self, images, labels, c=1.0, kappa=0, max_iter=100, learning_rate=0.01):
        '''Carlini-Wagner L2 attack optimised with Adam in tanh space.

        Minimises c * max(real_logit - best_other_logit + kappa, 0) + squared L2
        distance to the clean image. For every sample the misclassified candidate
        with the smallest L2 distance seen so far is kept; samples that were never
        misclassified are returned unchanged.
        '''
        self.model.eval()
        
        images = images.clone().detach().to(self.device)
        labels = labels.to(self.device)
        batch_size = images.shape[0]
        
        # Start the search at the clean image: w = atanh(x), so tanh(w) == x.
        # (Starting from w = 0 would begin at an all-zero image instead.)
        # The input is clamped just inside (-1, 1) because atanh is infinite at +/-1.
        w = torch.atanh(torch.clamp(images, -1 + 1e-6, 1 - 1e-6)).detach().requires_grad_(True)
        optimizer = torch.optim.Adam([w], lr=learning_rate)
        
        best_adv = images.clone()
        best_loss = float('inf') * torch.ones(batch_size, device=self.device)
        
        for step in range(max_iter):
            # Transform w to valid image range
            adv_images = torch.tanh(w) * 1.0
            
            # Forward pass
            outputs = self.model(adv_images)
            
            # CW loss formulation
            real = outputs.gather(1, labels.unsqueeze(1)).squeeze(1)
            
            # Largest logit among the wrong classes (true class masked with -inf)
            other = outputs.clone()
            other.scatter_(1, labels.unsqueeze(1), -float('inf'))
            other_max = other.max(1)[0]
            
            # Loss: want other_max > real (misclassification)
            f_loss = torch.clamp(real - other_max + kappa, min=0)
            
            # L2 distance loss
            l2_dist = torch.sum((adv_images - images) ** 2, dim=[1, 2, 3])
            
            # Combined loss
            loss = (c * f_loss + l2_dist).sum()
            
            # Optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Track best adversarial examples (evaluated on the pre-step candidate)
            with torch.no_grad():
                fooled = outputs.argmax(1) != labels
                improved = fooled & (l2_dist < best_loss)
                best_loss = torch.where(improved, l2_dist, best_loss)
                best_adv[improved] = adv_images[improved]
        
        return best_adv.detach()
"""

# Write the generated module as UTF-8 so that any non-ASCII character in the source text
# is stored the same way regardless of the platform's default encoding.
with open(file_path, "w", encoding="utf-8") as f:
    f.write(new_code)

print(" attacks.py created for FedAvg!")
print(f" File path: {file_path}")

 attacks.py created for FedAvg!
 File path: /kaggle/working/FedAvg-PyTorch/attacks.py


In [8]:
file_path = "/kaggle/working/FedAvg-PyTorch/model.py"
new_code = '''
# ========================================
# model.py - CLIP-based Classifier for MedMNIST (Auto-configured)
# ========================================
import torch
from torch import nn
import clip

class CLIPMedMNISTClassifier(nn.Module):
    """Classifier that scores CLIP image features against fixed class-name text embeddings.

    The text side of CLIP is frozen and the text embedding of one prompt per
    class is computed once at construction time and stored as a buffer; the
    logits are 100 times the cosine similarity between the normalised image
    features and those text embeddings.
    """

    def __init__(self, args, name='clip_model'):
        """Load CLIP, freeze its text encoder and precompute the class text features.

        Args:
            args: configuration namespace; reads device, class_names, num_classes
                and, when present, clip_model (the CLIP variant to load).
            name: label stored on the instance as ``self.name``.
        """
        super(CLIPMedMNISTClassifier, self).__init__()
        self.name = name
        
        # Load pretrained CLIP. The variant comes from args.clip_model so the
        # --clip_model option takes effect; 'ViT-B/32' is used when it is absent.
        backbone = getattr(args, 'clip_model', 'ViT-B/32')
        self.clip_model, self.preprocess = clip.load(backbone, device=args.device)
        
        # Get class names from args (auto-configured based on dataset)
        self.class_names = args.class_names
        self.num_classes = args.num_classes
        
        # Freeze text encoder completely
        for param in self.clip_model.transformer.parameters():
            param.requires_grad = False
        for param in self.clip_model.token_embedding.parameters():
            param.requires_grad = False
        for param in self.clip_model.ln_final.parameters():
            param.requires_grad = False
        self.clip_model.positional_embedding.requires_grad = False
        self.clip_model.text_projection.requires_grad = False
        
        # Precompute and freeze text embeddings
        with torch.no_grad():
            text_tokens = clip.tokenize([f"a microscopic image of {c}" for c in self.class_names]).to(args.device)
            text_features = self.clip_model.encode_text(text_tokens)
            text_features /= text_features.norm(dim=-1, keepdim=True)
        
        # Register as buffer
        self.register_buffer('text_features', text_features)
    
    def forward(self, images):
        """Return the class logits for a batch of images."""
        image_features = self.clip_model.encode_image(images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logits = 100.0 * image_features @ self.text_features.T
        return logits
    
    def get_trainable_params(self):
        """Return the parameters that still require gradients."""
        return [p for p in self.parameters() if p.requires_grad]
    
    def freeze_layers(self, num_layers_to_freeze):
        """Freeze the first ``num_layers_to_freeze`` residual blocks of the visual transformer."""
        if num_layers_to_freeze > 0:
            for i, layer in enumerate(self.clip_model.visual.transformer.resblocks):
                if i < num_layers_to_freeze:
                    for param in layer.parameters():
                        param.requires_grad = False

'''

# Write the generated module as UTF-8 so that any non-ASCII character in the source text
# is stored the same way regardless of the platform's default encoding.
with open(file_path, "w", encoding="utf-8") as f:
    f.write(new_code)

print(" model.py updated!")
print(f" File saved at: {file_path}")

 model.py updated!
 File saved at: /kaggle/working/FedAvg-PyTorch/model.py


In [None]:
file_path = "/kaggle/working/FedAvg-PyTorch/server.py"

new_code = """
# ========================================
# server.py - FedAvg with Attack Testing
# ========================================
import copy
import random
import numpy as np
import torch
import os
import json
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, mean_squared_error
from model import CLIPMedMNISTClassifier as ImageClassifier
from get_data import load_medmnist_data
from client import Client
from attacks import AdversarialAttacks

class FedAvgServer:
    '''FedAvg coordinator.

    Holds the global model, one model copy per client and the data loaders, and
    runs the communication rounds: dispatch the global weights, train the
    selected clients, average their weights, evaluate (optionally under an
    adversarial attack) and checkpoint.
    '''

    def __init__(self, args):
        '''Build models and loaders, then resume from checkpoint_latest.pth in args.checkpoint_dir if it exists.'''
        self.args = args
        os.makedirs(args.checkpoint_dir, exist_ok=True)
        self.current_round = 0
        self.best_global_acc = 0
        self.history = {'rounds': [], 'avg_accuracy': [], 'best_accuracy': []}
        self.global_model = ImageClassifier(args, name="server").to(args.device)
        self.client_models = []
        for i in range(self.args.K):
            model = copy.deepcopy(self.global_model)
            model.name = f"Client_{i}"
            self.client_models.append(model)
        self.client_loaders, self.test_loader = load_medmnist_data(args)

        latest_ckpt = os.path.join(args.checkpoint_dir, "checkpoint_latest.pth")
        if os.path.exists(latest_ckpt):
            self.load_checkpoint(latest_ckpt)
            print(f"\\n Resuming training from Round {self.current_round}")
        else:
            print("\\n Starting new training session")

    def load_checkpoint(self, checkpoint_path):
        '''Restore global/client weights, round counter, best accuracy and history from a checkpoint file.'''
        print(f"\\n Loading checkpoint: {checkpoint_path}")
        # weights_only=False unpickles arbitrary objects: only load checkpoints
        # produced by this training code, never files from an untrusted source.
        checkpoint = torch.load(checkpoint_path, map_location=self.args.device, weights_only=False)
        self.global_model.load_state_dict(checkpoint['server_state_dict'])
        if 'client_state_dicts' in checkpoint:
            for i, state_dict in enumerate(checkpoint['client_state_dicts']):
                self.client_models[i].load_state_dict(state_dict)
        self.current_round = checkpoint.get('round', 0)
        self.best_global_acc = checkpoint.get('best_global_acc', 0)
        self.history = checkpoint.get('history', {'rounds': [], 'avg_accuracy': [], 'best_accuracy': []})

    def dispatch(self, selected_clients):
        '''Copy the global model parameters into every selected client model.'''
        for idx in selected_clients:
            client_model = self.client_models[idx]
            for client_param, global_param in zip(client_model.parameters(), self.global_model.parameters()):
                client_param.data = global_param.data.clone()

    def aggregate(self, selected_clients):
        '''Set the global parameters to the sample-count-weighted average of the selected clients.'''
        total_samples = sum([len(self.client_loaders[idx].dataset) for idx in selected_clients])
        global_params = {k: torch.zeros_like(v.data) for k, v in self.global_model.named_parameters()}
        for idx in selected_clients:
            weight = len(self.client_loaders[idx].dataset) / total_samples
            client_params = dict(self.client_models[idx].named_parameters())
            for k in global_params.keys():
                global_params[k] += client_params[k].data * weight
        for k, v in self.global_model.named_parameters():
            v.data = global_params[k].data.clone()

    def client_update(self, idx):
        '''Run args.E local epochs on client idx with a freshly created Client (and optimizer).'''
        client_model = self.client_models[idx]
        client_loader = self.client_loaders[idx]
        client_obj = Client(model=client_model,
                            train_loader=client_loader,
                            device=self.args.device,
                            lr=self.args.lr,
                            weight_decay=self.args.weight_decay)
        client_obj.train(epochs=self.args.E)

    @staticmethod
    def _logits(outputs):
        '''Return the logits from a model output that is a tensor, a tuple, or an object with a logits attribute.'''
        if hasattr(outputs, 'logits'):
            return outputs.logits
        if isinstance(outputs, tuple):
            return outputs[0]
        return outputs

    def test_global_model(self):
        '''Evaluate the global model on the test loader, optionally under an adversarial attack.

        The attack is selected by args.attack_type ('none', 'fgsm', 'pgd' or 'cw').

        Returns:
            dict with accuracy (percent), macro precision, recall and f1, and rmse
            computed between the label ids and the predicted ids.
        '''
        self.global_model.eval()
        all_labels = []
        all_preds = []

        # Initialize attacker if needed
        attacker = None
        if self.args.attack_type != 'none':
            attacker = AdversarialAttacks(self.global_model, self.args.device)
            print(f"\\n  Testing with {self.args.attack_type.upper()} attack (Œµ={self.args.attack_epsilon})")

        # Attacks need input gradients; clean evaluation does not.
        context_manager = torch.no_grad() if self.args.attack_type == 'none' else torch.enable_grad()

        with context_manager:
            for images, labels in tqdm(self.test_loader, desc="  Testing", leave=False, ncols=100):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                # reshape(-1) instead of squeeze(): squeeze() turns a batch holding a
                # single sample into a 0-dim tensor, which breaks the attacks and metrics.
                labels = labels.reshape(-1)

                # Apply the attack to the clean batch (each attack clones its input).
                if self.args.attack_type == 'fgsm':
                    images = attacker.fgsm_attack(images, labels,
                                                   self.args.attack_epsilon)
                elif self.args.attack_type == 'pgd':
                    images = attacker.pgd_attack(images, labels,
                                                  epsilon=self.args.attack_epsilon,
                                                  alpha=self.args.pgd_alpha,
                                                  iters=self.args.pgd_iters)
                elif self.args.attack_type == 'cw':
                    images = attacker.cw_attack(images, labels,
                                                 c=self.args.cw_c,
                                                 max_iter=self.args.cw_max_iter)

                # Get predictions
                with torch.no_grad():
                    outputs = self._logits(self.global_model(images))
                    _, predicted = torch.max(outputs, 1)
                    all_labels.extend(labels.cpu().numpy())
                    all_preds.extend(predicted.cpu().numpy())

        all_labels = np.array(all_labels)
        all_preds = np.array(all_preds)

        acc = 100.0 * np.mean(all_labels == all_preds)
        try:
            precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
            recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)
            f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
        except Exception as exc:
            # Best effort: keep the round going, but say why the metrics are zero.
            print(f"  Warning: precision/recall/f1 could not be computed ({exc}); reporting 0.0")
            precision, recall, f1 = 0.0, 0.0, 0.0

        rmse = float(np.sqrt(mean_squared_error(all_labels, all_preds)))

        metrics = {
            'accuracy': acc,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'rmse': rmse
        }
        return metrics

    def save_checkpoint(self, round_num, metrics):
        '''Write checkpoint_latest.pth and training_history.json into args.checkpoint_dir.

        metrics is accepted for interface compatibility and is not stored.
        '''
        checkpoint = {
            'round': round_num,
            'server_state_dict': self.global_model.state_dict(),
            'client_state_dicts': [model.state_dict() for model in self.client_models],
            'best_global_acc': self.best_global_acc,
            'history': self.history,
            'args': vars(self.args)
        }
        
        torch.save(checkpoint, os.path.join(self.args.checkpoint_dir, 'checkpoint_latest.pth'))
        
        with open(os.path.join(self.args.checkpoint_dir, 'training_history.json'), 'w') as f:
            json.dump(self.history, f, indent=4)

    def _record_round(self, round_num, metrics):
        '''Update the best accuracy and append this round's results to the history.

        history['best_accuracy'] receives one entry per round so it always has the
        same length as history['rounds'] (main.py plots one against the other). If a
        restored history lacks those entries they are rebuilt as the running maximum
        of the recorded accuracies before the new round is appended.
        '''
        accuracy = metrics['accuracy']
        if accuracy > self.best_global_acc:
            self.best_global_acc = accuracy

        best_series = self.history.setdefault('best_accuracy', [])
        recorded_acc = self.history['avg_accuracy']
        if len(best_series) != len(recorded_acc):
            running, rebuilt = 0, []
            for past in recorded_acc:
                running = max(running, past)
                rebuilt.append(running)
            best_series[:] = rebuilt

        self.history['rounds'].append(round_num)
        recorded_acc.append(accuracy)
        best_series.append(self.best_global_acc)
        for key in ('precision', 'recall', 'f1', 'rmse'):
            self.history.setdefault(key, []).append(metrics[key])

    def run(self):
        '''Run the remaining communication rounds (from current_round up to args.r) and return the global model.'''
        start_round = self.current_round
        for r in range(start_round, self.args.r):
            print(f"\\n{'='*60}")
            print(f" Round [{r+1}/{self.args.r}]")
            print('='*60)
            
            # Sample max(int(C * K), 1) distinct clients for this round.
            m = max(int(self.args.C * self.args.K), 1)
            selected_clients = random.sample(range(self.args.K), m)
            print(f" Selected Clients: {selected_clients}")
            
            self.dispatch(selected_clients)
            
            for idx in selected_clients:
                print(f"\\n Training Client {idx}...")
                self.client_update(idx)
            
            self.aggregate(selected_clients)
            
            metrics = self.test_global_model()
            self._record_round(r + 1, metrics)
            
            self.current_round = r + 1
            self.save_checkpoint(r+1, metrics)
            
            print(f"\\n{'='*60}")
            print(f" Round [{r+1}/{self.args.r}] Results:")
            print(f"{'='*60}")
            print(f"   Accuracy:  {metrics['accuracy']:.2f}%")
            print(f"   Precision: {metrics['precision']:.4f}")
            print(f"   Recall:    {metrics['recall']:.4f}")
            print(f"   F1-Score:  {metrics['f1']:.4f}")
            print(f"   RMSE:      {metrics['rmse']:.4f}")
            print(f"   Best Acc:  {self.best_global_acc:.2f}%")
            print('='*60)
            
        return self.global_model
"""

# Write the generated module as UTF-8 so that any non-ASCII character in the source text
# is stored the same way regardless of the platform's default encoding.
with open(file_path, "w", encoding="utf-8") as f:
    f.write(new_code)

print(" server.py updated with attack testing!")

 server.py updated with attack testing!


In [10]:
file_path = "/kaggle/working/FedAvg-PyTorch/main.py"

new_code = """
# ========================================
# main.py - Run FedAvg with Attack Testing
# ========================================
from args import args_parser
from server import FedAvgServer
import torch
import os
import json
import matplotlib.pyplot as plt

def plot_training_history(history, save_path):
    '''Save a two-panel figure describing the federated training progress.

    Left panel: per-round average accuracy plus the best accuracy so far.
    Right panel: round-to-round accuracy change (drawn only when more than
    one round has been recorded).

    Args:
        history: dict with 'rounds' and 'avg_accuracy' lists of equal length.
            'best_accuracy' is optional; when it is missing or its length
            differs from 'avg_accuracy', a running maximum of 'avg_accuracy'
            is plotted instead.
        save_path: destination path handed to plt.savefig.
    '''
    rounds = history['rounds']
    avg_acc = history['avg_accuracy']

    # 'best_accuracy' is not guaranteed to be maintained next to
    # 'avg_accuracy': indexing it directly raises KeyError when absent, and a
    # shorter list makes plt.plot fail on mismatched x/y lengths.
    best_acc = history.get('best_accuracy') or []
    if len(best_acc) != len(avg_acc):
        best_acc = []
        running_best = float('-inf')
        for acc in avg_acc:
            running_best = max(running_best, acc)
            best_acc.append(running_best)

    plt.figure(figsize=(12, 5))

    # Panel 1: accuracy curves
    plt.subplot(1, 2, 1)
    plt.plot(rounds, avg_acc, 'b-', label='Avg Accuracy', linewidth=2)
    plt.plot(rounds, best_acc, 'r--', label='Best Accuracy', linewidth=2)
    plt.xlabel('Communication Round')
    plt.ylabel('Accuracy (%)')
    plt.title('Federated Learning Accuracy Progress')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Panel 2: per-round deltas, which need at least two data points
    plt.subplot(1, 2, 2)
    if len(avg_acc) > 1:
        improvements = [curr - prev for prev, curr in zip(avg_acc, avg_acc[1:])]
        plt.bar(rounds[1:], improvements, alpha=0.7)
        plt.axhline(y=0, color='r', linestyle='-', linewidth=0.5)
        plt.xlabel('Communication Round')
        plt.ylabel('Accuracy Change (%)')
        plt.title('Round-to-Round Accuracy Change')
        plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f' Training plot saved at: {save_path}')

def test_with_all_attacks(server, args):
    '''Test model with all attack types'''
    attacks = ['none', 'fgsm', 'pgd', 'cw']
    results = {}
    
    print('\\n' + '='*60)
    print('  ADVERSARIAL ROBUSTNESS TESTING')
    print('='*60)
    
    for attack in attacks:
        original_attack = args.attack_type
        args.attack_type = attack
        
        print(f"\\n{'='*60}")
        print(f" Testing with {attack.upper()} attack")
        print('='*60)
        
        metrics = server.test_global_model()
        results[attack] = metrics
        
        print(f"\\n Results:")
        print(f"   Accuracy:  {metrics['accuracy']:.2f}%")
        print(f"   Precision: {metrics['precision']:.4f}")
        print(f"   Recall:    {metrics['recall']:.4f}")
        print(f"   F1-Score:  {metrics['f1']:.4f}")
        print(f"   RMSE:      {metrics['rmse']:.4f}")
    
    args.attack_type = original_attack
    
    # Save results
    results_path = os.path.join(args.checkpoint_dir, 'attack_results.json')
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=4)
    print(f"\\n Attack results saved at: {results_path}")
    
    return results

def main():
    '''Script entry point.

    Parses the arguments, prints the configuration and builds the server.
    If the server's current round already equals the configured number of
    rounds, only the adversarial robustness tests run. Otherwise the model
    is trained, the final weights and the training plot are saved under
    args.checkpoint_dir, and the robustness tests follow.
    '''
    args = args_parser()
    rule = '='*60
    dashes = '-'*60

    print('\\n' + rule)
    print(' FEDAVG CONFIGURATION')
    print(rule)
    print(f'Dataset: {args.dataset.upper()}')
    print(f'Classes: {args.num_classes} | Clients: {args.K} | Rounds: {args.r}')
    print(f'Batch: {args.B} | LR: {args.lr} | Device: {args.device}')
    print(rule)

    server = FedAvgServer(args)

    # Nothing left to train: evaluate robustness and stop here
    if server.current_round == args.r:
        print('\\n Training already complete!')
        print('  Running adversarial robustness testing...')
        attack_results = test_with_all_attacks(server, args)

        # Full metric table per attack
        print('\\n' + rule)
        print(' ATTACK RESULTS COMPARISON')
        print(rule)
        print(f"{'Attack':<15} {'Accuracy':<10} {'Precision':<10} {'Recall':<10} {'F1'}")
        print(dashes)
        for name in attack_results:
            scores = attack_results[name]
            print(f"{name.upper():<15} {scores['accuracy']:>6.2f}%    {scores['precision']:>7.4f}    {scores['recall']:>7.4f}    {scores['f1']:>7.4f}")
        print(rule)
        return

    # Train, then persist the resulting global weights
    final_model = server.run()
    final_model_path = os.path.join(args.checkpoint_dir, 'final_global_model.pth')
    torch.save(final_model.state_dict(), final_model_path)
    print(f'\\n Final model saved: {final_model_path}')

    # Training curves
    plot_path = os.path.join(args.checkpoint_dir, 'training_plot.png')
    plot_training_history(server.history, plot_path)

    # Summary of the run
    print('\\n' + rule)
    print(' TRAINING SUMMARY')
    print(rule)
    print(f'Best Accuracy: {server.best_global_acc:.2f}%')
    final_acc = server.history['avg_accuracy'][-1]
    print(f'Final Accuracy: {final_acc:.2f}%')
    print(rule)

    # Robustness of the freshly trained model
    print('\\n  Testing adversarial robustness...')
    attack_results = test_with_all_attacks(server, args)

    # Accuracy per attack and its drop relative to the clean ('none') run
    print('\\n' + rule)
    print(' ATTACK RESULTS COMPARISON')
    print(rule)
    print(f"{'Attack':<15} {'Accuracy':<10} {'Drop'}")
    print(dashes)
    clean_acc = attack_results['none']['accuracy']
    for name, scores in attack_results.items():
        if name != 'none':
            drop = clean_acc - scores['accuracy']
        else:
            drop = 0
        print(f"{name.upper():<15} {scores['accuracy']:>6.2f}%    {drop:>5.2f}%")
    print(rule)

if __name__ == '__main__':
    main()
"""

# Persist the generated script. Its header contains a non-ASCII dash and
# Python reads source files as UTF-8 by default, so write UTF-8 explicitly
# instead of relying on the locale's encoding.
with open(file_path, "w", encoding="utf-8") as f:
    f.write(new_code)

print("‚úÖ main.py updated with attack testing!")

‚úÖ main.py updated with attack testing!


In [11]:
# Install the medmnist package before main.py is launched (presumably the
# loader behind the ./data/medmnist cache seen in the run log — confirm).
# NOTE(review): the version is unpinned; pin it once a known-good release is chosen.
!pip install medmnist --quiet


[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m664.8/664.8 MB[0m [31m1.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m211.5/211.5 MB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m56.3/56.3 MB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m127.9/127.9 MB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚î

In [12]:
# Pinned to the commit the recorded run resolved, so a re-run installs the
# same CLIP code instead of whatever the default branch points to that day.
!pip install git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1


Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-v08abj7f
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-v08abj7f
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ftfy (from clip==1.0)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m44.8/44.8 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25l[?25hdone
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369490 sha256=e5b7cd1f16a5f72d1311baee79673bb6e131b57bfbd33b18a916d5ccdf0818ed
  Stored in 

In [None]:
# Launch the generated main.py as a separate Python process: it trains
# (or detects that training is already complete) and then runs the attack tests.
!python /kaggle/working/FedAvg-PyTorch/main.py


 FEDAVG CONFIGURATION
Dataset: ORGANAMNIST
Classes: 11 | Clients: 5 | Rounds: 1
Batch: 32 | LR: 0.003 | Device: cuda
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 338M/338M [00:03<00:00, 102MiB/s]

 Loading ORGANAMNIST Dataset
Number of classes: 11
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 38.2M/38.2M [00:34<00:00, 1.11MB/s]
 Using Train+Val: 41052 samples for federated learning

 Creating balanced Non-IID split (dominant_ratio=0.7)...
 Split cached to: ./data/medmnist/client_indices_organamnist_K5_dr0.7.pkl

 Client-wise Image Distribution
Client  0:  8151 samples | Dominant: Class 0 (19.5%)
           [C0:1593, C1:122, C2:119, C3:140, C4:340, C5:891, C6:1440, C7:991, C8:988, C9:712, C10:815]
Client  1:  7743 samples | Dominant: Class 6 (18.6%)
           [C0:171, C1:1136, C2:119, C3:140, C4:340, C5:891, C6:1440, C7:991, C8:98