In [None]:
!pip install opendatasets

In [None]:
!pip install tensorflow

In [None]:
!pip install tensorflow_federated

In [None]:
!nvidia-smi

In [None]:
!pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio===0.9.0 -f https://download.pytorch.org/whl/torch_stable.html

In [None]:
import opendatasets as od
od.download("https://www.kaggle.com/prashant268/chest-xray-covid19-pneumonia")
od.download("https://www.kaggle.com/sudalairajkumar/novel-corona-virus-2019-dataset")

In [1]:
import torch
import torchvision
from torchvision import datasets, transforms
import numpy as np

In [2]:
LEARNING_RATE = 0.001 # 0.0001
MAX_EPOCHS = 10
TARGET_FOLDER = "weights"

In [3]:
transform = transforms.Compose([transforms.Resize((244, 244))
                                , transforms.ToTensor()]
                               #, transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # find mean and std of dataset
                              )

test_set = datasets.ImageFolder('chest-xray-covid19-pneumonia/Data/test', transform=transform)

train_set = dataset = datasets.ImageFolder('chest-xray-covid19-pneumonia/Data/train', transform=transform)

In [4]:
def label_preparation(labels):
    labels = np.array(labels)
    labels[labels > 0] = 1
    return list(labels)

def label_preparation_tensor(labels):
    labels[labels > 0] = 1
    return labels

train_set.targets = label_preparation(train_set.targets)

test_set.targets = label_preparation(test_set.targets)

In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(f"Current device: {device}")

Current device: cuda:0


In [6]:
def calc_accuracy(result, labels):
    result = torch.sigmoid(result).round()
    
    correct_results_sum = (result == labels).sum().float()
    acc = correct_results_sum/labels.shape[0]
    acc *= 100
    
    return acc


# Federated

In [7]:
number_of_clients = 4

In [8]:
from collections import OrderedDict

def update_client(federated_model, client):
    client.load_state_dict(federated_model.state_dict(), True)
    return client
    
def federated_average(federated_model, client_models):
    average_weights = OrderedDict()
    
    number_of_clients = len(client_models)
    for client_model in client_models:
        for key, value in client_model.state_dict().items():
            if key in average_weights:
                average_weights[key] += (1./number_of_clients) * value.clone()
            else:
                average_weights[key] = (1./number_of_clients) * value.clone()
                
                
    federated_model.load_state_dict(average_weights, True)
    return federated_model
    

In [9]:
def train_federated(model, data_loader, optimizer, loss):
    """
    model -- neural net
    data_loader -- dataloader for train images
    optimizer -- optimizer
    """
    model.train()
    
    accuracy = 0
    for step, [images, labels] in enumerate(data_loader, 1):
        images = images.to(device)
        labels = label_preparation_tensor(labels.to(device))

        optimizer.zero_grad()
        
        result = model(images)
        targets = labels.float()

        loss_value = loss(result.float(), targets)

        # backpropagation
        loss_value.backward()
        optimizer.step()
                                    
        if step % 10 == 0:
            accuracy += calc_accuracy(result, targets)
            print(f"TRAINING - Step: {step}, loss: {loss_value}, rolling accuracy: {accuracy*10/step}")

In [10]:
def test_federated(model, test_loader, loss):
    """    
    model -- neural net 
    test_loader -- dataloader of test images
    epoch -- current epoch
    """
    model.eval()
    model.to(device)
    
    with torch.no_grad():
        loss_value = 0
        accuracy = 0
        for step, [images, labels] in enumerate(test_loader, 1):
            images = images.to(device)
            labels = label_preparation_tensor(labels.to(device))

            result = model(images)
            targets = labels.detach().unsqueeze(1).float()

            loss_value += loss(result.detach(), targets)
            accuracy += calc_accuracy(result.detach(), targets)

        loss_value /= step
        accuracy /=  step
        
        if device.type == "cuda": 
            torch.cuda.empty_cache()
      
    model.to("cpu")
    print(f"TESTING - Loss: {loss_value}, Accuracy: {accuracy}")
    return accuracy

