In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
LEARNING_RATE = 3e-3 #0.14
EPOCHS = 25
TARGET_FOLDER = "weights"
K_FOLDS = 3

In [14]:
import torch.nn as nn
import torch
from torchvision import datasets, transforms
from sklearn.model_selection import KFold
import numpy as np

In [4]:
transform = transforms.Compose([transforms.Resize((256, 256)),
                                 transforms.ToTensor()])

data_set = datasets.ImageFolder('/content/drive/MyDrive/data_aml/X-Ray Image DataSet', transform=transform)


In [5]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(ConvBlock, self).__init__()
        for_pad = lambda s: s if s > 2 else 3
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=(for_pad(kernel_size) - 1)//2, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        return out

class TripleConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(TripleConvBlock, self).__init__()
        self.conv_block_1 = ConvBlock(in_channels, out_channels)
        self.conv_block_2 = ConvBlock(out_channels, in_channels, kernel_size=1)  
        self.conv_block_3 = ConvBlock(in_channels, out_channels)

    def forward(self, x):
        out = self.conv_block_1(x)
        out = self.conv_block_2(out)
        out = self.conv_block_3(out)
        return out

class Model3(nn.Module):
    def __init__(self):
        super(Model3, self).__init__()
        self.seq = nn.Sequential(
        ConvBlock(3, 8),
        nn.MaxPool2d(2, stride=2),
        ConvBlock(8, 16),
        nn.MaxPool2d(2, stride=2),
        TripleConvBlock(16, 32),
        nn.MaxPool2d(2, stride=2),
        TripleConvBlock(32,64),
        nn.MaxPool2d(2, stride=2),
        TripleConvBlock(64,128),
        nn.MaxPool2d(2, stride=2),
        TripleConvBlock(128,256),
        ConvBlock(256, 128, kernel_size=1),
        ConvBlock(128, 256),
        nn.Conv2d(256, 3, 3, padding=(3-1)//2, stride=1),
        nn.ReLU(),
        nn.BatchNorm2d(3),
        nn.Flatten(),
        nn.Linear(507,3)
        )

    def forward(self, x):
        return self.seq(x)

In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from torchsummary import summary
model = Model3().to(device)
summary(model, (3, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 8, 256, 256]             216
       BatchNorm2d-2          [-1, 8, 256, 256]              16
         LeakyReLU-3          [-1, 8, 256, 256]               0
         ConvBlock-4          [-1, 8, 256, 256]               0
         MaxPool2d-5          [-1, 8, 128, 128]               0
            Conv2d-6         [-1, 16, 128, 128]           1,152
       BatchNorm2d-7         [-1, 16, 128, 128]              32
         LeakyReLU-8         [-1, 16, 128, 128]               0
         ConvBlock-9         [-1, 16, 128, 128]               0
        MaxPool2d-10           [-1, 16, 64, 64]               0
           Conv2d-11           [-1, 32, 64, 64]           4,608
      BatchNorm2d-12           [-1, 32, 64, 64]              64
        LeakyReLU-13           [-1, 32, 64, 64]               0
        ConvBlock-14           [-1, 32,

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [7]:
def calc_accuracy(result, labels):
    result = torch.round(result)
    probs = torch.softmax(result, dim=1)

    correct_results_sum = (probs.argmax(dim=1) == labels).sum().float()
    acc = correct_results_sum/labels.shape[0]
    acc = torch.round(acc * 100)
    
    return acc


In [8]:
def train(model, data_loader, optimizer, criterion):
    for step, [images, labels] in enumerate(data_loader,0):
        optimizer.zero_grad()
        result = model(images)
        
        loss = criterion(result, labels.long().squeeze())
                
        # backpropagation
        loss.backward()
        optimizer.step()
                                    
        # if step % 10 == 0:
        #    print(f"Step: {step}, loss: {loss}")

In [9]:
def test(model, test_loader, criterion):
    loss = 0
    accuracy = 0

    for step, [images, labels] in enumerate(test_loader,0):
        result = model(images)
        loss += criterion(result.detach(), labels.detach().long().squeeze())
        accuracy += calc_accuracy(result.detach(), labels.detach())
    loss /= step
    accuracy /=  step
  
    #print(f"Loss: {loss}, Accuracy: {accuracy}")
    return f"Accuracy is {accuracy.item():.3f}%"

In [10]:
import torchvision

class DatasetWithTransform(torch.utils.data.Dataset):
    dataset: torch.utils.data.Dataset
    transform: torchvision.transforms.Compose

    def __init__(self, dataset: torch.utils.data.Dataset, transform: torchvision.transforms.Compose) -> None:
        self.dataset = dataset
        self.transform = transform
                      
    def __getitem__(self, index):
        x, y = self.dataset[index]
        x = self.transform(x)
        
        if isinstance(y, int):
          y = torch.IntTensor([y])
        elif isinstance(y, long):
          y = torch.LongTensor([y])
        elif isinstance(y, float):
          y = torch.FloatTensor([y])
        elif isinstance(y, double):
          y = torch.DoubleTensor([y])
    
        return x, y

    def __len__(self):
        return len(self.dataset)

In [15]:
from __future__ import annotations
from typing import List, Tuple, Callable, Any
from collections import OrderedDict
import copy
import time

class FederatedLearningTest:
    __batch_size = 32
    __shuffle_train_data = True
    __num_workers = 0
    __pin_memory = True
    
    def __init__(
        self,
        model: torch.nn.Module,
        train_dataset: torch.utils.data.Dataset,
        test_dataset: torch.utils.data.Dataset,
        train_epoch_fn: Callable[[torch.nn.Module, torch.utils.data.DataLoader, torch.optim.Optimizer, torch.nn.modules.loss._Loss], None], 
        test_fn: Callable[[torch.nn.Module, torch.utils.data.DataLoader, torch.nn.modules.loss._Loss], Any],
        use_gpu: bool,
        epochs_to_train: int,
        local_epochs_to_train: int
    ):
        self.__model = model
        self.__train_dataset = train_dataset
        self.__test_dataset = test_dataset
        self.__train_epoch_fn = train_epoch_fn
        self.__test_fn = test_fn
        self.__use_gpu = use_gpu
        self.__epochs_to_train = epochs_to_train
        self.__local_epochs_to_train = local_epochs_to_train

    def __update_client_model(self, federated_model, client_model):
        client_model.load_state_dict(federated_model.state_dict(), True)
        return client_model

    def __federated_average(self, federated_model, client_models):
        average_weights = OrderedDict()

        number_of_clients = len(client_models)
        for client_model in client_models:
            for key, value in client_model.state_dict().items():
                if key in average_weights:
                    average_weights[key] += (1./number_of_clients) * value.clone()
                else:
                    average_weights[key] = (1./number_of_clients) * value.clone()


        federated_model.load_state_dict(average_weights, True)
        return federated_model
        
    # TODO future: currently only works with image data, make it more generic
    def __prepare_train_data(
        self,
        train_dataset: torch.utils.data.Dataset,
        no_of_clients: int,
        augment_data: bool,
        full_data_on_each_client: bool,
        batch_size: int, shuffle: bool, num_workers: int, pin_memory: bool
    ) -> List[torch.utils.data.DataLoader]:
    
        if augment_data:
            # flip, then rotate and shift
            transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomAffine(degrees=30, translate=(0.1, 0.1)),
                transforms.ToTensor()
            ])
        else:
            transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])

        data_loaders = []
        if full_data_on_each_client:
            data_loaders = [
                torch.utils.data.DataLoader(
                    DatasetWithTransform(train_dataset, transform),
                    batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory
                )
            for _ in range(no_of_clients)]
        else:
            chunk_size_train = len(train_dataset) // no_of_clients
            indices_train = np.random.permutation(np.arange(chunk_size_train * no_of_clients)) 

            for idx in range(no_of_clients):
                data_loader_train = torch.utils.data.Subset(train_dataset, indices_train[idx*chunk_size_train:(idx+1)*chunk_size_train])
                data_loaders += [
                    torch.utils.data.DataLoader(
                        DatasetWithTransform(data_loader_train, transform),
                        batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory
                    )
                ]

        return data_loaders
    
    def __get_device(self):
        if not self.__use_gpu:
            device = torch.device("cpu")
            print("Using CPU")
        elif torch.cuda.is_available():
            device = torch.device("cuda")
            print("Using GPU")
        else:
            device = torch.device("cpu")
            print("You requested to use GPU, but CUDA is not available. Using CPU instead")
        
        return device
    
    def __wrap_data_loader(self, loader, device):
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            yield x, y
    
    def set_batch_size(self, batch_size: int) -> FederatedLearningTest:
        self.__batch_size = batch_size
        return self
    
    def set_shuffle_train_data(self, shuffle_train_data: bool) -> FederatedLearningTest:
        self.__shuffle_train_data = shuffle_train_data
        return self
    
    def set_dataloader_workers(self, num_workers: int) -> FederatedLearningTest:
        self.__num_workers = num_workers
        return self
    
    def set_pin_training_memory(self, pin_memory: bool) -> FederatedLearningTest:
        self.__pin_memory = pin_memory
        return self
    
    def compare(
        self,
        augment_data: bool,
        full_data_on_each_client: bool,
        no_of_clients: int,
        construct_optimizer_fn: Callable[[torch.nn.Module], torch.optim.Optimizer],
        construct_loss_fn: Callable[[], torch.nn.modules.loss._Loss],
        test_after_each_epoch: bool = False
    ):
        device = self.__get_device()
        if device.type == "cuda": 
            torch.cuda.empty_cache()
        
        sequential_training_start = time.time()
        print("Training sequential model")
        sequential_model = copy.deepcopy(self.__model).to(device)
        
        sequential_optimizer = construct_optimizer_fn(sequential_model)
        sequential_loss = construct_loss_fn().to(device)
        
        sequential_train_loader = torch.utils.data.DataLoader(
            self.__train_dataset,
            batch_size=self.__batch_size,
            shuffle=self.__shuffle_train_data,
            num_workers=self.__num_workers,
            pin_memory=self.__pin_memory
        )
        
        sequential_test_loader = torch.utils.data.DataLoader(
            self.__test_dataset,
            batch_size=self.__batch_size,
            shuffle=False,
            num_workers=self.__num_workers
        )
        
        for idx in range(self.__epochs_to_train):
            print(f"Epoch: {idx+1}")
            sequential_model.train()
            sequential_train_loader_device = self.__wrap_data_loader(sequential_train_loader, device)
            self.__train_epoch_fn(sequential_model, sequential_train_loader_device, sequential_optimizer, sequential_loss)
            sequential_train_loader_device = None
            
            if test_after_each_epoch:
                sequential_model.eval()
                with torch.no_grad():
                    print(f"Test results for epoch {idx+1}:")
                    sequential_test_loader_device = self.__wrap_data_loader(sequential_test_loader, device)
                    print(self.__test_fn(sequential_model, sequential_test_loader_device, sequential_loss))
                    sequential_test_loader_device = None
        
        print(f"Sequential training complete after {time.time() - sequential_training_start:.2f} seconds, testing ...")
        
        sequential_test_start = time.time()
        sequential_model.eval()
        with torch.no_grad():
            sequential_test_loader = self.__wrap_data_loader(sequential_test_loader, device)
            sequential_test_results = self.__test_fn(sequential_model, sequential_test_loader, sequential_loss)
        
        sequential_model.to("cpu")
        sequential_train_loader = None
        sequential_test_loader = None
        sequential_loss = None
        sequential_optimizer = None
        sequential_model = None
        
        print(f"Sequential testing complete after {time.time() - sequential_test_start:.2f} seconds, results:")
        print(sequential_test_results)
        
        if device.type == "cuda": 
            torch.cuda.empty_cache()
            
        print("\n++++++++++++++++++++++++++++++++++++++++\n")
        
        federated_training_start = time.time()
        print("Training federated model(s)")
        federated_model = copy.deepcopy(self.__model)
        client_models = [copy.deepcopy(self.__model) for _ in range(no_of_clients)]
        
        federated_loss = construct_loss_fn().to(device)
        
        federated_training_dataloaders = self.__prepare_train_data(
            self.__train_dataset,
            no_of_clients,
            augment_data,
            full_data_on_each_client,
            self.__batch_size,
            self.__shuffle_train_data,
            self.__num_workers,
            self.__pin_memory
        )
        
        federated_test_loader = torch.utils.data.DataLoader(
            self.__test_dataset,
            batch_size=self.__batch_size,
            shuffle=False,
            num_workers=self.__num_workers
        )

        for idx in range(self.__epochs_to_train):
            for client_idx in range (no_of_clients):
                print(f"Epoch: {idx+1}, client {client_idx+1}")
                
                client_model = client_models[client_idx].to(device)
                
                client_model = self.__update_client_model(federated_model, client_model)
                client_optimizer = construct_optimizer_fn(client_model)
                
                for local_epoch_idx in range(self.__local_epochs_to_train): 
                    client_training_loader_device = self.__wrap_data_loader(federated_training_dataloaders[client_idx], device)
                    self.__train_epoch_fn(client_model, client_training_loader_device, client_optimizer, federated_loss)
                    client_training_loader_device = None

                if device.type == "cuda": 
                    torch.cuda.empty_cache()

                client_model.to("cpu")
                client_optimizer = None
            
            federated_model = self.__federated_average(federated_model, client_models)

            if test_after_each_epoch:
                federated_model.to(device)
                federated_model.eval()
                with torch.no_grad():
                    print(f"Test results for epoch {idx+1}:")
                    federated_test_loader_device = self.__wrap_data_loader(federated_test_loader, device)
                    print(self.__test_fn(federated_model, federated_test_loader_device, federated_loss))
                    federated_test_loader_device = None
                    
                federated_model.to("cpu")
        
        
        print(f"Federated training complete after {time.time() - federated_training_start:.2f} seconds, testing ...")
        
        federated_test_start = time.time()
        federated_model.to(device)
        federated_model.eval()
        with torch.no_grad():
            federated_test_loader = self.__wrap_data_loader(federated_test_loader, device)
            federated_test_results = self.__test_fn(federated_model, federated_test_loader, federated_loss)
        
        federated_model.to("cpu")
        federated_training_dataloaders = None
        federated_test_loader = None
        federated_loss = None
        client_models = None
        federated_model = None
        
        print(f"Federated testing complete after {time.time() - federated_test_start:.2f} seconds, results:")
        print(federated_test_results)

