In [None]:
!pip install opendatasets
!pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio===0.9.0 -f https://download.pytorch.org/whl/torch_stable.html

In [None]:
# Shell escape: print the NVIDIA driver/GPU status so the reader can see which GPU (if any)
# this runtime provides before training starts.
!nvidia-smi

In [None]:
import opendatasets as od
# Fetch the Kaggle datasets into the current working directory (one sub-folder per dataset).
# NOTE(review): only the chest X-ray dataset is read by the cells visible in this notebook;
# confirm whether the second download is still needed.
for kaggle_url in (
    "https://www.kaggle.com/prashant268/chest-xray-covid19-pneumonia",
    "https://www.kaggle.com/sudalairajkumar/novel-corona-virus-2019-dataset",
):
    od.download(kaggle_url)

In [12]:
import torch
import torchvision
from torchvision import datasets, transforms
import numpy as np

In [13]:
# Notebook-wide configuration constants.
# NOTE(review): none of these three names is referenced by the cells visible in this
# notebook (the optimizer below is built with torch's defaults) — confirm whether they
# are used elsewhere or can be wired in / removed.
LEARNING_RATE = 0.001 # alternative value noted by the author: 0.0001
MAX_EPOCHS = 10
TARGET_FOLDER = "weights"

In [14]:
# Shared preprocessing for both splits: resize to a fixed 244x244 and convert to a tensor.
# NOTE(review): 244 may be a typo for the conventional 224 — confirm.
# TODO: per-channel normalisation is not applied yet; the dataset's own mean/std still have
# to be computed (the ImageNet statistics mean=[0.485, 0.456, 0.406],
# std=[0.229, 0.224, 0.225] were considered as a placeholder).
transform = transforms.Compose([
    transforms.Resize((244, 244)),
    transforms.ToTensor(),
])

# One sub-folder per class; the class index is derived from the folder name.
test_set = datasets.ImageFolder(root='chest-xray-covid19-pneumonia/Data/test', transform=transform)

# `dataset` is kept as a second name for the training split, exactly as before.
train_set = dataset = datasets.ImageFolder(root='chest-xray-covid19-pneumonia/Data/train', transform=transform)

In [15]:
def label_preparation(labels):
    """Collapse class indices into binary labels: every value greater than 0 becomes 1.

    The input sequence is copied into a NumPy array first, so the caller's
    object is left untouched. Returns a plain list of the array's elements.
    """
    binary = np.array(labels)
    np.putmask(binary, binary > 0, 1)
    return list(binary)

def label_preparation_tensor(labels):
    """Binarise a batch of labels in place: every entry greater than 0 is set to 1.

    The passed-in object itself is modified and returned (no copy is made).
    """
    is_positive = labels > 0
    labels[is_positive] = 1
    return labels

# Replace the `targets` attribute of both splits with binarised labels (0 stays 0, >0 becomes 1).
# NOTE(review): torchvision's ImageFolder presumably serves items from `.samples`, not
# `.targets`, so batches would still carry the original class indices — which would explain
# why train_fn/test_fn call label_preparation_tensor on every batch. Confirm.
train_set.targets = label_preparation(train_set.targets)

test_set.targets = label_preparation(test_set.targets)

In [16]:
def calc_accuracy(result, labels):
    """Return the percentage of predictions in a batch that match `labels`.

    `result` holds raw logits: they are passed through a sigmoid and rounded
    to obtain hard 0/1 predictions, which are compared element-wise with
    `labels`. The count of matches is divided by `labels.shape[0]` and scaled
    to a percentage; the value is returned as a float tensor.
    """
    predictions = torch.round(torch.sigmoid(result))
    hits = predictions.eq(labels).sum().float()
    return hits / labels.shape[0] * 100


# Federated Training

In [17]:
def train_fn(model, data_loader, optimizer, loss):
    """Run one pass over `data_loader`, taking one optimizer step per batch.

    Args:
        model: network called on every image batch; its output is fed to `loss`.
        data_loader: iterable yielding `(images, labels)` pairs.
        optimizer: optimizer whose gradients are reset before each batch and
            which is stepped after each backward pass.
        loss: callable `loss(output, targets)` returning a value that supports
            `.backward()`.

    Returns nothing; the model is updated through `optimizer`.
    """
    for images, labels in data_loader:
        # Labels are binarised in place (>0 -> 1) and converted to float targets.
        binary_targets = label_preparation_tensor(labels).float()

        # Per the original author's note: the plain DataLoader yields targets of shape
        # [batch], the custom dataset wrapper yields [batch, 1]; bring both to [batch, 1].
        if binary_targets.dim() == 1:
            binary_targets = binary_targets.unsqueeze(1)

        optimizer.zero_grad()
        logits = model(images)
        batch_loss = loss(logits.float(), binary_targets)

        # backpropagation followed by the parameter update
        batch_loss.backward()
        optimizer.step()

