## Import Libraries

In [None]:
import os 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torchvision
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as trns
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor
from torch.utils.data import random_split, DataLoader, Dataset

# Render figures inline in the notebook and set the default figure size (width, height in inches).
%matplotlib inline
plt.rcParams['figure.figsize'] = [8, 10]

## Load dataset

In [None]:
# List available images
# ls '../datasets/nih-chest-xray/images'

In [None]:
# Directory containing the chest X-ray PNG images. This is a relative path, so it is resolved against the notebook's working directory.
data_path = '../datasets/nih-chest-xray/images'

In [None]:
# Map each regular '.png' file name found in data_path to its full path.
images_list = {
    name: os.path.join(data_path, name)
    for name in os.listdir(data_path)
    if os.path.isfile(os.path.join(data_path, name)) and name.endswith('.png')
}

In [None]:
# List full paths
# images_list

In [None]:
# Load the image label table and preview its first 10 rows.
# Later cells read column 0 as the image file name and column 1 as '|'-separated finding labels.
df = pd.read_csv('../datasets/nih-chest-xray/sample_labels.csv')
df.head(10)

In [None]:
def plotImage(img):
    """Show `img` (PIL image or array-like) with 0-255 pixel values rescaled to 0-1."""
    pixels = np.array(img)
    plt.imshow(pixels / 255)
    
def openImage(str_path):
    """Load the image at `str_path` and return it as a 3-channel RGB PIL image.

    The source file is opened in a ``with`` block so its handle is released
    deterministically. ``convert`` returns a new, fully loaded image, so the
    result stays valid after the source file is closed.
    """
    with Image.open(str_path) as img:
        return img.convert('RGB')

In [None]:
plotImage(openImage(images_list[list(images_list.keys())[1]]))

In [None]:
# Label vocabulary for the multi-hot targets. Each entry must match the
# spelling in the CSV's finding labels exactly, because targets are built with
# a plain membership test (`cls in labels`).
# NOTE(review): the NIH ChestX-ray label files spell this finding
# 'Pleural_Thickening' (capital T). Confirm against sample_labels.csv.
classes = [
    'Atelectasis',
    'Consolidation',
    'Infiltration',
    'Pneumothorax',
    'Edema',
    'Emphysema',
    'Fibrosis',
    'Effusion',
    'Pneumonia',
    'Pleural_Thickening',
    'Cardiomegaly',
    'Nodule',
    'Mass',
    'Hernia',
    'No Finding',
]

In [None]:
class DataLoaderCompose(Dataset):
    """Multi-label image dataset (a ``torch.utils.data.Dataset``, despite its name).

    Args:
        data: pair ``(file_names, labels)``.
            - ``file_names`` iterates over image file names that are keys of the
              notebook-global ``images_list``.
            - ``labels[idx]`` is the list of finding names for image ``idx``.
        transforms: callable applied to each RGB PIL image.

    Each item is ``(transformed_image, target)``. ``target`` is a float32
    multi-hot vector with one entry per name in the global ``classes``.
    """

    def __init__(self, data, transforms):
        self.image_paths = [images_list[f] for f in data[0]]
        self.labels = data[1]
        self.transforms = transforms

    def __len__(self):
        # Length comes from this instance's own paths, not from any notebook-level
        # variable, so datasets built from other inputs report their true size.
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = self.transforms(openImage(self.image_paths[idx]))
        # 1.0 for every class listed for this image, 0.0 otherwise.
        target = torch.tensor(
            [int(cls in self.labels[idx]) for cls in classes], dtype=torch.float32
        )
        return (image, target)

In [None]:
# (image file names, per-image list of findings) for up to the first 5000 rows.
# Column 0 holds the file name; column 1 holds '|'-separated finding labels.
# Clamp to the table length so a smaller CSV does not raise IndexError.
num_samples = min(5000, len(df))
data = (df.iloc[:num_samples, 0], [df.iloc[i, 1].split('|') for i in range(num_samples)])