In [11]:
def run_federated_training(federated_model, client_models, client_training_loader):
    # use pos weights because of unbalanced data set
    federated_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1./10])).to(device) # binary crossentropy
    # federated_loss = torch.nn.CrossEntropyLoss(weight=torch.tensor([1./10])).to(device) # sparce categorical crossentropy (federated)

    # start training
    for epoch in range(MAX_EPOCHS):
        for client_idx in range (number_of_clients):
            print(f"+++ FEDERATED MODEL {client_idx}, EPOCH: {epoch+1} +++++++++")

            client_model = client_models[client_idx]
            client_model.to(device)
            client_model = update_client(federated_model, client_model)
            client_optimizer = torch.optim.Adam(client_model.parameters())
        
            train_federated(client_model, client_training_loader[client_idx], client_optimizer, federated_loss)
        
            if device.type == "cuda": 
                torch.cuda.empty_cache()
            
            client_model.to("cpu")
            
            # save interim weights
            #torch.save(client_model.state_dict(), f'./{TARGET_FOLDER}/client_model_{client_idx}_epoch_{epoch}.ckpt')

        federated_model = federated_average(federated_model, client_models)
        
        # save interim weights
        #torch.save(federated_model.state_dict(), f'./{TARGET_FOLDER}/epoch_{epoch}.ckpt')
        
        if test_federated(federated_model, federated_test_loader, federated_loss) > 97 and epoch >= 4:
            print("Early return: SUCCESS")
            return


In [12]:
class DatasetWithTransform(torch.utils.data.Dataset):
    dataset: torch.utils.data.Dataset
    transform: torchvision.transforms.Compose

    def __init__(self, dataset: torch.utils.data.Dataset, transform: torchvision.transforms.Compose) -> None:
        self.dataset = dataset
        self.transform = transform
                      
    def __getitem__(self, index):
        x, y = self.dataset[index]
        x = self.transform(x)
        y = torch.Tensor([y])
        return x, y

    def __len__(self):
        return len(self.dataset)

In [13]:
from typing import List, Tuple

def prepare_train_data(
    train_set: torch.utils.data.Dataset,
    no_of_clients: int,
    augment_data: bool,
    full_data_on_each_client: bool,
    batch_size=32, shuffle=True, num_workers=0, pin_memory=True
) -> List[torch.utils.data.DataLoader]:
    
    if augment_data:
        # flip, shift & rotate
        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomAffine(degrees=90, translate=(0.2,0.2)),
            transforms.ToTensor()
        ])
    else:
        transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])
    
    data_loaders = []
    if full_data_on_each_client:
        data_loaders = [
            torch.utils.data.DataLoader(
                DatasetWithTransform(train_set, transform),
                batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory
            )
        for _ in range(no_of_clients)]
    else:
        chunk_size_train = len(train_set) // no_of_clients
        indices_train = np.random.permutation(np.arange(chunk_size_train * no_of_clients)) 

        for idx in range(no_of_clients):
            data_loader_train = torch.utils.data.Subset(train_set, indices_train[idx*chunk_size_train:(idx+1)*chunk_size_train])
            data_loaders += [
                torch.utils.data.DataLoader(
                    DatasetWithTransform(data_loader_train, transform),
                    batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory
                )
            ]
    
    return data_loaders

In [15]:
federated_model = torchvision.models.resnet18(pretrained=False, num_classes=1)
federated_test_loader = torch.utils.data.DataLoader(test_set, batch_size=32, shuffle=False, num_workers=0)
client_models = [torchvision.models.resnet18(pretrained=False, num_classes=1) for _ in range(number_of_clients)]

train_loader_no_split_no_augment = prepare_train_data(train_set, number_of_clients, augment_data=False, full_data_on_each_client=True)

run_federated_training(federated_model, client_models, train_loader_no_split_no_augment)

train_loader_no_split_no_augment = None

