In [None]:
import os
import copy

import numpy as np
import torch
import torch.nn as nn
import math

from torch import optim
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from unet import UNet
from dice_loss import dice_coeff

####################################################
# for data preparation
####################################################
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from sklearn.metrics import roc_auc_score, accuracy_score, balanced_accuracy_score
####################################################
# for plotting
####################################################
import matplotlib.pyplot as plt
from IPython.display import clear_output
############################
# Helper func
############################
from helper import * 
#################################
# ---- network configuration -------------------------------------------
# Passed as n_channels / n_classes when the UNet models are constructed.
N_CHANNELS = 1
N_CLASSES = 1
bilinear = True

# ---- training schedule -----------------------------------------------
BATCH_SIZE = 16
EPOCHS = 300  # number of iterations of the main training loop

# ---- image geometry (forwarded to the Cancer dataset wrapper) --------
IMAGE_SIZE = (256, 256)
CROP_SIZE = (224, 224)

# ---- data location and participating clients -------------------------
DIR = 'dataset/breast'
CLIENTS = ['BUS', 'BUSIS', 'UDIAT']
# Names with a '_2' suffix, one per client; the same suffix is used as the
# key of each client's second network/optimizer.
CLIENTS_2 = [name + '_2' for name in CLIENTS]
TOTAL_CLIENTS = len(CLIENTS)
#####################################
# add the classification segment ####

In [None]:
# File holding the pre-computed train/test split paths (loaded with np.load).
CONSISTENT_PATH = 'breast_5folds.npy'

# Use the first GPU when one is available; otherwise fall back to the CPU so
# the notebook still runs (slowly) on machines without CUDA instead of
# failing at the first `.to(device)` call.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Adam learning rate and weight decay.
LR, WD = 1e-3, 1e-4

# NOTE(review): LAM, BETA and TH are not referenced in the cells visible
# here -- presumably loss weights / a threshold used further down; confirm.
LAM, BETA, TH = 10, 1.5, 0.9

# Key selecting which split of `consistent_path[client]` is used.
VERSION = 1

# Per-client weights; filled with each client's share of the training data
# once the datasets have been built.
WEIGHTS_CL = [0.0, 0.0, 0.0]

In [None]:
# Epoch index at which the loop switches to the post-warmup aggregation
# weights and starts training the unlabeled clients.
WARMUP_EPOCH = 150

# Supervision type of each client, aligned position-by-position with CLIENTS.
CLIENTS_SUPERVISION = ['unlabeled', 'unlabeled', 'labeled']

# Federated strategy: 'FedST' or 'FedRGD'.
MODE = 'FedST'

# Binary cross-entropy criterion handed to train_model as CE_LOSS.
CE_ = nn.BCELoss()

# Aggregation weights used before WARMUP_EPOCH: only the labeled client counts.
WEIGHTS = [0.0, 0.0, 1.0]

# Aggregation weights used from WARMUP_EPOCH onwards.
#   FedST : put more weight on the client with strong supervision.
#   other : random client, but effectively put 1/2 trust in the strongly
#           supervised client.
WEIGHTS_POSTWARMUP = (
    [0.05, 0.05, 0.9]
    if MODE == 'FedST'
    else [0.5 / 2., 0.5 / 2., 0.5]
)

# Load the train/test file paths

In [None]:
# Load the pre-computed splits: a pickled dict indexed as
# consistent_path[client][VERSION][{'x_train','x_test','y_train','y_test'}].
# NOTE: allow_pickle=True unpickles arbitrary Python objects -- only load
# split files that come from a trusted source.
consistent_path = np.load(CONSISTENT_PATH, allow_pickle=True).item()
breast_dataset = dict()

# "Normal" images always come from the BUS classification folder and are the
# same for every client, so list them once instead of once per client.
# Only the top level of the GT folder is read: the paths built below assume
# the files sit directly inside `.../GT` and `.../original`.
normal_root = DIR + '/' + 'BUS' + '/classification/'
_, _, gt_files = next(os.walk(normal_root + 'GT'), (None, [], []))
selected = [f for f in gt_files if f.startswith('normal')]
normal_x = [normal_root + 'original' + '/' + f for f in selected]
normal_y = [normal_root + 'GT' + '/' + f for f in selected]

denom_ = 0
for idx_, (client, sup) in enumerate(zip(CLIENTS, CLIENTS_SUPERVISION)):
    dir_of_interest = consistent_path[client][VERSION]
    x_train = dir_of_interest['x_train']
    x_test = dir_of_interest['x_test']
    y_train = dir_of_interest['y_train']
    y_test = dir_of_interest['y_test']

    # Add the normal images to the training set of unlabeled clients.
    # New lists are built rather than extending in place, so the lists held
    # inside `consistent_path` are left untouched.
    if sup == 'unlabeled':
        x_train = list(x_train) + normal_x
        y_train = list(y_train) + normal_y

    print('full training data')
    # Store the raw sample count for now; it is normalised after the loop.
    WEIGHTS_CL[idx_] = len(x_train)
    denom_ += len(x_train)

    breast_dataset[client + '_train'] = Cancer(x_train, y_train, train=True,
                                               IMAGE_SIZE=IMAGE_SIZE,
                                               CROP_SIZE=CROP_SIZE)
    breast_dataset[client + '_test'] = Cancer(x_test, y_test, train=False,
                                              IMAGE_SIZE=IMAGE_SIZE,
                                              CROP_SIZE=CROP_SIZE)

# Turn the per-client sample counts into fractions of the total training set.
for idx_ in range(len(WEIGHTS_CL)):
    WEIGHTS_CL[idx_] = WEIGHTS_CL[idx_] / denom_
    

