In [5]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import nibabel as nib
from ipywidgets import interact
from torch.utils.data import Dataset, DataLoader
import os
from collections import Counter
from pathlib import Path
from torch.utils.data import random_split
import math
import torch.nn.functional as F
from collections import OrderedDict
import pytorch_lightning as pl
import os
import random
from torch.utils.data import DataLoader, random_split

In [25]:
# read label csv
data_dir = '/kaggle/input/mri-dataset/not_skull_stripped'
label_path = list(Path(data_dir).glob("*.xlsx"))
label_ls = pd.read_excel(label_path[0])

In [26]:
label_ls

Unnamed: 0,subject_age,subject_dx,subject_sex,subject_id,dataset_name
0,55.4,pathology,m,sub-BrainAge000000,ABIDE/Caltech
1,22.9,pathology,m,sub-BrainAge000001,ABIDE/Caltech
2,39.2,pathology,m,sub-BrainAge000002,ABIDE/Caltech
3,22.8,pathology,m,sub-BrainAge000003,ABIDE/Caltech
4,34.6,pathology,f,sub-BrainAge000004,ABIDE/Caltech
...,...,...,...,...,...
23209,66,control,f,sub-BrainAge023209,RocklandSample
23210,69,control,m,sub-BrainAge023210,RocklandSample
23211,23,control,m,sub-BrainAge023211,RocklandSample
23212,54,control,f,sub-BrainAge023212,RocklandSample


In [27]:
subject_paths = []

for subject_id in label_ls['subject_id']:
    mri_folder = Path(data_dir) / subject_id / "anat"
    
    mri_file = list(mri_folder.glob(f"{subject_id}_T1w.nii"))
    
    if mri_file: 
        full_path = mri_file[0] / mri_file[0].name  
        subject_paths.append(full_path.as_posix())
    else:
        subject_paths.append(None) 

label_ls['subject_path'] = subject_paths
print(label_ls.head())

  subject_age subject_dx subject_sex          subject_id   dataset_name  \
0        55.4  pathology           m  sub-BrainAge000000  ABIDE/Caltech   
1        22.9  pathology           m  sub-BrainAge000001  ABIDE/Caltech   
2        39.2  pathology           m  sub-BrainAge000002  ABIDE/Caltech   
3        22.8  pathology           m  sub-BrainAge000003  ABIDE/Caltech   
4        34.6  pathology           f  sub-BrainAge000004  ABIDE/Caltech   

                                        subject_path  
0  /kaggle/input/mri-dataset/not_skull_stripped/s...  
1  /kaggle/input/mri-dataset/not_skull_stripped/s...  
2  /kaggle/input/mri-dataset/not_skull_stripped/s...  
3  /kaggle/input/mri-dataset/not_skull_stripped/s...  
4  /kaggle/input/mri-dataset/not_skull_stripped/s...  


In [28]:
labels = label_ls[["subject_sex", "subject_id", "subject_path"]]

In [29]:
labels_data = labels.dropna(subset=['subject_sex', 'subject_path']).reset_index(drop=True)
labels_data

Unnamed: 0,subject_sex,subject_id,subject_path
0,m,sub-BrainAge000000,/kaggle/input/mri-dataset/not_skull_stripped/s...
1,m,sub-BrainAge000001,/kaggle/input/mri-dataset/not_skull_stripped/s...
2,m,sub-BrainAge000002,/kaggle/input/mri-dataset/not_skull_stripped/s...
3,m,sub-BrainAge000003,/kaggle/input/mri-dataset/not_skull_stripped/s...
4,f,sub-BrainAge000004,/kaggle/input/mri-dataset/not_skull_stripped/s...
...,...,...,...
10270,f,sub-BrainAge023209,/kaggle/input/mri-dataset/not_skull_stripped/s...
10271,m,sub-BrainAge023210,/kaggle/input/mri-dataset/not_skull_stripped/s...
10272,m,sub-BrainAge023211,/kaggle/input/mri-dataset/not_skull_stripped/s...
10273,f,sub-BrainAge023212,/kaggle/input/mri-dataset/not_skull_stripped/s...