In [18]:
def test_fn(model, test_loader, loss):
    """Evaluate `model` on `test_loader` and return a formatted accuracy string.

    Args:
        model: network called on every image batch; expected to return logits.
        test_loader: iterable yielding `(images, labels)` pairs.
        loss: accepted only to keep the callback signature used by
            `FederatedLearningTest`; the previous implementation accumulated
            a mean loss with it that was never reported, so it is no longer
            evaluated.

    Returns:
        A string of the form ``"Test accuracy is 12.345%"``.

    Raises:
        ValueError: if `test_loader` yields no samples (previously this
            surfaced as an unbound-variable error).
    """
    weighted_accuracy = 0.0
    sample_count = 0

    with torch.no_grad():
        for images, labels in test_loader:
            labels = label_preparation_tensor(labels)

            result = model(images)
            targets = labels.float()

            # Same shape handling as train_fn: only add the trailing dimension when the
            # loader delivered flat [batch] labels, so [batch, 1] labels keep working.
            if targets.dim() == 1:
                targets = targets.unsqueeze(1)

            # Weight each batch by its size so a shorter final batch does not count
            # as much as a full one in the overall accuracy.
            batch_size = targets.shape[0]
            weighted_accuracy += calc_accuracy(result, targets).item() * batch_size
            sample_count += batch_size

    if sample_count == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")

    accuracy = weighted_accuracy / sample_count
    return f"Test accuracy is {accuracy:.3f}%"

In [19]:
class DatasetWithTransform(torch.utils.data.Dataset):
    """Dataset wrapper that applies `transform` to the sample part of every item.

    Items of the wrapped dataset must be `(x, y)` pairs. `x` is passed through
    `transform`; a Python `int` or `float` label `y` is wrapped into a
    one-element tensor, any other label type is returned unchanged.
    """
    # Annotations are quoted so that defining the class never has to evaluate them.
    dataset: "torch.utils.data.Dataset"
    transform: "torchvision.transforms.Compose"

    def __init__(self, dataset: "torch.utils.data.Dataset", transform: "torchvision.transforms.Compose") -> None:
        """Keep references to the wrapped dataset and the callable applied to each sample."""
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index):
        """Return `(transform(x), y)` for the wrapped item at `index`."""
        x, y = self.dataset[index]
        x = self.transform(x)

        # Python 3 has no separate `long`/`double` types: `int` and `float` already cover
        # them, and the former checks against those names raised NameError for every
        # label that was not an `int`. Labels of other types pass through untouched.
        if isinstance(y, int):
            y = torch.IntTensor([y])
        elif isinstance(y, float):
            y = torch.FloatTensor([y])

        return x, y

    def __len__(self):
        """Number of items in the wrapped dataset."""
        return len(self.dataset)

In [21]:
from __future__ import annotations
from typing import List, Tuple, Callable, Any
from collections import OrderedDict
import copy
import time