In [16]:
number_of_clients = 4

def construct_optimizer(model):
    return torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

def construct_loss():
    return torch.nn.CrossEntropyLoss()

In [None]:
kfold = KFold(n_splits=K_FOLDS, shuffle=True)

for fold, (train_ids, test_ids) in enumerate(kfold.split(data_set)):
    
    print(f'FOLD {fold+1}')
    print('--------------------------------')
    
    dataset_train = torch.utils.data.Subset(data_set, train_ids)
    dataset_test = torch.utils.data.Subset(data_set, test_ids)
  
    model_to_test = Model3()
      
    federated_test = FederatedLearningTest(
        model_to_test, dataset_train, dataset_test,
        train_epoch_fn=train, 
        test_fn=test,
        use_gpu=True, epochs_to_train=EPOCHS, local_epochs_to_train=1
    ).set_batch_size(32).set_shuffle_train_data(True).set_dataloader_workers(0).set_pin_training_memory(True)

    federated_test.compare(
        augment_data=False,
        full_data_on_each_client=False,
        no_of_clients=number_of_clients,
        construct_optimizer_fn = construct_optimizer,
        construct_loss_fn = construct_loss,
        test_after_each_epoch = True
    )

FOLD 1
--------------------------------
Using GPU
Training sequential model
Epoch: 1
Test results for epoch 1:
Accuracy is 11.636%
Epoch: 2
Test results for epoch 2:
Accuracy is 66.273%
Epoch: 3
Test results for epoch 3:
Accuracy is 32.182%
Epoch: 4
Test results for epoch 4:
Accuracy is 74.818%
Epoch: 5
Test results for epoch 5:
Accuracy is 59.455%
Epoch: 6
Test results for epoch 6:
Accuracy is 68.636%
Epoch: 7
Test results for epoch 7:
Accuracy is 63.545%
Epoch: 8
Test results for epoch 8:
Accuracy is 52.909%
Epoch: 9
Test results for epoch 9:
Accuracy is 79.091%
Epoch: 10
Test results for epoch 10:
Accuracy is 78.909%
Epoch: 11
Test results for epoch 11:
Accuracy is 70.273%
Epoch: 12
Test results for epoch 12:
Accuracy is 81.727%
Epoch: 13
Test results for epoch 13:
Accuracy is 61.455%
Epoch: 14
Test results for epoch 14:
Accuracy is 60.636%
Epoch: 15
Test results for epoch 15:
Accuracy is 45.000%
Epoch: 16
Test results for epoch 16:
Accuracy is 74.909%
Epoch: 17
Test results for epo

