#### Central Training Notebook without Differential Privacy
This notebook sets up a central machine learning workflow for a neural network based on the UNET architecture. It is organized into sections. It starts by loading the necessary external modules (e.g., PyTorch, Albumentations, and tqdm), then loads the data and preprocesses it for consistency, and finally defines and trains the model. The last section examines the training results for two purposes: to check the validity of the model, and to check that privacy is preserved, since we intend to use the model on sensitive medical data.

### 1. Loading External Libraries

In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
import torchvision
from torch.utils.data import DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim

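### 2. Custom Dataset Definition
In this section we define WoundDataset, a PyTorch Dataset that loads each image and its associated mask from disk. The directory paths themselves are specified later, in Section 6. Each mask is expected to have the same filename as its image, inside the mask directory. Images are loaded as RGB and masks as single-channel grayscale, and mask pixels equal to 255 are set to 1.0.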

In [None]:
class WoundDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, num_images=None, start_index=0):
        """
        :param image_dir: str, path to the directory containing the images
        :param mask_dir: str, path to the directory containing the masks
        :param transform: callable, optional transform to be applied on a sample
        :param num_images: int, optional number of images to load
        :param start_index: int, index to start loading images from
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)
        self.images.sort()  # Make sure images are sorted to have consistent order
        
        if num_images is not None:
            end_index = start_index + num_images
            self.images = self.images[start_index:end_index]
        else:
            self.images = self.images[start_index:]
            
    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.images[index])
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0
        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]
        return image, mask
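The sketch below isolates two steps of WoundDataset so you can check them without PyTorch: how `start_index` and `num_images` select files, and how a mask is converted. It re-implements those steps with NumPy. The helper names `select_files` and `binarize_mask` are hypothetical, and so are the filenames. This is an illustration, not part of the notebook.

```python
import numpy as np


def select_files(filenames, start_index=0, num_images=None):
    """Sort filenames and slice them the way WoundDataset.__init__ does."""
    names = sorted(filenames)
    if num_images is not None:
        return names[start_index:start_index + num_images]
    return names[start_index:]


def binarize_mask(mask_u8):
    """Convert a grayscale mask like WoundDataset.__getitem__: float32, 255 -> 1.0."""
    mask = mask_u8.astype(np.float32)
    mask[mask == 255.0] = 1.0  # only exact 255 values are mapped
    return mask


# Hypothetical filenames; WoundDataset gets them from os.listdir(image_dir)
files = ["wound_03.png", "wound_01.png", "wound_02.png", "wound_04.png"]
print(select_files(files, start_index=1, num_images=2))  # ['wound_02.png', 'wound_03.png']

# Toy mask with one intermediate gray value (128)
mask = np.array([[0, 255], [255, 128]], dtype=np.uint8)
print(binarize_mask(mask))
```

Only pixels equal to exactly 255 become 1.0. An intermediate value such as 128, for example from anti-aliasing or JPEG compression, stays 128.0, so this code assumes masks are stored as strict 0/255 images. If that assumption does not hold, a threshold such as `mask > 127` would be a more forgiving choice.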


### 3. UNET Model Definition

The code below defines UNET, a convolutional neural network (CNN) architecture commonly used for semantic image segmentation. UNET has a symmetric structure with a contracting path (encoder) and an expansive path (decoder). The code is organized into two classes: DoubleConv and UNET.

The DoubleConv class is the basic building block of the network. It consists of two consecutive 3x3 convolutional layers, each followed by batch normalization and a rectified linear unit (ReLU) activation. Each convolution uses stride 1 and padding 1, so the block preserves the height and width of its input. Only the number of channels changes, from in_channels to out_channels. The convolutions are created without a bias term because the batch normalization that follows removes the mean and adds its own learnable shift. DoubleConv blocks are used in both the contracting and expanding paths.

The UNET class assembles the full network. It takes three arguments:

- in_channels: the number of input channels (3 for RGB).
- out_channels: the number of output channels (1 for a binary mask).
- features: the list of channel counts for each level (default [64, 128, 256, 512]), which also sets the depth of the network.

The network consists of a downsampling (encoder) path, a bottleneck, and an upsampling (decoder) path. At each level, the encoder applies a DoubleConv block, stores the output as a skip connection, and halves the spatial dimensions with 2x2 max pooling. The bottleneck is a DoubleConv block that doubles the channel count (512 to 1024 by default). Each decoder level does the following:

- upsamples with a 2x2 transposed convolution;
- resizes the result if its shape differs from the stored skip connection, which happens when pooling floored an odd dimension;
- concatenates the two along the channel dimension and applies another DoubleConv block.

Finally, a 1x1 convolution maps the 64 channels to out_channels, producing one logit per pixel for the segmentation mask.

In [None]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class UNET(nn.Module):
    def __init__(
            self, in_channels=3, out_channels=1, features=[64, 128, 256, 512],
    ):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x)
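To see when the `TF.resize` branch in `forward` actually fires, the sketch below traces the spatial size of a square input through the network. It uses only the size rules of the layers defined above:

- padded 3x3 convolutions keep the size;
- 2x2 max pooling with stride 2 halves it, flooring odd values;
- the 2x2 transposed convolution with stride 2 doubles it.

The function name `unet_spatial_trace` is hypothetical. This is plain arithmetic and does not run the model.

```python
def unet_spatial_trace(size, depth=4):
    """Trace the height (= width) of a square input through the UNET above.

    depth is the number of encoder levels, i.e. len(features) (4 by default).
    Returns the bottleneck size and the (upsampled, skip) size pairs that differ.
    """
    skips = []
    h = size
    for _ in range(depth):
        skips.append(h)  # DoubleConv (3x3, padding 1) keeps the size
        h = h // 2       # MaxPool2d(kernel_size=2, stride=2) floors odd sizes
    bottleneck = h
    mismatches = []
    for skip in reversed(skips):
        h = h * 2        # ConvTranspose2d(kernel_size=2, stride=2) doubles the size
        if h != skip:
            mismatches.append((h, skip))  # forward() resizes x to the skip's size here
        h = skip         # after resize + concatenation the map has the skip's size
    return bottleneck, mismatches


# Image sizes explored in Section 6
for size in [140, 240, 512]:
    bottleneck, mismatches = unet_spatial_trace(size)
    print(f"input {size}: bottleneck {bottleneck}x{bottleneck}, resized levels {mismatches}")
```

Sizes 240 and 512 are divisible by 2^4 = 16, so every decoder map already matches its skip connection. For 140, the encoder floors 35 to 17. The decoder's 16x16 and 34x34 maps must therefore be resized to 17x17 and 35x35 before concatenation.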

### 4. Utility Function Definitions

save_checkpoint(state, filename):
Saves the current state of the training process, so that you can resume training later or use the trained model for inference. The state is written with torch.save to the file given by filename.

- state: a dictionary holding the model's state. In this notebook it contains the model parameters ("state_dict") and the optimizer state ("optimizer"). You can add any other information needed to resume training or run inference.
- filename: the path of the file where the checkpoint is saved.

load_checkpoint(checkpoint, model):
Loads a saved checkpoint into a model, restoring its weights for further training or inference. Only the model weights are restored. The optimizer state stored in the checkpoint is not loaded.

- checkpoint: the dictionary containing the saved state, as returned by torch.load.
- model: the network into which the weights are loaded.

get_loaders(train_dir, train_maskdir, val_dir, val_maskdir, batch_size, train_transform, val_transform, num_workers=4, pin_memory=True):
Creates and returns the data loaders for the training and validation datasets.

- train_dir, train_maskdir: the directories containing the training images and their corresponding masks.
- val_dir, val_maskdir: the directories containing the validation images and their corresponding masks.
- batch_size: the number of samples per batch.
- train_transform, val_transform: the transformations applied to the training and validation datasets, respectively.
- num_workers: the number of worker processes used for data loading (0 loads data in the main process).
- pin_memory: if True, the loader copies tensors into CUDA pinned memory before returning them, which can speed up data transfer to CUDA-enabled GPUs.

The function returns two data loaders: train_loader, which shuffles the training data, and val_loader, which does not. The training dataset is created with num_images=10, so only the first 10 training images (in sorted filename order) are used.

save_predictions_as_imgs(loader, model, folder, device):
Saves the model's predictions on a dataset as images in the specified folder.

- loader: a DataLoader providing batches of input images and the corresponding ground-truth masks.
- model: the network whose predictions are saved.
- folder: the directory where the prediction images are saved.
- device: the device on which inference runs, typically a CUDA-enabled GPU.

The model is set to evaluation mode before making predictions. This makes layers such as batch normalization use their stored running statistics instead of per-batch statistics. The sigmoid outputs are thresholded at 0.5 to obtain binary masks. For each batch, the function saves two images, one with the predicted masks and one with the ground-truth masks, so predicted and actual values can be compared. It switches the model back to training mode at the end.

Together, these functions support training, evaluating, and saving the results of the segmentation model. get_loaders builds its datasets with the WoundDataset class defined in Section 2.
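get_loaders caps the training set at num_images=10, and a DataLoader keeps the last incomplete batch by default (drop_last=False). As a result, the number of optimizer steps per epoch is ceil(10 / batch_size). The sketch below computes this for the batch sizes listed in Section 6. The helper name `batches_per_epoch` is hypothetical, and the function only reproduces the length rule; it does not build a DataLoader.

```python
import math


def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch, i.e. len(loader)."""
    if drop_last:
        return num_samples // batch_size        # incomplete last batch is discarded
    return math.ceil(num_samples / batch_size)  # incomplete last batch is kept


# Distinct batch sizes from Section 6, with the 10-image training cap in get_loaders
for bs in [4, 6, 8, 10, 12, 15, 20, 25, 30]:
    print(f"batch_size={bs:>2}: {batches_per_epoch(10, bs)} batch(es) per epoch")
```

With 10 training images, every batch size of 10 or more yields a single batch containing all 10 images. Runs with batch sizes from 10 to 30 therefore train the same way, apart from random initialization and augmentation.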

In [None]:
def save_checkpoint(state, filename="my_checkpoint_large_04lr.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)

def load_checkpoint(checkpoint, model):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])

def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
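    # NOTE: num_images=10 limits training to the first 10 images in sorted
    # filename order; pass num_images=None to use the whole training directory.
    train_ds = WoundDataset(
        image_dir=train_dir,
        mask_dir=train_maskdir,
        transform=train_transform,
        num_images=10,
    )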

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )

    val_ds = WoundDataset(
        image_dir=val_dir,
        mask_dir=val_maskdir,
        transform=val_transform,
    )

    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )

    return train_loader, val_loader

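def save_predictions_as_imgs(
    loader, model, folder="central/adam/prediction_images/", device="cuda"
):
    os.makedirs(folder, exist_ok=True)  # save_image does not create missing folders
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()  # binary masks from a 0.5 threshold
        # One grid image per batch: predicted masks and ground-truth masks
        torchvision.utils.save_image(preds, os.path.join(folder, f"pred_{idx}.png"))
        torchvision.utils.save_image(y.unsqueeze(1), os.path.join(folder, f"{idx}.png"))

    model.train()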

### 5. Train & Validation Function Definitions

confusion_matrix(preds, y):
This function compares the binarized predictions (preds) with the ground-truth mask (y) pixel by pixel. It counts the true negatives (tn), false positives (fp), false negatives (fn), and true positives (tp), and returns them in that order. These counts summarize per-pixel binary classification performance.

validate_model(loader, model, criterion, device="cuda"):
This function evaluates the model on the given data loader (loader) in evaluation mode, without tracking gradients. It computes the average loss under criterion. After thresholding the sigmoid outputs at 0.5, it also computes the pixel accuracy, Dice score, IoU (Intersection over Union) score, and confusion-matrix counts for binary segmentation. It tracks the number of correctly classified pixels and the total number of pixels processed. Loss, Dice, and IoU are averaged over batches, and all values are returned as a list.
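The reported metrics can be written in terms of the confusion-matrix counts:

- accuracy = (tp + tn) / (tp + tn + fp + fn);
- Dice = 2tp / (2tp + fp + fn);
- IoU = tp / (tp + fp + fn).

The NumPy sketch below checks these formulas on a toy 2x3 mask against the Dice and IoU expressions used inside validate_model. The arrays are illustrative only. The helper names `confusion_counts` and `metrics_from_counts` are hypothetical; they are chosen so they do not shadow the notebook's `confusion_matrix`.

```python
import numpy as np


def confusion_counts(preds, y):
    """Same counting rule as the notebook's confusion_matrix, on NumPy 0/1 arrays."""
    tp = int(((preds == 1) & (y == 1)).sum())
    fp = int(((preds == 1) & (y == 0)).sum())
    fn = int(((preds == 0) & (y == 1)).sum())
    tn = int(((preds == 0) & (y == 0)).sum())
    return tn, fp, fn, tp


def metrics_from_counts(tn, fp, fn, tp):
    """Pixel accuracy (%), Dice, and IoU from confusion-matrix counts."""
    accuracy = 100.0 * (tp + tn) / (tn + fp + fn + tp)
    dice = 2 * tp / (2 * tp + fp + fn)
    iou = tp / (tp + fp + fn)
    return accuracy, dice, iou


# Toy binary prediction and ground truth (illustrative values only)
preds = np.array([[1, 1, 0], [0, 1, 0]], dtype=np.float32)
y = np.array([[1, 0, 0], [0, 1, 1]], dtype=np.float32)

tn, fp, fn, tp = confusion_counts(preds, y)  # (2, 1, 1, 2)
accuracy, dice, iou = metrics_from_counts(tn, fp, fn, tp)

# The expressions used inside validate_model give the same Dice and IoU
dice_nb = 2 * (preds * y).sum() / (preds.sum() + y.sum() + 1e-8)
intersection = (preds * y).sum()
iou_nb = (intersection + 1e-8) / ((preds + y).sum() - intersection + 1e-8)
print(accuracy, dice, iou, dice_nb, iou_nb)
```

validate_model averages Dice and IoU per batch. Its scores can therefore differ from Dice and IoU computed once from the summed tn, fp, fn, and tp that it also returns. The two agree when the validation set fits in a single batch.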

train_model(loader, model, optimizer, loss_fn, scaler):
This function trains the model for one epoch on the provided data loader (loader). For each batch, it computes predictions, evaluates the loss with the given loss function (loss_fn), and updates the model's weights through backpropagation. The forward pass runs under torch.cuda.amp.autocast() for mixed-precision training. The scaler, a torch.cuda.amp.GradScaler created in the training loop, scales the loss before the backward pass so that small float16 gradients do not underflow. A learning rate scheduler (optim.lr_scheduler.ReduceLROnPlateau) appears in the code but is commented out, so the learning rate stays fixed during training. The function reads the global DEVICE defined in Section 6.

In [None]:
def confusion_matrix(preds, y):
    tp = ((preds == 1) & (y == 1)).sum()
    fp = ((preds == 1) & (y == 0)).sum()
    fn = ((preds == 0) & (y == 1)).sum()
    tn = ((preds == 0) & (y == 0)).sum()

    return tn, fp, fn, tp

def validate_model(loader, model, criterion, device="cuda"):
    """Evaluate the network on the validation set and return summary metrics."""
    print("~~~~ In test ~~~~")
    loss = 0
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    iou_score = 0
    tn_sum = fp_sum = fn_sum = tp_sum = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            logits = model(x)
            # BCEWithLogitsLoss applies the sigmoid itself, so it must get raw logits
            loss += criterion(logits, y).item()
            preds = (torch.sigmoid(logits) > 0.5).float()
            tn, fp, fn, tp = confusion_matrix(preds, y)
            tn_sum += tn
            fp_sum += fp
            fn_sum += fn
            tp_sum += tp
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / (
                (preds.sum() + y.sum()) + 1e-8
            )
            intersection = (preds * y).sum()
            union = (preds + y).sum() - intersection
            iou_score += (intersection + 1e-8) / (union + 1e-8)

    num_batches = len(loader)
    loss /= num_batches
    acc = num_correct/num_pixels*100
    diceS = dice_score/num_batches
    iouS = iou_score/num_batches
    correct_pixel = num_correct
    total_pixel = num_pixels
    print(f"Loss = {loss}")
    print(f"IoU Score = {iouS}")
    print(f"Dice Score = {diceS}")
    print("~~~~~ Out of test ~~~~~")

    model.train()
    
    result = [acc.item(), diceS.item(), iouS.item(), loss, correct_pixel.item(), total_pixel, tn_sum.item(), fp_sum.item(), fn_sum.item(), tp_sum.item()]
    
    return result

def train_model(loader, model, optimizer, loss_fn, scaler):
    print("In Train")
    print(f'length of loader {len(loader)}')
    model.train()
    loop = tqdm(loader)
    #scheduler =  optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5, verbose=True)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)
        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)
        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())
    #scheduler.step(loss)

    print("Out of Train")
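`nn.BCEWithLogitsLoss` combines a sigmoid and binary cross-entropy in a single, numerically stable operation. It therefore expects raw logits, which is what train_model passes it. If probabilities that have already been through `torch.sigmoid` are passed instead, the sigmoid is applied twice.

The NumPy sketch below compares both inputs on three toy pixels. It re-implements the loss with the standard stable formula max(z, 0) - z*y + log(1 + exp(-|z|)); this is an illustration, not PyTorch's source. The helper names are hypothetical.

```python
import numpy as np


def bce_with_logits(z, y):
    """Mean binary cross-entropy on logits z, computed in a numerically stable way."""
    return float(np.mean(np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z)))))


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


# Toy logits for three pixels and their binary targets (illustrative values only)
logits = np.array([4.0, -4.0, 2.0])
targets = np.array([1.0, 0.0, 1.0])

on_logits = bce_with_logits(logits, targets)                  # intended usage
on_probabilities = bce_with_logits(sigmoid(logits), targets)  # sigmoid applied twice
print(f"{on_logits:.4f} vs {on_probabilities:.4f}")
```

When the sigmoid is doubled, the inputs to the second sigmoid lie in (0, 1). Each positive pixel then contributes at least log(1 + e^-1) ≈ 0.313 and each negative pixel at least log 2 ≈ 0.693, however confident the model is. A loss computed that way cannot approach zero and is not comparable with the training loss.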

### 6. Define Hyperparameters and DataFrame

In [None]:
column_names = {'epoch':[]
            , 'accuracy':[]
            , 'dice_score':[]
            , 'IoU_score':[]
            , 'Loss': []
            , 'correct_pixels':[]
            , 'total_pixel':[]
            , 'tn': []
            , 'fp': []
            , 'fn': []
            , 'tp': []
            , 'lr': []}

labels = ["Accuracy", "Dice Score", "IOU", "Loss", "Correct Pixel", "Total Pixel", "tn", "fp", "fn", "tp","lr"]

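# Grid to explore: learning rates (LR), batch sizes (BS), and image sizes.
# NOTE: 25 appears twice in BATCH_SIZES and 0.00001 twice in learning_rates;
# each duplicate repeats a run and overwrites its checkpoint and CSV file.
BATCH_SIZES = [25, 4, 6, 8, 10, 12, 15, 20, 25, 30]
learning_rates = [0.00001, 0.00001, 0.000001]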
imageSize = [140, 240, 512]

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
PIN_MEMORY = True
LOAD_MODEL = False
# Paths to the training and validation images and masks
TRAIN_IMG_DIR = "../../wound_data/data/woundData/train_images"
TRAIN_MASK_DIR = "../../wound_data/data/woundData/train_masks"
VAL_IMG_DIR = "../../wound_data/data/woundData/val_images"
VAL_MASK_DIR = "../../wound_data/data/woundData/val_masks"

In [None]:
print(DEVICE)

### 7. Set optimizer

In [None]:
optim_flag = "adam"
#optim_flag = "sgd"

### 8. Image Transform Functions

##### train_transform:

This pipeline augments and preprocesses the training images. Albumentations applies the geometric operations (resize, rotation, flips) to each image and its mask together, while Normalize changes only the image. The operations are:

- A.Resize: resizes the images to the specified height and width (IMAGE_HEIGHT and IMAGE_WIDTH), so all training images have consistent dimensions.
- A.Rotate: rotates every image (p=1.0) by a random angle of up to 35 degrees in either direction. This increases the diversity of training samples and helps the model generalize to different orientations of wounds in the images.
- A.HorizontalFlip: flips images horizontally with probability 0.5, simulating variations in orientation and making the model more robust.
- A.VerticalFlip: flips images vertically with a lower probability (0.1), adding further variability through vertical reflections.
- A.Normalize: computes (pixel - mean * max_pixel_value) / (std * max_pixel_value) for each channel. With the settings used here (mean=0, std=1, max_pixel_value=255), this simply rescales pixel values from [0, 255] to [0, 1]. It does not center the data or give it unit variance.
- ToTensorV2: converts the image from an HxWxC NumPy array to a CxHxW PyTorch tensor and the mask to an HxW tensor, matching the model's input format.

##### val_transforms:

This pipeline preprocesses the validation images. It includes:

- A.Resize: resizes the validation images to the same height and width as the training images, so training and validation inputs have consistent dimensions.
- A.Normalize: applies the same pixel scaling as the training pipeline.
- ToTensorV2: converts the image and mask to PyTorch tensors.

It deliberately omits augmentations such as rotation and flips, because validation should measure the model's performance on unaltered data.
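The effect of the Normalize settings can be checked with a few lines of NumPy. The function below re-implements the per-channel formula given in the Albumentations documentation, (img - mean * max_pixel_value) / (std * max_pixel_value). It is a sketch for illustration, not the library's code. The ImageNet statistics in the second call are shown only for contrast; the notebook does not use them.

```python
import numpy as np


def normalize(img, mean, std, max_pixel_value=255.0):
    """Per-channel normalization following the Albumentations Normalize formula."""
    mean = np.asarray(mean, dtype=np.float32) * max_pixel_value
    std = np.asarray(std, dtype=np.float32) * max_pixel_value
    return (img.astype(np.float32) - mean) / std  # broadcasts over the last (channel) axis


# A toy 1x3 RGB image (HWC) with a black, a mid-gray, and a white pixel
img = np.array([[[0, 0, 0], [128, 128, 128], [255, 255, 255]]], dtype=np.uint8)

# Settings used in train_transform and val_transforms: plain rescaling to [0, 1]
scaled = normalize(img, mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
print(scaled.min(), scaled.max())  # 0.0 1.0

# For contrast only: ImageNet statistics would center and rescale each channel
imagenet = normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
```

With mean 0 and std 1, the transform reduces to dividing by 255. Any centering would have to come from non-zero means, as in the ImageNet example.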

### 9. Training our Model

Hyperparameter Exploration Loop: The code loops over every combination of learning rate, batch size, and image size defined in Section 6. On each iteration it sets LEARNING_RATE, BATCH_SIZE, IMAGE_HEIGHT, and IMAGE_WIDTH for the current combination. This allows a systematic study of how these hyperparameters affect model training and evaluation.

Training and Evaluation: For each combination, the code sets up the transforms, data loaders, model, loss function (BCEWithLogitsLoss), and optimizer. It then performs the following steps:

- validates the untrained model on the validation dataset and prints the baseline metrics;
- trains for NUM_EPOCHS (200) epochs, saving a checkpoint after each epoch and then recording validation metrics (accuracy, Dice score, IoU, loss, and confusion-matrix counts);
- appends each epoch's results to a DataFrame, making training progress easy to track and analyze;
- saves the DataFrame as a CSV file, preserving the results for later analysis.
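The three nested loops in the next cell iterate over the Cartesian product of the grids from Section 6. The sketch below enumerates that product with `itertools.product` and builds each CSV filename with the same pattern as the training cell. It also shows how many distinct runs remain once the duplicated entries in BATCH_SIZES and learning_rates are collapsed.

The grids are copied under new names (`lr_grid`, `bs_grid`, `size_grid`) so that running the sketch does not overwrite the notebook's globals. The helper `run_filename` is hypothetical.

```python
from itertools import product

# Copies of the Section 6 grids, under new names so the notebook's globals are untouched
lr_grid = [0.00001, 0.00001, 0.000001]
bs_grid = [25, 4, 6, 8, 10, 12, 15, 20, 25, 30]
size_grid = [140, 240, 512]


def run_filename(lr, bs, img_size, optimizer_name="adam", num_epochs=200):
    """Same naming pattern as csv_filename in the training cell."""
    return (f"{optimizer_name}_output_bs{bs}_lr{lr}"
            f"_imgSize{img_size}_epoch{num_epochs}.csv")


# Loop order matches the notebook: learning rate, then batch size, then image size
runs = [run_filename(lr, bs, s) for lr, bs, s in product(lr_grid, bs_grid, size_grid)]
unique_runs = list(dict.fromkeys(runs))  # drop repeats, keep first-seen order

print(len(runs), "loop iterations,", len(unique_runs), "distinct output files")
print(unique_runs[0])  # adam_output_bs25_lr1e-05_imgSize140_epoch200.csv
```

Each repeated iteration retrains an identical configuration for 200 epochs and overwrites the earlier run's checkpoint and CSV. If the duplicates are not intentional, removing them would skip 36 of the 90 training runs.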

In [None]:
for lr in learning_rates:
    for bs in BATCH_SIZES:
        for img_size in imageSize:
            
            df = pd.DataFrame(columns=column_names)
            LEARNING_RATE = lr
            BATCH_SIZE = bs
            NUM_EPOCHS = 200
            NUM_WORKERS = 0
            IMAGE_HEIGHT = img_size 
            IMAGE_WIDTH = img_size  
        
            csv_filename = str(optim_flag) + "_output_" + "bs" + str(BATCH_SIZE) + "_lr" + str(LEARNING_RATE) +"_imgSize"+ str(IMAGE_HEIGHT)+"_epoch"+ str(NUM_EPOCHS)+".csv"

            train_transform = A.Compose(
                [
                    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
                    A.Rotate(limit=35, p=1.0),
                    A.HorizontalFlip(p=0.5),
                    A.VerticalFlip(p=0.1),
                    A.Normalize(
                        mean=[0.0, 0.0, 0.0],
                        std=[1.0, 1.0, 1.0],
                        max_pixel_value=255.0,
                    ),
                    ToTensorV2(),
                ],
            )
            
            # used to transform validation set
            val_transforms = A.Compose(
                [
                    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
                    A.Normalize(
                        mean=[0.0, 0.0, 0.0],
                        std=[1.0, 1.0, 1.0],
                        max_pixel_value=255.0,
                    ),
                    ToTensorV2(),
                ],
            )
            
            
            train_loader, val_loader = get_loaders(
                TRAIN_IMG_DIR,
                TRAIN_MASK_DIR,
                VAL_IMG_DIR,
                VAL_MASK_DIR,
                BATCH_SIZE,
                train_transform,
                val_transforms,
                NUM_WORKERS,
                PIN_MEMORY,
            )
            
            
            model = UNET().to(DEVICE)
            loss_fn = nn.BCEWithLogitsLoss()
            
            if optim_flag == "adam":
                optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
            else:
                optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)
    
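            checkpoint_name = "central/"+str(optim_flag)+"/checkpoint2/"+str(optim_flag)+"_checkpoint_imgSize" + str(IMAGE_WIDTH) + "_bs" + str(BATCH_SIZE) + "_lr" + str(LEARNING_RATE) + ".pth.tar"
            # torch.save does not create missing folders, so create the checkpoint directory
            os.makedirs(os.path.dirname(checkpoint_name), exist_ok=True)
            print('check point')
            print(checkpoint_name)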
            
            if LOAD_MODEL:
                # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
                load_checkpoint(torch.load(checkpoint_name, map_location=DEVICE), model)
            
            single_epoch_result = validate_model(val_loader, model, loss_fn, device=DEVICE)
            print(f'length of train loader -> {len(train_loader)}')
            print("First Check - Before training starts")
            for label, item in zip(labels, single_epoch_result):
                print(f"{label}: {item}")
                
            scaler = torch.cuda.amp.GradScaler()
            
            for epoch in range(NUM_EPOCHS):
                single_epoch_result = []
                train_model(train_loader, model, optimizer, loss_fn, scaler)
                
                # save model
                checkpoint = {
                    "state_dict": model.state_dict(),
                    "optimizer":optimizer.state_dict(),
                }
                save_checkpoint(checkpoint, checkpoint_name)
                print('checkpoint saved')
                
                # Validate after each epoch and record the metrics in the DataFrame
                single_epoch_result = validate_model(val_loader, model, loss_fn, device=DEVICE)
                single_epoch_result.append(LEARNING_RATE)
                
                single_epoch_result.insert(0, epoch)
                df.loc[len(df)] = single_epoch_result
            
                # save some example predictions to a folder
                # save_predictions_as_imgs(
                #     val_loader, model, folder="/central/adam/prediction_images/", device=DEVICE
                # )
            
            
            # Save the DataFrame as a CSV file (creating the output folder if needed)
            csv_path = 'central/'+str(optim_flag)+'/dataframe_fix/' + csv_filename
            os.makedirs(os.path.dirname(csv_path), exist_ok=True)
            df.to_csv(csv_path, index=False)
