In [1]:
!pip install torch pandas scikit-learn

Collecting scikit-learn
  Downloading scikit_learn-1.7.0-cp310-cp310-win_amd64.whl.metadata (14 kB)
Collecting joblib>=1.2.0 (from scikit-learn)
  Downloading joblib-1.5.1-py3-none-any.whl.metadata (5.6 kB)
Collecting threadpoolctl>=3.1.0 (from scikit-learn)
  Downloading threadpoolctl-3.6.0-py3-none-any.whl.metadata (13 kB)
Downloading scikit_learn-1.7.0-cp310-cp310-win_amd64.whl (10.7 MB)
   ---------------------------------------- 0.0/10.7 MB ? eta -:--:--
   ----------- ---------------------------- 3.1/10.7 MB 14.2 MB/s eta 0:00:01
   --------------------- ------------------ 5.8/10.7 MB 13.0 MB/s eta 0:00:01
   ----------------------------- ---------- 7.9/10.7 MB 12.2 MB/s eta 0:00:01
   -------------------------------------- - 10.2/10.7 MB 12.0 MB/s eta 0:00:01
   ---------------------------------------- 10.7/10.7 MB 10.5 MB/s eta 0:00:00
Downloading joblib-1.5.1-py3-none-any.whl (307 kB)
Downloading threadpoolctl-3.6.0-py3-none-any.whl (18 kB)
Installing collected packages: threa

In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

# --- 1. Key Hyperparameters and Configuration ---
# These have been tuned to prevent divergence and ensure stable training.

# Single place for every tunable knob of the experiment; later code reads
# these values by key, e.g. CONFIG["num_clients"].
CONFIG = dict(
    num_clients=10,        # how many client subsets get_data() creates
    num_rounds=40,         # raised so that learning progress becomes visible
    clients_per_round=5,   # presumably the number of clients sampled each round -- confirm against the training loop
    local_epochs=2,        # kept small to limit client drift
    batch_size=32,
    learning_rate=0.01,    # kept low to avoid overshooting
    mu=0.1,                # FedProx hyperparameter: weight of the proximal term
    non_iid_alpha=0.5,     # data heterogeneity control; lower values are more non-IID
)

# --- 2. Model Definition (Simple CNN for MNIST) ---

class WellnessModel(nn.Module):
    """Small two-stage CNN that maps single-channel images to 10 class logits.

    Both convolutions use a 5x5 kernel with padding 2 and stride 1, so they keep
    the spatial size; each 2x2 max-pool halves it. The first linear layer
    expects 32 * 7 * 7 features, which corresponds to 28x28 inputs.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor: (conv -> ReLU -> pool) twice, 1 -> 16 -> 32 channels.
        feature_layers = [
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.conv_stack = nn.Sequential(*feature_layers)

        # Classifier head: flattened features -> 512 hidden units -> 10 logits.
        head_layers = [
            nn.Linear(32 * 7 * 7, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        ]
        self.fc_stack = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for the batch ``x``."""
        feature_maps = self.conv_stack(x)
        batch_size = feature_maps.size(0)
        # Collapse channel and spatial dimensions into one vector per sample.
        flat_features = feature_maps.view(batch_size, -1)
        return self.fc_stack(flat_features)

# --- 3. Data Loading and Non-IID Partitioning ---

def _dirichlet_partition(labels, num_clients, num_classes, alpha, rng):
    """Assign every sample index to exactly one client with Dirichlet label skew.

    For each class, the indices of that class are shuffled and then cut into
    ``num_clients`` contiguous pieces whose relative sizes are drawn from a
    symmetric Dirichlet(``alpha``) distribution over the clients.

    Args:
        labels: 1-D NumPy array of integer class labels, one entry per sample.
        num_clients: Number of shards to produce.
        num_classes: Classes ``0 .. num_classes - 1`` are partitioned.
        alpha: Dirichlet concentration; smaller values give more skewed shards.
        rng: Random source exposing ``shuffle`` and ``dirichlet`` (the
            ``np.random`` module or a ``np.random.RandomState``).

    Returns:
        A list of ``num_clients`` lists of plain ``int`` indices. The lists are
        pairwise disjoint and together contain every sample whose label lies in
        ``range(num_classes)``. A client may receive few (or no) samples.
    """
    client_indices = [[] for _ in range(num_clients)]
    for class_id in range(num_classes):
        class_members = np.where(labels == class_id)[0]
        rng.shuffle(class_members)
        shares = rng.dirichlet([alpha] * num_clients)
        # Cumulative shares become cut positions. The last cumulative value is
        # dropped so np.split gives the remainder to the final client and no
        # sample is lost to rounding.
        cut_points = (np.cumsum(shares)[:-1] * len(class_members)).astype(int)
        for client_id, shard in enumerate(np.split(class_members, cut_points)):
            client_indices[client_id].extend(shard.tolist())
    return client_indices


def get_data(seed=None):
    """Download MNIST and split the training set into non-IID client subsets.

    The training indices are divided into ``CONFIG["num_clients"]`` disjoint
    shards: no sample is shared between clients and every training sample is
    assigned to some client. Label skew is controlled by
    ``CONFIG["non_iid_alpha"]``.

    Args:
        seed: Optional integer. When given, the partition is reproducible.
            When ``None`` (default), the global ``np.random`` state is used, so
            a prior ``np.random.seed(...)`` call still takes effect.

    Returns:
        ``(client_datasets, test_dataset)`` where ``client_datasets`` is a list
        of ``torch.utils.data.Subset`` views of the MNIST training set (one per
        client) and ``test_dataset`` is the full MNIST test set.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # A dedicated RandomState keeps a seeded call from touching global state.
    rng = np.random if seed is None else np.random.RandomState(seed)

    labels = np.asarray(train_dataset.targets)
    client_data_indices = _dirichlet_partition(
        labels,
        num_clients=CONFIG["num_clients"],
        num_classes=10,  # MNIST digits 0-9
        alpha=CONFIG["non_iid_alpha"],
        rng=rng,
    )

    client_datasets = [torch.utils.data.Subset(train_dataset, indices) for indices in client_data_indices]
    return client_datasets, test_dataset

class CustomDataset(Dataset):
    """Thin ``Dataset`` adapter that forwards indexing and length to ``subset``."""

    def __init__(self, subset):
        # Keep a reference only; the wrapped object is not copied.
        self.subset = subset

    def __len__(self):
        """Number of items in the wrapped object."""
        size = len(self.subset)
        return size

    def __getitem__(self, index):
        """Return whatever the wrapped object yields at ``index``."""
        item = self.subset[index]
        return item


Matplotlib is building the font cache; this may take a moment.
