In [1]:
from torchvision import datasets, transforms
from fastai.callback.all import *
from fastai.vision.all import *
import flwr as fl

from copy import deepcopy
from torch import optim

from torch.nn.utils import clip_grad_norm_
from opacus.utils import module_modification

import warnings
warnings.filterwarnings('ignore')

from flwr.common import *

In [2]:
import numpy as np
import torch
import torch.utils.data
import torchvision
from typing import Callable
import pandas as pd
import matplotlib.pyplot as plt
import random
from sklearn.metrics import roc_curve, auc
from fastai.vision.all import ClassificationInterpretation

def set_seed(dls, seed):
    """Seed the random number generators used by this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch,
    puts cuDNN into deterministic / non-benchmark mode, seeds every CUDA
    device when CUDA is available, and finally seeds ``dls.rng``.

    Parameters
    ----------
    dls
        Object exposing an ``rng`` attribute that has a ``seed`` method.
    seed
        Value passed unchanged to every seeding call.
    """
    # Global generators: Python, NumPy, PyTorch (in that order).
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # cuDNN switches.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Generator owned by the dataloaders object.
    dls.rng.seed(seed)

def save_matrix(learn, path):
    """Plot the confusion matrix of ``learn`` and save the figure to ``path``.

    Parameters
    ----------
    learn
        Learner handed to ``ClassificationInterpretation.from_learner``.
    path
        Destination passed to ``plt.savefig``.
    """
    # Build the interpretation and draw a 7x7 inch confusion matrix in one go.
    ClassificationInterpretation.from_learner(learn).plot_confusion_matrix(figsize=(7, 7))
    # Persist whatever figure is current after plotting.
    plt.savefig(path)


def save_roc(learn, path):
    """Compute the ROC curve of ``learn``, print its area and save the plot.

    Predictions come from ``learn.get_preds(with_loss=True)``; column 1 of the
    predictions is exponentiated and used as the score of the positive class
    (label 1).

    Parameters
    ----------
    learn
        Learner providing ``get_preds``.
    path
        Destination passed to ``plt.savefig``.
    """
    # The third returned value (the losses) is not used here.
    preds, targets, _ = learn.get_preds(with_loss=True)
    positive_scores = np.exp(preds[:, 1])

    false_pos_rate, true_pos_rate, _ = roc_curve(targets, positive_scores, pos_label=1)
    roc_auc = auc(false_pos_rate, true_pos_rate)
    print(f'ROC area is {roc_auc}')

    curve_label = 'ROC curve (area = %0.2f)' % roc_auc
    diagonal = [0, 1]

    plt.figure()
    plt.plot(false_pos_rate, true_pos_rate, color='darkorange', label=curve_label)
    # Dashed diagonal used as the reference line.
    plt.plot(diagonal, diagonal, color='navy', linestyle='--', label='Worst case')
    # Slight margins so the curve is not drawn on top of the axes.
    plt.xlim([-0.01, 1.0])
    plt.ylim([0.0, 1.01])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic')
    plt.legend(loc="lower right")
    plt.savefig(path)