In [None]:
# Preprocessing pipeline:
#   1. Resize to 240x240.
#   2. Convert to a float tensor.
#   3. Normalise each channel with the ImageNet mean/std expected by torchvision's pretrained models.
image_transforms = trns.Compose([
    trns.Resize((240, 240)),
    trns.ToTensor(),
    trns.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=True),
])
dataset = DataLoaderCompose(data, image_transforms)

In [None]:
# Inspect the multi-hot target vector of sample 1.
print(dataset[1][1])

In [None]:
# Undo the Normalize transform before display. matplotlib clips float RGB data
# to [0, 1], so showing the normalised tensor directly distorts the image.
_display_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_display_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
_sample_image = dataset[1][0] * _display_std + _display_mean
# (C, H, W) -> (H, W, C) for imshow.
plt.imshow(_sample_image.clamp(0, 1).permute((1, 2, 0)))

## Train/validation split (80/20)

In [None]:
# Random 80/20 partition of the dataset into training and validation subsets.
n_train_examples = int(len(dataset) * 0.80)
n_validation_examples = len(dataset) - n_train_examples
train_dataset, validation_dataset = random_split(dataset, [n_train_examples, n_validation_examples])

In [None]:
# Record and display the sizes of both subsets.
train_dataset_size, validation_dataset_size = len(train_dataset), len(validation_dataset)
(train_dataset_size, validation_dataset_size)

# Benchmark Experiments

### Define base model

In [None]:
# Load an ImageNet-pretrained ResNet-34 and display its architecture for inspection.
# The benchmark classes below build their own ResNet-34 instances.
# NOTE(review): newer torchvision versions deprecate `pretrained=` in favour of `weights=`.
resnet34 = models.resnet34(pretrained=True)
resnet34

### Constants for all models

In [None]:
# Hyperparameters shared by every benchmark.
num_classes = len(classes)  # one output per label; keeps the head size in sync with `classes`
input_shape = (3, 240, 240)  # (channels, height, width) produced by the Resize((240, 240)) transform

batch_size = 64
# NOTE(review): 0.1 is unusually high for Adam (typical values are ~1e-3). Confirm it is intended.
learning_rate = 1e-1

epochs_per_client = 1  # alternative value: 5. For non-federated models this is simply the epoch count.
opt_func = torch.optim.Adam
# Element-wise gradient clipping threshold passed to `train`; None disables clipping.
grad_clip = None

### Constants for federated learning models

In [None]:
# Federated-learning settings (alternative values: 5 clients, 10 rounds).
num_clients, rounds = 2, 2

### DataLoader functions

In [None]:
def get_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)