+++ FEDERATED MODEL 0, EPOCH: 1 +++++++++


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


TRAINING - Step: 10, loss: 0.050081975758075714, rolling accuracy: 90.625
TRAINING - Step: 20, loss: 0.022462524473667145, rolling accuracy: 92.1875
TRAINING - Step: 30, loss: 0.01131326425820589, rolling accuracy: 93.75000762939453
TRAINING - Step: 40, loss: 0.04491039365530014, rolling accuracy: 91.40625
TRAINING - Step: 50, loss: 0.023222673684358597, rolling accuracy: 91.875
TRAINING - Step: 60, loss: 0.045296888798475266, rolling accuracy: 90.10417175292969
TRAINING - Step: 70, loss: 0.06982383131980896, rolling accuracy: 90.625
TRAINING - Step: 80, loss: 0.03420264273881912, rolling accuracy: 89.84375
TRAINING - Step: 90, loss: 0.005413800943642855, rolling accuracy: 90.97222137451172
TRAINING - Step: 100, loss: 0.03828900307416916, rolling accuracy: 90.3125
TRAINING - Step: 110, loss: 0.024426111951470375, rolling accuracy: 90.625
TRAINING - Step: 120, loss: 0.02831759862601757, rolling accuracy: 91.14583587646484
TRAINING - Step: 130, loss: 0.008062878623604774, rolling accurac

TRAINING - Step: 40, loss: 0.01668170839548111, rolling accuracy: 94.53125
TRAINING - Step: 50, loss: 0.0321396142244339, rolling accuracy: 95.625
TRAINING - Step: 60, loss: 0.15118177235126495, rolling accuracy: 90.62500762939453
TRAINING - Step: 70, loss: 0.00391416298225522, rolling accuracy: 91.96428680419922
TRAINING - Step: 80, loss: 0.01864749565720558, rolling accuracy: 92.96875
TRAINING - Step: 90, loss: 0.06693791598081589, rolling accuracy: 93.40277862548828
TRAINING - Step: 100, loss: 0.040872469544410706, rolling accuracy: 91.875
TRAINING - Step: 110, loss: 0.01573413610458374, rolling accuracy: 92.32954406738281
TRAINING - Step: 120, loss: 0.029136471450328827, rolling accuracy: 92.44792175292969
TRAINING - Step: 130, loss: 0.029816988855600357, rolling accuracy: 92.78845977783203
TRAINING - Step: 140, loss: 0.038741081953048706, rolling accuracy: 92.63392639160156
TRAINING - Step: 150, loss: 0.03776372969150543, rolling accuracy: 92.91667175292969
TRAINING - Step: 160, l

TRAINING - Step: 60, loss: 0.00802704505622387, rolling accuracy: 94.79167175292969
TRAINING - Step: 70, loss: 0.015077120624482632, rolling accuracy: 95.53571319580078
TRAINING - Step: 80, loss: 0.07080896198749542, rolling accuracy: 95.3125
TRAINING - Step: 90, loss: 0.011457400396466255, rolling accuracy: 95.1388931274414
TRAINING - Step: 100, loss: 0.00916589330881834, rolling accuracy: 95.0
TRAINING - Step: 110, loss: 0.04465518891811371, rolling accuracy: 94.88636016845703
TRAINING - Step: 120, loss: 0.127092644572258, rolling accuracy: 94.79167175292969
TRAINING - Step: 130, loss: 0.004293479491025209, rolling accuracy: 95.19230651855469
TRAINING - Step: 140, loss: 0.012415893375873566, rolling accuracy: 95.08928680419922
TRAINING - Step: 150, loss: 0.020268872380256653, rolling accuracy: 95.20833587646484
TRAINING - Step: 160, loss: 0.06175714731216431, rolling accuracy: 94.921875
+++ FEDERATED MODEL 1, EPOCH: 4 +++++++++
TRAINING - Step: 10, loss: 0.010696601122617722, rolling