class FederatedLearningTest:
    """Compare conventional ("sequential") training with simulated federated averaging.

    `compare` first trains a deep copy of the model on the whole training set,
    then trains one deep copy per client on per-client data loaders, averaging
    the client weights into a federated model after every global epoch. The
    result of `test_fn` is printed for both variants.
    """

    # Class-level defaults for the DataLoader settings; the set_* methods override
    # them per instance.
    __batch_size = 32
    __shuffle_train_data = True
    __num_workers = 0
    __pin_memory = True

    def __init__(
        self,
        model: torch.nn.Module,
        train_dataset: torch.utils.data.Dataset,
        test_dataset: torch.utils.data.Dataset,
        train_epoch_fn: Callable[[torch.nn.Module, torch.utils.data.DataLoader, torch.optim.Optimizer, torch.nn.modules.loss._Loss], None],
        test_fn: Callable[[torch.nn.Module, torch.utils.data.DataLoader, torch.nn.modules.loss._Loss], Any],
        use_gpu: bool,
        epochs_to_train: int,
        local_epochs_to_train: int
    ):
        """Store the experiment configuration; nothing is trained here.

        Args:
            model: template model; every training run works on a deep copy of it.
            train_dataset: data used for sequential training and split/shared among clients.
            test_dataset: data passed (through a DataLoader) to `test_fn`.
            train_epoch_fn: called as `(model, loader, optimizer, loss)` to train one epoch.
            test_fn: called as `(model, loader, loss)`; its return value is printed.
            use_gpu: request CUDA; falls back to CPU when CUDA is unavailable.
            epochs_to_train: sequential epochs, and number of federated rounds.
            local_epochs_to_train: epochs each client trains per federated round.
        """
        self.__model = model
        self.__train_dataset = train_dataset
        self.__test_dataset = test_dataset
        self.__train_epoch_fn = train_epoch_fn
        self.__test_fn = test_fn
        self.__use_gpu = use_gpu
        self.__epochs_to_train = epochs_to_train
        self.__local_epochs_to_train = local_epochs_to_train

    def __update_client_model(self, federated_model, client_model):
        """Copy the federated model's state into `client_model` (strict key matching) and return it."""
        client_model.load_state_dict(federated_model.state_dict(), True)
        return client_model

    def __federated_average(self, federated_model, client_models):
        """Load the unweighted mean of all client state dicts into `federated_model` and return it.

        Raises:
            ValueError: if `client_models` is empty (there is nothing to average).
        """
        number_of_clients = len(client_models)
        if number_of_clients == 0:
            raise ValueError("Cannot average an empty list of client models")

        client_weight = 1. / number_of_clients
        average_weights = OrderedDict()
        for client_model in client_models:
            for key, value in client_model.state_dict().items():
                # The multiplication already creates a new tensor, so the clients'
                # own state is never modified by the accumulation below.
                contribution = client_weight * value
                if key in average_weights:
                    average_weights[key] += contribution
                else:
                    average_weights[key] = contribution

        federated_model.load_state_dict(average_weights, True)
        return federated_model

    # TODO future: currently only works with image data, make it more generic
    def __prepare_train_data(
        self,
        train_dataset: torch.utils.data.Dataset,
        no_of_clients: int,
        augment_data: bool,
        full_data_on_each_client: bool,
        batch_size: int, shuffle: bool, num_workers: int, pin_memory: bool
    ) -> List[torch.utils.data.DataLoader]:
        """Build one training DataLoader per client.

        With `full_data_on_each_client` every client iterates over the whole
        dataset; otherwise the dataset is split into `no_of_clients` disjoint,
        equally sized random chunks. When `augment_data` is set, random flips
        and an affine transform are applied to every sample.

        Raises:
            ValueError: if `no_of_clients` is smaller than 1, or the dataset has
                fewer samples than clients when it has to be split.
        """
        if no_of_clients < 1:
            raise ValueError(f"no_of_clients must be at least 1, got {no_of_clients}")

        if augment_data:
            # flip, then rotate and shift
            transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomAffine(degrees=30, translate=(0.1, 0.1)),
                transforms.ToTensor()
            ])
        else:
            transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])

        data_loaders = []
        if full_data_on_each_client:
            data_loaders = [
                torch.utils.data.DataLoader(
                    DatasetWithTransform(train_dataset, transform),
                    batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory
                )
            for _ in range(no_of_clients)]
        else:
            chunk_size_train = len(train_dataset) // no_of_clients
            if chunk_size_train == 0:
                raise ValueError(
                    f"Cannot split {len(train_dataset)} samples among {no_of_clients} clients"
                )

            # Shuffle ALL indices before truncating to a multiple of the chunk size, so the
            # samples left over by the integer division are random ones rather than always
            # the last entries of the dataset.
            indices_train = np.random.permutation(len(train_dataset))[:chunk_size_train * no_of_clients]

            for idx in range(no_of_clients):
                client_subset = torch.utils.data.Subset(
                    train_dataset,
                    indices_train[idx*chunk_size_train:(idx+1)*chunk_size_train]
                )
                data_loaders += [
                    torch.utils.data.DataLoader(
                        DatasetWithTransform(client_subset, transform),
                        batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory
                    )
                ]

        return data_loaders

    def __get_device(self):
        """Return the torch device to train on and print which one was chosen."""
        if not self.__use_gpu:
            device = torch.device("cpu")
            print("Using CPU")
        elif torch.cuda.is_available():
            device = torch.device("cuda")
            print("Using GPU")
        else:
            device = torch.device("cpu")
            print("You requested to use GPU, but CUDA is not available. Using CPU instead")

        return device

    def __wrap_data_loader(self, loader, device):
        """Generator yielding every `(x, y)` batch of `loader` moved to `device`."""
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            yield x, y

    # The return annotations below are quoted so they do not require the class name to be
    # resolvable while the class body is still being executed.
    def set_batch_size(self, batch_size: int) -> "FederatedLearningTest":
        """Set the batch size of all DataLoaders; returns `self` for chaining."""
        self.__batch_size = batch_size
        return self

    def set_shuffle_train_data(self, shuffle_train_data: bool) -> "FederatedLearningTest":
        """Choose whether training DataLoaders shuffle; returns `self` for chaining."""
        self.__shuffle_train_data = shuffle_train_data
        return self

    def set_dataloader_workers(self, num_workers: int) -> "FederatedLearningTest":
        """Set `num_workers` of all DataLoaders; returns `self` for chaining."""
        self.__num_workers = num_workers
        return self

    def set_pin_training_memory(self, pin_memory: bool) -> "FederatedLearningTest":
        """Set `pin_memory` of the training DataLoaders; returns `self` for chaining."""
        self.__pin_memory = pin_memory
        return self

    def compare(
        self,
        augment_data: bool,
        full_data_on_each_client: bool,
        no_of_clients: int,
        construct_optimizer_fn: Callable[[torch.nn.Module], torch.optim.Optimizer],
        construct_loss_fn: Callable[[], torch.nn.modules.loss._Loss],
        test_after_each_epoch: bool = False
    ):
        """Train sequentially, then federated, printing timings and test results for both.

        Args:
            augment_data: apply random augmentation to the clients' training data.
            full_data_on_each_client: give every client the whole training set
                instead of a disjoint chunk.
            no_of_clients: number of simulated clients (must be >= 1).
            construct_optimizer_fn: builds a fresh optimizer for a given model.
            construct_loss_fn: builds a fresh loss module.
            test_after_each_epoch: additionally print test results after every
                sequential epoch / federated round.

        Raises:
            ValueError: if `no_of_clients` is smaller than 1. Checked up front so an
                invalid setting fails before the sequential training has been run.
        """
        if no_of_clients < 1:
            raise ValueError(f"no_of_clients must be at least 1, got {no_of_clients}")

        device = self.__get_device()
        if device.type == "cuda":
            torch.cuda.empty_cache()

        sequential_training_start = time.time()
        print("Training sequential model")
        sequential_model = copy.deepcopy(self.__model).to(device)

        sequential_optimizer = construct_optimizer_fn(sequential_model)
        sequential_loss = construct_loss_fn().to(device)

        sequential_train_loader = torch.utils.data.DataLoader(
            self.__train_dataset,
            batch_size=self.__batch_size,
            shuffle=self.__shuffle_train_data,
            num_workers=self.__num_workers,
            pin_memory=self.__pin_memory
        )

        sequential_test_loader = torch.utils.data.DataLoader(
            self.__test_dataset,
            batch_size=self.__batch_size,
            shuffle=False,
            num_workers=self.__num_workers
        )

        for idx in range(self.__epochs_to_train):
            print(f"Epoch: {idx+1}")
            sequential_model.train()
            sequential_train_loader_device = self.__wrap_data_loader(sequential_train_loader, device)
            self.__train_epoch_fn(sequential_model, sequential_train_loader_device, sequential_optimizer, sequential_loss)
            sequential_train_loader_device = None

            if test_after_each_epoch:
                sequential_model.eval()
                with torch.no_grad():
                    print(f"Test results for epoch {idx+1}:")
                    sequential_test_loader_device = self.__wrap_data_loader(sequential_test_loader, device)
                    print(self.__test_fn(sequential_model, sequential_test_loader_device, sequential_loss))
                    sequential_test_loader_device = None

        print(f"Sequential training complete after {time.time() - sequential_training_start:.2f} seconds, testing ...")

        sequential_test_start = time.time()
        sequential_model.eval()
        with torch.no_grad():
            sequential_test_loader = self.__wrap_data_loader(sequential_test_loader, device)
            sequential_test_results = self.__test_fn(sequential_model, sequential_test_loader, sequential_loss)

        # Drop every reference to the sequential run so its memory can be reclaimed
        # before the federated run starts.
        sequential_model.to("cpu")
        sequential_train_loader = None
        sequential_test_loader = None
        sequential_loss = None
        sequential_optimizer = None
        sequential_model = None

        print(f"Sequential testing complete after {time.time() - sequential_test_start:.2f} seconds, results:")
        print(sequential_test_results)

        if device.type == "cuda":
            torch.cuda.empty_cache()

        print("\n++++++++++++++++++++++++++++++++++++++++\n")

        federated_training_start = time.time()
        print("Training federated model(s)")
        federated_model = copy.deepcopy(self.__model)
        client_models = [copy.deepcopy(self.__model) for _ in range(no_of_clients)]

        federated_loss = construct_loss_fn().to(device)

        federated_training_dataloaders = self.__prepare_train_data(
            self.__train_dataset,
            no_of_clients,
            augment_data,
            full_data_on_each_client,
            self.__batch_size,
            self.__shuffle_train_data,
            self.__num_workers,
            self.__pin_memory
        )

        federated_test_loader = torch.utils.data.DataLoader(
            self.__test_dataset,
            batch_size=self.__batch_size,
            shuffle=False,
            num_workers=self.__num_workers
        )

        for idx in range(self.__epochs_to_train):
            for client_idx in range(no_of_clients):
                print(f"Epoch: {idx+1}, client {client_idx+1}")

                client_model = client_models[client_idx].to(device)

                # Every round starts from the current federated weights and a fresh optimizer.
                client_model = self.__update_client_model(federated_model, client_model)
                # Mirror the sequential path: make sure the client trains in train mode even
                # if the template model was handed over in eval mode.
                client_model.train()
                client_optimizer = construct_optimizer_fn(client_model)

                for local_epoch_idx in range(self.__local_epochs_to_train):
                    client_training_loader_device = self.__wrap_data_loader(federated_training_dataloaders[client_idx], device)
                    self.__train_epoch_fn(client_model, client_training_loader_device, client_optimizer, federated_loss)
                    client_training_loader_device = None

                # Move the client off the device and drop its optimizer first, so that
                # emptying the CUDA cache can actually release that memory.
                client_model.to("cpu")
                client_optimizer = None

                if device.type == "cuda":
                    torch.cuda.empty_cache()

            federated_model = self.__federated_average(federated_model, client_models)

            if test_after_each_epoch:
                federated_model.to(device)
                federated_model.eval()
                with torch.no_grad():
                    print(f"Test results for epoch {idx+1}:")
                    federated_test_loader_device = self.__wrap_data_loader(federated_test_loader, device)
                    print(self.__test_fn(federated_model, federated_test_loader_device, federated_loss))
                    federated_test_loader_device = None

                federated_model.to("cpu")

        print(f"Federated training complete after {time.time() - federated_training_start:.2f} seconds, testing ...")

        federated_test_start = time.time()
        federated_model.to(device)
        federated_model.eval()
        with torch.no_grad():
            federated_test_loader = self.__wrap_data_loader(federated_test_loader, device)
            federated_test_results = self.__test_fn(federated_model, federated_test_loader, federated_loss)

        federated_model.to("cpu")
        federated_training_dataloaders = None
        federated_test_loader = None
        federated_loss = None
        client_models = None
        federated_model = None

        print(f"Federated testing complete after {time.time() - federated_test_start:.2f} seconds, results:")
        print(federated_test_results)