def to_device(data, device):
    """Move a tensor/module, or a (possibly nested) list/tuple of them, to `device`.

    Lists and tuples come back as lists. Leaf objects are moved with
    ``.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return list(map(lambda item: to_device(item, device), data))

class DeviceDataLoader(DataLoader):
    """Wrap a loader so that every batch it yields is moved to `device`.

    Only ``__iter__`` and ``__len__`` are provided. ``DataLoader.__init__`` is
    not called, so the usual DataLoader attributes are not set on this object.
    """

    def __init__(self, dl, device):
        self.dl = dl          # wrapped loader (any iterable with a length)
        self.device = device  # target device for every batch

    def __iter__(self):
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        return len(self.dl)

# Device (CUDA if available, otherwise CPU) used by all benchmarks below.
device = get_device()

In [None]:
# Report which device will be used.
print(f'Device detected is {device}')

## Benchmark 1: Vanilla PyTorch model

In [None]:
def _is_multi_hot(out, labels):
    """True when `labels` are per-class float targets with the same shape as `out`."""
    return labels.shape == out.shape and labels.is_floating_point()


def accuracy(out, labels):
    """Return the accuracy of logits `out` against `labels` as a scalar tensor.

    Two label layouts are supported:

    * Multi-hot float targets (same shape as `out`, as produced by
      ``DataLoaderCompose``): element-wise accuracy of ``sigmoid(out) > 0.5``
      over every (sample, class) entry.
    * Class indices (one entry per sample): top-1 accuracy.
    """
    if _is_multi_hot(out, labels):
        preds = (torch.sigmoid(out) > 0.5).to(labels.dtype)
        return (preds == labels).float().mean()
    _, preds = torch.max(out, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))


class ImageClassificationBase(nn.Module):
    """Training/validation helpers shared by the benchmark classifiers.

    Subclasses implement ``forward`` returning raw logits.
    """

    def _compute_loss(self, out, labels):
        # Multi-hot targets make this a multi-label problem: each class is an
        # independent sigmoid + BCE. Index targets use softmax cross-entropy.
        if _is_multi_hot(out, labels):
            return nn.functional.binary_cross_entropy_with_logits(out, labels)
        return nn.functional.cross_entropy(out, labels)

    def training_step(self, batch):
        images, labels = batch
        out = self(images)
        return self._compute_loss(out, labels)

    def validation_step(self, batch):
        images, labels = batch
        out = self(images)
        loss = self._compute_loss(out, labels)
        acc = accuracy(out, labels)
        return {"val_loss": loss.detach(), "val_acc": acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation loss and accuracy into Python floats."""
        batch_loss = [x["val_loss"] for x in outputs]
        epoch_loss = torch.stack(batch_loss).mean()
        batch_acc = [x["val_acc"] for x in outputs]
        epoch_acc = torch.stack(batch_acc).mean()
        return {"val_loss": epoch_loss.item(), "val_acc": epoch_acc.item()}

    def epoch_end(self, epoch, epochs, result):
        print("Epoch: [{}/{}], last_lr: {:.4f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
        epoch+1, epochs, result["lrs"][-1], result["train_loss"], result["val_loss"], result["val_acc"]))


class ResNet34(ImageClassificationBase):
    """ImageNet-pretrained ResNet-34 with its final layer replaced by a `num_classes`-way logit head."""

    def __init__(self):
        super().__init__()
        self.network = models.resnet34(pretrained=True)
        number_of_features = self.network.fc.in_features
        self.network.fc = nn.Linear(number_of_features, num_classes)

    def forward(self, xb):
        return self.network(xb)

    def freeze(self):
        """Leave only the new `fc` head trainable so it can warm up before full fine-tuning."""
        # The autograd flag is `requires_grad`. A misspelled name such as
        # `require_grad` would silently create an unused attribute instead.
        for param in self.network.parameters():
            param.requires_grad = False
        for param in self.network.fc.parameters():
            param.requires_grad = True

    def unfreeze(self):
        """Make every parameter of the network trainable again."""
        for param in self.network.parameters():
            param.requires_grad = True

In [None]:
# Build the Benchmark 1 classifier (pretrained ResNet-34 with a new head) and display it.
model = ResNet34()
model

In [None]:
# Move the network (parameters and buffers) to `device`; it was instantiated in the previous cell.
model = to_device(model, device)

In [None]:
# Batch both subsets with 3 worker processes and pinned memory; only the training loader shuffles.
loader_kwargs = dict(num_workers=3, pin_memory=True)
train_dl = DataLoader(train_dataset, batch_size, shuffle=True, **loader_kwargs)
val_dl = DataLoader(validation_dataset, batch_size, **loader_kwargs)

In [None]:
# Wrap both loaders so every batch is moved to `device` as it is produced.
train_dl, val_dl = DeviceDataLoader(train_dl, device), DeviceDataLoader(val_dl, device)

