In [None]:
from collections import OrderedDict
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import Dataset, DataLoader, random_split
import nibabel as nib
from pathlib import Path

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Metrics, Context, ndarrays_to_parameters
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

# Select the compute device once; later cells use DEVICE when placing models
# and when deciding how many GPUs each simulated client may claim.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")
# Turn off the progress bars of the `datasets` library.
disable_progress_bar()

In [None]:
# Sanity check: load a single sample volume and display its voxel-array
# shape together with the affine matrix and header stored in the file.
# NOTE(review): the path is an absolute Kaggle input path; this cell only
# works where that dataset is mounted.
im = nib.load('/kaggle/input/datasetzip/not_skull_stripped/sub-BrainAge000019/anat/sub-BrainAge000019_T1w.nii/sub-BrainAge000019_T1w.nii')
data = im.get_fdata()
data.shape, im.affine, im.header

In [None]:
# Root directory of the image dataset; the label spreadsheet is expected to
# sit directly inside it as an .xlsx file.
data_dir = '/kaggle/input/datasetzip/not_skull_stripped'
label_path = list(Path(data_dir).glob("*.xlsx"))
# NOTE(review): takes the first .xlsx found; raises IndexError if the
# directory contains none, and the choice is arbitrary if it contains several.
label_ls = pd.read_excel(label_path[0])

In [None]:
# Keep only control subjects whose recorded sex is either 'm' or 'f'.
is_control = label_ls['subject_dx'] == 'control'
has_known_sex = label_ls['subject_sex'].isin(['m', 'f'])
label_ls = label_ls[is_control & has_known_sex]

In [None]:
# Narrow the label table to the two columns needed for classification and
# preview the first rows.
sexes = label_ls[['subject_sex','subject_id']]
sexes.head()

In [None]:
# Build a subject_id -> subject_sex lookup used by MRIDataset to label each
# volume; the cell output is the number of labelled subjects.
sexes_dict = sexes.set_index('subject_id')['subject_sex'].to_dict()
len(sexes_dict)

In [None]:
class MRIDataset(Dataset):
    """Map-style dataset of NIfTI volumes labelled by subject sex.

    Image files are discovered with the glob pattern ``*/*/*/*.nii`` below
    ``im_dir``. A file is kept only if one of its path components starts with
    ``"sub-BrainAge"``, that component is a key of ``label_ls``, and the file
    can be opened with nibabel.

    Args:
        im_dir: Root directory that is searched for ``.nii`` files.
        label_ls: Mapping from subject id to sex code (``'m'`` or ``'f'``).
        transform: Optional callable applied to the normalised float32 array.
        validate_data: When True (the default, matching the previous
            behaviour) every candidate volume is fully decoded once during
            construction so that unreadable files are excluded up front.
            This reads the whole dataset from disk. Pass False to only open
            each file with ``nib.load`` (nibabel defers reading the voxel
            data), which is much cheaper but will not detect files whose
            voxel data cannot be decoded.

    Each item is a ``(volume, label)`` pair where ``label`` is a long tensor
    holding 0 for ``'m'`` and 1 for ``'f'``.
    """

    # Integer class index for each supported sex code.
    _SEX_TO_LABEL = {'m': 0, 'f': 1}

    def __init__(self, im_dir, label_ls, transform=None, validate_data=True):
        self.im_dir = Path(im_dir)
        self.label_ls = label_ls
        self.transform = transform
        self.validate_data = validate_data

        # Gather valid image paths (sorted so indices are stable across runs)
        self.im_filenames = [
            path for path in sorted(self.im_dir.glob("*/*/*/*.nii"))
            if self._is_valid(path)
        ]

    def _is_valid(self, path):
        """Return True if ``path`` has a known label and can be opened."""
        import warnings

        subject_id = self.extract_subject_id(path)
        if subject_id not in self.label_ls:
            return False
        try:
            image = nib.load(path)
            if self.validate_data:
                # Decode the full voxel array so that files which open but
                # cannot be read are rejected here rather than mid-training.
                image.get_fdata()
            return True
        except Exception as exc:
            # Best effort: skip the file, but say so instead of silently
            # shrinking the dataset.
            warnings.warn(f"Skipping unreadable image {path}: {exc}")
            return False

    def extract_subject_id(self, im_path):
        """Return the first path component starting with "sub-BrainAge", or None."""
        for part in Path(im_path).parts:
            if part.startswith("sub-BrainAge"):
                return part
        return None

    def __len__(self):
        return len(self.im_filenames)

    def __getitem__(self, idx):
        im_path = self.im_filenames[idx]

        # Resolve the label first so a bad label fails before the costly read.
        subject_id = self.extract_subject_id(im_path)
        sex = self.label_ls.get(subject_id)
        label = self._SEX_TO_LABEL.get(sex)
        if label is None:
            raise ValueError(f"Invalid label for subject {subject_id}: {sex}")

        im = nib.load(im_path).get_fdata()

        # Min-max normalise to [0, 1); the epsilon avoids divide-by-zero on
        # constant volumes.
        lowest, highest = np.min(im), np.max(im)
        im = ((im - lowest) / (highest - lowest + 1e-5)).astype(np.float32)

        if self.transform:
            im = self.transform(im)

        # No channel axis is added here; train()/test() insert it with
        # x.unsqueeze(1) on the batched tensor.
        return torch.as_tensor(im), torch.tensor(label, dtype=torch.long)