TRAINING - Step: 90, loss: 0.18532055616378784, rolling accuracy: 95.1388931274414
TRAINING - Step: 100, loss: 0.011924189515411854, rolling accuracy: 95.625
TRAINING - Step: 110, loss: 0.027639111503958702, rolling accuracy: 94.31817626953125
TRAINING - Step: 120, loss: 0.01886804774403572, rolling accuracy: 94.53125762939453
TRAINING - Step: 130, loss: 0.01733272522687912, rolling accuracy: 94.47115325927734
TRAINING - Step: 140, loss: 0.019774161279201508, rolling accuracy: 94.19642639160156
TRAINING - Step: 150, loss: 0.022301238030195236, rolling accuracy: 94.16667175292969
TRAINING - Step: 160, loss: 0.0027332142926752567, rolling accuracy: 94.53125
+++ FEDERATED MODEL 3, EPOCH: 5 +++++++++
TRAINING - Step: 10, loss: 0.005727320909500122, rolling accuracy: 96.875
TRAINING - Step: 20, loss: 0.137340247631073, rolling accuracy: 95.3125
TRAINING - Step: 30, loss: 0.021540846675634384, rolling accuracy: 96.87500762939453
TRAINING - Step: 40, loss: 0.01501383539289236, rolling accurac

In [15]:
federated_model = torchvision.models.resnet18(pretrained=False, num_classes=1)
federated_test_loader = torch.utils.data.DataLoader(test_set, batch_size=32, shuffle=False, num_workers=0)
client_models = [torchvision.models.resnet18(pretrained=False, num_classes=1) for _ in range(number_of_clients)]

train_loader_split_no_augment = prepare_train_data(train_set, number_of_clients, augment_data=False, full_data_on_each_client=False)

run_federated_training(federated_model, client_models, train_loader_split_no_augment)

train_loader_split_no_augment = None

+++ FEDERATED MODEL 0, EPOCH: 1 +++++++++


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


TRAINING - Step: 10, loss: 0.05574168264865875, rolling accuracy: 84.375
TRAINING - Step: 20, loss: 0.07420063018798828, rolling accuracy: 82.8125
TRAINING - Step: 30, loss: 0.06762982159852982, rolling accuracy: 84.37500762939453
TRAINING - Step: 40, loss: 0.018281031399965286, rolling accuracy: 88.28125
+++ FEDERATED MODEL 1, EPOCH: 1 +++++++++
TRAINING - Step: 10, loss: 0.16669614613056183, rolling accuracy: 84.375
TRAINING - Step: 20, loss: 0.02745596319437027, rolling accuracy: 90.625
TRAINING - Step: 30, loss: 0.12275668978691101, rolling accuracy: 88.54167175292969
TRAINING - Step: 40, loss: 0.1068161129951477, rolling accuracy: 89.84375
+++ FEDERATED MODEL 2, EPOCH: 1 +++++++++
TRAINING - Step: 10, loss: 0.07781316339969635, rolling accuracy: 53.125
TRAINING - Step: 20, loss: 0.045210178941488266, rolling accuracy: 73.4375
TRAINING - Step: 30, loss: 0.01338769868016243, rolling accuracy: 80.20833587646484
TRAINING - Step: 40, loss: 0.060781337320804596, rolling accuracy: 82.812

+++ FEDERATED MODEL 3, EPOCH: 6 +++++++++
TRAINING - Step: 10, loss: 0.0034376138355582952, rolling accuracy: 100.0
TRAINING - Step: 20, loss: 0.06453358381986618, rolling accuracy: 90.625
TRAINING - Step: 30, loss: 0.015329144895076752, rolling accuracy: 92.70833587646484
TRAINING - Step: 40, loss: 0.00958099402487278, rolling accuracy: 93.75
TESTING - Loss: 0.015391766093671322, Accuracy: 96.18901824951172
+++ FEDERATED MODEL 0, EPOCH: 7 +++++++++
TRAINING - Step: 10, loss: 0.016877103596925735, rolling accuracy: 93.75
TRAINING - Step: 20, loss: 0.013673205859959126, rolling accuracy: 93.75
TRAINING - Step: 30, loss: 0.0312855988740921, rolling accuracy: 92.70833587646484
TRAINING - Step: 40, loss: 0.021787084639072418, rolling accuracy: 92.96875
+++ FEDERATED MODEL 1, EPOCH: 7 +++++++++
TRAINING - Step: 10, loss: 0.005191548261791468, rolling accuracy: 100.0
TRAINING - Step: 20, loss: 0.00571422278881073, rolling accuracy: 100.0
TRAINING - Step: 30, loss: 0.01611458510160446, rollin