In [None]:
@torch.no_grad()
def evaluate(model, val_dl):
    """Run `model` over `val_dl` in eval mode with gradients disabled.

    Returns ``model.validation_epoch_end`` applied to the list of per-batch
    ``model.validation_step`` results.
    """
    model.eval()
    step_results = []
    for batch in val_dl:
        step_results.append(model.validation_step(batch))
    return model.validation_epoch_end(step_results)

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group, or None if it has none."""
    return next((group["lr"] for group in optimizer.param_groups), None)
    
def train(epochs, max_lr, model, train_dl, val_dl, weight_decay=0,
          grad_clip=None, opt_func=torch.optim.Adam):
    """Train `model` with a one-cycle learning-rate schedule, validating after every epoch.

    Args:
        epochs: number of passes over `train_dl`.
        max_lr: peak learning rate for ``OneCycleLR``.
        model: module providing ``training_step``, ``validation_step``,
            ``validation_epoch_end`` and ``epoch_end``.
        train_dl: iterable of training batches. ``len(train_dl)`` is used as
            the number of scheduler steps per epoch.
        val_dl: iterable of validation batches passed to ``evaluate``.
        weight_decay: forwarded to the optimizer.
        grad_clip: if truthy, gradients are clipped element-wise to ``[-grad_clip, grad_clip]``.
        opt_func: optimizer class, built as ``opt_func(params, max_lr, weight_decay=...)``.

    Returns:
        One dict per epoch: the ``evaluate`` result plus ``"train_loss"`` (mean
        batch loss) and ``"lrs"`` (learning rate used at each step).
    """
    torch.cuda.empty_cache()

    history = []
    opt = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)
    sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr, epochs=epochs,
                                                steps_per_epoch=len(train_dl))

    for epoch in range(epochs):
        model.train()
        train_losses = []
        lrs = []
        # Iterate the loader directly; no progress-bar library is required.
        for batch in train_dl:
            loss = model.training_step(batch)
            # Store a detached value. Keeping `loss` itself would hold a
            # reference to each batch's autograd graph for the whole epoch.
            train_losses.append(loss.detach())
            loss.backward()

            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)

            opt.step()
            opt.zero_grad()

            # Record the lr used for this step before the scheduler advances it.
            lrs.append(get_lr(opt))
            sched.step()

        result = evaluate(model, val_dl)
        result["train_loss"] = torch.stack(train_losses).mean().item()
        result["lrs"] = lrs
        model.epoch_end(epoch, epochs, result)
        history.append(result)
    return history

In [None]:
# Validation metrics before any training, with the classification head freshly initialised.
result = evaluate(model, val_dl)
result

In [None]:
# Intended to leave only the new `fc` head trainable for the first run.
# This relies on ResNet34.freeze setting each parameter's `requires_grad`.
model.freeze()

In [None]:
%%time

# Train Benchmark 1 with the one-cycle schedule; `history` receives one result dict per epoch.
# NOTE(review): `grad_clip` must be defined in an earlier cell before this one runs.
history = train(epochs_per_client, learning_rate, model, train_dl, val_dl,
                        grad_clip=grad_clip, opt_func=opt_func)

In [None]:
# Per-epoch validation accuracy of Benchmark 1, plotted against 1-based epoch numbers.
# Stored as `val_accuracies` so the `accuracy()` metric used by `validation_step`
# is not overwritten by this list.
val_accuracies = [x["val_acc"] for x in history]
plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, "-bx")

plt.title("Accuracy vs number of epochs")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")

## Benchmark 2: Vanilla federated learning model

In [None]:
class SimpleFederatedNetwork(torch.nn.Module):
    """Pretrained ResNet-34 with a sigmoid multi-label head, used for federated training.

    Only the layers in ``track_layers`` (``layer4`` and the new ``fc`` head) are
    trainable, and only their parameters are exchanged by ``get_parameters`` /
    ``apply_parameters``. ``forward`` returns per-class probabilities, not logits.
    """

    def __init__(self):
        super().__init__()
        self.network = models.resnet34(pretrained=True)
        self.network.fc = torch.nn.Linear(self.network.fc.in_features, num_classes)
        # Layers whose parameters are trained locally and exchanged with the server.
        self.track_layers = {
            'layer4': self.network.layer4,
            'linear': self.network.fc,
        }
        self.freeze()

    def freeze(self):
        """Disable gradients everywhere except in the tracked layers."""
        for param in self.network.parameters():
            param.requires_grad = False
        for layer in self.track_layers.values():
            for param in layer.parameters():
                param.requires_grad = True

    def forward(self, x_batch):
        return torch.sigmoid(self.network(x_batch))

    def get_track_layers(self):
        return self.track_layers

    def apply_parameters(self, parameters_dict):
        """Overwrite tracked-layer parameters in place.

        Args:
            parameters_dict: maps a ``track_layers`` name to a list of values in
                the same order as that layer's ``parameters()``. Only layers
                present in the dict are modified.
        """
        with torch.no_grad():
            for layer_name, new_values in parameters_dict.items():
                layer_params = list(self.track_layers[layer_name].parameters())
                for i, param in enumerate(layer_params):
                    # Exact in-place copy, converted to the parameter's dtype/device.
                    param.copy_(torch.as_tensor(new_values[i], dtype=param.dtype, device=param.device))

    def get_parameters(self):
        """Return detached copies of the tracked layers' parameters, keyed like ``track_layers``."""
        return {
            layer_name: [param.data.clone().detach() for param in layer.parameters()]
            for layer_name, layer in self.track_layers.items()
        }

    def batch_accuracy(self, outputs, labels):
        """Fraction of all (sample, class) entries where ``outputs > 0.5`` matches ``labels``."""
        with torch.no_grad():
            # Average over every element. Dividing the match count by the batch
            # size alone would scale the result by the number of classes.
            return torch.tensor(((outputs > 0.5) == labels).float().mean().item())

    def _process_batch(self, batch):
        images, labels = batch
        outputs = self(images)
        loss = torch.nn.functional.binary_cross_entropy(outputs, labels)
        acc = self.batch_accuracy(outputs, labels)
        return (loss, acc)

    def fit(self, dataset, epochs, lr, batch_size=128, opt=torch.optim.SGD):
        """Train locally on `dataset` and return ``[(mean_loss, mean_accuracy), ...]``, one pair per epoch.

        NOTE(review): ``self.train()`` also puts normalisation layers in frozen
        blocks into training mode (presumably BatchNorm in ResNet-34). Their
        buffers are not part of ``get_parameters``.
        """
        self.train()
        dataloader = DeviceDataLoader(DataLoader(dataset, batch_size, shuffle=True), device)
        optimizer = opt(self.parameters(), lr)
        history = []
        for _ in range(epochs):
            losses = []
            accs = []
            for batch in dataloader:
                loss, acc = self._process_batch(batch)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                # Keep only a detached value so each batch's autograd graph can be released.
                losses.append(loss.detach())
                accs.append(acc)
            avg_loss = torch.stack(losses).mean().item()
            avg_acc = torch.stack(accs).mean().item()
            history.append((avg_loss, avg_acc))
        return history

    def evaluate(self, dataset, batch_size=64):
        """Return ``(mean_loss, mean_accuracy)`` over `dataset`, computed in eval mode without gradients."""
        self.eval()
        dataloader = DeviceDataLoader(DataLoader(dataset, batch_size), device)
        losses = []
        accs = []
        with torch.no_grad():
            for batch in dataloader:
                loss, acc = self._process_batch(batch)
                losses.append(loss)
                accs.append(acc)
        avg_loss = torch.stack(losses).mean().item()
        avg_acc = torch.stack(accs).mean().item()
        return (avg_loss, avg_acc)