In [17]:
kfold = KFold(n_splits=K_FOLDS, shuffle=True)

for fold, (train_ids, test_ids) in enumerate(kfold.split(data_set)):
    
    print(f'FOLD {fold+1}')
    print('--------------------------------')
    
    dataset_train = torch.utils.data.Subset(data_set, train_ids)
    dataset_test = torch.utils.data.Subset(data_set, test_ids)
  
    model_to_test = Model3()
      
    federated_test = FederatedLearningTest(
        model_to_test, dataset_train, dataset_test,
        train_epoch_fn=train, 
        test_fn=test,
        use_gpu=True, epochs_to_train=EPOCHS, local_epochs_to_train=5
    ).set_batch_size(32).set_shuffle_train_data(True).set_dataloader_workers(0).set_pin_training_memory(True)

    federated_test.compare(
        augment_data=False,
        full_data_on_each_client=False,
        no_of_clients=number_of_clients,
        construct_optimizer_fn = construct_optimizer,
        construct_loss_fn = construct_loss,
        test_after_each_epoch = False
    )

FOLD 1
--------------------------------
Using GPU
Training sequential model
Epoch: 1
Epoch: 2
Epoch: 3
Epoch: 4
Epoch: 5
Epoch: 6
Epoch: 7
Epoch: 8
Epoch: 9
Epoch: 10
Epoch: 11
Epoch: 12
Epoch: 13
Epoch: 14
Epoch: 15
Epoch: 16
Epoch: 17
Epoch: 18
Epoch: 19
Epoch: 20
Epoch: 21
Epoch: 22
Epoch: 23
Epoch: 24
Epoch: 25
Sequential training complete after 541.10 seconds, testing ...
Sequential testing complete after 11.35 seconds, results:
Accuracy is 48.636%

++++++++++++++++++++++++++++++++++++++++

Training federated model(s)
Epoch: 1, client 1
Epoch: 1, client 2
Epoch: 1, client 3
Epoch: 1, client 4
Epoch: 2, client 1
Epoch: 2, client 2
Epoch: 2, client 3
Epoch: 2, client 4
Epoch: 3, client 1
Epoch: 3, client 2
Epoch: 3, client 3
Epoch: 3, client 4
Epoch: 4, client 1
Epoch: 4, client 2
Epoch: 4, client 3
Epoch: 4, client 4
Epoch: 5, client 1
Epoch: 5, client 2
Epoch: 5, client 3
Epoch: 5, client 4
Epoch: 6, client 1
Epoch: 6, client 2
Epoch: 6, client 3
Epoch: 6, client 4
Epoch: 7, clien