In [None]:
def iid_client_split(dataset, num_client = 3,  val_ratio = 0.2, seed = 42):
    """Split ``dataset`` into contiguous client shards with train/val parts.

    The dataset is cut into ``num_client`` consecutive index ranges of
    ``len(dataset) // num_client`` samples each; the last client also
    receives the remainder. Every shard is then divided into a training and a
    validation subset with ``random_split``.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.
        num_client: Number of client shards to create (must be >= 1).
        val_ratio: Fraction of each shard placed in the validation subset.
        seed: Base seed for the per-client train/val split. Client ``i`` uses
            ``seed + i``, so repeated calls return identical splits. This
            matters because the split is recomputed every time a client's
            loaders are built; an unseeded split would move samples between
            train and validation from one call to the next. Pass ``None`` to
            fall back to the global RNG.

    Returns:
        A list of ``(train_dataset, val_dataset)`` tuples, one per client.

    Raises:
        ValueError: If ``num_client`` is smaller than 1.
    """
    if num_client < 1:
        raise ValueError(f"num_client must be >= 1, got {num_client}")

    client_datasets = []
    sample_per_client = len(dataset) // num_client

    for i in range(num_client):
        start_idx = i * sample_per_client
        # The last client absorbs the samples left over by integer division.
        end_idx = (i + 1) * sample_per_client if i < num_client - 1 else len(dataset)
        indices = list(range(start_idx, end_idx))

        client_dataset = torch.utils.data.Subset(dataset, indices)

        split_kwargs = {}
        if seed is not None:
            split_kwargs["generator"] = torch.Generator().manual_seed(seed + i)
        train_dataset, val_dataset = random_split(
            client_dataset, [1 - val_ratio, val_ratio], **split_kwargs
        )

        client_datasets.append((train_dataset, val_dataset))
    return client_datasets

In [None]:
# Number of simulated federated clients (also the number of data partitions).
NUM_CLIENTS = 10
pytorch_transforms = transforms.Compose(
        [transforms.ToTensor()]
    )
torchdatasets = MRIDataset(data_dir, sexes_dict, pytorch_transforms)
# Sanity check: number of volumes that have a label and could be loaded.
print(len(torchdatasets))

# Hold out 20% as a global test set; the fixed seed keeps the split reproducible.
trainvalset, testset = random_split(torchdatasets, [0.8, 0.2], generator = torch.Generator().manual_seed(42))