In [30]:
labels_data["subject_path"][0]

'/kaggle/input/mri-dataset/not_skull_stripped/sub-BrainAge000000/anat/sub-BrainAge000000_T1w.nii/sub-BrainAge000000_T1w.nii'

In [31]:
data = labels_data.sample(frac=1).reset_index(drop=True)

num_clients = 5
client_data = {i: {'subject_id': [], 'subject_sex': [], 'subject_path': []} for i in range(num_clients)}

for idx, row in data.iterrows():
    client_id = idx % num_clients  
    client_data[client_id]['subject_id'].append(row['subject_id'])
    client_data[client_id]['subject_sex'].append(row['subject_sex'])
    client_data[client_id]['subject_path'].append(row['subject_path'])

In [34]:
class MRIDataset(Dataset):
    def __init__(self, subject_paths, subject_sexes, transform=None):
        self.subject_paths = subject_paths
        self.subject_sexes = subject_sexes
        self.transform = transform

    def __len__(self):
        return len(self.subject_paths)

    def __getitem__(self, idx):
        nii_path = self.subject_paths[idx]
        try:
            img = nib.load(nii_path).get_fdata()
            img = np.nan_to_num(img)
            img = (img - np.min(img)) / (np.max(img) - np.min(img) + 1e-8)  # tránh chia 0
            img_tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0)

            label = 0 if self.subject_sexes[idx] == 'f' else 1
            label_tensor = torch.tensor(label, dtype=torch.long)

            return img_tensor, label_tensor

        except Exception as e:
            print(f"[Warning] Failed to load: {nii_path}\nError: {e}")
            # Cách 1: raise để DataLoader bỏ mẫu này nếu dùng custom sampler
            raise RuntimeError(f"Corrupted sample at index {idx}: {nii_path}")

    # def resize_image(self, img):
    #     target_shape = (128, 128, 128)
    #     resized_img = np.resize(img, target_shape)
    #     return resized_img

In [35]:
clients_dataset = []
for i in range(5):
    dataset = MRIDataset(
        subject_paths=client_data[i]['subject_path'],
        subject_sexes=client_data[i]['subject_sex']
    )
    clients_dataset.append(dataset)

In [36]:
clients_dataloaders = []

for dataset in clients_dataset:
    total_len = len(dataset)
    train_len = int(0.7 * total_len)
    val_len = int(0.2 * total_len)
    test_len = total_len - train_len - val_len  

    train_set, val_set, test_set = random_split(dataset, [train_len, val_len, test_len])

    train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=8, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=8, shuffle=False)

    clients_dataloaders.append({
        'train': train_loader,
        'val': val_loader,
        'test': test_loader
    })

