## Environment Setup and Library Imports

In this section, we initialize the working environment and import necessary libraries for our federated learning project. Key components include:

1. **PyTorch Libraries (`torch`, `torchvision`):** PyTorch is our primary framework for building and training neural network models. `torchvision` is used for image processing and transformation tasks.

2. **Albumentations (`albumentations`):** A powerful library for image augmentation, allowing us to apply various transformations to our images, which is crucial for enhancing the robustness of our model.

3. **Flower (`flwr`):** This is a federated learning framework that we use to simulate and manage federated learning environments. It helps in distributing the model training across multiple clients.

4. **WandB (`wandb`):** Used for experiment tracking and visualization. It helps in monitoring the training process, logging metrics, and comparing different runs.

5. **CUDA Environment:** We set up CUDA for GPU acceleration to speed up our training process. This is critical for handling large-scale data and complex model architectures efficiently.

Understanding these components is essential for grasping the workflow and architecture of our federated learning experiment.


In [None]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union
import copy
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
from tqdm import tqdm
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchmetrics import JaccardIndex
import flwr as fl
from flwr.common import FitRes, Metrics, Parameters, Scalar
from flwr.server.client_proxy import ClientProxy
import wandb


# Use the GPU when one is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(
    f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}"
)

## Configuration of Initial Parameters

Before we dive into the data processing and model training, it's crucial to define some initial parameters that will be used throughout the notebook. These parameters are fundamental to our federated learning setup and data handling:

1. **Number of Clients (`NUM_CLIENTS`):** This parameter specifies the number of clients that will participate in the federated learning process. In a federated learning environment, the model is trained across multiple decentralized devices or servers holding local data samples.

2. **Batch Size (`BATCH_SIZE`):** The batch size determines how many samples will be processed before the model's internal parameters are updated. It is a key factor in optimizing the training process, balancing speed and resource consumption.

3. **Image Dimensions (`IMAGE_HEIGHT`, `IMAGE_WIDTH`):** These parameters define the height and width of the images that will be used for training. Resizing images to uniform dimensions is important for consistent processing and is often a requirement for neural network inputs.

Setting these parameters correctly is essential for ensuring that the federated learning process runs smoothly and that the data is handled properly.
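As a quick worked example of how `BATCH_SIZE` interacts with each client's data: a client holding n images yields ⌈n / 12⌉ batches per epoch, because `DataLoader` keeps the final partial batch unless `drop_last=True`. The helper name and the client size of 100 images below are hypothetical, chosen only for illustration.

```python
import math

BATCH_SIZE = 12

def batches_per_epoch(num_images: int, batch_size: int = BATCH_SIZE, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields for a dataset of `num_images` samples."""
    if drop_last:
        return num_images // batch_size  # the incomplete final batch is discarded
    return math.ceil(num_images / batch_size)  # the final partial batch is kept

# Hypothetical client with 100 images: 8 full batches of 12 plus one batch of 4
print(batches_per_epoch(100))                  # 9
print(batches_per_epoch(100, drop_last=True))  # 8
```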


In [None]:
NUM_CLIENTS = 3
BATCH_SIZE = 12
IMAGE_HEIGHT = 240
IMAGE_WIDTH = 240

## Data Preprocessing and Transformation

Data preprocessing and transformation are critical steps in preparing our dataset for effective model training, especially in image-based applications. In this section, we define the transformations that will be applied to our images:

1. **Image Resizing:** We resize images to the predefined dimensions (`IMAGE_HEIGHT` and `IMAGE_WIDTH`). This ensures uniformity in input size for the model.

2. **Augmentation Techniques:**
   - **Rotation (`Rotate`):** Each image is rotated by a random angle between -35° and 35° (`limit=35`, applied to every sample since `p=1.0`), making the model robust to orientation changes.
   - **Horizontal Flip (`HorizontalFlip`):** Flips the image horizontally with probability 0.5, further diversifying the dataset.
   - **Vertical Flip (`VerticalFlip`):** Flips the image vertically with a lower probability of 0.1.

3. **Normalization (`Normalize`):** We normalize the image pixel values. With `mean=0`, `std=1`, and `max_pixel_value=255`, this simply rescales pixels from [0, 255] to [0, 1] without per-channel standardization; keeping inputs in a small, consistent range makes training more stable.

4. **Conversion to Tensor (`ToTensorV2`):** The transformed image is converted into a PyTorch tensor in channels-first (C, H, W) layout, the format expected by PyTorch convolution layers.

These transformations are combined into an `A.Compose` pipeline named `transform`, which is applied to the training dataset. Because the pipeline is called with both `image=` and `mask=`, the spatial steps (resize, rotation, flips) are applied identically to the mask, while `Normalize` only affects the image. We also define `val_transforms` for the validation dataset, which performs only resizing, normalization, and tensor conversion, so that the model is evaluated on unaugmented images.
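The effect of this `Normalize` configuration can be checked with plain NumPy. Albumentations computes `(img - mean * max_pixel_value) / (std * max_pixel_value)` per channel, so with zero mean and unit standard deviation every pixel is simply divided by 255. The helper below is a hypothetical re-implementation of that formula for illustration, not the library's own code.

```python
import numpy as np

def normalize_like_albumentations(img, mean, std, max_pixel_value=255.0):
    """Apply (img - mean*max) / (std*max) per channel, as Albumentations' Normalize does."""
    mean = np.asarray(mean, dtype=np.float32) * max_pixel_value
    std = np.asarray(std, dtype=np.float32) * max_pixel_value
    return (img.astype(np.float32) - mean) / std

# A tiny 1x3 RGB "image" with black, mid-grey and white pixels
img = np.array([[[0, 0, 0], [128, 128, 128], [255, 255, 255]]], dtype=np.uint8)
out = normalize_like_albumentations(img, mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
print(out.min(), out.max())  # 0.0 1.0
```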


In [None]:
# Define the transform
transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

# used to transform validation set
val_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

## Custom Function for Dataset Splitting

This section introduces a custom function named `adjusted_random_split`, designed to split a dataset into subsets of specified lengths. This function is particularly useful in scenarios where the standard random split might not perfectly divide the dataset due to rounding issues or specific length requirements.

### Functionality of `adjusted_random_split`:
1. **Handling Standard Split:** Initially, the function attempts to split the dataset into subsets based on the provided `lengths`. This is done using PyTorch's `random_split` function.

2. **Error Handling and Adjustment:**
   - The function includes error handling to catch a `ValueError` that occurs if the sum of the specified lengths does not match the total length of the dataset.
   - If this error is encountered, the function adjusts the length of the last subset to compensate for any discrepancy. This ensures that the total length of all subsets exactly matches the length of the dataset.

3. **Retrying the Split:** After adjusting the lengths, the function retries the split operation to ensure successful partitioning of the dataset.

This custom function is essential for ensuring precise dataset splitting, especially in a federated learning context where data distribution among different clients needs to be controlled and accurate.
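The length adjustment on its own is plain integer arithmetic. The hypothetical helper `adjust_lengths` below isolates it from `random_split`, so the rule (the last split absorbs any surplus or shortfall) can be checked without a dataset. It copies the list and rejects a negative final length, two safeguards that are assumptions of this sketch.

```python
from typing import List

def adjust_lengths(lengths: List[int], dataset_len: int) -> List[int]:
    """Return a copy of `lengths` whose last entry absorbs the mismatch with `dataset_len`."""
    adjusted = list(lengths)  # do not modify the caller's list
    adjusted[-1] += dataset_len - sum(adjusted)
    if adjusted[-1] < 0:
        raise ValueError("lengths exceed the dataset size by more than the last split")
    return adjusted

print(adjust_lengths([33, 33, 33], 100))  # [33, 33, 34]
print(adjust_lengths([40, 40, 40], 100))  # [40, 40, 20]
```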


In [None]:
def adjusted_random_split(dataset, lengths):
    lengths = list(lengths)  # work on a copy so the caller's list is not modified
    try:
        subsets = random_split(dataset, lengths)
    except ValueError:
        # Only handle the case where the requested lengths do not add up to the dataset size
        if sum(lengths) == len(dataset):
            raise  # re-raise any other ValueError unchanged
        print("Length mismatch detected. Adjusting lengths to match dataset.")
        lengths[-1] += len(dataset) - sum(lengths)  # adjust last split length to match dataset length
        if lengths[-1] < 0:
            raise ValueError("Requested lengths exceed the dataset size by more than the last split.")
        subsets = random_split(dataset, lengths)  # retry the split
    return subsets

## Custom Dataset Class for Wound Images

This section of the notebook defines a custom dataset class named `WoundDataset`, which is a subclass of PyTorch's `Dataset`. This class is tailored for handling a dataset consisting of wound images and their corresponding masks, crucial for tasks like image segmentation in medical imaging.

### Structure and Functionality of `WoundDataset`:
1. **Initialization (`__init__`):**
   - The constructor takes `image_dir` (directory containing images), `mask_dir` (directory containing corresponding masks), and an optional `transform`.
   - It builds the list of samples from the file names in `image_dir`; each mask is expected to have the same file name in `mask_dir`.

2. **Length Method (`__len__`):**
   - This method returns the total number of images in the dataset, allowing PyTorch to understand the dataset's size.

3. **Item Access Method (`__getitem__`):**
   - Given an index, this method loads the image and its corresponding mask from the directories.
   - Images are converted to RGB and masks to grayscale. Mask pixels equal to 255 are set to 1, so masks stored as 0/255 images become binary (0 for background, 1 for wound); any other pixel values are left unchanged.
   - If transformations are specified, they are applied to both the image and the mask. This is crucial for data augmentation and preprocessing.
   - The method returns a tuple containing the transformed image and mask.

This custom class is essential for efficiently loading and preprocessing the dataset, particularly for a task that involves simultaneous handling of images and their corresponding segmentation masks. By customizing the dataset loading process, we can ensure that the data is correctly formatted and transformed, ready for use in model training.
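The mask handling in `__getitem__` maps only pixels exactly equal to 255 to 1, so it assumes masks are saved losslessly with values 0 and 255. The NumPy sketch below contrasts that rule with a threshold at 127, a hypothetical alternative that also copes with intermediate values such as JPEG compression artefacts or antialiased edges.

```python
import numpy as np

# A toy grayscale mask row: background, a compression artefact, near-white, and foreground
mask = np.array([[0, 3, 250, 255]], dtype=np.float32)

# Rule used in WoundDataset: only exact 255 values become 1 (3 and 250 are left as they are)
exact = mask.copy()
exact[exact == 255.0] = 1.0
print(exact.tolist())  # [[0.0, 3.0, 250.0, 1.0]]

# Hypothetical alternative: threshold at the midpoint so every value ends up 0 or 1
thresholded = (mask > 127).astype(np.float32)
print(thresholded.tolist())  # [[0.0, 0.0, 1.0, 1.0]]
```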


In [None]:
class WoundDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = sorted(os.listdir(image_dir))  # sorted for a deterministic sample order

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.images[index])
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        # Masks are stored as 0/255 grayscale images; map the foreground value 255 to 1
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            # Albumentations applies the spatial transforms to the image and the mask together
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]
        return image, mask

## Dataset Preparation and Distribution for Federated Learning

This section of the notebook focuses on loading the wound images and masks, combining them into a single dataset, and then distributing them among clients for federated learning. This process is critical for ensuring that each client receives a representative portion of the data.

### Steps Involved in Dataset Preparation:
1. **Loading Datasets:**
   - We use the `WoundDataset` class to load the original training split (with the augmenting `transform`) and the validation split (with `val_transforms`), each consisting of image/mask pairs.

2. **Combining Datasets:**
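   - The two datasets are combined with `ConcatDataset` so they can be shuffled and redistributed into new training and testing sets. Each sample keeps the transform of the dataset it came from, so after redistribution some images in the new test set still receive random augmentations, and some images in the new training set are not augmented.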

3. **Shuffling the Combined Dataset:**
   - The combined dataset is shuffled to remove any ordering bias from the original files. Since `random_split` in the next step also permutes the indices, this shuffle is an extra safeguard.

4. **Splitting into New Training and Testing Sets:**
   - The shuffled dataset is then split into new training and testing sets. The training fraction is drawn uniformly between 0.7 and 0.9 on every run, so the split sizes change from run to run unless Python's `random` module is seeded.

5. **Adjusted Random Split for Federated Learning:**
   - To ensure each client receives at least one image, we distribute the datasets among clients. We start by assigning one image to each client and then distribute the remaining images randomly.
   - The `adjusted_random_split` function is used to handle the splitting, ensuring that the sum of splits matches the dataset size exactly.

6. **Creating Data Loaders:**
   - For each subset of the new training and testing sets, a `DataLoader` is created. These loaders are essential for batch processing during model training and evaluation.

7. **Diagnostic Print Statements:**
   - Finally, print statements are used to output the size of the new train and test sets, the split ratio, and the number of batches per client. This helps in verifying the distribution and ensuring everything is as expected.

Through these steps, we prepare our datasets in a manner that suits federated learning, where data is distributed across multiple clients for decentralized training.
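The client-size logic from step 5 can be sketched in pure Python. The helper `client_lengths` is hypothetical, and it uses a local `random.Random(seed)` generator as an assumption for reproducibility; the notebook itself draws from the unseeded module-level `random`.

```python
import random
from typing import List

def client_lengths(total: int, num_clients: int, seed: int = 0) -> List[int]:
    """Give every client one sample, then assign each remaining sample to a random client."""
    if total < num_clients:
        raise ValueError("need at least one sample per client")
    rng = random.Random(seed)  # local generator, so the split is reproducible
    lengths = [1] * num_clients
    for _ in range(total - num_clients):
        lengths[rng.randrange(num_clients)] += 1
    return lengths

lengths = client_lengths(total=100, num_clients=3, seed=42)
print(lengths, sum(lengths))  # the three sizes always sum to 100
```

Because each remaining image picks a client uniformly at random, client sizes are roughly equal on average but can differ noticeably for small datasets.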


In [None]:
# Load images and masks
trainset = WoundDataset("../../wound_data/data/woundData/train_images", "../../wound_data/data/woundData/train_masks", transform=transform)
testset = WoundDataset("../../wound_data/data/woundData/val_images", "../../wound_data/data/woundData/val_masks", transform=val_transforms)

# Combine train and test sets
combined_set = ConcatDataset([trainset, testset])

# Shuffle the combined dataset
indices = list(range(len(combined_set)))
random.shuffle(indices)
shuffled_set = torch.utils.data.Subset(combined_set, indices)

# Split into new train and test 
split_ratio = random.uniform(0.7, 0.9)
train_size = int(split_ratio * len(shuffled_set))
test_size = len(shuffled_set) - train_size
adj_trainset, adj_testset = random_split(shuffled_set, [train_size, test_size])

# Ensure each client gets at least one picture
lengths_train = [1] * NUM_CLIENTS
lengths_test = [1] * NUM_CLIENTS

# Distribute the remaining pictures randomly among the clients
remaining_train_pics = len(adj_trainset) - NUM_CLIENTS
remaining_test_pics = len(adj_testset) - NUM_CLIENTS

for i in range(remaining_train_pics):
    lengths_train[random.randint(0, NUM_CLIENTS - 1)] += 1

for i in range(remaining_test_pics):
    lengths_test[random.randint(0, NUM_CLIENTS - 1)] += 1

new_trainset = adjusted_random_split(adj_trainset, lengths_train)
new_testset = adjusted_random_split(adj_testset, lengths_test)

train_loaders = []
test_loaders = []

# Use a separate loop variable so the original `trainset`/`testset` are not overwritten
for subset in new_trainset:
    train_loaders.append(DataLoader(subset, batch_size=BATCH_SIZE, shuffle=True))
for subset in new_testset:
    test_loaders.append(DataLoader(subset, batch_size=BATCH_SIZE, shuffle=False))

print(f"New train set size: {len(adj_trainset)}")
print(f"New test set size: {len(adj_testset)}")
print(f"Split ratio: {split_ratio:.3f}")

# Report how many images and batches each client received
for i, loader in enumerate(train_loaders):
    print(f"Train Client {i}: {len(loader.dataset)} images, {len(loader)} batches")

for i, loader in enumerate(test_loaders):
    print(f"Test Client {i}: {len(loader.dataset)} images, {len(loader)} batches")

## Function to Display and Save Original Dataset Samples

In this part of the notebook, we define a function named `display_original_samples`. This function is designed to display and save a specified number of original images and masks from the dataset. This visual inspection is crucial for understanding the nature of the data we are working with and for ensuring the quality of the images and masks before proceeding with training and analysis.

### Overview of `display_original_samples` Function:
1. **Function Parameters:**
   - `dataset`: The dataset from which samples will be displayed.
   - `num_samples`: The number of samples to display. By default, set to 3.
   - `save_path`: The file path where the figure will be saved.
   - `target_size`: The target size to which the images and masks will be resized.

2. **Processing and Displaying Samples:**
   - The function iterates through the specified number of samples from the dataset.
   - Each image and its corresponding mask are resized to the `target_size` for consistent display.
   - Both the original image and mask are displayed side by side for each sample. This helps in visualizing the data and understanding the correlation between the images and masks.

3. **Customization and Saving the Figure:**
   - Titles are added for clarity, and axes are turned off for a cleaner look, suitable for scientific presentations or publications.
   - The layout is adjusted for optimal spacing, and the figure is saved to the specified `save_path`.
   - The figure is also displayed inline in the notebook for immediate inspection.

### Example Usage:
After defining the function, an example usage is shown where the `WoundDataset` is loaded without transformations. The `display_original_samples` function is then called with this dataset to visualize the first three samples.

This function is an essential tool for preliminary data analysis, allowing researchers to visually inspect the dataset and confirm that the images and masks are correctly aligned and processed.
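When a binary mask is resized for display, the resampling filter matters. Nearest-neighbour resampling only copies existing pixel values, so a 0/255 mask stays binary, whereas smoothing filters such as bilinear blend neighbouring pixels and introduce grey values along edges. The snippet below illustrates this with Pillow on a toy 2x2 mask.

```python
import numpy as np
from PIL import Image

# Toy 0/255 mask with a diagonal pattern
mask = np.array([[0, 255], [255, 0]], dtype=np.uint8)

nearest = np.array(Image.fromarray(mask).resize((4, 4), Image.NEAREST))
bilinear = np.array(Image.fromarray(mask).resize((4, 4), Image.BILINEAR))

print(sorted(np.unique(nearest).tolist()))  # [0, 255]: only the original values
print(len(np.unique(bilinear)))             # more than two values: edges are blended
```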


In [None]:
def display_original_samples(dataset, num_samples=3, save_path='figures/original_samples.png', target_size=(256, 256)):
    """
    Display and save the first `num_samples` original images and masks from the dataset.

    Parameters:
    dataset (Dataset): The dataset to display images from (must return numpy arrays, i.e. transform=None).
    num_samples (int): Number of samples to display (default 3).
    save_path (str): File path to save the figure.
    target_size (tuple): The (width, height) to which images and masks are resized.
    """
    # squeeze=False keeps `ax` two-dimensional even when num_samples == 1
    fig, ax = plt.subplots(num_samples, 2, figsize=(10, 5 * num_samples), squeeze=False)

    for i in range(num_samples):
        image, mask = dataset[i]

        # Resize images and masks; nearest-neighbour resampling keeps the mask binary
        resized_image = Image.fromarray(image).resize(target_size)
        resized_mask = Image.fromarray((mask * 255).astype(np.uint8)).resize(target_size, Image.NEAREST)

        # Display images and masks
        ax[i, 0].imshow(np.array(resized_image))
        ax[i, 1].imshow(np.array(resized_mask), cmap='gray')

        # Customization for a scientific journal
        ax[i, 0].set_title('Original Image')
        ax[i, 1].set_title('Original Mask')
        ax[i, 0].axis('off')
        ax[i, 1].axis('off')

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)  # create the output folder if needed
    plt.savefig(save_path)
    plt.show()

# Example usage:
original_dataset = WoundDataset(
    "../../wound_data/data/woundData/train_images",
    "../../wound_data/data/woundData/train_masks",
    transform=None,
)
display_original_samples(original_dataset, num_samples=3)


## Function to Display and Save Samples from a Client's Dataset

In this part of the notebook, we define a function named `display_client_samples`. This function is intended for displaying and saving a specified number of images and masks from a given client's dataset in a federated learning context. This is crucial for verifying the data each client receives and ensuring consistency and alignment of images and masks across different clients.

### Overview of `display_client_samples` Function:
1. **Function Parameters:**
   - `data_loader`: The DataLoader associated with a specific client's dataset.
   - `num_samples`: The number of samples to display, defaulting to 3.
   - `save_path`: The file path where the figure will be saved.
   - `target_size`: The target size for resizing images and masks.

2. **Processing and Displaying Samples:**
   - The function iterates over the DataLoader and, for each of the first `num_samples` batches, takes the first image and its mask.
   - Each image and mask are converted from tensors to numpy arrays, scaled to a 0-255 range, and resized to the target size for consistent visualization.
   - The images and masks are displayed side by side for each sample. This allows for a direct comparison and ensures that they are correctly aligned.

3. **Customization and Saving the Figure:**
   - Titles are added for clarity, and axes are turned off for a cleaner presentation.
   - Colorbars are added next to each image and mask for better visual interpretation of the data.
   - The layout is adjusted for optimal spacing, and the figure is saved to the specified `save_path`.
   - The figure is also displayed in the notebook for immediate review.

### Example Usage:
The function is demonstrated with an example where it is called on a DataLoader from one of the clients (`train_loaders[client_id]`). This showcases how the function can be used to inspect the data distribution and quality for individual clients in a federated learning setup.

This visualization function is an essential tool for analyzing the data distribution in federated learning scenarios, ensuring that each client receives appropriate and well-preprocessed data.
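The tensor-to-image conversion inside this function is the usual channel-reordering step: `ToTensorV2` produces channels-first (C, H, W) data, while Pillow and Matplotlib expect channels-last (H, W, C). The NumPy sketch below performs the same reordering as `permute(1, 2, 0)` followed by scaling to 0-255, on a synthetic array; the helper name is hypothetical, and the clipping guards against values slightly outside [0, 1].

```python
import numpy as np

def chw_float_to_hwc_uint8(chw: np.ndarray) -> np.ndarray:
    """Convert a (C, H, W) float array with values in [0, 1] into an (H, W, C) uint8 image."""
    hwc = np.transpose(chw, (1, 2, 0))  # same reordering as tensor.permute(1, 2, 0)
    hwc = np.clip(hwc, 0.0, 1.0)        # guard against values slightly outside [0, 1]
    return (hwc * 255).astype(np.uint8)

chw = np.zeros((3, 2, 4), dtype=np.float32)
chw[0] = 1.0  # red channel fully on
img = chw_float_to_hwc_uint8(chw)
print(img.shape)           # (2, 4, 3)
print(img[0, 0].tolist())  # [255, 0, 0]
```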


In [None]:
def display_client_samples(data_loader, num_samples=3, save_path='figures/sample_figure.png', target_size=(256, 256)):
    """
    Display and save the first image and mask of each of the first `num_samples` batches
    from a client's DataLoader, resized to a common size so they are aligned.

    Parameters:
    data_loader (DataLoader): The DataLoader for the client's dataset.
    num_samples (int): Number of samples to display (default 3).
    save_path (str): File path to save the figure.
    target_size (tuple): The (width, height) to which images and masks are resized.
    """
    # squeeze=False keeps `ax` two-dimensional even when num_samples == 1
    fig, ax = plt.subplots(num_samples, 2, figsize=(10, 5 * num_samples), squeeze=False)

    for i, (images, masks) in enumerate(data_loader):
        if i >= num_samples:
            break

        # Convert the (C, H, W) tensor with values in [0, 1] to an (H, W, C) uint8 array
        image = images[0].permute(1, 2, 0).cpu().numpy()
        image = (np.clip(image, 0, 1) * 255).astype(np.uint8)
        resized_image = Image.fromarray(image).resize(target_size)

        mask = masks[0].cpu().numpy()  # Convert tensor to numpy array
        mask = (mask * 255).astype(np.uint8)  # Scale 0/1 to 0/255 and convert to uint8
        resized_mask = Image.fromarray(mask).resize(target_size, Image.NEAREST)  # keep the mask binary

        # Display images and masks
        img_ax = ax[i, 0].imshow(np.array(resized_image))
        mask_ax = ax[i, 1].imshow(np.array(resized_mask), cmap='gray')

        # Customization for a scientific journal
        ax[i, 0].set_title('Client Image')
        ax[i, 1].set_title('Mask')
        ax[i, 0].axis('off')
        ax[i, 1].axis('off')

        # Adding colorbars
        fig.colorbar(img_ax, ax=ax[i, 0], orientation='vertical')
        fig.colorbar(mask_ax, ax=ax[i, 1], orientation='vertical')

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)  # create the output folder if needed
    plt.savefig(save_path)
    plt.show()

# Example usage:
client_id = 0  # Replace with the desired client ID
display_client_samples(train_loaders[client_id], num_samples=3)


## Defining the UNET Architecture for Image Segmentation

In this section, we define the architecture of a U-Net, which is a convolutional neural network widely used for image segmentation tasks, particularly in medical imaging. The U-Net architecture is notable for its effectiveness in handling segmentation challenges where precise localization is crucial.

### Structure of the U-Net Model:
1. **Double Convolution Block (`DoubleConv`):**
   - This module is a core component of the U-Net, performing two consecutive convolution operations.
   - Each convolution is followed by batch normalization and a ReLU activation, ensuring non-linear transformations and stabilized learning.
   - The `DoubleConv` block acts as a foundational element in both the downsampling (encoder) and upsampling (decoder) paths of the U-Net.

2. **U-Net Architecture (`UNET`):**
   - The U-Net model is characterized by its encoder-decoder structure.
   - **Encoder:** The downsampling path consists of several `DoubleConv` blocks. The spatial dimensions of the input are reduced while increasing the feature depth, capturing complex features at different scales. Max pooling is used between these blocks for downsampling.
   - **Bottleneck:** This part processes the output from the last downsampling step, serving as a bridge between the encoder and decoder paths.
   - **Decoder:** The upsampling path mirrors the encoder, with transposed convolutions for upsampling the feature maps, followed by `DoubleConv` blocks. Skip connections from the corresponding encoder layers are concatenated with these upsampled outputs, aiding in precise localization and feature preservation.

3. **Final Convolution:**
   - A final 1x1 convolution maps the `features[0]` decoder channels to `out_channels` outputs. For binary segmentation this is a single channel of raw logits; the sigmoid is applied later, inside the loss or at evaluation time.

### Practical Implementation:
- The channel widths of the encoder levels are set by `features` (default `[64, 128, 256, 512]`) and the number of predicted masks by `out_channels`, so the same class can be reused for other segmentation tasks.
- Skip connections pass high-resolution encoder features directly to the decoder, which helps recover precise, pixel-level boundaries, a key requirement in medical image analysis.

By combining context from the deep, low-resolution layers with localization detail from the skip connections, the U-Net has become a common choice for medical image segmentation.
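Each encoder level halves the spatial size with `MaxPool2d(kernel_size=2, stride=2)` (which floors odd sizes), and each `ConvTranspose2d(kernel_size=2, stride=2)` doubles it. When an input side is not divisible by 2^4 = 16, an upsampled map ends up one pixel smaller than its skip connection, which is why `forward` falls back to `TF.resize`. The pure-Python trace below is a sketch that follows only these size rules (the 3x3 convolutions use padding 1 and preserve size); the function name is hypothetical.

```python
from typing import List, Tuple

def trace_unet_sizes(size: int, depth: int = 4) -> Tuple[List[int], List[Tuple[int, int]]]:
    """Return encoder skip sizes and (upsampled, skip) size pairs for one spatial side."""
    skips = []
    for _ in range(depth):
        skips.append(size)  # DoubleConv keeps the size (3x3 kernels, padding 1)
        size = size // 2    # MaxPool2d(2, 2) floors odd sizes
    pairs = []
    for skip in reversed(skips):
        size = size * 2     # ConvTranspose2d(kernel_size=2, stride=2) doubles the size
        pairs.append((size, skip))
        size = skip         # after resizing (if needed) the decoder matches the skip
    return skips, pairs

print(trace_unet_sizes(240))  # 240 is divisible by 16: every pair matches
print(trace_unet_sizes(250))  # 250 -> 125 -> 62 -> 31 -> 15: the first up step gives 30 vs 31
```

With `IMAGE_HEIGHT = IMAGE_WIDTH = 240`, no resize is needed; the fallback only matters for input sizes that are not multiples of 16.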


In [None]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class UNET(nn.Module):
    def __init__(
            self, in_channels=3, out_channels=1, features=[64, 128, 256, 512],
    ):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x)

## Functions for Model Evaluation and Training

This section includes key functions essential for training and evaluating the segmentation model. It defines the `confusion_matrix` function for computing performance metrics, `check_accuracy` for evaluating model accuracy, and the `train` function for training the model.

### Confusion Matrix Function (`confusion_matrix`):
1. **Purpose:** Computes the confusion matrix for binary segmentation. This matrix is vital for understanding the model's performance in terms of true positives, true negatives, false positives, and false negatives.
2. **Parameters:**
   - `preds`: Predictions from the model, assumed to be binary values after thresholding.
   - `y`: Ground truth labels.
3. **Returns:** True negatives (tn), false positives (fp), false negatives (fn), and true positives (tp).

### Accuracy Checking Function (`check_accuracy`):
1. **Functionality:** Evaluates the model's performance on a given dataset (data loader). It uses metrics like loss, dice score, and IoU (Intersection over Union) score for a comprehensive evaluation.
2. **Implementation Details:**
   - The function calculates loss using binary cross-entropy with logits.
   - It iterates over the dataset, computing the confusion matrix and other metrics for each batch.
   - The function aggregates these metrics to provide an overall evaluation of the model's accuracy on the dataset.
3. **Results:** The function prints the true negatives, false positives, false negatives, true positives, IoU score, and Dice score, and returns a list `[accuracy, dice, iou, loss, correct_pixels, total_pixels]`.

### Training Function (`train`):
1. **Objective:** Manages the training process of the neural network on a specified training dataset.
2. **Key Steps:**
   - Sets up the loss function (binary cross-entropy with logits) and the optimizer (Adam).
   - Optionally uses multiple GPUs for training if available.
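   - Uses mixed precision: `autocast` runs the forward pass in float16 where it is safe, and `GradScaler` scales the loss before backpropagation so that small float16 gradients do not underflow.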
   - Iterates over the epochs, performing forward and backward passes for each batch.
3. **Functionality:** This function is responsible for adjusting the model's weights based on the training data, with the aim of improving its segmentation accuracy.

These functions are central to the model's lifecycle, from training to performance evaluation. They ensure the model is trained effectively and provide detailed insights into its accuracy and reliability in segmenting images.
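For binary masks, Dice and IoU can both be written in terms of the confusion-matrix counts: Dice = 2TP / (2TP + FP + FN) and IoU = TP / (TP + FP + FN). The NumPy sketch below reproduces the products used in `confusion_matrix` on a toy example; the helper name is hypothetical. It also shows why the notebook's Dice formula `2 * (preds * y).sum() / (preds.sum() + y.sum())` is equivalent: `preds.sum()` equals TP + FP and `y.sum()` equals TP + FN.

```python
import numpy as np

def binary_counts(preds: np.ndarray, y: np.ndarray):
    """Confusion-matrix counts for 0/1 arrays, using the same products as confusion_matrix."""
    tp = float((y * preds).sum())
    tn = float(((1 - y) * (1 - preds)).sum())
    fp = float(((1 - y) * preds).sum())
    fn = float((y * (1 - preds)).sum())
    return tn, fp, fn, tp

y = np.array([1, 1, 0, 0, 1, 0], dtype=np.float32)      # ground truth
preds = np.array([1, 0, 0, 1, 1, 0], dtype=np.float32)  # thresholded predictions

tn, fp, fn, tp = binary_counts(preds, y)
dice = 2 * tp / (2 * tp + fp + fn)
iou = tp / (tp + fp + fn)
print(tn, fp, fn, tp)                 # 2.0 1.0 1.0 2.0
print(round(dice, 4), round(iou, 4))  # 0.6667 0.5
```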


In [None]:
def confusion_matrix(preds, y):
    """
    Compute confusion matrix for binary segmentation.

    Args:
        preds: Predictions from the model. Assumes binary values after thresholding.
        y: Ground truth labels.

    Returns:
        tn: True negatives
        fp: False positives
        fn: False negatives
        tp: True positives
    """
    tp = (y * preds).sum().to(torch.float32)
    tn = ((1 - y) * (1 - preds)).sum().to(torch.float32)
    fp = ((1 - y) * preds).sum().to(torch.float32)
    fn = (y * (1 - preds)).sum().to(torch.float32)

    return tn, fp, fn, tp

def check_accuracy(loader, model, device="cuda"):
    """Evaluate the network on the data in `loader` (e.g. a client's test set)."""
    print("~~~~ In test ~~~~")
    criterion = torch.nn.BCEWithLogitsLoss()
    loss = 0.0
    num_correct = 0
    num_pixels = 0
    dice_score = 0.0
    iou_score = 0.0
    tn_total, fp_total, fn_total, tp_total = 0.0, 0.0, 0.0, 0.0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.float().to(device).unsqueeze(1)
            logits = model(x)
            # BCEWithLogitsLoss applies the sigmoid internally, so it receives raw logits
            loss += criterion(logits, y).item()
            # Threshold the probabilities before computing pixel-level metrics
            preds = (torch.sigmoid(logits) > 0.5).float()
            tn, fp, fn, tp = confusion_matrix(preds, y)
            tn_total += tn.item()
            fp_total += fp.item()
            fn_total += fn.item()
            tp_total += tp.item()
            num_correct += (preds == y).sum().item()
            num_pixels += torch.numel(preds)
            dice_score += ((2 * (preds * y).sum()) / (preds.sum() + y.sum() + 1e-8)).item()
            # Calculate IoU score
            intersection = (preds * y).sum()
            union = (preds + y).sum() - intersection
            iou_score += ((intersection + 1e-8) / (union + 1e-8)).item()

    num_batches = len(loader)
    loss /= num_batches
    acc = num_correct / num_pixels * 100
    diceS = dice_score / num_batches
    iouS = iou_score / num_batches
    # Confusion-matrix counts are summed over all batches
    print(f"True Negatives: {tn_total}")
    print(f"False Positives: {fp_total}")
    print(f"False Negatives: {fn_total}")
    print(f"True Positives: {tp_total}")
    print(f"IoU Score = {iouS}")
    print(f"Dice Score = {diceS}")
    print("~~~~~ Out of test ~~~~~")

    model.train()

    # [accuracy (%), mean Dice, mean IoU, mean loss, correct pixels, total pixels]
    result = [acc, diceS, iouS, loss, num_correct, num_pixels]

    return result
  

def train(net, trainloader, epochs: int, lr: float, device=DEVICE):
    """Train the network on a client's local training set."""
    print("~~~~ In train ~~~~")

    criterion = torch.nn.BCEWithLogitsLoss()  # expects raw logits
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    scaler = GradScaler()  # loss scaling for mixed-precision training

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
    net.to(device)
    net.train()

    for epoch in range(epochs):
        print(f'epoch => {epoch}')
        print(f'number of training batches: {len(trainloader)}')
        for images, labels in trainloader:
            images = images.to(device)
            labels = labels.float().unsqueeze(1).to(device)
            # Forward pass in mixed precision
            with autocast():
                outputs = net(images)
                loss = criterion(outputs, labels)
            # Backward pass: scale the loss to avoid float16 gradient underflow
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        print(f"end of epoch {epoch}")
    print("~~~~ Out of train ~~~~")


## Weighted Average Calculation for Aggregated Metrics

This section defines `weighted_average`, the function the server uses to aggregate the evaluation metrics reported by the clients (it is passed to the strategy as `evaluate_metrics_aggregation_fn`). It computes a weighted average of each metric across clients, using the number of examples each client evaluated on as the weight, so the aggregated values reflect performance over the whole distributed validation set rather than treating every client equally.

### Functionality of `weighted_average`:
1. **Purpose:** To compute the weighted average of various metrics (accuracy, dice score, IoU score, and loss) from multiple clients in a federated learning setup.
2. **Parameters:** 
   - `metrics`: A list of tuples, where each tuple contains the number of examples used by a client and the metrics dictionary reported by that client.

3. **Processing Steps:**
   - **Metric Extraction and Weighting:** For each metric (accuracy, dice score, IoU score, loss), the function multiplies the metric value by the number of examples contributed by the client. This step ensures that clients with more data have a proportionally greater influence on the final averaged metric.
   - **Aggregation:** The function then calculates the sum of these weighted metrics and divides it by the total number of examples across all clients. This results in a weighted average for each metric.

4. **Return Value:**
   - The function returns a dictionary containing the aggregated metrics: weighted average accuracy, dice score, IoU score, and loss.
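
A quick worked example of the weighting in step 3, using two hypothetical clients (the numbers are made up for illustration and are not results from this experiment). The helper `weighted_mean` is a stand-alone sketch of the same arithmetic, not the notebook's function:

```python
from typing import Dict, List, Tuple


def weighted_mean(metrics: List[Tuple[int, Dict[str, float]]], key: str) -> float:
    """Weighted average of metrics[key], weighting each client by its example count."""
    total = sum(num_examples for num_examples, _ in metrics)
    return sum(num_examples * m[key] for num_examples, m in metrics) / total


# Hypothetical reports: (num_examples, metrics) from two clients
reports = [(100, {"accuracy": 90.0}), (300, {"accuracy": 70.0})]

weighted = weighted_mean(reports, "accuracy")                       # (100*90 + 300*70) / 400
unweighted = sum(m["accuracy"] for _, m in reports) / len(reports)  # (90 + 70) / 2
print(weighted, unweighted)  # 75.0 80.0
```

The larger client pulls the aggregate toward its own accuracy; an unweighted mean would overstate performance by five points in this toy case.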

### Significance in Federated Learning:
- In federated learning, where data is distributed across multiple clients, simple averaging might not accurately represent model performance due to the varying amount of data each client holds.
- The `weighted_average` function addresses this by calculating weighted averages, providing a more representative measure of the model's overall performance across all clients.

This function plays a critical role in evaluating the federated learning model, ensuring that the aggregated metrics are a fair representation of the model's performance on the diverse and distributed dataset.


In [None]:
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Weighted average of client metrics, weighting each client by its number of examples."""
    print("Client metrics:", metrics)

    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        return {}

    def wavg(key: str) -> float:
        # sum over clients of (num_examples * metric), divided by the total number of examples
        return sum(num_examples * m[key] for num_examples, m in metrics) / total_examples

    acc = wavg("accuracy")
    dice_score = wavg("dice_score")
    iou_score = wavg("iouS")
    loss = wavg("loss")

    # wandb.log({"acc": acc, "dice_score": dice_score, "iou_score": iou_score, "loss": loss})

    return {"accuracy": acc, "dice_score": dice_score, "iouS": iou_score, "loss": loss}

## Custom Federated Averaging Strategy with Model Saving

In this part of the notebook, we define `SaveModelStrategy`, a custom class that extends Flower's federated averaging strategy (`FedAvg`). This class is tailored to not only aggregate model weights and metrics from different clients but also save the aggregated model state after each round of federated training.

### Key Features of `SaveModelStrategy`:
1. **Aggregate and Save Model Weights:**
   - The `aggregate_fit` method is overridden to add functionality for saving the model.
   - After aggregating parameters using the standard federated averaging method, the aggregated model weights are saved to a file. This is crucial for tracking the evolution of the model over training rounds and for potential later use or analysis.

2. **Process of Aggregation and Saving:**
   - **Parameter Aggregation:** Initially, the method calls its superclass (`FedAvg`) to perform the standard aggregation of parameters and metrics.
   - **Model Saving:** If the aggregation is successful, the method converts the aggregated parameters into a format compatible with PyTorch's `state_dict`, and then updates the model's state.
   - **Checkpoint Creation:** The model's state is saved as a checkpoint file, named according to the current training round. This enables checkpointing, allowing for model recovery and evaluation at different training stages.

3. **Return Value:**
   - The method returns the aggregated parameters and metrics, consistent with the expected output of an aggregation function in federated learning.
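
As a sketch of how a saved round could later be restored for evaluation: the snippet below uses a tiny stand-in model (`TinyNet`, hypothetical, in place of `UNET`) and a temporary directory rather than the notebook's checkpoint folder. It saves a `state_dict` the same way `aggregate_fit` does and reloads it into a freshly constructed model of the same architecture:

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for UNET; any module with a matching architecture works."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


server_round = 5
checkpoint_dir = tempfile.mkdtemp()
path = os.path.join(checkpoint_dir, f"model_round_{server_round}.pth")

# Save, as aggregate_fit does after each round
trained = TinyNet()
torch.save(trained.state_dict(), path)

# Restore into a new instance of the same architecture
restored = TinyNet()
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()  # switch to inference mode before evaluating

same = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```

`map_location="cpu"` lets a checkpoint written on a GPU machine be inspected on a CPU-only one; the model can be moved to a device afterwards.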

### Importance in Federated Learning:
- The ability to save model checkpoints during federated training is valuable for monitoring model progress, debugging, and potentially resuming training from a specific round.
- This custom strategy enhances the basic federated averaging by adding a crucial aspect of model management, making it more practical for real-world federated learning scenarios.

The `SaveModelStrategy` class provides a concrete example of how federated learning strategies can be extended and customized to meet specific requirements, such as model checkpointing in this case.

##### Set the path where each round's aggregated model is saved in `aggregate_fit` below


In [None]:
from flwr.common import FitRes, Parameters, Scalar
from flwr.server.client_proxy import ClientProxy

# Directory for per-round checkpoints; change this to your own path
CHECKPOINT_DIR = "review_3c_12bs_000001_bce"


class SaveModelStrategy(fl.server.strategy.FedAvg):
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate model weights using weighted average and store checkpoint"""

        # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)

        if aggregated_parameters is not None:
            print(f"Saving round {server_round} aggregated_parameters...")

            # Convert `Parameters` to `List[np.ndarray]`
            aggregated_ndarrays: List[np.ndarray] = fl.common.parameters_to_ndarrays(aggregated_parameters)

            # Convert `List[np.ndarray]` to a PyTorch `state_dict`, using the server-side `net`
            # (created before the simulation starts) for the key order
            params_dict = zip(net.state_dict().keys(), aggregated_ndarrays)
            state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
            net.load_state_dict(state_dict, strict=True)

            # Save the model; torch.save does not create missing directories
            os.makedirs(CHECKPOINT_DIR, exist_ok=True)
            torch.save(net.state_dict(), os.path.join(CHECKPOINT_DIR, f"model_round_{server_round}.pth"))

        return aggregated_parameters, aggregated_metrics

## Function for Saving Model Predictions as Images

This section of the notebook introduces the `save_predictions_as_imgs` function. It is designed to save the model's predictions as image files, allowing for a visual inspection of the model's performance on a given dataset. This function is particularly useful in scenarios where evaluating the quality of segmentation or other image-based outputs is crucial.

### Functionality of `save_predictions_as_imgs`:
1. **Purpose:** To generate and save prediction images from the model for a given dataset. This visual output helps in assessing the model's accuracy and effectiveness in tasks like image segmentation.
2. **Parameters:** 
   - `loader`: DataLoader containing the dataset to be processed.
   - `model`: The trained model used for generating predictions.
   - `client_id`: Identifier for the client (useful in federated learning contexts).
   - `folder`: Destination folder for saving the images.
   - `device`: The device (e.g., CPU, CUDA) on which the model is running.

3. **Process:**
   - The function iterates over the dataset provided by the loader.
   - For each batch, it computes the model's predictions, applies a sigmoid followed by a 0.5 threshold to obtain binary masks, and saves them as an image (one grid image per batch).
   - The corresponding ground-truth masks are saved alongside the predictions, allowing a direct visual comparison.
   - Files are named with the batch index and the client identifier (`pred_{idx}_{client_id}.png` for predictions, `{idx}_{client_id}.png` for ground truth), which keeps outputs from different clients apart.

4. **Model State Management:**
   - The model is set to evaluation mode (`model.eval()`) before processing to disable any training-specific behaviors like dropout.
   - Once the predictions are saved, the model is reverted back to training mode (`model.train()`).
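
The 0.5 threshold in step 3 is applied to sigmoid probabilities. Because the sigmoid is monotonic and equals exactly 0.5 at zero, thresholding probabilities at 0.5 gives the same mask as thresholding the raw logits at 0. A small pure-Python check of that equivalence on toy logit values (not model output):

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


logits = [-3.2, -0.1, 0.0, 0.4, 2.7]  # toy per-pixel logits

mask_from_probs = [1.0 if sigmoid(z) > 0.5 else 0.0 for z in logits]
mask_from_logits = [1.0 if z > 0.0 else 0.0 for z in logits]

print(mask_from_probs)  # [0.0, 0.0, 0.0, 1.0, 1.0]
print(mask_from_probs == mask_from_logits)  # True
```

So comparing logits to 0 would skip the sigmoid, up to floating-point rounding for logits extremely close to zero (where the computed sigmoid can round to exactly 0.5).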

### Practical Applications:
- This function is useful for visually evaluating the model, particularly in segmentation, where the spatial placement of the predicted masks matters as much as pixel-level metrics.
- In federated learning, where models are trained across different clients, this function aids in understanding the model's performance specific to each client's data.

Saved prediction images complement the numeric metrics (accuracy, Dice, IoU) with a qualitative view of where the model succeeds or fails.


In [None]:
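def save_predictions_as_imgs(
    loader, model, client_id, folder="prediction_images/", device="cuda"
):
    """Save thresholded predictions and ground-truth masks for every batch as PNG grids."""
    os.makedirs(folder, exist_ok=True)
    model.eval()
    with torch.no_grad():
        for idx, (x, y) in enumerate(loader):
            x = x.to(device=device)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()  # binarize probabilities at 0.5
            torchvision.utils.save_image(
                preds, os.path.join(folder, f"pred_{idx}_{client_id}.png")
            )
            # Ground truth: add the channel dimension and cast to float for save_image
            torchvision.utils.save_image(
                y.float().unsqueeze(1), os.path.join(folder, f"{idx}_{client_id}.png")
            )

    model.train()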

## Functions for Handling Model Parameters in Federated Learning

In this section, we define two essential functions, `get_parameters` and `set_parameters`, which are crucial for managing the model's parameters in a federated learning context. These functions facilitate the extraction and update of model parameters, enabling the synchronization of local models with global parameters across different clients.

### Function for Extracting Model Parameters (`get_parameters`):
1. **Purpose:** Retrieves the parameters of a PyTorch model as a list of NumPy arrays. This is typically used to gather the local model parameters from a client in federated learning.
2. **Parameters:** 
   - `net`: The neural network model from which parameters are to be extracted.
3. **Functionality:** 
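   - The function iterates over the model's state dictionary (`state_dict`) and converts each tensor to a NumPy array. `state_dict` also contains buffers such as BatchNorm running statistics, so these are exchanged along with the trainable weights.
   - Flower's `NumPyClient` exchanges model weights as lists of NumPy arrays, so this conversion is what allows the parameters to be sent to the server.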
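
To see the effect of exchanging `state_dict` entries rather than only trainable parameters, here is a small check on a hypothetical toy model (a convolution followed by `BatchNorm2d`, standing in for the U-Net). The `get_parameters` below repeats the notebook's one-line logic so the snippet is self-contained:

```python
import torch.nn as nn

# Hypothetical toy model standing in for UNET
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, bias=False), nn.BatchNorm2d(4))


def get_parameters(net):
    # Same logic as the notebook's get_parameters: one NumPy array per state_dict entry
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


arrays = get_parameters(model)
print(len(list(model.parameters())))  # 3 trainable tensors: conv weight, BN weight, BN bias
print(len(arrays))                    # 6: adds running_mean, running_var, num_batches_tracked
print(list(model.state_dict().keys()))
```

Sending only `net.parameters()` would leave each client's BatchNorm running statistics out of the federated average.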

### Function for Updating Model with Global Parameters (`set_parameters`):
1. **Purpose:** Updates a local model with a new set of parameters. This is typically used to synchronize a client's local model with the global model parameters in federated learning.
2. **Parameters:** 
   - `device`: The device (e.g., CPU, GPU) on which the model is running.
   - `net`: The neural network model to be updated.
   - `parameters`: A list of NumPy arrays representing the new model parameters.
3. **Process:**
   - The function builds a new state dictionary by pairing the provided arrays with the model's `state_dict` keys in order, so the arrays must arrive in exactly the order produced by `get_parameters`.
   - The model is then updated in place with `load_state_dict(..., strict=True)`, which raises an error if any key is missing or unexpected.
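
Because the pairing relies on `zip`, a length mismatch is silently truncated before `load_state_dict` ever sees it. A minimal pure-Python sketch of a stricter pairing step; `pair_by_order` is a hypothetical helper, and plain lists stand in for NumPy arrays:

```python
from collections import OrderedDict
from typing import Dict, List, Sequence


def pair_by_order(keys: Sequence[str], values: List) -> Dict[str, object]:
    """Map state_dict keys to received arrays by position, refusing length mismatches."""
    if len(keys) != len(values):
        raise ValueError(f"expected {len(keys)} arrays, got {len(values)}")
    return OrderedDict(zip(keys, values))


keys = ["conv.weight", "conv.bias", "bn.running_mean"]
received = [[0.1, 0.2], [0.0], [0.5]]

state = pair_by_order(keys, received)
print(list(state))  # ['conv.weight', 'conv.bias', 'bn.running_mean']

try:
    pair_by_order(keys, received[:2])  # one array missing
except ValueError as err:
    print(err)  # expected 3 arrays, got 2
```

On Python 3.10+, `zip(keys, values, strict=True)` also raises `ValueError` on a length mismatch.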

### Importance in Federated Learning:
- These functions are fundamental in federated learning frameworks, where model parameters need to be frequently exchanged between the server and clients.
- `get_parameters` allows for the efficient collection of local model parameters, while `set_parameters` ensures that local models are consistently updated with global advancements.

Used together, these functions let every client start each round from the same global weights and send its locally updated weights back to the server for aggregation.


In [None]:
# Extract the local model's state_dict (parameters and buffers) as NumPy arrays
def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


# Update the local model in place with global parameters, matched to state_dict keys by order.
# `device` is unused: load_state_dict copies values into the model's existing tensors.
def set_parameters(device, net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})  # keeps original dtypes
    net.load_state_dict(state_dict, strict=True)

## Custom Flower Client for Federated Learning

In this section, we introduce the `FlowerClient` class, a custom client implementation for federated learning using the Flower framework. This class is responsible for handling the training and evaluation of a neural network model (such as U-Net) on each client's local data. Additionally, we define the `client_fn` function, which instantiates and configures these clients.

### Overview of `FlowerClient` Class:
1. **Initialization:**
   - The class is initialized with a client ID, neural network model (`net`), training and validation data loaders, device, learning rate, and the number of training epochs.
2. **Parameter Synchronization Methods:**
   - `set_parameters`: Updates the client's local model with global parameters received from the server.
   - `get_parameters`: Retrieves the current parameters of the client's local model.
3. **Training and Evaluation:**
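   - `fit`: Loads the global parameters, trains the local model on the client's data, and returns the updated parameters, the number of training examples, and a (here empty) metrics dictionary.
   - `evaluate`: Loads the global parameters, evaluates the model on the client's validation data, and returns the loss, the number of validation examples, and a dictionary of metrics (accuracy, Dice score, IoU score, loss).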
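
To make these return values concrete: Flower's `NumPyClient.fit` returns `(parameters, num_examples, metrics)` and `evaluate` returns `(loss, num_examples, metrics)`, and FedAvg uses `num_examples` as the averaging weight. The class below is a dependency-free toy that only mimics that contract (it does not subclass Flower's class); its single "weight" is set to the mean of the client's local data, a hypothetical stand-in for training:

```python
from typing import Dict, List, Tuple


class ToyClient:
    """Hypothetical stand-in mimicking NumPyClient's fit/evaluate return shapes."""

    def __init__(self, data: List[float]):
        self.data = data
        self.weights = [0.0]

    def fit(self, parameters: List[float], config: Dict) -> Tuple[List[float], int, Dict]:
        self.weights = list(parameters)
        # "Training": move the single weight to the local data mean
        self.weights[0] = sum(self.data) / len(self.data)
        return self.weights, len(self.data), {}

    def evaluate(self, parameters: List[float], config: Dict) -> Tuple[float, int, Dict]:
        w = parameters[0]
        loss = sum((x - w) ** 2 for x in self.data) / len(self.data)  # mean squared error
        return loss, len(self.data), {"loss": loss}


clients = [ToyClient([1.0, 1.0]), ToyClient([4.0, 4.0, 4.0, 4.0])]
fit_results = [c.fit([0.0], {}) for c in clients]

# Server side: FedAvg-style average weighted by num_examples
total = sum(n for _, n, _ in fit_results)
global_w = sum(params[0] * n for params, n, _ in fit_results) / total
print(global_w)  # (1.0*2 + 4.0*4) / 6 = 3.0

evals = [c.evaluate([global_w], {}) for c in clients]
print([loss for loss, _, _ in evals])  # [4.0, 1.0]
```

This weighting is why `num_examples` should count samples (e.g. `len(loader.dataset)`) rather than batches (`len(loader)`): batch counts misweight clients whose batch sizes or final partial batches differ.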

### Functionality of `client_fn`:
- **Purpose:** This function is a factory for creating `FlowerClient` instances.
- **Parameters:** It takes a client ID (`cid`) and uses it to set up a new client with its own model, training and validation data loaders, and training configurations.
- **Implementation:** It initializes a U-Net model, assigns the respective data loaders for training and validation, and sets the learning rate and epochs for training.
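
In simulation, Flower calls `client_fn` with the client ID as a string, which is why the code indexes `train_loaders[int(cid)]`. The sketch below shows that lookup pattern with plain lists standing in for the per-client loaders. How `train_loaders` and `test_loaders` are actually built is defined elsewhere in the notebook, so the contiguous partition and the `client_fn_sketch` factory here are hypothetical:

```python
from typing import Dict, List

NUM_CLIENTS = 3
sample_ids = list(range(10))  # toy dataset of 10 sample indices

# Hypothetical partition: contiguous shards, the last client gets the remainder
shard = len(sample_ids) // NUM_CLIENTS
partitions: List[List[int]] = [
    sample_ids[i * shard : (i + 1) * shard] if i < NUM_CLIENTS - 1 else sample_ids[i * shard :]
    for i in range(NUM_CLIENTS)
]


def client_fn_sketch(cid: str) -> Dict[str, object]:
    """Hypothetical factory: builds a fresh client description for a string ID."""
    data = partitions[int(cid)]  # Flower passes cid as a string, e.g. "0"
    return {"cid": cid, "data": data, "num_examples": len(data)}


clients = [client_fn_sketch(str(i)) for i in range(NUM_CLIENTS)]
print([c["num_examples"] for c in clients])  # [3, 3, 4]
```

Because each call builds a fresh object (in the notebook, a new `UNET`), anything that must persist between rounds should not be stored only on the client instance.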

### Significance in Federated Learning:
- The `FlowerClient` class encapsulates the behavior of a federated learning client, including how it trains locally and communicates with the federated server.
- By using the `client_fn`, we can easily create multiple client instances, each with its unique dataset and model, facilitating the distributed training process in federated learning.

This custom implementation demonstrates how federated learning clients can be tailored to specific requirements, such as training specialized models like U-Net, and how they can be efficiently managed and configured within the Flower framework.


In [None]:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader, device, learning_rate, epochs):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader
        self.device = device
        self.epochs = epochs
        self.lr = learning_rate

    def set_parameters(self, parameters):
        print(f"[Client {self.cid}] set_parameters")
        params_dict = zip(self.net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        # Load into this client's model, not the global server-side `net`
        self.net.load_state_dict(state_dict, strict=True)

    def get_parameters(self, config):
        print(f"[Client {self.cid}] get_parameters")
        return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

    def fit(self, parameters, config):
        print(f"[Client {self.cid}] fit, config: {config}")
        self.set_parameters(parameters)
        train(self.net, self.trainloader, self.epochs, self.lr, self.device)
        # num_examples counts samples (not batches) so FedAvg weights clients correctly
        return self.get_parameters(config={}), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        self.set_parameters(parameters)
        # result = [acc, dice, iou, loss, correct_pixels, total_pixels]
        result = check_accuracy(self.valloader, self.net, self.device)
        # Optionally save example predictions to a folder:
        # save_predictions_as_imgs(self.valloader, self.net, self.cid, folder="prediction_images/", device=self.device)
        print("~~~ loss = ", str(result[3]))
        return float(result[3]), len(self.valloader.dataset), {
            "accuracy": result[0],
            "dice_score": result[1],
            "iouS": result[2],
            "loss": result[3],
        }

def client_fn(cid) -> FlowerClient:
    net = UNET().to(DEVICE)
    trainloader = train_loaders[int(cid)]
    valloader = test_loaders[int(cid)]
    lr = 0.000001
    epochs = 30
    print("~~~~client created~~~~")
    return FlowerClient(cid, net, trainloader, valloader, DEVICE, lr, epochs)

## Initializing the Federated Learning Environment with Flower

This section of the notebook focuses on setting up the federated learning environment using Flower, a framework for building federated learning systems. The setup involves creating a model instance, defining a custom training strategy, and configuring the federated learning server and clients.

### Model Initialization and Parameter Extraction:
1. **Model Creation:** An instance of the U-Net model (`net`) is created and moved to the specified device (e.g., GPU or CPU).
2. **Parameter Extraction:** The initial parameters of the model are extracted using the `get_parameters` function. These parameters serve as the starting point for federated learning.

### Custom Strategy Definition (`SaveModelStrategy`):
- A custom strategy called `SaveModelStrategy` is defined, which extends Flower's federated averaging strategy.
- Key configurations of this strategy include:
   - `fraction_fit` and `fraction_evaluate`: Fraction of the available clients sampled for training and for evaluation in each round.
   - `min_fit_clients` and `min_evaluate_clients`: Minimum number of clients required for training and evaluation.
   - `initial_parameters`: The initial model parameters, converted to Flower's parameter format.
   - `evaluate_metrics_aggregation_fn`: Function to aggregate evaluation metrics, set to `weighted_average` for a representative aggregation.
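
How `fraction_*` and `min_*_clients` combine: Flower's `FedAvg` samples `max(int(num_available * fraction), min_clients)` clients per round, and sampling waits until at least `min_available_clients` clients are connected. The function below is a hypothetical re-statement of that arithmetic for sanity-checking a configuration, not part of Flower's API:

```python
def clients_per_round(num_available: int, fraction: float, min_clients: int) -> int:
    """Hypothetical mirror of FedAvg's sampling rule: max(int(n * fraction), min_clients)."""
    return max(int(num_available * fraction), min_clients)


# The notebook's settings, assuming NUM_CLIENTS = 3 (hypothetical value)
print(clients_per_round(3, fraction=1.0, min_clients=3))  # fit: 3
print(clients_per_round(3, fraction=1.0, min_clients=2))  # evaluate: 3
# With a small fraction, the minimum dominates
print(clients_per_round(10, fraction=0.1, min_clients=3))  # 3
```

With `fraction_fit=1.0` and `fraction_evaluate=1.0`, every available client participates in each round, so the minimums only matter when fewer clients are connected.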

### Federated Learning Simulation Configuration:
1. **Client Function:** The `client_fn` function is used to instantiate clients. Each client will have its instance of the U-Net model and its dataset.
2. **Number of Clients:** The total number of clients participating in the federated learning is set to `NUM_CLIENTS`.
3. **Server Configuration:** 
   - The server is configured to run a specified number of training rounds (`num_rounds`).
4. **Client Resources:** 
   - When running on a GPU, `client_resources` specifies the resources reserved for each client (e.g., `num_gpus`); the values apply per client, not to the whole simulation.
5. **Starting the Simulation:** 
   - The simulation is started using `fl.simulation.start_simulation`, passing the client function, number of clients, server configuration, strategy, and client resources.
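
Because `client_resources` is reserved per client, it caps how many clients the simulation can run at the same time. A back-of-the-envelope sketch (hypothetical helper that ignores CPU limits and other overheads):

```python
import math


def max_concurrent_clients(total_gpus: int, gpus_per_client: float) -> int:
    """Upper bound on simultaneously running clients when GPUs are the bottleneck."""
    return math.floor(total_gpus / gpus_per_client)


print(max_concurrent_clients(4, 4))    # 1: each client claims all four GPUs
print(max_concurrent_clients(4, 1))    # 4
print(max_concurrent_clients(4, 0.5))  # 8: two clients share each GPU
```

A result of 0 (a client requesting more GPUs than the machine has) means no client could be scheduled at all.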

### Importance of This Setup:
- This setup is critical for launching a federated learning system where multiple clients train a shared model in a distributed manner.
- It demonstrates how to initialize a federated learning environment, define a custom strategy for training and evaluation, and configure the server and clients for simulation.

This initialization process sets the stage for federated training of the U-Net model, leveraging Flower's capabilities to manage and coordinate the distributed learning process.


In [None]:
net = UNET().to(DEVICE)
params = get_parameters(net)

strategy = SaveModelStrategy(
    fraction_fit=1.0,
    fraction_evaluate=1.0,
    min_fit_clients=3,
    min_evaluate_clients=2,
    min_available_clients=NUM_CLIENTS,
    initial_parameters=fl.common.ndarrays_to_parameters(params),
    evaluate_metrics_aggregation_fn=weighted_average,
)

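# Resources reserved *per client* (defaults to 1 CPU and 0 GPUs). {"num_gpus": 4} makes each
# client claim four GPUs, so on a 4-GPU machine only one client trains at a time; a fraction
# such as 0.5 would let two clients share one GPU.
client_resources = None
if DEVICE.type == "cuda":
    client_resources = {"num_gpus": 4}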

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=300),  # 300 federated rounds
    strategy=strategy,
    client_resources=client_resources,
)