class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    '''
    Samples elements randomly from a given list of indices for imbalanced dataset

    Every considered index is drawn, with replacement, with a weight of
    1 / (number of considered elements sharing its label).

    Parameters
    ----------
    dataset:
        dataset the labels are read from
    indices: list
        a list of indices; every index of the dataset when omitted
    num_samples: int
        number of samples to draw; len(indices) when omitted
    callback_get_label: Callable
        optional function called with the dataset as its only argument and
        returning the labels. The result is either aligned with `indices`
        (same length, position by position) or covers the whole dataset, in
        which case the labels are looked up through `indices`.

    Raises
    ------
    ValueError
        if no `indices` are given and the number of labels differs from the
        number of dataset elements
    '''

    def __init__(self, dataset, indices: list = None, num_samples: int = None, callback_get_label: Callable = None):
        # if indices is not provided, all elements in the dataset will be considered
        self.indices = list(range(len(dataset))) if indices is None else indices

        # define custom callback
        self.callback_get_label = callback_get_label

        # if num_samples is not provided, draw len(indices) samples in each iteration
        self.num_samples = len(self.indices) if num_samples is None else num_samples

        # one label per entry of self.indices, in the same order as self.indices
        labels = pd.Series(np.asarray(self._get_labels(dataset)))
        if len(labels) != len(self.indices):
            if indices is None:
                raise ValueError(
                    f"got {len(labels)} labels for {len(self.indices)} dataset elements"
                )
            # labels describe the whole dataset: keep only the requested elements
            labels = labels.iloc[list(self.indices)].reset_index(drop=True)

        # distribution of classes among the considered elements
        label_to_count = labels.value_counts()

        # weights stay in the order of self.indices (no re-sorting), because
        # __iter__ maps a drawn position i back to self.indices[i]
        weights = 1.0 / labels.map(label_to_count)

        self.weights = torch.DoubleTensor(weights.to_list())

    def _get_labels(self, dataset):
        """Return the labels of `dataset`, one per element, in dataset order.

        The user callback has priority; otherwise the labels are read from the
        attributes of the known torchvision / torch dataset types.
        """
        if self.callback_get_label:
            return self.callback_get_label(dataset)
        elif isinstance(dataset, torchvision.datasets.MNIST):
            # `targets` is the current name of the deprecated `train_labels`
            return dataset.targets.tolist()
        elif isinstance(dataset, torchvision.datasets.ImageFolder):
            return [x[1] for x in dataset.imgs]
        elif isinstance(dataset, torchvision.datasets.DatasetFolder):
            # samples is a list of (path, class_index) pairs
            return [sample[1] for sample in dataset.samples]
        elif isinstance(dataset, torch.utils.data.Subset):
            # resolve the labels of the wrapped dataset, then keep the subset's elements
            parent_labels = self._get_labels(dataset.dataset)
            return [parent_labels[i] for i in dataset.indices]
        elif isinstance(dataset, torch.utils.data.Dataset):
            return dataset.get_labels()
        else:
            raise NotImplementedError

    def __iter__(self):
        # draw positions according to the weights, then translate them to dataset indices
        drawn = torch.multinomial(self.weights, self.num_samples, replacement=True)
        return (self.indices[i] for i in drawn.tolist())

    def __len__(self):
        return self.num_samples

def get_imbalance_weights(ds):
    """Return one sampling weight per training item, inverse to its class frequency.

    Parameters
    ----------
    ds
        Object whose ``train`` attribute is iterable and yields items whose
        element ``[1]`` is the label.

    Returns
    -------
    torch.DoubleTensor
        ``weights[i] == 1 / (number of training items sharing the label of item i)``.
    """
    labels = [x[1] for x in ds.train]
    # `inverse[i]` is the position of labels[i] inside the sorted unique labels,
    # i.e. the right index into `label_counts` even when the label values are
    # not exactly 0..K-1 (a missing class, string labels, ...).
    _, inverse, label_counts = np.unique(labels, return_inverse=True, return_counts=True)
    weights = torch.DoubleTensor((1 / label_counts)[inverse.reshape(-1)])
    return weights

In [3]:
# set device: first GPU when CUDA is available, CPU otherwise
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # only touch the CUDA runtime when there is a GPU to select
    torch.cuda.set_device(device)

In [4]:
# get model
# `resnet18` is bound by name here (not called); it is handed to `cnn_learner` further down.
model = resnet18

In [5]:
# Folder of images used below both to build the datasets/dataloaders and as the evaluation set.
data_path = Path('data/test') 

In [6]:
# Image -> category block: items are the image files under the source path and
# each label is the name of the file's parent folder.
# The splitter is seeded on purpose: `dblock.datasets(...)` and
# `dblock.dataloaders(...)` below each split the items again, and without an
# explicit splitter fastai falls back to an unseeded random split, so the
# weights computed from `ds.train` would not line up with the items of
# `dls.train`.
dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
    get_items = get_image_files,
    get_y = parent_label,
# Antoine has resized images in some cases
#    item_tfms = [Resize(32)],
    splitter = RandomSplitter(seed=42))

In [8]:
# Datasets built from the block; used below only as the input of `get_imbalance_weights`.
ds = dblock.datasets(data_path)

In [9]:
# Weighted dataloaders (batch size 64, on `device`): `wgts` carries one weight per
# training item, computed from the class frequencies of `ds.train`.
# NOTE(review): the dataloaders are built from `data_path` again rather than from `ds`;
# confirm their training split is the same one the weights were computed on.
dls = dblock.dataloaders(data_path, bs=64, device=device, dl_type=WeightedDL, wgts=get_imbalance_weights(ds), num_workers=0)