In [24]:
number_of_clients = 4


def construct_optimizer(model):
    """Return a new Adam optimizer over `model`'s parameters, using torch's default settings."""
    return torch.optim.Adam(model.parameters())


def construct_loss():
    """Return binary cross-entropy on logits with a pos_weight of 0.1."""
    # author's rationale for the pos_weight: the data set is unbalanced
    positive_weight = torch.tensor([1./10])
    return torch.nn.BCEWithLogitsLoss(pos_weight=positive_weight)


# Untrained ResNet-18 with a single output logit (binary classification).
model_to_test = torchvision.models.resnet18(pretrained=False, num_classes=1)

# One local epoch per federated round; the notebook refers to this setting as FedSGD.
federated_test = FederatedLearningTest(
    model=model_to_test,
    train_dataset=train_set,
    test_dataset=test_set,
    train_epoch_fn=train_fn,
    test_fn=test_fn,
    use_gpu=True,
    epochs_to_train=5,
    local_epochs_to_train=1,
)
federated_test.set_batch_size(32)
federated_test.set_shuffle_train_data(True)
federated_test.set_dataloader_workers(0)
federated_test.set_pin_training_memory(True)

# Each client gets its own disjoint chunk of the training data, without augmentation.
federated_test.compare(
    augment_data=False,
    full_data_on_each_client=False,
    no_of_clients=number_of_clients,
    construct_optimizer_fn=construct_optimizer,
    construct_loss_fn=construct_loss,
    test_after_each_epoch=False,
)