# NOTE(review): load_datasets() rebuilds the dataset and both splits on every
# call, so in the cells visible here these module-level results only serve as
# a check that the pipeline runs.
client_datasets = iid_client_split(trainvalset, num_client=NUM_CLIENTS, val_ratio=0.2)

In [None]:
BATCH_SIZE = 8

def load_datasets(partition_id: int):
    """Build the train/val/test DataLoaders for one client partition.

    The full ``MRIDataset`` is rebuilt, split 80/20 into train+val and test
    with a fixed seed, and the train+val part is sharded across
    ``NUM_CLIENTS`` clients by ``iid_client_split``.

    Args:
        partition_id: Index of the client partition, in ``[0, NUM_CLIENTS)``.

    Returns:
        ``(trainloader, valloader, testloader)``. The test loader covers the
        shared held-out test set and is the same for every partition.

    Raises:
        ValueError: If ``partition_id`` is outside ``[0, NUM_CLIENTS)``.
    """
    # Fail fast: a negative id would otherwise silently select a partition
    # from the end of the list, and a too-large id would only fail after the
    # expensive dataset scan below.
    if not 0 <= partition_id < NUM_CLIENTS:
        raise ValueError(
            f"partition_id must be in [0, {NUM_CLIENTS}), got {partition_id}"
        )

    print("load_datasets starting")
    pytorch_transforms = transforms.Compose(
        [transforms.ToTensor()]
    )
    torchdatasets = MRIDataset(data_dir, sexes_dict, pytorch_transforms)
    trainvalset, testset = random_split(torchdatasets, [0.8, 0.2], generator = torch.Generator().manual_seed(42))
    client_datasets = iid_client_split(trainvalset, num_client=NUM_CLIENTS, val_ratio=0.2)
    train_set, val_set = client_datasets[partition_id]
    test_set = testset

    # Pinned host memory only helps host-to-GPU copies; requesting it without
    # CUDA makes DataLoader emit a warning.
    pin_memory = torch.cuda.is_available()

    # Create train/val for each partition and wrap it into DataLoader
    trainloader = DataLoader(
        train_set, batch_size=BATCH_SIZE, shuffle=True, pin_memory=pin_memory
    )
    valloader = DataLoader(val_set, batch_size=BATCH_SIZE, pin_memory=pin_memory)
    testloader = DataLoader(test_set, batch_size=BATCH_SIZE, pin_memory=pin_memory)
    print("load_datasets finished")
    return trainloader, valloader, testloader

In [None]:
### MODEL
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict



class _DenseLayer(nn.Sequential):

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super().__init__()
        self.add_module('norm1', nn.BatchNorm3d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module(
            'conv1',
            nn.Conv3d(num_input_features,
                      bn_size * growth_rate,
                      kernel_size=1,
                      stride=1,
                      bias=False))
        self.add_module('norm2', nn.BatchNorm3d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module(
            'conv2',
            nn.Conv3d(bn_size * growth_rate,
                      growth_rate,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super().forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features,
                                     p=self.drop_rate,
                                     training=self.training)
        return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate,
                 drop_rate):
        super().__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate,
                                growth_rate, bn_size, drop_rate)
            self.add_module('denselayer{}'.format(i + 1), layer)


class _Transition(nn.Sequential):

    def __init__(self, num_input_features, num_output_features):
        super().__init__()
        self.add_module('norm', nn.BatchNorm3d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module(
            'conv',
            nn.Conv3d(num_input_features,
                      num_output_features,
                      kernel_size=1,
                      stride=1,
                      bias=False))
        self.add_module('pool', nn.AvgPool3d(kernel_size=2, stride=2))


class DenseNet(nn.Module):

    """
    3D Densenet-BC model class

    Args:
        n_input_channels (int) - number of channels of the input volume
        conv1_t_size (int) - kernel size of the first convolution along the
          first spatial axis (the other two axes use 7)
        conv1_t_stride (int) - stride of the first convolution along the
          first spatial axis (the other two axes use 2)
        no_max_pool (bool) - if True, omit the max-pooling layer after the
          first convolution
        growth_rate (int) - how many filters to add each layer (k in paper)
        block_config (sequence of ints) - how many layers in each dense block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of outputs of the final linear layer

    ``forward`` expects a 5-D tensor (batch, channel, and three spatial axes)
    and returns raw, un-activated outputs of shape ``(batch, num_classes)``.
    """

    def __init__(self,
                 n_input_channels=1,
                 conv1_t_size=7,
                 conv1_t_stride=1,
                 no_max_pool=False,
                 growth_rate=32,
                 block_config=(6, 12, 24, 16),
                 num_init_features=64,
                 bn_size=4,
                 drop_rate=0,
                 num_classes=1):

        super().__init__()

        # First convolution. The stem is assembled in a local list so that
        # self.features is only ever bound to an nn.Module.
        stem = [('conv1',
                 nn.Conv3d(n_input_channels,
                           num_init_features,
                           kernel_size=(conv1_t_size, 7, 7),
                           stride=(conv1_t_stride, 2, 2),
                           padding=(conv1_t_size // 2, 3, 3),
                           bias=False)),
                ('norm1', nn.BatchNorm3d(num_init_features)),
                ('relu1', nn.ReLU(inplace=True))]
        if not no_max_pool:
            stem.append(
                ('pool1', nn.MaxPool3d(kernel_size=3, stride=2, padding=1)))
        self.features = nn.Sequential(OrderedDict(stem))

        # Each denseblock, followed by a channel-halving transition except
        # after the last one
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers,
                                num_input_features=num_features,
                                bn_size=bn_size,
                                growth_rate=growth_rate,
                                drop_rate=drop_rate)
            self.features.add_module('denseblock{}'.format(i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2)
                self.features.add_module('transition{}'.format(i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm3d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Single initialisation pass over all modules. An earlier duplicate
        # pass over the conv/BN layers was removed: this loop overwrote every
        # value it set.
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        # Global average pool over the three spatial axes, then flatten to
        # (batch, channels) for the linear classifier.
        out = F.adaptive_avg_pool3d(out,
                                    output_size=(1, 1,
                                                 1)).view(features.size(0), -1)
        out = self.classifier(out)
        return out

In [None]:
def train(net, trainloader, epochs: int, verbose=False, device = None, lr = 1e-3):
    """Train the network on the training set.

    Optimises ``net`` with Adam and ``BCEWithLogitsLoss``. Each batch ``x`` is
    given a channel axis with ``x.unsqueeze(1)`` before the forward pass, and
    the integer labels are converted to a float column to match the single
    logit per sample.

    Args:
        net: Model producing one logit per sample.
        trainloader: Iterable of ``(x, y)`` batches.
        epochs: Number of passes over ``trainloader``.
        verbose: If True, print the mean loss and accuracy after each epoch.
        device: Device to train on. ``None`` (default) selects ``"cuda:0"``
            when CUDA is available and ``"cpu"`` otherwise, so the function
            also works on CPU-only machines.
        lr: Learning rate for Adam.
    """
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    net.to(device)
    criterion = nn.BCEWithLogitsLoss().to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    # Re-enable autograd explicitly in case a caller left it disabled.
    torch.set_grad_enabled(True)
    net.train(True)
    for epoch in range(epochs):
        # Accumulate per-sample totals so a smaller final batch does not get
        # the same weight as a full one.
        loss_sum, num_correct, num_seen = 0.0, 0, 0
        for (x, y) in trainloader:
            x, y = x.to(device), y.to(device)
            target = y.float().unsqueeze(1)
            optimizer.zero_grad()
            output = net(x.unsqueeze(1))
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            batch_size = y.size(0)
            loss_sum += loss.item() * batch_size
            predicted = (torch.sigmoid(output.detach()) > 0.5).float()
            num_correct += int((predicted == target).sum().item())
            num_seen += batch_size
        # An empty loader yields NaN metrics instead of a ZeroDivisionError.
        epoch_loss = loss_sum / num_seen if num_seen else float("nan")
        acc_end = num_correct / num_seen if num_seen else float("nan")
        if verbose:
            print(f"Local epoch {epoch+1}: train loss {epoch_loss}, acc {acc_end} ")

def test(net, testloader, device = "cuda:0") -> float | float:
    """Evaluate the network on the entire test set."""
    net.to(device)
    criterion = nn.BCEWithLogitsLoss().to(device)
    losses, acc_end = 0.0, 0.0
    net.eval()
    with torch.no_grad():
        for (x, y) in testloader:
            x, y = x.to(device), y.to(device)
            output = net(x.unsqueeze(1))
            loss = criterion(output, y.float().unsqueeze(1))
            losses += loss.item()
            acc = ((torch.sigmoid(output) > 0.5).float() == y.unsqueeze(1)).float().mean()
            acc_end += acc
        losses /= len(testloader)
        acc_end /= len(testloader)
    return losses, acc_end

In [None]:
# Build the loaders of partition 0 for the centralised baseline run below.
trainloader, valloader, testloader = load_datasets(partition_id=0)

In [None]:
# Centralised baseline: train a single model on partition 0 for five epochs,
# validating after each, then report performance on the held-out test set.
net = DenseNet(num_init_features=32,growth_rate=16,block_config=(4, 8, 16, 12)).to(DEVICE)

for epoch in range(5):
    # Pass DEVICE explicitly so training and evaluation run where the model
    # was placed instead of on the functions' default device.
    train(net, trainloader, 1, True, device=DEVICE, lr = 1e-4)
    loss, accuracy = test(net, valloader, device=DEVICE)
    print(f"Epoch {epoch+1}: validation loss {loss}, accuracy {accuracy}")

loss, accuracy = test(net, testloader, device=DEVICE)
print(f"Final test set performance:\n\tloss {loss}\n\taccuracy {accuracy}")

In [None]:
### Flower separation ###

In [None]:
def set_parameters(net, parameters: List[np.ndarray]):
    """Load ``parameters`` into ``net``.

    The arrays are paired positionally with the keys of ``net.state_dict()``
    and loaded with ``strict=True``.
    """
    print("Setting parameters")
    state_dict = OrderedDict()
    for key, array in zip(net.state_dict().keys(), parameters):
        state_dict[key] = torch.as_tensor(array)
    net.load_state_dict(state_dict, strict=True)


def get_parameters(net) -> List[np.ndarray]:
    """Return every ``state_dict`` entry of ``net`` as a NumPy array, in key order."""
    print("Returning parameters")
    arrays = []
    for tensor in net.state_dict().values():
        arrays.append(tensor.cpu().numpy())
    return arrays

In [None]:
class FlowerClient(NumPyClient):
    """Flower client that trains and evaluates one model on one data partition.

    Args:
        net: Model to train; it is moved to CUDA when available, else CPU.
        trainloader: DataLoader over this client's training subset.
        valloader: DataLoader over this client's validation subset.
        partition_id: Identifier of the partition, used in log messages.
    """

    def __init__(self, net, trainloader, valloader, partition_id):
        print(f"[Client {partition_id}] initializing")
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader
        self.partition_id = partition_id
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.net.to(self.device)

    def get_parameters(self, config):
        """Return the current local model weights as NumPy arrays."""
        print(f"[Client {self.partition_id}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load the given weights, train for one local epoch, return the update."""
        print(f"[Client {self.partition_id}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, 1, True, self.device, lr=7e-4)
        # Report the number of training samples, not the number of batches:
        # the server uses this value to weight each client's update.
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load the given weights and evaluate them on the local validation set."""
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader, self.device)
        # Sample count (not batch count) so metric aggregation is weighted
        # by the amount of validation data each client holds.
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}

In [None]:
def client_fn(context: Context) -> Client:
    """Build the Flower client for the partition named in ``context``.

    The partition index is read from ``context.node_config["partition-id"]``.
    A fresh DenseNet is created and the loaders of that partition are built,
    so every client trains and validates on its own share of the data.
    """
    partition_id = context.node_config["partition-id"]

    # Fresh model; its weights are overwritten by the server's parameters
    # before each fit/evaluate call.
    print(f"Client {partition_id} loading model")
    model = DenseNet(
        num_init_features=32,
        growth_rate=16,
        block_config=(4, 8, 16, 12),
    ).to(DEVICE)

    # Loaders of this client's partition; the shared test loader is not
    # needed on the client side.
    print(f"Client {partition_id} loading data partition")
    loaders = load_datasets(partition_id=partition_id)
    train_loader, val_loader = loaders[0], loaders[1]

    # FlowerClient is a NumPyClient, so convert it to a flwr.client.Client.
    print(f"Client {partition_id} starting")
    numpy_client = FlowerClient(model, train_loader, val_loader, partition_id)
    return numpy_client.to_client()


# Create the ClientApp
client = ClientApp(client_fn=client_fn)

In [None]:
# Resources reserved for every simulated client: always one CPU, plus one
# whole GPU when training on CUDA (none otherwise).
# Refer to the Flower framework documentation for more details about Flower
# simulations and how to set up the `backend_config`.
backend_config = {
    "client_resources": {
        "num_cpus": 1,
        "num_gpus": 1.0 if DEVICE == "cuda" else 0.0,
    }
}

In [None]:
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregate client accuracies into an example-weighted mean.

    Args:
        metrics: One ``(num_examples, metrics_dict)`` pair per client; each
            ``metrics_dict`` must contain an ``"accuracy"`` entry.

    Returns:
        ``{"accuracy": weighted_mean}``, or an empty dict when the clients
        reported no examples at all (so no meaningful mean exists).
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        # Nothing to average; avoid a ZeroDivisionError.
        return {}

    # Multiply accuracy of each client by number of examples used
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)

    # Aggregate and return custom metric (weighted average)
    return {"accuracy": weighted_sum / total_examples}

In [None]:
def server_fn(context: Context) -> ServerAppComponents:
    """Construct components that set the ServerApp behaviour.

    You can use settings in `context.run_config` to parameterize the
    construction of all elements (e.g the strategy or the number of rounds)
    wrapped in the returned ServerAppComponents object.

    The strategy is FedAvg, started from the weights of a freshly initialised
    DenseNet. All ``NUM_CLIENTS`` clients train in every round and half of
    them are sampled for evaluation.
    """
    print("Server is getting model parameters")
    # The model is only needed to extract its initial weights as NumPy
    # arrays, so it is kept on the CPU.
    ndarrays = get_parameters(net = DenseNet(num_init_features=32,growth_rate=16,block_config=(4, 8, 16, 12)))
    parameters = ndarrays_to_parameters(ndarrays)

    # Create FedAvg strategy. The client counts are derived from NUM_CLIENTS
    # so that they stay consistent with the number of simulated supernodes.
    strategy = FedAvg(
        fraction_fit=1.0,
        fraction_evaluate=0.5,
        min_fit_clients=NUM_CLIENTS,
        min_evaluate_clients=max(1, NUM_CLIENTS // 2),
        min_available_clients=NUM_CLIENTS,
        initial_parameters=parameters,
        evaluate_metrics_aggregation_fn=weighted_average,  # <-- pass the metric aggregation function
    )

    # Configure the server for 10 rounds of training
    config = ServerConfig(num_rounds=10)
    print("Server starting")
    return ServerAppComponents(strategy=strategy, config=config)


# Create a new server instance with the updated FedAvg strategy
server = ServerApp(server_fn=server_fn)

In [None]:
# Run simulation
# Launch the simulated federation with one supernode per data partition,
# each given the resources declared in backend_config.
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
)