In [10]:
# Seed Python, NumPy, PyTorch and the dataloaders' own generator.
# NOTE(review): this runs after `ds` and `dls` were created, so any randomness
# used while building them is not covered by this seed.
set_seed(dls, 42)

In [11]:
# Get model and change last layer
# Learner built from the dataloaders and the architecture chosen above,
# reporting accuracy and binary ROC AUC.
learn = cnn_learner(dls, model, metrics=[accuracy, RocAucBinary()])
# Pass the model through opacus' batch-norm conversion helper
# (presumably replaces BatchNorm layers with a differential-privacy-compatible
# normalisation — confirm against the opacus version in use).
learn.model = module_modification.convert_batchnorm_modules(learn.model)
# Replace the module at position [1][8] with a bias-free linear layer mapping
# 512 features to 2 outputs.
learn.model[1][8] = nn.Linear(512, 2, bias=False)

In [12]:
# Move the (modified) model to the selected device; the trailing `;` keeps the
# notebook from echoing the returned module.
learn.model.to(device);

In [13]:
# Weights archive loaded for evaluation (the file name refers to round 3).
# SECURITY: `allow_pickle=True` lets NumPy unpickle object arrays, which can
# execute arbitrary code — only load archives from a trusted source.
loaded_weights = np.load('weights/Split_03_03_03/cancer_database/federated03_03_03/round-3-weights.npz', allow_pickle=True)

In [14]:
def set_parameters(learn, parameters):
    """Load a flat sequence of arrays into ``learn.model``.

    The i-th entry of ``parameters`` is assigned to the i-th key of
    ``learn.model.state_dict()``.

    Parameters
    ----------
    learn
        Object whose ``model`` attribute is a ``torch.nn.Module``.
    parameters
        Iterable of array-likes (e.g. NumPy arrays), one per state-dict entry,
        in state-dict order.

    Raises
    ------
    ValueError
        If more arrays are supplied than the model has state-dict entries
        (they used to be dropped silently).
    RuntimeError
        Raised by ``load_state_dict(strict=True)`` when entries are missing or
        shapes do not match.
    """
    # Local import so the function does not rely on a star import providing the name.
    from collections import OrderedDict

    keys = list(learn.model.state_dict().keys())
    parameters = list(parameters)
    if len(parameters) > len(keys):
        raise ValueError(
            f"received {len(parameters)} arrays for a model with {len(keys)} state entries"
        )

    # `torch.as_tensor` keeps the shape and dtype of every array. The legacy
    # `torch.Tensor(v)` constructor treats a 0-d integer array (e.g. a counter
    # buffer) as a *size* and casts everything to float32.
    # `np.array(v)` makes a writable copy, so the tensor never aliases the input.
    state_dict = OrderedDict(
        (key, torch.as_tensor(np.array(value))) for key, value in zip(keys, parameters)
    )
    learn.model.load_state_dict(state_dict, strict=True)

In [15]:
# Push the stored weights into the model.
# `[None][0]` prepends an axis to `loaded_weights['arr_0']` and immediately takes
# its first entry; presumably this unwraps a 0-d object array into the stored
# parameters object expected by `parameters_to_weights` — TODO confirm.
set_parameters(learn, parameters_to_weights(loaded_weights['arr_0'][None][0]))

In [79]:
# Evaluation files: every image file found under `data_path`.
# NOTE(review): the cell counter jumps from 15 to 79 here — re-run the notebook
# from a fresh kernel to confirm no hidden state is involved.
test_fnames = get_image_files(data_path)

In [80]:
# Test dataloader derived from `dls`; `with_labels=True` asks for the labels as
# well so that loss and metrics can be computed.
tst_dl = dls.test_dl(test_fnames, with_labels=True)

In [81]:
# Evaluate the loaded weights on the test dataloader. The output shows three
# numbers — presumably the loss followed by the two metrics configured on the
# learner (accuracy, ROC AUC); confirm.
learn.validate(dl=tst_dl)

(#3) [0.47540485858917236,0.8208417296409607,0.8882455247120778]