How to Use This Version for Checkpoint Resume in Kaggle

Before running this notebook for the second time (to resume training from a checkpoint), follow the steps below carefully.

1. Download the latest checkpoint file
Download the file named checkpoint_latest.pth from the previous version of your notebook or experiment.

2. Upload the checkpoint to Kaggle Input Directory
Place the downloaded file inside your Kaggle input path, for example:
/kaggle/input/path1/pytorch/default/1/checkpoint_latest.pth

3. Run the following code cell before starting training
This code will copy the checkpoint file to the working directory (/kaggle/working/checkpoints) so that training can resume from the saved state.

4. Resume Training
After the checkpoint file is copied successfully, running the rest of the notebook will automatically start training from the previous checkpoint instead of starting from scratch.



## You can change the dataset and the attack type by changing the corresponding names in the args.py file. No other changes are needed.

In [None]:
import os
import shutil

# Source and destination paths
src = "/kaggle/input/path1/pytorch/default/1/checkpoint_latest.pth"
dst_dir = "/kaggle/working/checkpoints"
dst = os.path.join(dst_dir, "checkpoint_latest.pth")

# Step 1: Check if source file exists
if not os.path.exists(src):
    print(f"❌ Source file not found: {src}")
else:
    print(f"✅ Found source file: {src}")

    # Step 2: Ensure destination directory exists
    if not os.path.exists(dst_dir):
        os.makedirs(dst_dir)
        print(f"📂 Created destination directory: {dst_dir}")
    else:
        print(f"📁 Destination directory already exists: {dst_dir}")

    # Step 3: Copy the file
    shutil.copy(src, dst)
    print(f"✅ Copied file to: {dst}")

    # Step 4: List all files in destination
    files = os.listdir(dst_dir)
    if files:
        print("\n📄 Files in /kaggle/working/checkpoints:")
        for f in files:
            print(" ├──", f)
    else:
        print("⚠ Destination directory is empty (unexpected).")
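
The input path in `src` depends on how the checkpoint was uploaded, so the example path above may not match your notebook. The following is a minimal sketch, assuming the file keeps the name checkpoint_latest.pth, that searches /kaggle/input for it. `find_checkpoints` is a hypothetical helper and not part of this notebook.

```python
import os

def find_checkpoints(input_root="/kaggle/input", filename="checkpoint_latest.pth"):
    """Return every path under input_root whose file name equals filename."""
    matches = []
    # os.walk yields nothing if input_root does not exist, so the result is just [].
    for dirpath, _dirnames, filenames in os.walk(input_root):
        if filename in filenames:
            matches.append(os.path.join(dirpath, filename))
    return sorted(matches)

# Print every candidate; copy the one you want into `src` in the cell above.
for path in find_checkpoints():
    print(path)
```

If more than one path is printed, several uploaded versions contain a checkpoint, and the one to resume from has to be picked by hand.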


In [1]:
import os

# Target directory
base_dir = "/kaggle/working/FedPer-PyTorch"
os.makedirs(base_dir, exist_ok=True)

# Python files to create
files = ["main.py", "server.py", "client.py", "model.py", "get_data.py", "args.py"]

# Create each file if not exists
for file in files:
    file_path = os.path.join(base_dir, file)
    if not os.path.exists(file_path):
        with open(file_path, "w") as f:
            f.write(f"# {file} - auto-created placeholder\n")
        print(f" Created: {file_path}")
    else:
        print(f" Already exists: {file_path}")

print("\n Folder and files ready in:", base_dir)


 Created: /kaggle/working/FedPer-PyTorch/main.py
 Created: /kaggle/working/FedPer-PyTorch/server.py
 Created: /kaggle/working/FedPer-PyTorch/client.py
 Created: /kaggle/working/FedPer-PyTorch/model.py
 Created: /kaggle/working/FedPer-PyTorch/get_data.py
 Created: /kaggle/working/FedPer-PyTorch/args.py

 Folder and files ready in: /kaggle/working/FedPer-PyTorch


In [2]:
file_path = "/kaggle/working/FedPer-PyTorch/args.py"

new_code = '''
# ========================================
# args.py - FedPer
# ========================================
import argparse
import torch

# Dataset configurations - Change dataset name here to switch
DATASET_CONFIGS = {
    'pathmnist': {
        'num_classes': 9,
        'class_names': [
            "adipose tissue", "background", "debris", "lymphocytes",
            "mucus", "smooth muscle", "normal colon mucosa",
            "cancer-associated stroma", "colorectal adenocarcinoma epithelium"
        ],
        'input_channels': 3
    },
    'tissuemnist': {
        'num_classes': 8,
        'class_names': [
            "collecting duct", "distal convoluted tubule",
            "glomerular endothelial cells", "interstitial endothelial cells",
            "leukocytes", "podocytes", "proximal tubule", "thick ascending limb"
        ],
        'input_channels': 1
    },
    'organamnist': {
        'num_classes': 11,
        'class_names': [
            "bladder", "femur-left", "femur-right", "heart",
            "kidney-left", "kidney-right", "liver",
            "lung-left", "lung-right", "spleen", "pelvis"
        ],
        'input_channels': 1
    },
    'octmnist': {
        'num_classes': 4,
        'class_names': [
            "choroidal neovascularization", "diabetic macular edema",
            "drusen", "normal"
        ],
        'input_channels': 1
    }
}

def args_parser():
    parser = argparse.ArgumentParser(description="FedPer - Config File")
    
    # -------------------------------
    # MAIN: Change this to switch dataset
    # -------------------------------
    parser.add_argument('--dataset', type=str, default='organamnist', 
                        choices=['pathmnist', 'tissuemnist', 'organamnist', 'octmnist'],
                        help='MedMNIST dataset to use')
    
    # -------------------------------
    # Federated Learning Parameters
    # -------------------------------
    parser.add_argument('--E', type=int, default=5, help='local epochs')
    parser.add_argument('--r', type=int, default=50, help='number of communication rounds')
    parser.add_argument('--K', type=int, default=5, help='total number of clients')
    parser.add_argument('--C', type=float, default=1, help='client sampling rate per round')
    parser.add_argument('--B', type=int, default=32, help='batch size')
    parser.add_argument('--use_combined', action='store_true', 
                        help='Use train+val+test combined (all data)')

    # -------------------------------
    # Model parameters (auto-configured)
    # -------------------------------
    parser.add_argument('--clip_model', type=str, default='ViT-B/32', help='CLIP model variant')
    parser.add_argument('--freeze_clip', action='store_true', help='Freeze CLIP backbone')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout (stronger regularization)')
    parser.add_argument('--Kp', type=int, default=2, help='number of personalized layers')

    # -------------------------------
    # Optimizer Settings
    # -------------------------------
    parser.add_argument('--lr', type=float, default=0.003, help='learning rate')
    parser.add_argument('--optimizer', type=str, default='sgd', help='optimizer')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum for SGD')

    # -------------------------------
    # Checkpoint Settings
    # -------------------------------
    parser.add_argument('--checkpoint_dir', type=str,
                        default='./checkpoints_fedper',
                        help='directory to save checkpoints')
    # -------------------------------
    # Adversarial Attack Settings
    # -------------------------------
    parser.add_argument('--enable_attack', action='store_true',
                        help='Enable adversarial attacks during testing')
    parser.add_argument('--attack_type', type=str, default='fgsm',
                        choices=['fgsm', 'pgd', 'cw'],
                        help='Type of adversarial attack')
    parser.add_argument('--attack_epsilon', type=float, default=0.03,
                        help='Perturbation budget for FGSM/PGD (L-inf norm)')
    parser.add_argument('--pgd_alpha', type=float, default=0.01,
                        help='Step size for PGD attack')
    parser.add_argument('--pgd_steps', type=int, default=10,
                        help='Number of PGD iterations')
    parser.add_argument('--cw_c', type=float, default=1.0,
                        help='C&W trade-off constant c (weight of the attack loss term)')
    parser.add_argument('--cw_steps', type=int, default=100,
                        help='Number of optimization steps for C&W')
    parser.add_argument('--cw_lr', type=float, default=0.01,
                        help='Learning rate for C&W optimization')
    
    # -------------------------------
    # Device
    # -------------------------------
    parser.add_argument('--device', default=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                        help='cuda or cpu')
    
    # -------------------------------
    # Balanced Non-IID Settings
    # -------------------------------
    parser.add_argument('--dominant_ratio', type=float, default=0.7,
                        help='Fraction of dominant class per client (0-1)')

    args = parser.parse_args(args=[])
    
    # Auto-configure based on selected dataset
    if args.dataset in DATASET_CONFIGS:
        config = DATASET_CONFIGS[args.dataset]
        args.num_classes = config['num_classes']
        args.input_channels = config['input_channels']
        args.class_names = config['class_names']
    else:
        raise ValueError(f"Dataset {args.dataset} not configured!")
    
    return args
'''

with open(file_path, "w") as f:
    f.write(new_code)

print(" args.py updated!")

 args.py updated!
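
Because `args_parser()` calls `parse_args(args=[])`, command-line flags are never read and every value comes from its `default=`. The sketch below illustrates what that means for switching the dataset and the attack type. `demo_args_parser` and `DEMO_CONFIGS` are hypothetical, trimmed-down stand-ins for `args_parser()` and `DATASET_CONFIGS`, written only for this illustration.

```python
import argparse

# Hypothetical two-entry excerpt of DATASET_CONFIGS (dataset -> num_classes).
DEMO_CONFIGS = {'organamnist': 11, 'pathmnist': 9}

def demo_args_parser(dataset_default='organamnist'):
    """Trimmed-down stand-in for args_parser(); not the real function."""
    parser = argparse.ArgumentParser(description="demo")
    parser.add_argument('--dataset', type=str, default=dataset_default,
                        choices=list(DEMO_CONFIGS))
    parser.add_argument('--enable_attack', action='store_true')
    parser.add_argument('--attack_type', type=str, default='fgsm',
                        choices=['fgsm', 'pgd', 'cw'])
    # Same call as in args.py: the empty list means sys.argv is never read,
    # so every value comes from its default.
    args = parser.parse_args(args=[])
    # Derived fields are filled in here, so the dataset has to be chosen
    # before parsing (by editing default=), not afterwards.
    args.num_classes = DEMO_CONFIGS[args.dataset]
    return args

args = demo_args_parser()
print(args.dataset, args.num_classes, args.enable_attack, args.attack_type)

# Emulates editing default='pathmnist' in args.py.
args = demo_args_parser(dataset_default='pathmnist')
print(args.dataset, args.num_classes)

# Attack settings have no derived fields, so they can also be set after parsing.
# choices= is not re-checked at this point, so validate by hand.
args.enable_attack = True
args.attack_type = 'pgd'
assert args.attack_type in ('fgsm', 'pgd', 'cw')
```

Note that `--enable_attack` is a `store_true` flag, so with an empty argument list it stays `False` unless it is set after parsing or its default is changed; whether the rest of the notebook does this is not shown here.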


In [3]:
file_path = "/kaggle/working/FedPer-PyTorch/get_data.py"

new_code = r'''
# ========================================
# get_data.py - Fixed with Retry Logic for FedPer
# Supports: pathmnist, tissuemnist, organamnist, octmnist
# ========================================
import os
import numpy as np
import torch
import pickle
import time
from torch.utils.data import DataLoader, Subset, ConcatDataset
from torchvision import transforms
from medmnist import INFO
from medmnist.dataset import PathMNIST, TissueMNIST, OrganAMNIST, OCTMNIST

# Map dataset names to classes
DATASET_MAP = {
    'pathmnist': PathMNIST,
    'tissuemnist': TissueMNIST,
    'organamnist': OrganAMNIST,
    'octmnist': OCTMNIST
}

def get_transforms():
    """Returns train and test transforms (handles grayscale and RGB)"""
    # convert("RGB") replicates the single channel of grayscale images and
    # leaves RGB images (e.g. PathMNIST) unchanged, so CLIP always gets 3 channels.
    # Assumes the dataset yields PIL images, as the MedMNIST classes do.
    to_rgb = transforms.Lambda(lambda img: img.convert("RGB"))

    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        to_rgb,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        to_rgb,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    
    return train_transform, test_transform

def balanced_noniid_split(dataset, num_clients, dominant_ratio=0.7):
    """Balanced Non-IID split: dominant_ratio% to main client, rest distributed"""
    labels = np.array([dataset[i][1].item() for i in range(len(dataset))])
    num_classes = len(np.unique(labels))
    client_indices = [[] for _ in range(num_clients)]

    # Prepare class indices
    class_indices = {c: np.where(labels == c)[0].tolist() for c in range(num_classes)}
    for c in class_indices:
        np.random.shuffle(class_indices[c])

    # Step 1: assign dominant_ratio% of dominant class to primary client
    for client_id in range(num_clients):
        dominant_class = client_id % num_classes
        n_dominant = int(len(class_indices[dominant_class]) * dominant_ratio)
        if n_dominant > 0:
            client_indices[client_id].extend(class_indices[dominant_class][:n_dominant])
            class_indices[dominant_class] = class_indices[dominant_class][n_dominant:]

    # Step 2: distribute remaining samples equally among other clients
    for c in range(num_classes):
        remaining = class_indices[c]
        np.random.shuffle(remaining)
        other_clients = [i for i in range(num_clients) if i % num_classes != c]
        for i, idx in enumerate(remaining):
            client_id = other_clients[i % len(other_clients)]
            client_indices[client_id].append(idx)

    # Shuffle each client's indices
    for i in range(num_clients):
        np.random.shuffle(client_indices[i])

    return client_indices

def download_with_retry(DatasetClass, root, split, transform, max_retries=3, delay=5):
    """Download dataset with retry logic"""
    for attempt in range(max_retries):
        try:
            print(f"  Attempt {attempt + 1}/{max_retries}: Downloading {split} split...")
            dataset = DatasetClass(root=root, split=split, download=True, transform=transform)
            print(f"   Successfully loaded {split} split")
            return dataset
        except Exception as e:
            if attempt < max_retries - 1:
                print(f"   Download failed: {str(e)}")
                print(f"  Waiting {delay} seconds before retry...")
                time.sleep(delay)
            else:
                print(f"   Failed after {max_retries} attempts")
                raise RuntimeError(f"""
Dataset download failed after {max_retries} attempts.

Possible solutions:
1. Wait a few minutes and try again (server may be overloaded)
2. Manually download the dataset:
   - Go to: https://zenodo.org/records/10519652
   - Download the .npz file for your dataset
   - Place it in: {root}/
3. Change dataset in args.py to one that's already downloaded
4. Check your internet connection

Error: {str(e)}
""")

def load_medmnist_data(args):
    """Load MedMNIST data for FedPer - automatically configured based on args.dataset"""
    train_transform, test_transform = get_transforms()
    data_root = './data/medmnist'
    os.makedirs(data_root, exist_ok=True)

    data_flag = args.dataset.lower()
    if data_flag not in DATASET_MAP:
        raise ValueError(f"Dataset {data_flag} not supported. Choose from {list(DATASET_MAP.keys())}")

    n_classes = args.num_classes

    print("\n" + "="*60)
    print(f" Loading {data_flag.upper()} Dataset for FedPer")
    print("="*60)
    print(f"Number of classes: {n_classes}")

    DatasetClass = DATASET_MAP[data_flag]

    # Check if dataset already exists
    dataset_file = os.path.join(data_root, f'{data_flag}.npz')
    if os.path.exists(dataset_file):
        print(f" Found existing dataset: {dataset_file}")
        print("  Loading from cache...")
        
    # Load datasets with retry logic
    print("\n Loading dataset splits...")
    train_dataset = download_with_retry(DatasetClass, data_root, 'train', train_transform)
    val_dataset = download_with_retry(DatasetClass, data_root, 'val', train_transform)
    test_dataset = download_with_retry(DatasetClass, data_root, 'test', test_transform)

    combined_train = ConcatDataset([train_dataset, val_dataset])
    print(f" Using Train+Val: {len(combined_train)} samples for federated learning")

    # Load or create client indices
    cache_file = f'./data/medmnist/client_indices_{data_flag}_K{args.K}_dr{args.dominant_ratio}_fedper.pkl'
    if os.path.exists(cache_file):
        print(f"\n Loading cached split from: {cache_file}")
        with open(cache_file, 'rb') as f:
            client_indices = pickle.load(f)
    else:
        print(f"\n Creating balanced Non-IID split (dominant_ratio={args.dominant_ratio})...")
        client_indices = balanced_noniid_split(combined_train, args.K, dominant_ratio=args.dominant_ratio)
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        with open(cache_file, 'wb') as f:
            pickle.dump(client_indices, f)
        print(f" Split cached to: {cache_file}")

    # Print client-wise class distribution
    print("\n" + "="*60)
    print(" Client-wise Image Distribution")
    print("="*60)
    total_samples = 0
    for client_id, indices in enumerate(client_indices):
        client_labels = [combined_train[i][1].item() for i in indices]
        label_counts = np.bincount(client_labels, minlength=n_classes)
        total_client = len(indices)
        total_samples += total_client
        distribution_str = ", ".join([f"C{c}:{label_counts[c]}" for c in range(n_classes)])
        dominant_class = np.argmax(label_counts)
        dominant_pct = (label_counts[dominant_class] / total_client) * 100
        print(f"Client {client_id:2d}: {total_client:5d} samples | Dominant: Class {dominant_class} ({dominant_pct:.1f}%)")
        print(f"           [{distribution_str}]")

    print(f"\nTotal samples: {total_samples}")
    print("="*60 + "\n")

    # Create DataLoaders
    client_loaders = []
    for i, indices in enumerate(client_indices):
        subset = Subset(combined_train, indices)
        loader = DataLoader(subset, batch_size=args.B, shuffle=True, num_workers=0, pin_memory=True)
        client_loaders.append(loader)

    test_loader = DataLoader(test_dataset, batch_size=args.B, shuffle=False, num_workers=0, pin_memory=True)
    return client_loaders, test_loader
'''

with open(file_path, "w") as f:
    f.write(new_code)

print(" get_data.py updated!")

 get_data.py updated!
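
To see what `balanced_noniid_split` produces, the sketch below re-implements the same two steps on a plain list of labels using only the standard library. `toy_noniid_split` is a hypothetical name, `random` replaces `np.random`, and the toy labels are made up, so this illustrates the logic and is not the function written to get_data.py.

```python
import random
from collections import Counter

def toy_noniid_split(labels, num_clients, dominant_ratio=0.7, seed=0):
    """Simplified stand-in for balanced_noniid_split; labels must be 0..n-1."""
    rng = random.Random(seed)
    num_classes = len(set(labels))
    by_class = {c: [i for i, y in enumerate(labels) if y == c]
                for c in range(num_classes)}
    for idxs in by_class.values():
        rng.shuffle(idxs)

    clients = [[] for _ in range(num_clients)]

    # Step 1: each client takes dominant_ratio of what is left of its dominant class.
    for cid in range(num_clients):
        c = cid % num_classes
        n = int(len(by_class[c]) * dominant_ratio)
        clients[cid].extend(by_class[c][:n])
        by_class[c] = by_class[c][n:]

    # Step 2: the leftovers of each class go round-robin to the clients
    # whose dominant class is a different one.
    for c in range(num_classes):
        others = [i for i in range(num_clients) if i % num_classes != c]
        for j, idx in enumerate(by_class[c]):
            clients[others[j % len(others)]].append(idx)

    return clients

# Toy data: 3 classes with 100 samples each, split across 3 clients.
labels = [c for c in range(3) for _ in range(100)]
split = toy_noniid_split(labels, num_clients=3)
for cid, idxs in enumerate(split):
    print(cid, len(idxs), dict(Counter(labels[i] for i in idxs)))
```

Like the original, this sketch divides by `len(others)` in step 2, so it fails if some class has no client with a different dominant class, for example with a single client.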


In [4]:
# ========================================
# STEP 1: Fix model.py
# ========================================
file_path = "/kaggle/working/FedPer-PyTorch/model.py"

fixed_model_code = r'''
# ========================================
# model.py - CLIP-based FedPer Classifier (FIXED)
# ========================================
import torch
from torch import nn
import clip

class CLIPFedPerClassifier(nn.Module):
    """
    FedPer with CLIP: Proper Implementation
    ----------------------------------------
    Shared Layers: Last 2 transformer blocks of CLIP (AGGREGATED)
    Personalized Layers: MLP head (NOT aggregated)
    """
    def __init__(self, args, name='clip_fedper_model'):
        super(CLIPFedPerClassifier, self).__init__()
        self.name = name
        self.num_classes = args.num_classes
        self.Kp = getattr(args, 'Kp', 2)
        self.dropout = getattr(args, 'dropout', 0.3)
        
        # Load pretrained CLIP model
        self.clip_model, self.preprocess = clip.load(args.clip_model, device=args.device)
        
        # Get CLIP feature dimension
        self.feature_dim = self.clip_model.visual.output_dim
        
        # Auto-configured class names
        self.class_names = args.class_names
        
        # FedPer Strategy: Unfreeze last transformer blocks
        self._setup_fedper_layers()
        
        # Personalized MLP Head (NOT aggregated)
        self.head = nn.Sequential(
            nn.Linear(self.feature_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(self.dropout * 0.7),
            
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(self.dropout * 0.5),
            
            nn.Linear(256, self.num_classes)
        )
        
        self._initialize_weights()
    
    def _setup_fedper_layers(self):
        """Freeze most of CLIP, unfreeze last 2 transformer blocks"""
        # Freeze all first
        for param in self.clip_model.parameters():
            param.requires_grad = False
        
        # Unfreeze last 2 transformer blocks (shared, will be aggregated)
        if hasattr(self.clip_model.visual, 'transformer'):
            num_blocks = len(self.clip_model.visual.transformer.resblocks)
            num_shared = 2  # Last 2 blocks for aggregation
            
            for block in self.clip_model.visual.transformer.resblocks[-num_shared:]:
                for param in block.parameters():
                    param.requires_grad = True
            
            print(f"  ✓ Unfroze last {num_shared}/{num_blocks} transformer blocks (shared layers)")
        
        # Also unfreeze final projection layer
        if hasattr(self.clip_model.visual, 'proj') and self.clip_model.visual.proj is not None:
            self.clip_model.visual.proj.requires_grad = True
            print(f"  ✓ Unfroze CLIP projection layer (shared)")
    
    def _initialize_weights(self):
        for m in self.head.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    
    def forward(self, images):
        # CLIP features (last 2 blocks are trainable)
        image_features = self.clip_model.encode_image(images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        image_features = image_features.float()
        
        # Personalized head
        logits = self.head(image_features)
        return logits
    
    # ========================================
    # FedPer Required Methods
    # ========================================
    def get_shared_params(self):
        """Returns trainable CLIP parameters for aggregation"""
        shared_params = []
        for param in self.clip_model.parameters():
            if param.requires_grad:
                shared_params.append(param)
        return shared_params
    
    def get_personalized_params(self):
        """Returns personalized head parameters (NOT aggregated)"""
        return [p for p in self.head.parameters() if p.requires_grad]
    
    def get_trainable_params(self):
        """Return all trainable parameters"""
        return [p for p in self.parameters() if p.requires_grad]
'''

with open(file_path, "w") as f:
    f.write(fixed_model_code)

print("✅ model.py FIXED!")

✅ model.py FIXED!


In [5]:
file_path = "/kaggle/working/FedPer-PyTorch/client.py"

new_code = r'''
# ========================================
# client.py - Client
# ========================================
import torch
from torch import nn
import torch.optim as optim
from tqdm import tqdm
import copy

class Client:
    def __init__(self, model, train_loader, device, val_loader=None, lr=0.0001, weight_decay=5e-4, Kp=2, shared_params=None):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.lr = lr
        self.weight_decay = weight_decay
        self.Kp = Kp  # number of personalized layers (head)
        self.shared_params = [p.clone().detach() for p in shared_params] if shared_params else None

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9, weight_decay=self.weight_decay)

    def compute_accuracy(self, loader):
        self.model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(self.device), labels.to(self.device)
                labels = labels.reshape(-1)  # stays 1-D even for a batch of one
                outputs = self.model(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return 100 * correct / total if total > 0 else 0

    def train(self, epochs=1, shared_params=None):
        """Train client model; keep head personalized, optionally update shared layers"""
        print(f"\n Training {self.model.name} with FedPer (personalized head)...")

        # Initialize shared layers if provided
        if shared_params is not None:
            model_shared = self.model.get_shared_params()
            for m_p, g_p in zip(model_shared, shared_params):
                m_p.data = g_p.data.clone()

        for epoch in range(epochs):
            self.model.train()
            total_loss = 0.0
            pbar = tqdm(self.train_loader, desc=f"  Epoch {epoch+1}/{epochs}", ncols=100, leave=False)
            
            for images, labels in pbar:
                images, labels = images.to(self.device), labels.to(self.device)
                labels = labels.reshape(-1)  # stays 1-D even for a batch of one

                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()
                pbar.set_postfix({'loss': f'{loss.item():.4f}'})

            avg_loss = total_loss / len(self.train_loader)
            train_acc = self.compute_accuracy(self.train_loader)
            log_msg = f"  Epoch {epoch+1:2d}/{epochs} | Loss: {avg_loss:.4f} | Train Acc: {train_acc:.2f}%"

            if self.val_loader is not None:
                val_acc = self.compute_accuracy(self.val_loader)
                val_loss = 0
                self.model.eval()
                with torch.no_grad():
                    for images, labels in self.val_loader:
                        images, labels = images.to(self.device), labels.to(self.device)
                        labels = labels.reshape(-1)  # stays 1-D even for a batch of one
                        outputs = self.model(images)
                        loss = self.criterion(outputs, labels)
                        val_loss += loss.item()
                val_loss /= len(self.val_loader)
                log_msg += f" | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%"

            print(log_msg)

        print(f" {self.model.name} local training complete (FedPer)\n")
        return self.model.state_dict()
'''

with open(file_path, "w") as f:
    f.write(new_code)

print(" client.py updated!")



 client.py updated!
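
After local training, FedPer averages only the shared parameters across clients and leaves each personalized head untouched. The following is a minimal sketch of sample-weighted averaging on plain Python lists, assuming each client's shared parameters are flattened into a list of floats. `weighted_average` is a hypothetical name, and the notebook's server operates on tensors instead.

```python
def weighted_average(client_params, client_sizes):
    """Average parameter lists, weighting each client by its sample count."""
    total = sum(client_sizes)
    aggregated = [0.0] * len(client_params[0])
    for params, n_samples in zip(client_params, client_sizes):
        weight = n_samples / total
        for i, value in enumerate(params):
            aggregated[i] += value * weight
    return aggregated

# Two toy clients: the second holds three times as many samples,
# so its parameters count three times as much.
shared = [[1.0, 2.0], [3.0, 4.0]]
sizes = [100, 300]
print(weighted_average(shared, sizes))  # [2.5, 3.5]
```

Weighting by sample count means a client with more data pulls the shared layers further toward its own update, which is the same weighting the standard FedAvg rule uses.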


In [6]:
# ========================================
# STEP 2: Fix server.py - Global model
# ========================================
file_path = "/kaggle/working/FedPer-PyTorch/server.py"

fixed_server_code = r'''
# ========================================
# server.py - FedPer Server (FIXED)
# ========================================
import copy
import random
import numpy as np
import torch
import os
import json
from model import CLIPFedPerClassifier as ImageClassifier
from get_data import load_medmnist_data
from client import Client
from sklearn.metrics import precision_score, recall_score, f1_score, mean_squared_error

class FedPerServer:
    def __init__(self, args, resume_from=None):
        self.args = args
        os.makedirs(args.checkpoint_dir, exist_ok=True)
        self.current_round = 0
        self.best_global_acc = 0
        self.history = {
            'rounds': [], 
            'avg_accuracy': [], 
            'best_accuracy': [],
            'precision': [],
            'recall': [],
            'f1': [],
            'rmse': []
        }

        self.global_model = ImageClassifier(args, name="server").to(args.device)
        self.client_models = []
        for i in range(self.args.K):
            model = copy.deepcopy(self.global_model)
            model.name = f"Client_{i}"
            self.client_models.append(model)

        self.client_loaders, self.test_loader = load_medmnist_data(args)
        if resume_from:
            self.load_checkpoint(resume_from)

    def load_checkpoint(self, checkpoint_path):
        """Load checkpoint"""
        print("\n Loading checkpoint from:", checkpoint_path)
        
        checkpoint = torch.load(
            checkpoint_path, 
            map_location=self.args.device,
            weights_only=False  
        )
        
        self.global_model.load_state_dict(checkpoint['server_state_dict'])
        
        if 'client_state_dicts' in checkpoint:
            for i, state_dict in enumerate(checkpoint['client_state_dicts']):
                self.client_models[i].load_state_dict(state_dict)
        
        self.current_round = checkpoint.get('round', 0)
        self.best_global_acc = checkpoint.get('best_global_acc', 0)
        self.history = checkpoint.get('history', {
            'rounds': [], 
            'avg_accuracy': [], 
            'best_accuracy': [],
            'precision': [],
            'recall': [],
            'f1': [],
            'rmse': []
        })
        
        print(" Checkpoint loaded successfully!")
        print(f"   Resuming from Round: {self.current_round}/{self.args.r}")
        print(f"   Best Accuracy: {self.best_global_acc:.2f}%")

    def dispatch(self, selected_clients):
        """Send global shared params to clients"""
        global_shared_params = self.global_model.get_shared_params()
        
        if len(global_shared_params) == 0:
            print("  WARNING: No shared parameters to dispatch!")
            return
            
        for idx in selected_clients:
            client_model = self.client_models[idx]
            client_params = client_model.get_shared_params()
            for c_param, g_param in zip(client_params, global_shared_params):
                c_param.data = g_param.data.clone()

    def aggregate(self, selected_clients):
        """Aggregate shared layers only"""
        global_shared_params = self.global_model.get_shared_params()
        
        if len(global_shared_params) == 0:
            print("  WARNING: No shared parameters to aggregate!")
            return
            
        total_samples = sum([len(self.client_loaders[idx].dataset) for idx in selected_clients])
        agg_params = [torch.zeros_like(p.data) for p in global_shared_params]

        for idx in selected_clients:
            client_model = self.client_models[idx]
            client_shared = client_model.get_shared_params()
            weight = len(self.client_loaders[idx].dataset) / total_samples
            for i, p in enumerate(client_shared):
                agg_params[i] += p.data * weight

        for p, agg_p in zip(global_shared_params, agg_params):
            p.data = agg_p.data.clone()
        
        print(f"  ✓ Aggregated {len(global_shared_params)} shared parameter tensors")

    def client_update(self, idx):
        client_model = self.client_models[idx]
        client_loader = self.client_loaders[idx]
        client_obj = Client(
            model=client_model,
            train_loader=client_loader,
            device=self.args.device,
            lr=self.args.lr,
            weight_decay=self.args.weight_decay,
            Kp=self.args.Kp
        )
        global_shared_params = self.global_model.get_shared_params()
        client_obj.train(epochs=self.args.E, shared_params=global_shared_params)

    def test_all_clients(self):
        """Test each client on global test set and return average"""
        client_accuracies = []
        all_labels_combined = []
        all_preds_combined = []
        
        for idx, client_model in enumerate(self.client_models):
            client_model.eval()
            client_labels = []
            client_preds = []
            
            with torch.no_grad():
                for images, labels in self.test_loader:
                    images, labels = images.to(self.args.device), labels.to(self.args.device)
                    labels = labels.reshape(-1)  # stays 1-D even for a batch of one
                    
                    outputs = client_model(images)
                    _, predicted = torch.max(outputs, 1)
                    client_labels.extend(labels.cpu().numpy())
                    client_preds.extend(predicted.cpu().numpy())
            
            client_labels = np.array(client_labels)
            client_preds = np.array(client_preds)
            client_acc = 100 * np.mean(client_labels == client_preds)
            client_accuracies.append(client_acc)
            
            all_labels_combined.extend(client_labels)
            all_preds_combined.extend(client_preds)
        
        # Average accuracy across all clients
        avg_acc = np.mean(client_accuracies)
        
        # Compute metrics on combined predictions
        all_labels_combined = np.array(all_labels_combined)
        all_preds_combined = np.array(all_preds_combined)
        
        precision = precision_score(all_labels_combined, all_preds_combined, average='macro', zero_division=0)
        recall = recall_score(all_labels_combined, all_preds_combined, average='macro', zero_division=0)
        f1 = f1_score(all_labels_combined, all_preds_combined, average='macro', zero_division=0)
        rmse = np.sqrt(mean_squared_error(all_labels_combined, all_preds_combined))
        
        print(f" Client Avg Test - Acc: {avg_acc:.2f}% | Prec: {precision:.3f} | Recall: {recall:.3f} | F1: {f1:.3f} | RMSE: {rmse:.3f}")
        print(f"   Individual: {[f'{acc:.1f}%' for acc in client_accuracies]}")
        
        return {
            'accuracy': avg_acc,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'rmse': rmse
        }

    def save_checkpoint(self, round_num, metrics):
        """Save checkpoint"""
        checkpoint = {
            'round': round_num,
            'server_state_dict': self.global_model.state_dict(),
            'client_state_dicts': [model.state_dict() for model in self.client_models],
            'best_global_acc': self.best_global_acc,
            'history': self.history,
            'args': vars(self.args)
        }
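        
        # Make sure the checkpoint directory exists before writing
        os.makedirs(self.args.checkpoint_dir, exist_ok=True)
        
        # Overwrite the rolling "latest" checkpoint that main.py looks for on resume
        latest_path = os.path.join(self.args.checkpoint_dir, 'checkpoint_latest.pth')
        torch.save(checkpoint, latest_path)
        
        # Mirror the metric history as JSON so it can be inspected without PyTorch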
        history_path = os.path.join(self.args.checkpoint_dir, 'training_history.json')
        with open(history_path, 'w') as f:
            json.dump(self.history, f, indent=4)

    def run(self):
        """Main training loop"""
        start_round = self.current_round
        
        separator_line = "=" * 80
        dash_line = "-" * 80
        
        print(f"\n{separator_line}")
        print(" STARTING FEDERATED TRAINING (FedPer)")
        print(separator_line)
        print(f"Starting from Round: {start_round + 1}/{self.args.r}")
        print(f"Current Best Accuracy: {self.best_global_acc:.2f}%")
        
        # Check shared params
        shared_params = self.global_model.get_shared_params()
        print(f"Shared parameters: {len(shared_params)} tensors")
        print(f"{separator_line}\n")
        
        for r in range(start_round, self.args.r):
            print(f"\n{dash_line}")
            print(f"Round {r+1}/{self.args.r}")
            print(dash_line)
            
            m = max(int(self.args.C * self.args.K), 1)
            selected_clients = random.sample(range(self.args.K), m)
            print(f"Selected {m} clients: {selected_clients}")
            
            self.dispatch(selected_clients)
            
            for idx in selected_clients:
                print(f"  Training Client {idx}...", end=" ")
                self.client_update(idx)
                print("✓")
            
            self.aggregate(selected_clients)
            
            # Test using client models (FedPer approach)
            metrics = self.test_all_clients()
            avg_acc = metrics['accuracy']
            
            self.history['rounds'].append(r+1)
            self.history['avg_accuracy'].append(avg_acc)
            self.history['best_accuracy'].append(max(self.best_global_acc, avg_acc))
            self.history['precision'].append(metrics['precision'])
            self.history['recall'].append(metrics['recall'])
            self.history['f1'].append(metrics['f1'])
            self.history['rmse'].append(metrics['rmse'])
            
            if avg_acc > self.best_global_acc:
                self.best_global_acc = avg_acc
            
            self.current_round = r + 1
            
            self.save_checkpoint(r+1, metrics)
            
            print(f"\n Round {r+1} Results:")
            print(f"   Accuracy: {avg_acc:.2f}%")
            print(f"   Best Accuracy: {self.best_global_acc:.2f}%")
            print(f"   Precision: {metrics['precision']:.3f}")
            print(f"   Recall: {metrics['recall']:.3f}")
            print(f"   F1-Score: {metrics['f1']:.3f}")
            print(f"   RMSE: {metrics['rmse']:.3f}")
        
        print(f"\n{separator_line}")
        print(" TRAINING COMPLETE")
        print(separator_line)
        print(f"Best Accuracy Achieved: {self.best_global_acc:.2f}%")
        print(f"{separator_line}\n")
        
        return self.global_model
'''

with open(file_path, "w") as f:
    f.write(fixed_server_code)

print("✅ server.py FIXED!")

✅ server.py FIXED!
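
The client-selection step in `run()` picks `m = max(int(C * K), 1)` clients per round with `random.sample`. Below is a minimal standalone sketch of that rule. The helper name `select_clients` and its `seed` argument are hypothetical and are not part of server.py; they are only there to make the behaviour easy to try in isolation.

```python
import random


def select_clients(K, C, seed=None):
    """Pick max(int(C * K), 1) distinct client indices out of K (hypothetical helper)."""
    rng = random.Random(seed)   # local RNG, so the global random state is untouched
    m = max(int(C * K), 1)      # same rule as in run(): at least one client per round
    return sorted(rng.sample(range(K), m))


if __name__ == "__main__":
    # With K=5 clients, C=0.5 gives int(2.5) = 2 clients per round
    print(len(select_clients(K=5, C=0.5)))
    # C=1.0 selects every client
    print(select_clients(K=5, C=1.0))
```

Because `int()` truncates, a small fraction such as `C=0.1` with `K=5` would give zero clients; the `max(..., 1)` guard keeps every round from being empty.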


In [7]:
file_path = "/kaggle/working/FedPer-PyTorch/main.py"

new_code = r"""
# ========================================
# main.py: Run FedPer with resume support (FedProx-style auto-resume)
# ========================================
from args import args_parser
from server import FedPerServer
from get_data import load_medmnist_data
import torch
import os
import json
import matplotlib.pyplot as plt

def plot_training_history(history, save_path):
    plt.figure(figsize=(12, 5))
    
    # Plot 1: Average Accuracy
    plt.subplot(1, 2, 1)
    plt.plot(history['rounds'], history['avg_accuracy'], 'b-', label='Avg Accuracy', linewidth=2)
    plt.plot(history['rounds'], history['best_accuracy'], 'r--', label='Best Accuracy', linewidth=2)
    plt.xlabel('Communication Round')
    plt.ylabel('Accuracy (%)')
    plt.title('Federated Learning Accuracy Progress')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # Plot 2: Accuracy Improvement
    plt.subplot(1, 2, 2)
    if len(history['avg_accuracy']) > 1:
        improvements = [history['avg_accuracy'][i] - history['avg_accuracy'][i-1] 
                        for i in range(1, len(history['avg_accuracy']))]
        plt.bar(history['rounds'][1:], improvements, alpha=0.7)
        plt.axhline(y=0, color='r', linestyle='-', linewidth=0.5)
        plt.xlabel('Communication Round')
        plt.ylabel('Accuracy Change (%)')
        plt.title('Round-to-Round Accuracy Change')
        plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f' Training plot saved at: {save_path}')

def main():
    # Load arguments from args.py
    args = args_parser()
    
    # Print configuration
    print('\n' + '='*60)  # single backslash: this code sits in a raw string, so '\\n' would print a literal backslash-n
    print(' FEDERATED LEARNING CONFIGURATION (FedPer)')
    print('='*60)
    print(f'Dataset: {args.dataset.upper()}')
    print(f'Clients: {args.K} | Rounds: {args.r} | Local Epochs: {args.E}')
    print(f'Batch Size: {args.B} | Learning Rate: {args.lr}')
    print(f'CLIP Model: {args.clip_model} | Device: {args.device}')
    print(f'Non-IID: {args.dominant_ratio*100:.0f}% dominant class per client')
    print(f'Personalized Layers (Kp): {args.Kp}')
    print('='*60)
    
    # Load client loaders & test loader
    client_loaders, test_loader = load_medmnist_data(args)
    
    # Check for existing checkpoint (FedProx style auto-resume)
    latest_checkpoint = os.path.join(args.checkpoint_dir, 'checkpoint_latest.pth')
    resume_from = None
    if os.path.exists(latest_checkpoint):
        try:
            checkpoint = torch.load(latest_checkpoint, map_location=args.device, weights_only=False)
            completed_rounds = checkpoint.get('round', 0)
            best_acc = checkpoint.get('best_global_acc', 0)
            print('\n' + '='*60)
            print(' CHECKPOINT DETECTED: auto-resuming!')
            print('='*60)
            print(f' Completed Rounds: {completed_rounds}/{args.r}')
            print(f' Best Accuracy: {best_acc:.2f}%')
            resume_from = latest_checkpoint
        except Exception as e:
            print(f' Error loading checkpoint: {e}')
            print(' Starting fresh training...')
    else:
        print('\n' + '='*60)
        print(' No checkpoint found. Starting fresh training...')
        print('='*60)
    
    # Initialize FedPer server
    server = FedPerServer(args, resume_from=resume_from)
    
    # Run federated training
    final_model = server.run()
    
    # Save final global model (shared layers)
    final_model_path = os.path.join(args.checkpoint_dir, 'final_global_model.pth')
    torch.save(final_model.state_dict(), final_model_path)
    print(f'\n Final global model (shared layers) saved at: {final_model_path}')
    
    # Report the best-model file if one exists (main.py itself does not write it)
    best_model_path = os.path.join(args.checkpoint_dir, 'best_model.pth')
    if os.path.exists(best_model_path):
        print(f' Best global model found at: {best_model_path}')
    
    # Save training history
    history_path = os.path.join(args.checkpoint_dir, 'training_history.json')
    with open(history_path, 'w') as f:
        json.dump(server.history, f, indent=4)
    print(f' Training history saved at: {history_path}')
    
    # Plot training curves
    plot_path = os.path.join(args.checkpoint_dir, 'training_plot.png')
    plot_training_history(server.history, plot_path)
    
    # Print summary
    print('\n' + '='*60)
    print(' TRAINING SUMMARY (FedPer)')
    print('='*60)
    print(f' Best Global Accuracy: {server.best_global_acc:.2f}%')
    print(f' Final Round Accuracy: {server.history["avg_accuracy"][-1]:.2f}%')
    print(f' Total Improvement: {server.history["avg_accuracy"][-1] - server.history["avg_accuracy"][0]:.2f}%')
    print(f' All checkpoints saved in: {args.checkpoint_dir}')
    print('='*60)

if __name__ == '__main__':
    main()

"""

with open(file_path, "w") as f:
    f.write(new_code)

print("✅ main.py updated!")


✅ main.py updated!
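
Since `save_checkpoint` writes `training_history.json` next to `checkpoint_latest.pth`, progress can be checked before launching main.py without loading the checkpoint itself. The following is a minimal sketch, assuming the history file holds the `rounds` and `avg_accuracy` lists that server.py appends to. The helper `completed_rounds` is hypothetical and not part of the repository.

```python
import json
import os


def completed_rounds(checkpoint_dir):
    """Return (rounds_done, last_avg_accuracy) from training_history.json, or (0, None)."""
    history_path = os.path.join(checkpoint_dir, "training_history.json")
    if not os.path.exists(history_path):
        return 0, None  # no history yet: training would start from round 1
    with open(history_path) as f:
        history = json.load(f)
    rounds = history.get("rounds", [])
    accs = history.get("avg_accuracy", [])
    if not rounds:
        return 0, None
    return rounds[-1], (accs[-1] if accs else None)


if __name__ == "__main__":
    # Directory used earlier in this notebook for copied checkpoints
    done, acc = completed_rounds("/kaggle/working/checkpoints")
    print(f"Completed rounds: {done}, last average accuracy: {acc}")
```

If this reports zero rounds after the checkpoint was copied, the history file was probably not copied along with it. That alone does not block resuming, since main.py reads the round number from `checkpoint_latest.pth`.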


In [8]:
!pip install medmnist --quiet


In [9]:
!pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-iky5rbah
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-iky5rbah
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... done
Collecting ftfy (from clip==1.0)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
Building wheels for collected packages: clip
  Building wheel for clip (setup.py) ... done
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369490 sha256=66c302bbfcfb97f3f5d553f6550aefb4196c0d65037fc2e7b569c66b5a4fa630
  Stored in 

In [None]:
!python -u /kaggle/working/FedPer-PyTorch/main.py

 FEDERATED LEARNING CONFIGURATION (FedPer)
Dataset: ORGANAMNIST
Clients: 5 | Rounds: 50 | Local Epochs: 5
Batch Size: 32 | Learning Rate: 0.003
CLIP Model: ViT-B/32 | Device: cuda
Non-IID: 70% dominant class per client
Personalized Layers (Kp): 2

 Loading ORGANAMNIST Dataset for FedPer
Number of classes: 11

 Loading dataset splits...
  Attempt 1/3: Downloading train split...
100%|██████████| 38.2M/38.2M [00:17<00:00, 2.13MB/s]
   Successfully loaded train split
  Attempt 1/3: Downloading val split...
   Successfully loaded val split
  Attempt 1/3: Downloading test split...
   Successfully loaded test split
 Using Train+Val: 41052 samples for federated learning

 Creating balanced Non-IID split (dominant_ratio=0.7)...
 Split cached to: ./data/medmnist/client_indices_organamnist_K5_dr0.7_fedper.pkl

 Client-wise Image Distribution
Client  0:  8151 samples | Dominant: Class 0 (19.5%)