In [None]:
class Client:
    """A federated-learning participant that owns one shard of the training data."""

    def __init__(self, client_id, dataset):
        self.client_id = client_id  # label printed before this client's training log
        self.dataset = dataset      # this client's local data shard

    def get_dataset_size(self):
        return len(self.dataset)

    def get_client_id(self):
        return self.client_id

    def train(self, parameters_dict, return_model_dict=False):
        """Train a fresh local model initialised from `parameters_dict`.

        Returns ``(tracked_parameters, state_dict_or_None)``. The full
        ``state_dict`` is included only when `return_model_dict` is true.
        """
        local_net = to_device(SimpleFederatedNetwork(), device)
        local_net.apply_parameters(parameters_dict)
        epoch_stats = local_net.fit(self.dataset, epochs_per_client, learning_rate, batch_size)
        print(self.client_id + ':')
        for epoch_no, (loss, acc) in enumerate(epoch_stats, start=1):
            print(f'Epoch [{epoch_no}]: Loss = {round(loss, 4)}, Accuracy = {round(acc, 4)}')
        tracked = local_net.get_parameters()
        model_state = local_net.state_dict() if return_model_dict else None
        return tracked, model_state

In [None]:
# Base shard size per client (floor division of the training-set size by the client count).
examples_per_client = train_dataset_size // num_clients