# Storage containers (data loaders, metric histories, models, optimizers)

In [None]:
# Per-client DataLoaders, keyed by client name.
training_clients = {}
testing_clients = {}
# Second set of training loaders (built with batch_size=1 further down).
training_clients_pl = {}

# Per-client metric histories: client name -> list of values.
acc_train = {}
acc_test = {}
loss_train = {}
loss_test = {}

# Models and optimizers, keyed by 'global', the client name, or name + '_2'.
nets = {}
optimizers = {}

In [None]:
def _new_unet():
    """Return a freshly initialised UNet built from the notebook config, on `device`."""
    return UNet(n_channels=N_CHANNELS, n_classes=N_CLASSES,
                bilinear=bilinear).to(device)


# Two global (server-side) models.
nets['global'] = _new_unet()
nets['global_2'] = _new_unet()

for client in CLIENTS:
    train_set = breast_dataset[client + '_train']

    # Mini-batch training loader; the batch size follows the BATCH_SIZE
    # config constant instead of a duplicated literal.
    training_clients[client] = DataLoader(train_set, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=8)
    # Same training data, served one sample at a time.
    training_clients_pl[client] = DataLoader(train_set, batch_size=1,
                                             shuffle=True, num_workers=8)
    # Evaluation loader: one sample at a time, fixed order.
    testing_clients[client] = DataLoader(breast_dataset[client + '_test'],
                                         batch_size=1, shuffle=False,
                                         num_workers=1)

    acc_train[client], acc_test[client] = [], []
    loss_train[client], loss_test[client] = [], []

    # Each client owns two networks (name and name + '_2'), each paired
    # with its own Adam optimizer.
    nets[client] = _new_unet()
    nets[client + '_2'] = _new_unet()
    for key in (client, client + '_2'):
        optimizers[key] = optim.Adam(nets[key].parameters(),
                                     lr=LR, weight_decay=WD)

## FedST or FedRGD: both require a warmup phase

In [None]:
best_avg_acc, best_epoch_avg = 0, 0
# `best_epoch` is the name the loop and the summary below actually use.
# Define it up front so the summary cannot raise NameError when no epoch
# ever improves on the initial best_avg_acc of 0.
best_epoch = 0
index = []

# Start from empty histories so re-running this cell does not append to the
# curves of a previous run.
for client in CLIENTS:
    acc_train[client], acc_test[client] = [], []
    loss_train[client], loss_test[client] = [], []

# (client, legend prefix) in plotting order; the position selects the entry
# of `colors` (not defined in this notebook section -- presumably provided
# by `from helper import *`; confirm).
curve_labels = [('UDIAT', 'UDIAT '), ('BUS', 'BUS  '), ('BUSIS', 'BUSIS  ')]

USE_UNLABELED_CLIENT = False
for epoch in range(EPOCHS):
    if epoch == WARMUP_EPOCH:
        # Warmup is over: switch aggregation weights and start training the
        # unlabeled clients too.
        # NOTE(review): this rebinds the global WEIGHTS, so re-running this
        # cell without re-running the config cell starts with the
        # post-warmup weights already in place.
        WEIGHTS = WEIGHTS_POSTWARMUP
        USE_UNLABELED_CLIENT = True

    index.append(epoch)

    #################### copy fed model ###################
    copy_fed(CLIENTS, nets, fed_name='global')

    #### conduct training #####
    for client, supervision_t in zip(CLIENTS, CLIENTS_SUPERVISION):
        if supervision_t == 'unlabeled' and not USE_UNLABELED_CLIENT:
            # Client is skipped during warmup; record a placeholder so its
            # curves stay the same length as `index` for plotting.
            acc_train[client].append(0)
            loss_train[client].append(0)
            continue

        # NOTE(review): warmup=True is passed for every epoch, including
        # those after WARMUP_EPOCH -- confirm this is intended.
        train_model(training_clients[client], nets[client],
                    optimizers[client], device,
                    acc=acc_train[client],
                    loss=loss_train[client],
                    supervision_type=supervision_t,
                    warmup=True, CE_LOSS=CE_)

    aggr_fed(CLIENTS, WEIGHTS, nets)

    ################### test ##############################
    # Evaluate the global model on every client's test set; the latest entry
    # of acc_test[client] is read right after each call.
    avg_acc = 0.0
    for client in CLIENTS:
        test(epoch, testing_clients[client], nets['global'], device,
             acc_test[client], loss_test[client])
        avg_acc += acc_test[client][-1]

    avg_acc = avg_acc / TOTAL_CLIENTS

    # Track the epoch with the best client-averaged test accuracy.
    if avg_acc > best_avg_acc:
        best_avg_acc = avg_acc
        best_epoch = epoch

    ################################
    # plot #########################
    ################################
    clear_output(wait=True)
    print(avg_acc, best_avg_acc)

    # Figure 0: training accuracy, figure 1: training loss.
    for fig_id, history, suffix in ((0, acc_train, 'train'),
                                    (1, loss_train, 'loss train')):
        plt.figure(fig_id)
        for color_idx, (name, prefix) in enumerate(curve_labels):
            plt.plot(index, history[name], colors[color_idx],
                     label=prefix + suffix)
        plt.legend()
        plt.show()

    plot_graphs(2, CLIENTS, index, acc_test, ' acc_test')

# Final summary: accuracy of each client at the shared best epoch, and each
# client's own best accuracy over all epochs.
print(best_avg_acc, best_epoch)
for client in CLIENTS:
    print(client)
    print("shared epoch specific")
    print(acc_test[client][best_epoch])
    print("max client-specific")
    print(np.max(acc_test[client]))