How to Use This Version for Checkpoint Resume in Kaggle

Before running this notebook for the second time (to resume training from a checkpoint), follow the steps below carefully.

1. Download the latest checkpoint file
Download the file named checkpoint_latest.pth from the previous version of your notebook or experiment.

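2. Upload the checkpoint to Kaggle
The /kaggle/input directory is read-only, so upload the downloaded file as a Kaggle dataset or model and attach it to this notebook. It will then appear under /kaggle/input, for example: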
/kaggle/input/path1/pytorch/default/1/checkpoint_latest.pth

3. Run the following code cell before starting training
This code will copy the checkpoint file to the working directory (/kaggle/working/checkpoints) so that training can resume from the saved state.

4. Resume Training
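After the checkpoint file is copied successfully, run the rest of the notebook. main.py looks for checkpoint_latest.pth in the checkpoint directory (./checkpoints by default, which is /kaggle/working/checkpoints when the script is launched from /kaggle/working) and, when it finds one, offers to continue from the saved round instead of starting from scratch.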



## You can change the dataset and the attack type by changing the name in the args.py file. No need to modify anything else.

In [None]:
import os
import shutil

# Source and destination paths
src = "/kaggle/input/path1/pytorch/default/1/checkpoint_latest.pth"
dst_dir = "/kaggle/working/checkpoints"
dst = os.path.join(dst_dir, "checkpoint_latest.pth")

# Step 1: Check if source file exists
if not os.path.exists(src):
    print(f"❌ Source file not found: {src}")
else:
    print(f"✅ Found source file: {src}")

    # Step 2: Ensure destination directory exists
    if not os.path.exists(dst_dir):
        os.makedirs(dst_dir)
        print(f"📂 Created destination directory: {dst_dir}")
    else:
        print(f"📁 Destination directory already exists: {dst_dir}")

    # Step 3: Copy the file
    shutil.copy(src, dst)
    print(f"✅ Copied file to: {dst}")

    # Step 4: List all files in destination
    files = os.listdir(dst_dir)
    if files:
        print("\n📄 Files in /kaggle/working/checkpoints:")
        for f in files:
            print(" ├──", f)
    else:
        print("⚠ Destination directory is empty (unexpected).")
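
A large checkpoint can be copied incompletely if the session is interrupted, so it may be worth confirming that the copy matches the source before resuming. The sketch below is one possible check, not part of the original notebook: `sha256_of` and `copy_and_verify` are hypothetical helper names, and the demo runs on a throwaway file rather than on the Kaggle paths used above.

```python
import hashlib
import os
import shutil
import tempfile


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of a file, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def copy_and_verify(src, dst):
    """Copy src to dst and report whether both files have identical contents."""
    os.makedirs(os.path.dirname(dst), exist_ok=True)
    shutil.copy(src, dst)
    return sha256_of(src) == sha256_of(dst)


# Demo on a temporary file; on Kaggle, src and dst would be the paths from the cell above.
with tempfile.TemporaryDirectory() as tmp:
    demo_src = os.path.join(tmp, "input", "checkpoint_latest.pth")
    demo_dst = os.path.join(tmp, "working", "checkpoints", "checkpoint_latest.pth")
    os.makedirs(os.path.dirname(demo_src))
    with open(demo_src, "wb") as f:
        f.write(b"dummy checkpoint bytes")
    copy_ok = copy_and_verify(demo_src, demo_dst)
    print("Copy verified:", copy_ok)
```

Reading in chunks keeps memory use flat even for checkpoints of several hundred megabytes.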


In [1]:
import os

# Target directory
base_dir = "/kaggle/working/FedProx-PyTorch"
os.makedirs(base_dir, exist_ok=True)

# Python files to create
files = ["main.py", "server.py", "client.py", "model.py", "get_data.py", "args.py"]

# Create each file if not exists
for file in files:
    file_path = os.path.join(base_dir, file)
    if not os.path.exists(file_path):
        with open(file_path, "w") as f:
            f.write(f"# {file} - auto-created placeholder\n")
        print(f" Created: {file_path}")
    else:
        print(f" Already exists: {file_path}")

print("\n Folder and files ready in:", base_dir)


 Created: /kaggle/working/FedProx-PyTorch/main.py
 Created: /kaggle/working/FedProx-PyTorch/server.py
 Created: /kaggle/working/FedProx-PyTorch/client.py
 Created: /kaggle/working/FedProx-PyTorch/model.py
 Created: /kaggle/working/FedProx-PyTorch/get_data.py
 Created: /kaggle/working/FedProx-PyTorch/args.py

 Folder and files ready in: /kaggle/working/FedProx-PyTorch


In [2]:
file_path = "/kaggle/working/FedProx-PyTorch/args.py"

new_code = '''
# ========================================
# args.py - FedProx Configuration with Dataset Auto-config
# ========================================
import argparse
import torch

# Dataset configurations
DATASET_CONFIGS = {
    'pathmnist': {
        'num_classes': 9,
        'class_names': [
            "adipose tissue", "background", "debris", "lymphocytes",
            "mucus", "smooth muscle", "normal colon mucosa",
            "cancer-associated stroma", "colorectal adenocarcinoma epithelium"
        ],
        'input_channels': 3
    },
    'tissuemnist': {
        'num_classes': 8,
        'class_names': [
            "collecting duct", "distal convoluted tubule",
            "glomerular endothelial cells", "interstitial endothelial cells",
            "leukocytes", "podocytes", "proximal tubule", "thick ascending limb"
        ],
        'input_channels': 1
    },
    'organamnist': {
        'num_classes': 11,
        'class_names': [
            "bladder", "femur-left", "femur-right", "heart",
            "kidney-left", "kidney-right", "liver",
            "lung-left", "lung-right", "spleen", "pelvis"
        ],
        'input_channels': 1
    },
    'octmnist': {
        'num_classes': 4,
        'class_names': [
            "choroidal neovascularization", "diabetic macular edema",
            "drusen", "normal"
        ],
        'input_channels': 1
    }
}

def args_parser():
    parser = argparse.ArgumentParser(description="FedProx - Config File")
    
    # -------------------------------
    # Dataset Selection
    # -------------------------------
    parser.add_argument('--dataset', type=str, default='organamnist',
                        choices=list(DATASET_CONFIGS.keys()),
                        help='Select MedMNIST dataset')
    
    # -------------------------------
    # Federated Learning Parameters
    # -------------------------------
    parser.add_argument('--E', type=int, default=5, help='local epochs')
    parser.add_argument('--r', type=int, default=50, help='communication rounds')
    parser.add_argument('--K', type=int, default=5, help='total number of clients')
    parser.add_argument('--C', type=float, default=1, help='client sampling rate per round')
    parser.add_argument('--B', type=int, default=32, help='batch size')
    parser.add_argument('--use_combined', action='store_true', help='Use train+val+test combined')

    # -------------------------------
    # Model parameters
    # -------------------------------
    parser.add_argument('--clip_model', type=str, default='ViT-B/32', help='CLIP model variant')
    parser.add_argument('--freeze_clip', action='store_true', help='Freeze CLIP backbone')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate')

    # -------------------------------
    # Optimizer Settings
    # -------------------------------
    parser.add_argument('--lr', type=float, default=0.003, help='learning rate')
    parser.add_argument('--optimizer', type=str, default='sgd', help='optimizer')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum for SGD')

    # -------------------------------
    # FedProx Specific Parameter
    # -------------------------------
    parser.add_argument('--mu', type=float, default=0.05, help='FedProx proximal term coefficient')

    # -------------------------------
    # Checkpoint Settings
    # -------------------------------
    parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints', help='directory for checkpoints')
    
    # -------------------------------
    # Device
    # -------------------------------
    parser.add_argument('--device', default=torch.device("cuda" if torch.cuda.is_available() else "cpu"), help='cuda or cpu')

    # -------------------------------
    # Balanced Non-IID
    # -------------------------------
    parser.add_argument('--dominant_ratio', type=float, default=0.7, help='Fraction of dominant class per client (0-1)')

    # args=[] means command-line flags are ignored: edit the defaults above to change a setting
    args = parser.parse_args(args=[])

    # Auto-configure dataset
    if args.dataset in DATASET_CONFIGS:
        config = DATASET_CONFIGS[args.dataset]
        args.num_classes = config['num_classes']
        args.input_channels = config['input_channels']
        args.class_names = config['class_names']
    else:
        raise ValueError(f"Dataset {args.dataset} not configured!")

    return args

'''

with open(file_path, "w") as f:
    f.write(new_code)

print(" args.py updated!")



 args.py updated!
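
Because `args_parser()` calls `parser.parse_args(args=[])`, flags passed on the command line to main.py are ignored and only the defaults written in args.py take effect. That is why the dataset is switched by editing the default. The sketch below illustrates the difference with a trimmed-down parser; `DEMO_CONFIGS` and `demo_args_parser` are hypothetical names used only for this illustration and are not part of the project files.

```python
import argparse

# Trimmed-down copy of two dataset entries, for illustration only.
DEMO_CONFIGS = {
    "organamnist": {"num_classes": 11, "input_channels": 1},
    "octmnist": {"num_classes": 4, "input_channels": 1},
}


def demo_args_parser(argv):
    """Parse the given list of flags, then attach the dataset-specific settings."""
    parser = argparse.ArgumentParser(description="FedProx - demo config")
    parser.add_argument("--dataset", type=str, default="organamnist",
                        choices=list(DEMO_CONFIGS.keys()))
    args = parser.parse_args(args=argv)
    config = DEMO_CONFIGS[args.dataset]
    args.num_classes = config["num_classes"]
    args.input_channels = config["input_channels"]
    return args


default_args = demo_args_parser([])                       # what args.py does: defaults only
oct_args = demo_args_parser(["--dataset", "octmnist"])    # explicit override
print(default_args.dataset, default_args.num_classes)
print(oct_args.dataset, oct_args.num_classes)
```

Passing an explicit list keeps the parser usable inside a notebook, where `sys.argv` holds kernel arguments rather than user flags.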


In [3]:
file_path = "/kaggle/working/FedProx-PyTorch/get_data.py"

new_code = r'''
# ========================================
# get_data.py - FedProx MedMNIST Loader with True Balanced Non-IID
# Supports: pathmnist, tissuemnist, organamnist, octmnist
# ========================================
import os
import numpy as np
import torch
import pickle
from torch.utils.data import DataLoader, Subset, ConcatDataset
from torchvision import transforms
from medmnist.dataset import PathMNIST, TissueMNIST, OrganAMNIST, OCTMNIST

DATASET_MAP = {
    'pathmnist': PathMNIST,
    'tissuemnist': TissueMNIST,
    'organamnist': OrganAMNIST,
    'octmnist': OCTMNIST
}

def get_transforms():
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    return train_transform, test_transform

def balanced_noniid_split(dataset, num_clients, dominant_ratio=0.7):
    labels = np.array([dataset[i][1].item() for i in range(len(dataset))])
    num_classes = len(np.unique(labels))
    client_indices = [[] for _ in range(num_clients)]

    class_indices = {c: np.where(labels==c)[0].tolist() for c in range(num_classes)}
    for c in class_indices:
        np.random.shuffle(class_indices[c])

    # dominant_ratio to main client
    for client_id in range(num_clients):
        dominant_class = client_id % num_classes
        n_dom = int(len(class_indices[dominant_class]) * dominant_ratio)
        client_indices[client_id].extend(class_indices[dominant_class][:n_dom])
        class_indices[dominant_class] = class_indices[dominant_class][n_dom:]

    # remaining samples distributed
    for c in range(num_classes):
        rem = class_indices[c]
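        other_clients = [i for i in range(num_clients) if i%num_classes != c]
        if not other_clients:  # e.g. a single client: fall back to all clients
            other_clients = list(range(num_clients))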
        for i, idx in enumerate(rem):
            client_indices[other_clients[i % len(other_clients)]].append(idx)

    for i in range(num_clients):
        np.random.shuffle(client_indices[i])
    return client_indices

def load_medmnist_data(args):
    train_transform, test_transform = get_transforms()
    data_root = './data/medmnist'
    os.makedirs(data_root, exist_ok=True)

    DatasetClass = DATASET_MAP[args.dataset.lower()]

    train_dataset = DatasetClass(root=data_root, split='train', download=True, transform=train_transform)
    val_dataset = DatasetClass(root=data_root, split='val', download=True, transform=train_transform)
    test_dataset = DatasetClass(root=data_root, split='test', download=True, transform=test_transform)

    combined_train = ConcatDataset([train_dataset, val_dataset])

    cache_file = f'./data/medmnist/client_indices_{args.dataset}_K{args.K}_dr{args.dominant_ratio}_fedprox.pkl'
    if os.path.exists(cache_file):
        with open(cache_file,'rb') as f:
            client_indices = pickle.load(f)
    else:
        client_indices = balanced_noniid_split(combined_train, args.K, dominant_ratio=args.dominant_ratio)
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        with open(cache_file,'wb') as f:
            pickle.dump(client_indices,f)

    client_loaders = [DataLoader(Subset(combined_train, idx), batch_size=args.B, shuffle=True, num_workers=0, pin_memory=True)
                      for idx in client_indices]
    test_loader = DataLoader(test_dataset, batch_size=args.B, shuffle=False, num_workers=0, pin_memory=True)
    return client_loaders, test_loader

'''

with open(file_path, "w") as f:
    f.write(new_code)

print(" get_data.py updated!")


 get_data.py updated!
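
To see what `balanced_noniid_split` produces, it can help to run the same two-phase logic on a tiny label list. The sketch below is a simplified, standard-library re-implementation written for illustration (the name `toy_balanced_noniid_split` and the seeded `random.Random` are assumptions, and the final per-client shuffle is omitted). It assumes labels are integers from 0 to num_classes - 1, as the original does.

```python
import random
from collections import Counter


def toy_balanced_noniid_split(labels, num_clients, dominant_ratio=0.7, seed=0):
    """Simplified stdlib version of the two-phase split, for illustration."""
    rng = random.Random(seed)
    num_classes = len(set(labels))
    class_indices = {c: [i for i, y in enumerate(labels) if y == c]
                     for c in range(num_classes)}
    for idxs in class_indices.values():
        rng.shuffle(idxs)

    client_indices = [[] for _ in range(num_clients)]
    # Phase 1: each client takes dominant_ratio of what is left of its dominant class.
    for client_id in range(num_clients):
        dom = client_id % num_classes
        n_dom = int(len(class_indices[dom]) * dominant_ratio)
        client_indices[client_id].extend(class_indices[dom][:n_dom])
        class_indices[dom] = class_indices[dom][n_dom:]

    # Phase 2: leftovers go round-robin to clients with a different dominant class.
    for c in range(num_classes):
        others = [i for i in range(num_clients) if i % num_classes != c]
        others = others or list(range(num_clients))
        for j, idx in enumerate(class_indices[c]):
            client_indices[others[j % len(others)]].append(idx)
    return client_indices


# 3 classes with 10 samples each, split across 3 clients.
toy_labels = [c for c in range(3) for _ in range(10)]
split = toy_balanced_noniid_split(toy_labels, num_clients=3)
for cid, idxs in enumerate(split):
    print(cid, dict(Counter(toy_labels[i] for i in idxs)))
```

Every sample is assigned exactly once, and each client ends up with 7 of the 10 samples of its dominant class plus a few samples from the others.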


In [4]:
!pip install torch torchvision ftfy regex tqdm

Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-

In [5]:
file_path = "/kaggle/working/FedProx-PyTorch/client.py"

new_code = """
# ========================================
# client.py - Client-side helper for FedProx (with Proximal Term + tqdm)
# ========================================

import torch
from torch import nn
import torch.optim as optim
from tqdm import tqdm
import copy


class Client:
    def __init__(self, model, train_loader, device, val_loader=None, lr=0.0001, 
                 weight_decay=5e-4, mu=0.01, global_params=None):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.lr = lr
        self.weight_decay = weight_decay
        self.mu = mu
        
        if global_params is not None:
            self.global_params = [param.clone().detach() for param in global_params]
        else:
            self.global_params = None

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(
            self.model.parameters(), 
            lr=self.lr, 
            momentum=0.9, 
            weight_decay=self.weight_decay
        )

    def compute_accuracy(self, loader):
        self.model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in loader:
                images = images.to(self.device)
                labels = labels.to(self.device).reshape(-1)  # (B, 1) -> (B,), also safe when B == 1
                
                outputs = self.model(images)
                if isinstance(outputs, tuple) or hasattr(outputs, 'logits'):
                    outputs = outputs.logits
                
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        return 100 * correct / total if total > 0 else 0

    def train(self, epochs=1):
        model_name = self.model.name
        mu_value = self.mu
        print(f"\\n Training {model_name} with FedProx objective (μ={mu_value})...")
        
        for epoch in range(epochs):
            self.model.train()
            total_loss = 0.0

            pbar = tqdm(
                self.train_loader,
                desc=f"  Epoch {epoch+1}/{epochs}",
                ncols=100,
                leave=False
            )
            
            for images, labels in pbar:
                images = images.to(self.device)
                labels = labels.to(self.device).reshape(-1)  # (B, 1) -> (B,), also safe when B == 1

                self.optimizer.zero_grad()
                outputs = self.model(images)
                
                if isinstance(outputs, tuple) or hasattr(outputs, 'logits'):
                    outputs = outputs.logits
                
                loss = self.criterion(outputs, labels)

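                # FedProx proximal term: (mu / 2) * ||w - w_global||^2.
                # Frozen parameters stay equal to the global copy, so they are skipped.
                if self.global_params is not None:
                    prox_term = 0.0
                    for param, global_param in zip(self.model.parameters(), self.global_params):
                        if not param.requires_grad:
                            continue
                        prox_term += (param - global_param.to(self.device)).pow(2).sum()
                    loss += (self.mu / 2) * prox_term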

                loss.backward()
                self.optimizer.step()
                
                total_loss += loss.item()
                pbar.set_postfix({'loss': f'{loss.item():.4f}'})

            avg_loss = total_loss / len(self.train_loader)
            train_acc = self.compute_accuracy(self.train_loader)
            log_msg = f"  Epoch {epoch+1:2d}/{epochs} | Loss: {avg_loss:.4f} | Train Acc: {train_acc:.2f}%"

            if self.val_loader is not None:
                val_acc = self.compute_accuracy(self.val_loader)
                val_loss = 0
                self.model.eval()
                with torch.no_grad():
                    for images, labels in self.val_loader:
                        images = images.to(self.device)
                        labels = labels.to(self.device).reshape(-1)  # (B, 1) -> (B,)
                        
                        outputs = self.model(images)
                        if isinstance(outputs, tuple) or hasattr(outputs, 'logits'):
                            outputs = outputs.logits
                        
                        loss = self.criterion(outputs, labels)
                        val_loss += loss.item()
                
                val_loss /= len(self.val_loader)
                log_msg += f" | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%"

            print(log_msg)

        final_msg = f" {model_name} local training complete (FedProx)"
        print(final_msg)
        print("")
        return self.model.state_dict()
"""

with open(file_path, "w") as f:
    f.write(new_code)
    
print("client.py updated!")


client.py updated!
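
The proximal term added in `Client.train` is the FedProx penalty (mu / 2) * ||w - w_global||^2, which pulls each client's weights back toward the global model during local training. The toy example below works through the arithmetic on plain Python lists; `prox_penalty` and `prox_gradient` are hypothetical helper names for this illustration, not functions from client.py.

```python
def prox_penalty(local_w, global_w, mu):
    """(mu / 2) * ||w - w_global||^2 for flat lists of floats."""
    return 0.5 * mu * sum((w - g) ** 2 for w, g in zip(local_w, global_w))


def prox_gradient(local_w, global_w, mu):
    """Gradient of the penalty with respect to the local weights: mu * (w - w_global)."""
    return [mu * (w - g) for w, g in zip(local_w, global_w)]


global_w = [1.0, -2.0, 0.5]
local_w = [1.5, -2.0, 0.0]
mu = 0.05  # the default --mu in args.py

penalty = prox_penalty(local_w, global_w, mu)
grad = prox_gradient(local_w, global_w, mu)
print(penalty, grad)
```

The gradient is zero for any weight that still equals its global value and grows linearly with the distance from it, so a larger mu keeps clients closer to the global model.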


In [6]:
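# Note: the CLIP package itself is installed in a later cell
# (pip install git+https://github.com/openai/CLIP.git); it must be installed before main.py runs.

# Write model.py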
file_path = "/kaggle/working/FedProx-PyTorch/model.py"

new_code = '''
# ========================================
# model.py - CLIP-based FedProx Classifier (Fixed dtype)
# ========================================
import torch
from torch import nn
import clip

class CLIPFedProxClassifier(nn.Module):
    """
    CLIP-based classifier for FedProx
    Shared CLIP backbone + trainable head
    """

    def __init__(self, args, name='clip_fedprox_model'):
        super().__init__()
        self.name = name
        self.num_classes = args.num_classes
        self.dropout = getattr(args, 'dropout', 0.5)

        # Load CLIP
        self.clip_model, self.preprocess = clip.load(args.clip_model, device=args.device)
        # Freeze CLIP backbone
        for param in self.clip_model.parameters():
            param.requires_grad = False

        self.feature_dim = self.clip_model.visual.output_dim
        self.class_names = args.class_names

        # Trainable head
        self.head = nn.Sequential(
            nn.Linear(self.feature_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(self.dropout),

            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(self.dropout * 0.7),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(self.dropout * 0.5),

            nn.Linear(256, self.num_classes)
        )
        self._init_weights()

    def _init_weights(self):
        for m in self.head.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, images):
        with torch.no_grad():
            feats = self.clip_model.encode_image(images)
            feats = feats / feats.norm(dim=-1, keepdim=True)
        
        # **FIX: Convert features to float32 to match head dtype**
        feats = feats.float()
        
        logits = self.head(feats)
        return logits

    def get_trainable_params(self):
        return [p for p in self.head.parameters() if p.requires_grad]

'''

with open(file_path, "w") as f:
    f.write(new_code)

print(" model.py created!")
print(f" File saved at: {file_path}")

 model.py created!
 File saved at: /kaggle/working/FedProx-PyTorch/model.py


In [7]:
file_path = "/kaggle/working/FedProx-PyTorch/server.py"

new_code = """
# ========================================
# server.py - FedProx Server with Resume Support
# ========================================
import copy
import random
import numpy as np
import torch
import os
import json
from model import CLIPFedProxClassifier as ImageClassifier
from sklearn.metrics import precision_score, recall_score, f1_score, mean_squared_error
from get_data import load_medmnist_data
from client import Client

class FedProxServer:
    def __init__(self, args, resume_from=None):
        self.args = args
        os.makedirs(args.checkpoint_dir, exist_ok=True)
        self.current_round = 0
        self.best_global_acc = 0
        self.history = {'rounds': [], 'avg_accuracy': [], 'best_accuracy': []}
        self.global_model = ImageClassifier(args, name="server").to(args.device)
        self.client_models = []
        for i in range(self.args.K):
            model = copy.deepcopy(self.global_model)
            model.name = f"Client_{i}"
            self.client_models.append(model)
        self.client_loaders, self.test_loader = load_medmnist_data(args)
        if resume_from:
            self.load_checkpoint(resume_from)

    def load_checkpoint(self, checkpoint_path):
        print(f"\\n Loading checkpoint from: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=self.args.device, weights_only=False)
        self.global_model.load_state_dict(checkpoint['server_state_dict'])
        if 'client_state_dicts' in checkpoint:
            for i, state_dict in enumerate(checkpoint['client_state_dicts']):
                self.client_models[i].load_state_dict(state_dict)
        self.current_round = checkpoint.get('round', 0)
        self.best_global_acc = checkpoint.get('best_global_acc', 0)
        self.history = checkpoint.get('history', {'rounds': [], 'avg_accuracy': [], 'best_accuracy': []})

    def dispatch(self, selected_clients):
        # Copy the full global state (weights and BatchNorm running statistics) to each client.
        global_state = self.global_model.state_dict()
        for idx in selected_clients:
            self.client_models[idx].load_state_dict(global_state)

    def aggregate(self, selected_clients):
        total_samples = sum(len(self.client_loaders[idx].dataset) for idx in selected_clients)
        global_state = self.global_model.state_dict()
        # Average every floating-point entry, including BatchNorm running statistics,
        # which named_parameters() does not cover. Accumulating in float32 limits
        # rounding error on half-precision CLIP weights.
        averaged = {k: torch.zeros_like(v, dtype=torch.float32)
                    for k, v in global_state.items() if v.is_floating_point()}
        for idx in selected_clients:
            weight = len(self.client_loaders[idx].dataset) / total_samples
            client_state = self.client_models[idx].state_dict()
            for k in averaged:
                averaged[k] += client_state[k].float() * weight
        # state_dict() tensors share storage with the model, so copy_ updates it in place.
        for k, v in averaged.items():
            global_state[k].copy_(v)

    def client_update(self, idx):
        client_model = self.client_models[idx]
        client_loader = self.client_loaders[idx]
        client_obj = Client(model=client_model,
                            train_loader=client_loader,
                            device=self.args.device,
                            lr=self.args.lr,
                            weight_decay=self.args.weight_decay,
                            mu=self.args.mu,
                            global_params=self.global_model.parameters())
        client_obj.train(epochs=self.args.E)

    def test_global_model(self):
        self.global_model.eval()
        all_labels = []
        all_preds = []
        with torch.no_grad():
            for images, labels in self.test_loader:
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                labels = labels.reshape(-1)  # (B, 1) -> (B,), also safe when B == 1
                outputs = self.global_model(images)
                if isinstance(outputs, tuple) or hasattr(outputs, 'logits'):
                    outputs = outputs.logits
                _, predicted = torch.max(outputs, 1)
                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(predicted.cpu().numpy())
        all_labels = np.array(all_labels)
        all_preds = np.array(all_preds)
        acc = 100 * np.mean(all_labels == all_preds)
        precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
        recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)
        f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
        rmse = np.sqrt(mean_squared_error(all_labels, all_preds))
        print(f" Global Test - Acc: {acc:.2f}% | Prec: {precision:.3f} | Recall: {recall:.3f} | F1: {f1:.3f} | RMSE: {rmse:.3f}")
        return acc
    

    def save_checkpoint(self, round_num, avg_acc):
        checkpoint = {
            'round': round_num,
            'server_state_dict': self.global_model.state_dict(),
            'client_state_dicts': [model.state_dict() for model in self.client_models],
            'best_global_acc': self.best_global_acc,
            'history': self.history,
            'args': vars(self.args)
        }
        
        torch.save(checkpoint, os.path.join(self.args.checkpoint_dir, 'checkpoint_latest.pth'))
        
        with open(os.path.join(self.args.checkpoint_dir, 'training_history.json'), 'w') as f:
            json.dump(self.history, f, indent=4)

    def run(self):
        start_round = self.current_round
        for r in range(start_round, self.args.r):
            m = max(int(self.args.C * self.args.K), 1)
            selected_clients = random.sample(range(self.args.K), m)
            self.dispatch(selected_clients)
            for idx in selected_clients:
                self.client_update(idx)
            self.aggregate(selected_clients)
            avg_acc = self.test_global_model()
            self.history['rounds'].append(r+1)
            self.history['avg_accuracy'].append(avg_acc)
            self.history['best_accuracy'].append(max(self.best_global_acc, avg_acc))
            if avg_acc > self.best_global_acc:
                self.best_global_acc = avg_acc
            self.current_round = r + 1
            self.save_checkpoint(r+1, avg_acc)
        return self.global_model

"""

with open(file_path, "w") as f:
    f.write(new_code)

print(" server.py created!")

 server.py created!
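
Two small pieces of arithmetic drive each round in `FedProxServer.run`: the number of sampled clients, `max(int(C * K), 1)`, and the sample-count weighting used in `aggregate`. The sketch below reproduces both on plain Python lists so the numbers are easy to follow; `sample_clients` and `weighted_average` are hypothetical helper names for this illustration and operate on toy data, not on model tensors.

```python
import random


def sample_clients(num_clients, fraction, rng):
    """Pick max(int(C * K), 1) distinct client ids."""
    m = max(int(fraction * num_clients), 1)
    return rng.sample(range(num_clients), m)


def weighted_average(client_weights, client_sizes):
    """Average flat weight lists, weighting each client by its sample count."""
    total = sum(client_sizes)
    dim = len(client_weights[0])
    averaged = [0.0] * dim
    for weights, size in zip(client_weights, client_sizes):
        for j in range(dim):
            averaged[j] += weights[j] * size / total
    return averaged


rng = random.Random(0)
selected = sample_clients(num_clients=5, fraction=0.4, rng=rng)  # 2 of 5 clients

# Two toy clients: the one with 300 samples counts three times as much as the one with 100.
toy_weights = [[0.0, 4.0], [4.0, 0.0]]
toy_sizes = [100, 300]
new_global = weighted_average(toy_weights, toy_sizes)
print(selected, new_global)
```

With the notebook's defaults (C = 1, K = 5) every client is selected each round, so the weighting by dataset size is the only thing that differentiates their contributions.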


In [8]:
file_path = "/kaggle/working/FedProx-PyTorch/main.py"

new_code = """
# ========================================
# main.py - Run FedProx with Resume Support
# ========================================
from args import args_parser
from server import FedProxServer
from get_data import load_medmnist_data
import torch
import os
import json
import matplotlib.pyplot as plt

def plot_training_history(history, save_path):
    plt.figure(figsize=(12, 5))
    
    # Plot 1: Average Accuracy
    plt.subplot(1, 2, 1)
    plt.plot(history['rounds'], history['avg_accuracy'], 'b-', label='Avg Accuracy', linewidth=2)
    plt.plot(history['rounds'], history['best_accuracy'], 'r--', label='Best Accuracy', linewidth=2)
    plt.xlabel('Communication Round')
    plt.ylabel('Accuracy (%)')
    plt.title('Federated Learning Accuracy Progress')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # Plot 2: Accuracy Improvement
    plt.subplot(1, 2, 2)
    if len(history['avg_accuracy']) > 1:
        improvements = [history['avg_accuracy'][i] - history['avg_accuracy'][i-1] 
                        for i in range(1, len(history['avg_accuracy']))]
        plt.bar(history['rounds'][1:], improvements, alpha=0.7)
        plt.axhline(y=0, color='r', linestyle='-', linewidth=0.5)
        plt.xlabel('Communication Round')
        plt.ylabel('Accuracy Change (%)')
        plt.title('Round-to-Round Accuracy Change')
        plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f'  Training plot saved at: {save_path}')

def main():
    # Load arguments from args.py
    args = args_parser()
    
    # Print configuration
    print('\\n' + '='*60)
    print(' FEDERATED LEARNING CONFIGURATION')
    print('='*60)
    print(f'Dataset: {args.dataset.upper()}')
    print(f'Clients: {args.K} | Rounds: {args.r} | Local Epochs: {args.E}')
    print(f'Batch Size: {args.B} | Learning Rate: {args.lr}')
    print(f'FedProx μ: {args.mu} | Weight Decay: {args.weight_decay}')
    print(f'CLIP Model: {args.clip_model} | Device: {args.device}')
    print(f'Non-IID: {args.dominant_ratio*100:.0f}% dominant class per client')
    print('='*60)
    
    # Load client loaders & test loader
    client_loaders, test_loader = load_medmnist_data(args)
    
    # Check for existing checkpoint
    latest_checkpoint = os.path.join(args.checkpoint_dir, 'checkpoint_latest.pth')
    resume_from = None
    
    if os.path.exists(latest_checkpoint):
        print('\\n' + '='*60)
        print(' CHECKPOINT FOUND!')
        print('='*60)
        try:
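            # weights_only=False matches server.py: the checkpoint also stores args and history
            checkpoint = torch.load(latest_checkpoint, map_location=args.device, weights_only=False)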
            completed_rounds = checkpoint.get('round', 0)
            best_acc = checkpoint.get('best_global_acc', 0)
            print(f' Checkpoint Details:')
            print(f'   - Completed Rounds: {completed_rounds}/{args.r}')
            print(f'   - Best Accuracy: {best_acc:.2f}%')
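            try:
                user_input = input('\\n  Resume from checkpoint? (y/n): ').strip().lower()
            except EOFError:
                # No interactive stdin (e.g. launched with !python): resume by default
                user_input = 'y'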
            if user_input == 'y':
                resume_from = latest_checkpoint
                print(' Resuming from checkpoint...')
            else:
                print(' Starting fresh training...')
        except Exception as e:
            print(f' Error loading checkpoint: {e}')
            print(' Starting fresh training...')
    else:
        print('\\n' + '='*60)
        print(' No checkpoint found. Starting fresh training...')
        print('='*60)
    
    # Initialize FedProx server
    server = FedProxServer(args, resume_from=resume_from)
    
    # Run federated training
    final_model = server.run()
    
    # Save final global model
    final_model_path = os.path.join(args.checkpoint_dir, 'final_global_model.pth')
    torch.save(final_model.state_dict(), final_model_path)
    print(f'\\n Final global model saved at: {final_model_path}')
    
    # Save best model if exists
    best_model_path = os.path.join(args.checkpoint_dir, 'best_model.pth')
    if os.path.exists(best_model_path):
        print(f' Best global model saved at: {best_model_path}')
    
    # Save training history
    history_path = os.path.join(args.checkpoint_dir, 'training_history.json')
    with open(history_path, 'w') as f:
        json.dump(server.history, f, indent=4)
    print(f' Training history saved at: {history_path}')
    
    # Plot training curves
    plot_path = os.path.join(args.checkpoint_dir, 'training_plot.png')
    plot_training_history(server.history, plot_path)
    
    # Print summary
    print('\\n' + '='*60)
    print(' TRAINING SUMMARY')
    print('='*60)
    print(f' Best Global Accuracy: {server.best_global_acc:.2f}%')
    print(f' Final Round Accuracy: {server.history["avg_accuracy"][-1]:.2f}%')
    print(f' Total Improvement: {server.history["avg_accuracy"][-1] - server.history["avg_accuracy"][0]:.2f}%')
    print(f' All checkpoints saved in: {args.checkpoint_dir}')
    print('='*60)

if __name__ == '__main__':
    main()
"""

with open(file_path, "w") as f:
    f.write(new_code)

print(" main.py created!")

 main.py created!
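
The resume mechanism comes down to saving the completed round number after every round and starting the loop from that number on the next run. The sketch below shows the pattern in isolation, with a small JSON file standing in for checkpoint_latest.pth; `run_rounds` and its `stop_after` parameter are hypothetical and exist only to simulate an interrupted session.

```python
import json
import os
import tempfile


def run_rounds(checkpoint_dir, total_rounds, stop_after=None):
    """Toy training loop that saves its progress after every round."""
    path = os.path.join(checkpoint_dir, "checkpoint_latest.json")  # stand-in for the .pth file
    state = {"round": 0, "history": []}
    if os.path.exists(path):
        with open(path) as f:
            state = json.load(f)          # resume: pick up the saved round counter

    for r in range(state["round"], total_rounds):
        state["history"].append(r + 1)    # placeholder for one round of training
        state["round"] = r + 1
        with open(path, "w") as f:
            json.dump(state, f)
        if stop_after is not None and state["round"] >= stop_after:
            break                         # simulate the session ending early
    return state


with tempfile.TemporaryDirectory() as ckpt_dir:
    first = run_rounds(ckpt_dir, total_rounds=5, stop_after=2)   # interrupted run
    second = run_rounds(ckpt_dir, total_rounds=5)                # resumed run
    print(first["round"], second["round"], second["history"])
```

The second call continues at round 3 rather than round 1, which mirrors how `run()` uses `self.current_round` as the start of its loop.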


In [9]:
!pip install git+https://github.com/openai/CLIP.git
import clip
print(clip.available_models())


Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-lv65amdl
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-lv65amdl
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: clip
  Building wheel for clip (setup.py) ... done
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369490 sha256=d0274ad70ca7c4a00b7d841a790186fabc947ca0e2fac08308c4d53015a44669
  Stored in directory: /tmp/pip-ephem-wheel-cache-xhwyukhi/wheels/3f/7c/a4/9b490845988bf7a4db33674d52f709f088f64392063872eb9a
Successfully built clip
Installing collected packages: clip
Successfully installed clip-1.0
['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']


In [10]:
!pip install medmnist --quiet



In [None]:
!python /kaggle/working/FedProx-PyTorch/main.py



 FEDERATED LEARNING CONFIGURATION
Dataset: ORGANAMNIST
Clients: 5 | Rounds: 50 | Local Epochs: 5
Batch Size: 32 | Learning Rate: 0.003
FedProx μ: 0.05 | Weight Decay: 0.0005
CLIP Model: ViT-B/32 | Device: cuda
Non-IID: 70% dominant class per client
100%|██████████████████████████████████████| 38.2M/38.2M [00:25<00:00, 1.52MB/s]

 No checkpoint found. Starting fresh training...
100%|███████████████████████████████████████| 338M/338M [00:03<00:00, 91.9MiB/s]

 Training Client_4 with FedProx objective (μ=0.05)...
  Epoch  1/5 | Loss: 1.5745 | Train Acc: 72.98%                                                     
  Epoch  2/5 | Loss: 1.2934 | Train Acc: 77.78%                                                     
  Epoch  3/5 | Loss: 1.2739 | Train Acc: 79.50%                                                     
  Epoch  4/5 | Loss: 1.2