In [37]:
class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super().__init__()
        self.add_module('norm1', nn.BatchNorm3d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module(
            'conv1',
            nn.Conv3d(num_input_features,
                      bn_size * growth_rate,
                      kernel_size=1,
                      stride=1,
                      bias=False))
        self.add_module('norm2', nn.BatchNorm3d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module(
            'conv2',
            nn.Conv3d(bn_size * growth_rate,
                      growth_rate,
                      kernel_size=3,
                      stride=1,
                      padding=1,
                      bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super().forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features,
                                     p=self.drop_rate,
                                     training=self.training)
        return torch.cat([x, new_features], 1)

class _DenseBlock(nn.Sequential):

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate,
                 drop_rate):
        super().__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate,
                                growth_rate, bn_size, drop_rate)
            self.add_module('denselayer{}'.format(i + 1), layer)


class _Transition(nn.Sequential):

    def __init__(self, num_input_features, num_output_features):
        super().__init__()
        self.add_module('norm', nn.BatchNorm3d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module(
            'conv',
            nn.Conv3d(num_input_features,
                      num_output_features,
                      kernel_size=1,
                      stride=1,
                      bias=False))
        self.add_module('pool', nn.AvgPool3d(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    
    """
    Densenet-BC model class
    
    Args:
        growth_rate (int) - how many filters to add each layer (k in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    """

    def __init__(self,
                 n_input_channels=1,
                 conv1_t_size=7,
                 conv1_t_stride=1,
                 no_max_pool=False,
                 growth_rate=32,
                 block_config=(6, 12, 24, 16),
                 num_init_features=64,
                 bn_size=4,
                 drop_rate=0,
                 num_classes=1):

        super().__init__()

        # First convolution
        self.features = [('conv1',
                          nn.Conv3d(n_input_channels,
                                    num_init_features,
                                    kernel_size=(conv1_t_size, 7, 7),
                                    stride=(conv1_t_stride, 2, 2),
                                    padding=(conv1_t_size // 2, 3, 3),
                                    bias=False)),
                         ('norm1', nn.BatchNorm3d(num_init_features)),
                         ('relu1', nn.ReLU(inplace=True))]
        if not no_max_pool:
            self.features.append(
                ('pool1', nn.MaxPool3d(kernel_size=3, stride=2, padding=1)))
        self.features = nn.Sequential(OrderedDict(self.features))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers,
                                num_input_features=num_features,
                                bn_size=bn_size,
                                growth_rate=growth_rate,
                                drop_rate=drop_rate)
            self.features.add_module('denseblock{}'.format(i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2)
                self.features.add_module('transition{}'.format(i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm3d(num_features))

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool3d(out,
                                    output_size=(1, 1,
                                                 1)).view(features.size(0), -1)
        out = self.classifier(out)
        return out

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

class DenseNetModule(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.criterion = nn.BCEWithLogitsLoss()
        self.lr = 0.01
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)

    def forward(self, x):
        return self.model(x)

    def train(self, dataloader, num_epochs=1):
        self.model.train()
        self.model.to(device)

        initial_weights = torch.nn.utils.parameters_to_vector(self.model.parameters()).detach().cpu().numpy()

        total_loss = 0.0
        total_samples = 0

        for epoch in range(num_epochs):
            print(f"Epoch {epoch+1}:")
            epoch_loss = 0.0
            epoch_accuracy = 0.0
            for x_batch, y_batch in dataloader:
                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)

                self.optimizer.zero_grad()
                loss, acc = self.compute_loss_and_accuracy(x_batch, y_batch)  # Sử dụng hàm này

                loss.backward()
                self.optimizer.step()
    
                batch_size = x_batch.size(0)
                epoch_loss += loss.item() * batch_size
                epoch_accuracy += acc.item() * batch_size
                total_samples += batch_size
    
            total_loss += epoch_loss
            print(f"Train loss: {epoch_loss / total_samples}, Accuracy: {epoch_accuracy / total_samples}")

        averaged_loss = total_loss / total_samples

        updated_weights = torch.nn.utils.parameters_to_vector(self.model.parameters()).detach().cpu().numpy()
        
        # delta_list = []
        # for key in initial_weights:
        #     delta = updated_weights[key] - initial_weights[key]
        #     delta_list.append(delta.view(-1))  
        
        # delta_vector = torch.cat(delta_list) 

        delta_vector = updated_weights - initial_weights

        return delta_vector, averaged_loss

    @torch.no_grad()
    def evaluate(self, dataloader):
        self.model.eval()
        self.model.to(device)
    
        total_correct = 0
        total_samples = 0
        total_loss = 0.0
        total_accuracy = 0.0
    
        for x_batch, y_batch in dataloader:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
    
            loss, acc = self.compute_loss_and_accuracy(x_batch, y_batch)
    
            total_loss += loss.item() * x_batch.size(0)
            total_accuracy += acc.item() * x_batch.size(0)
            total_samples += x_batch.size(0)
    
        avg_loss = total_loss / total_samples
        avg_accuracy = total_accuracy / total_samples

        print(f"Test loss: {avg_loss:.2f}, test accuracy: {avg_accuracy:.2f}")
    
        return avg_loss, avg_accuracy

    def compute_loss_and_accuracy(self, x, y):
        logits = self(x)
        loss = self.criterion(logits, y.float().unsqueeze(1))
        preds = (torch.sigmoid(logits) > 0.5).float()
        acc = (preds == y.unsqueeze(1)).float().mean()
        return loss, acc

In [22]:
class Client:
    def __init__(self, model, data, epochs=1):
        self.model = model
        self.data = data  
        self.lr = 0.01
        self.epochs = epochs
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)

    def train(self):
        update, loss = self.model.train(self.data, num_epochs=self.epochs)
        num_samples = len(self.data.dataset)
        return num_samples, update

    def get_weights(self):
        return torch.nn.utils.parameters_to_vector(self.model.parameters()).detach().cpu().numpy()

    def set_weights(self, new_weights):
        vector = torch.tensor(new_weights)
        torch.nn.utils.vector_to_parameters(vector, self.model.parameters())

In [23]:
class ServerModel:
    def __init__(self, model):
        self.model = model
        # self.rng = model.rng

    @property
    def size(self):
        return self.model.optimizer.size()

    @property
    def cur_model(self):
        return self.model

    def send_to(self, clients):
        """Copies server model weights to each client"""
        weights = self.get_weights()
        for c in clients:
            c.set_weights(weights)

    def get_weights(self):
        return torch.nn.utils.parameters_to_vector(
            self.model.model.parameters()
        ).detach().cpu().numpy()

    def set_weights(self, new_weights):
        vector = torch.tensor(new_weights, dtype=torch.float32)
        torch.nn.utils.vector_to_parameters(vector, self.model.model.parameters())

    @staticmethod
    def weighted_average_oracle(points, weights):
        """Computes weighted average of atoms with specified weights

        Args:
            points: list, whose weighted average we wish to calculate
                Each element is a list_of_np.ndarray
            weights: list of weights of the same length as atoms
        """
        tot_weights = np.sum(weights)
        weighted_updates = np.zeros_like(points[0])

        for w, p in zip(weights, points):
            weighted_updates += (w / tot_weights) * p

        return weighted_updates

    def update(self, updates, max_update_norm=None, maxiter=4, 
            fraction_to_discard=0.0, norm_bound=None, 
        ):
        """Updates server model using given client updates.

        Args:
            updates: list of (num_samples, update), where num_samples is the
                number of training samples corresponding to the update, and update
                is a list of variable weights
            aggregation: Algorithm used for aggregation. Allowed values are:
                [ 'mean', 'geom_median']
            max_update_norm: Reject updates larger than this norm,
            maxiter: maximum number of calls to the Weiszfeld algorithm if using the geometric median
        """
        if len(updates) == 0:
            print('No updates obtained. Continuing without update')
            return 1, False

        def accept_update(u):
            # Calculate norm of update
            norm = np.linalg.norm(u[1])
            return not (np.isinf(norm) or np.isnan(norm))
                
        all_updates = updates
        updates = [u for u in updates if accept_update(u)]
        
        if len(updates) < len(all_updates):
            print('Rejected {} individual updates because of NaN or Inf'.format(len(all_updates) - len(updates)))
        if len(updates) == 0:
            print('All individual updates rejected. Continuing without update')
            return 1, False
    
        points = [u[1] for u in updates]  # List of np.ndarray
        alphas = [u[0] for u in updates]  # List of num_samples
    
        weighted_updates, num_comm_rounds, _ = self.geometric_median_update(points, alphas, maxiter=maxiter)
    
        update_norm = np.linalg.norm(weighted_updates)
    
        if max_update_norm is None or update_norm < max_update_norm:
            current_weights = self.get_weights()
            new_weights = current_weights + weighted_updates
            self.set_weights(new_weights)
            updated = True
        else:
            print(f"Update norm = {update_norm} is too large. Update rejected")
            updated = False
    
        return num_comm_rounds, updated
        
    @staticmethod
    def geometric_median_update(points, alphas, maxiter=4, eps=1e-5, verbose=False, ftol=1e-6):
        """Computes geometric median of atoms with weights alphas using Weiszfeld's Algorithm
        """
        alphas = np.asarray(alphas, dtype=points[0].dtype) / sum(alphas)
        median = ServerModel.weighted_average_oracle(points, alphas)
        num_oracle_calls = 1

        # logging
        obj_val = ServerModel.geometric_median_objective(median, points, alphas)
        logs = []
        log_entry = [0, obj_val, 0, 0]
        logs.append(log_entry)
        if verbose:
            print('Starting Weiszfeld algorithm')
            print(log_entry)

        # start
        for i in range(maxiter):
            prev_median, prev_obj_val = median, obj_val
            weights = np.asarray([alpha / max(eps, ServerModel.l2dist(median, p)) for alpha, p in zip(alphas, points)],
                                 dtype=alphas.dtype)
            weights = weights / weights.sum()
            median = ServerModel.weighted_average_oracle(points, weights)
            num_oracle_calls += 1
            obj_val = ServerModel.geometric_median_objective(median, points, alphas)
            log_entry = [i+1, obj_val,
                         (prev_obj_val - obj_val)/obj_val,
                         ServerModel.l2dist(median, prev_median)]
            logs.append(log_entry)
            if verbose:
                print(log_entry)
            if abs(prev_obj_val - obj_val) < ftol * obj_val:
                break
        return median, num_oracle_calls, logs

    @staticmethod
    def l2dist(p1, p2):
        """L2 distance between p1, p2, each of which is a list of nd-arrays"""
        #return np.linalg.norm([np.linalg.norm(x1 - x2) for x1, x2 in zip(p1, p2)])
        return np.linalg.norm(p1 - p2)

    @staticmethod
    def geometric_median_objective(median, points, alphas):
        """Compute geometric median objective."""
        return sum([alpha * ServerModel.l2dist(median, p) for alpha, p in zip(alphas, points)])

In [38]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

base_model = DenseNet(num_init_features=32, growth_rate=16, block_config=(4, 8, 16, 12))
global_weights = copy.deepcopy(base_model.state_dict())

clients = []
for i in range(3):
    model = DenseNet(num_init_features=32, growth_rate=16, block_config=(4, 8, 16, 12))
    model.load_state_dict(copy.deepcopy(global_weights))
    model = DenseNetModule(model.to(device))
    client = Client(model=model, data=clients_dataloaders[i]["train"])
    clients.append(client)

server_model = DenseNet(num_init_features=32, growth_rate=16, block_config=(4, 8, 16, 12))
server_model.load_state_dict(copy.deepcopy(global_weights))
server_model = DenseNetModule(server_model.to(device))
server = ServerModel(model=server_model)


for round in range(5):  
    print(f"Round {round + 1}")
    
    server.send_to(clients)

    updates = []
    for client_index, client in enumerate(clients):
        print(f"Training on clients {client_index + 1}")
        num_samples, update = client.train() 
        updates.append((num_samples, update))
        loss, acc = client.model.evaluate(clients_dataloaders[client_index]["val"])
    print(f"Update server on round {round + 1}")
    server.update(updates)

  m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out')


Round 1
Training on clients 1
Epoch 1:
Train loss: 0.6129160450463169, Accuracy: 0.6703755217235178
Test loss: 1.00, test accuracy: 0.54
Training on clients 2
Epoch 1:
Train loss: 0.6161303473778992, Accuracy: 0.634214186535757
Test loss: 0.86, test accuracy: 0.52
Training on clients 3
Epoch 1:
Train loss: 0.6228067395750106, Accuracy: 0.6411682894564802
Test loss: 0.67, test accuracy: 0.59
Update server on round 1
Round 2
Training on clients 1
Epoch 1:
Train loss: 0.5359035450991072, Accuracy: 0.7593880391087751
Test loss: 0.71, test accuracy: 0.50
Training on clients 2
Epoch 1:
Train loss: 0.5216137270270866, Accuracy: 0.7482614742698191
Test loss: 0.70, test accuracy: 0.64
Training on clients 3
Epoch 1:
Train loss: 0.5282685822173849, Accuracy: 0.7538247566063978
Test loss: 0.45, test accuracy: 0.81
Update server on round 2
Round 3
Training on clients 1
Epoch 1:
Train loss: 0.45896652316019165, Accuracy: 0.7934631434203189
Test loss: 0.70, test accuracy: 0.60
Training on clients 2
E