Using GPU
Training sequential model
Epoch: 1
Epoch: 2
Epoch: 3
Epoch: 4
Epoch: 5
Sequential training complete after 713.35 seconds, testing ...
Sequential testing complete after 34.34 seconds, results:
Test accuracy is 88.491%

++++++++++++++++++++++++++++++++++++++++

Training federated model(s)
Epoch: 1, client 1
Epoch: 1, client 2
Epoch: 1, client 3
Epoch: 1, client 4
Epoch: 2, client 1
Epoch: 2, client 2
Epoch: 2, client 3
Epoch: 2, client 4
Epoch: 3, client 1
Epoch: 3, client 2
Epoch: 3, client 3
Epoch: 3, client 4
Epoch: 4, client 1
Epoch: 4, client 2
Epoch: 4, client 3
Epoch: 4, client 4
Epoch: 5, client 1
Epoch: 5, client 2
Epoch: 5, client 3
Epoch: 5, client 4
Federated training complete after 756.69 seconds, testing ...
Federated testing complete after 36.42 seconds, results:
Test accuracy is 96.265%


In [23]:
number_of_clients = 4


def construct_optimizer(model):
    """Return a new Adam optimizer over `model`'s parameters, using torch's default settings."""
    return torch.optim.Adam(model.parameters())