In [None]:
# Client shard sizes: exactly `num_clients` shards that sum to `train_dataset_size`.
# The first `train_dataset_size % num_clients` shards get one extra example, so no
# examples end up in a leftover shard that no client receives.
leftover_examples = train_dataset_size % num_clients
lengths = [examples_per_client + (1 if i < leftover_examples else 0) for i in range(num_clients)]

In [None]:
# Display the client shard sizes.
lengths

In [None]:
# Randomly partition the training subset into shards of the sizes in `lengths`,
# then wrap the first `num_clients` shards as federated clients.
client_datasets = random_split(train_dataset, lengths)
clients = [Client(f'client_{idx}', client_datasets[idx]) for idx in range(num_clients)]

In [None]:
# Instantiate the global (server-side) federated network and move it to `device`.
fl = to_device(SimpleFederatedNetwork(), device)

In [None]:
history = []

# Total number of examples held by all clients. The FedAvg weights are
# shard_size / total, so they sum to 1 even if some training examples belong to no client.
total_client_examples = sum(client.get_dataset_size() for client in clients)

for i in range(rounds):
    print('Start Round {} ...'.format(i + 1))
    # Global tracked parameters broadcast to every client this round.
    i_parameters = fl.get_parameters()
    # FedAvg accumulator: zero tensors matching each tracked parameter's shape/dtype/device.
    i_1_parameters = {
        layer_name: [torch.zeros_like(param) for param in params]
        for layer_name, params in i_parameters.items()
    }
    last_state_dict = None

    # Iterate through clients
    for client_idx, client in enumerate(clients):
        is_last_client = client_idx == len(clients) - 1
        client_parameters, state_dict = client.train(i_parameters, is_last_client)
        if is_last_client:
            last_state_dict = state_dict

        fraction = client.get_dataset_size() / total_client_examples
        for layer_name, params in client_parameters.items():
            for k, param in enumerate(params):
                i_1_parameters[layer_name][k] += fraction * param

    # Take the remaining state (e.g. buffers) from the last client, install the
    # averaged tracked parameters, and only then evaluate. The reported metrics
    # therefore describe the aggregated global model, not one client's local model.
    if last_state_dict is not None:
        fl.load_state_dict(last_state_dict)
    fl.apply_parameters(i_1_parameters)

    train_loss, train_acc = fl.evaluate(train_dataset)
    val_loss, val_acc = fl.evaluate(validation_dataset)
    print('Results round {}, train_loss = {}, val_loss = {}, val_acc = {}\n'.format(i + 1, round(train_loss, 4),
            round(val_loss, 4), round(val_acc, 4)))
    history.append((train_loss, val_loss))

In [None]:
# Plot the (train_loss, val_loss) pair recorded for each round against 1-based round numbers.
round_numbers = [n + 1 for n in range(len(history))]
train_losses = [entry[0] for entry in history]
val_losses = [entry[1] for entry in history]
plt.plot(round_numbers, train_losses, color='r', label='train loss')
plt.plot(round_numbers, val_losses, color='b', label='val loss')
plt.title('Training history')
plt.legend()
plt.show()

## Benchmark 3: Encrypted model