In [16]:
federated_model = torchvision.models.resnet18(pretrained=False, num_classes=1)
federated_test_loader = torch.utils.data.DataLoader(test_set, batch_size=32, shuffle=False, num_workers=0)
client_models = [torchvision.models.resnet18(pretrained=False, num_classes=1) for _ in range(number_of_clients)]

train_loader_split_augment = prepare_train_data(train_set, number_of_clients, augment_data=True, full_data_on_each_client=False)

run_federated_training(federated_model, client_models, train_loader_split_augment)

train_loader_split_augment = None

+++ FEDERATED MODEL 0, EPOCH: 1 +++++++++
TRAINING - Step: 10, loss: 0.19133463501930237, rolling accuracy: 46.875
TRAINING - Step: 20, loss: 0.08497636020183563, rolling accuracy: 56.25
TRAINING - Step: 30, loss: 0.07978109270334244, rolling accuracy: 62.500003814697266
TRAINING - Step: 40, loss: 0.07718903571367264, rolling accuracy: 68.75
+++ FEDERATED MODEL 1, EPOCH: 1 +++++++++
TRAINING - Step: 10, loss: 0.12506169080734253, rolling accuracy: 56.25
TRAINING - Step: 20, loss: 0.14163631200790405, rolling accuracy: 70.3125
TRAINING - Step: 30, loss: 0.11963967233896255, rolling accuracy: 68.75
TRAINING - Step: 40, loss: 0.08996886759996414, rolling accuracy: 71.875
+++ FEDERATED MODEL 2, EPOCH: 1 +++++++++
TRAINING - Step: 10, loss: 0.14307551085948944, rolling accuracy: 81.25
TRAINING - Step: 20, loss: 0.0527038648724556, rolling accuracy: 82.8125
TRAINING - Step: 30, loss: 0.12159614264965057, rolling accuracy: 71.875
TRAINING - Step: 40, loss: 0.10530561208724976, rolling accurac

TRAINING - Step: 10, loss: 0.11272535473108292, rolling accuracy: 71.875
TRAINING - Step: 20, loss: 0.11196562647819519, rolling accuracy: 79.6875
TRAINING - Step: 30, loss: 0.07710064947605133, rolling accuracy: 80.20833587646484
TRAINING - Step: 40, loss: 0.11920598149299622, rolling accuracy: 79.6875
TESTING - Loss: 0.1738169640302658, Accuracy: 92.30182647705078
+++ FEDERATED MODEL 0, EPOCH: 7 +++++++++
TRAINING - Step: 10, loss: 0.1313421130180359, rolling accuracy: 87.5
TRAINING - Step: 20, loss: 0.0959039032459259, rolling accuracy: 68.75
TRAINING - Step: 30, loss: 0.06914462894201279, rolling accuracy: 76.04167175292969
TRAINING - Step: 40, loss: 0.03857017308473587, rolling accuracy: 77.34375
+++ FEDERATED MODEL 1, EPOCH: 7 +++++++++
TRAINING - Step: 10, loss: 0.06267137825489044, rolling accuracy: 81.25
TRAINING - Step: 20, loss: 0.10168406367301941, rolling accuracy: 76.5625
TRAINING - Step: 30, loss: 0.07913432270288467, rolling accuracy: 79.16667175292969
TRAINING - Step: 