def construct_loss():
    """Return binary cross-entropy on logits with a pos_weight of 0.1."""
    # author's rationale for the pos_weight: the data set is unbalanced
    positive_weight = torch.tensor([1./10])
    return torch.nn.BCEWithLogitsLoss(pos_weight=positive_weight)


# Untrained ResNet-18 with a single output logit (binary classification).
model_to_test = torchvision.models.resnet18(pretrained=False, num_classes=1)

# Several local epochs per federated round (FedAvg); the author expected this to
# perform better than the single-local-epoch setting.
federated_test = FederatedLearningTest(
    model=model_to_test,
    train_dataset=train_set,
    test_dataset=test_set,
    train_epoch_fn=train_fn,
    test_fn=test_fn,
    use_gpu=True,
    epochs_to_train=5,
    local_epochs_to_train=5,
)
federated_test.set_batch_size(32)
federated_test.set_shuffle_train_data(True)
federated_test.set_dataloader_workers(0)
federated_test.set_pin_training_memory(True)

# Each client gets its own disjoint chunk of the training data, without augmentation.
federated_test.compare(
    augment_data=False,
    full_data_on_each_client=False,
    no_of_clients=number_of_clients,
    construct_optimizer_fn=construct_optimizer,
    construct_loss_fn=construct_loss,
    test_after_each_epoch=False,
)

Using GPU
Training sequential model
Epoch: 1
Epoch: 2
Epoch: 3
Epoch: 4
Epoch: 5
Sequential training complete after 642.70 seconds, testing ...
Sequential testing complete after 31.87 seconds, results:
Test accuracy is 96.494%

++++++++++++++++++++++++++++++++++++++++

Training federated model(s)
Epoch: 1, client 1
Epoch: 1, client 2
Epoch: 1, client 3
Epoch: 1, client 4
Epoch: 2, client 1
Epoch: 2, client 2
Epoch: 2, client 3
Epoch: 2, client 4
Epoch: 3, client 1
Epoch: 3, client 2
Epoch: 3, client 3
Epoch: 3, client 4
Epoch: 4, client 1
Epoch: 4, client 2
Epoch: 4, client 3
Epoch: 4, client 4
Epoch: 5, client 1
Epoch: 5, client 2
Epoch: 5, client 3
Epoch: 5, client 4
Federated training complete after 3510.09 seconds, testing ...
Federated testing complete after 31.76 seconds, results:
Test accuracy is 92.226%
