In [None]:
import os
import copy

import numpy as np
import torch
import torch.nn as nn
import math

from torch import optim
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from unet import UNet
from dice_loss import dice_coeff

####################################################
# for data preparation
####################################################
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from sklearn.metrics import roc_auc_score, accuracy_score, balanced_accuracy_score
####################################################
# for plotting
####################################################
import matplotlib.pyplot as plt
from IPython.display import clear_output
############################
# Helper func
############################
from helper import * 
#################################
# ---- data split / reproducibility -------------------------------------
TRAIN_RATIO = 0.8
RS = 30448  # random state

# ---- model ------------------------------------------------------------
# Channel / class counts handed to every UNet built in this notebook.
N_CHANNELS = 1
N_CLASSES = 1
bilinear = True

# ---- training schedule ------------------------------------------------
BATCH_SIZE = 16
EPOCHS = 300

# ---- image geometry ---------------------------------------------------
# Images are resized to IMAGE_SIZE; training samples are then randomly
# cropped to CROP_SIZE (see the Cancer dataset).
IMAGE_SIZE = (256, 256)
CROP_SIZE = (224, 224)

# ---- dataset layout and federated clients -----------------------------
DIR = 'dataset/3datasets_segment_v2'
CLIENTS = ['BUS', 'BUSIS', 'UDIAT']
# Names of the second network kept for each client.
CLIENTS_2 = [name + '_2' for name in CLIENTS]
TOTAL_CLIENTS = len(CLIENTS)

# ---- classification segment of the BUS client -------------------------
DIR_CLASSIFICATION = '/'.join((DIR, 'BUS', 'classification'))

In [None]:
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so that every later `.to(device)` still works on a machine without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# LR / WD: learning rate and weight decay passed to each Adam optimizer.
# TH: threshold forwarded to select_pl in the training cell.
LR, WD, TH = 1e-3, 1e-5, 0.9

## Training and testing data paths

In [None]:
# Number of training samples per client; the data-loading cell appends one
# entry per client (in CLIENTS order) and the aggregation weights are derived
# from these counts.
TOTAL_DATA = []

In [None]:
class Cancer(Dataset):
    """Image / segmentation-mask dataset that also yields a bounding-box mask.

    Each item is a triple ``(image, mask, bbox_mask)`` of tensors. Both files
    are opened in grayscale mode ('L'), resized to ``IMAGE_SIZE`` and, when
    ``train`` is true, randomly cropped to ``CROP_SIZE`` and randomly flipped
    (the same crop / flips are applied to image and mask).

    NOTE(review): ``Resize``, ``RandomCrop``, ``TF``, ``Image`` and ``random``
    are not imported in the visible cells; presumably they arrive through
    ``from helper import *`` -- confirm.

    Args:
        im_path: indexable collection of image file paths.
        mask_path: indexable collection of mask file paths, aligned with
            ``im_path``; its length defines the dataset length.
        train: enable random crop / flip augmentation.
        IMAGE_SIZE: size passed to ``Resize``.
        CROP_SIZE: output size of the random crop used in training mode.
    """

    # Margin subtracted from the smallest / added to the largest foreground
    # index when building the box. Because the upper bound is used as an
    # exclusive slice end, the box extends 2 pixels before the foreground and
    # 1 pixel after it.
    _BBOX_MARGIN = 2

    def __init__(self, im_path, mask_path, train=False,
                 IMAGE_SIZE=(256, 256), CROP_SIZE=(224, 224)):
        self.data = im_path
        self.label = mask_path
        self.train = train
        self.IMAGE_SIZE = IMAGE_SIZE
        self.CROP_SIZE = CROP_SIZE

    def transform(self, image, mask, train):
        """Resize (and, if ``train``, augment) an image/mask pair; return tensors."""
        resize = Resize(self.IMAGE_SIZE)
        image = resize(image)
        mask = resize(mask)
        if train:
            # Random crop: one set of parameters shared by image and mask
            i, j, h, w = RandomCrop.get_params(
                image, output_size=self.CROP_SIZE)
            image = TF.crop(image, i, j, h, w)
            mask = TF.crop(mask, i, j, h, w)
            # Random horizontal flipping
            if random.random() > 0.5:
                image = TF.hflip(image)
                mask = TF.hflip(mask)
            # Random vertical flipping
            if random.random() > 0.5:
                image = TF.vflip(image)
                mask = TF.vflip(mask)
        # Transform to tensor
        image = TF.to_tensor(image)
        mask = TF.to_tensor(mask)
        return image, mask

    def __len__(self):
        return len(self.label)

    @staticmethod
    def _bbox_from_mask(y):
        """Return a tensor shaped like ``y`` that is 1 inside the padded box.

        The box encloses every position where ``y == 1`` (``y`` is indexed as
        a 3-D tensor). The result is all zeros when ``y`` sums to zero or has
        no element exactly equal to 1.
        """
        bbox_mask = torch.zeros(y.shape)
        # if normal images no bbox / black #
        if torch.sum(y) > 0:
            _, rows, cols = torch.where(y == 1)
            # A mask can have a positive sum without any pixel being exactly
            # 1; torch.min / torch.max would raise on the empty index tensors.
            if rows.numel() > 0:
                margin = Cancer._BBOX_MARGIN
                # Clamp the start at 0: a negative slice start would wrap
                # around to the far edge and select an empty region.
                row_start = max(int(rows.min()) - margin, 0)
                col_start = max(int(cols.min()) - margin, 0)
                # An end beyond the tensor size is truncated by slicing.
                row_stop = int(rows.max()) + margin
                col_stop = int(cols.max()) + margin
                bbox_mask[:, row_start:row_stop, col_start:col_stop] = 1
        return bbox_mask

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(image, mask, bbox_mask)``."""
        image = Image.open(self.data[idx]).convert('L')
        mask = Image.open(self.label[idx]).convert('L')

        x, y = self.transform(image, mask, self.train)
        return x, y, self._bbox_from_mask(y)


In [None]:
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects, which can
# execute code on load. Only use a dict_path.npy that comes from a trusted source.
consistent_path = np.load('dict_path.npy', allow_pickle=True).item()
breast_dataset = dict()
# Start from an empty list so that re-running this cell does not append a
# second set of counts (which would corrupt the per-client weights).
TOTAL_DATA = []
DATA_TYPE = ['original', 'GT']
for client in CLIENTS:
    print("loading data from ", client)
    dir_of_interest = consistent_path[client]
    x_train = dir_of_interest['x_train']
    x_test = dir_of_interest['x_test']
    y_train = dir_of_interest['y_train']
    y_test = dir_of_interest['y_test']

    # add normal images for bus #
    DIR_INTEREST = DIR + '/' + client
    if client == 'BUS':
        gt_dir = DIR_INTEREST + '/classification/GT'
        # Only the top level of the GT directory is listed: the paths built
        # below assume the files sit directly inside it, so entries found in
        # sub-directories would point to non-existent files. A missing
        # directory yields no files, as before.
        _, _, files = next(os.walk(gt_dir), (gt_dir, [], []))
        # Sorted so the sample order does not depend on the file system.
        selected = sorted(f for f in files if f.startswith('normal'))
        # The same file names are used for both folders, which keeps images
        # and masks aligned index by index.
        for data in DATA_TYPE:
            tmp = [DIR_INTEREST + '/classification/' + data + '/' + f for f in selected]
            if data == 'GT':
                y_train += tmp
            else:
                x_train += tmp

    # to measure the weight #
    TOTAL_DATA.append(len(x_train))

    breast_dataset[client + '_train'] = Cancer(x_train, y_train, train=True,
                                               IMAGE_SIZE=IMAGE_SIZE,
                                               CROP_SIZE=CROP_SIZE)

    breast_dataset[client + '_test'] = Cancer(x_test, y_test, train=False,
                                              IMAGE_SIZE=IMAGE_SIZE,
                                              CROP_SIZE=CROP_SIZE)
    

In [None]:
# Static aggregation weights: each client's share of all training samples.
DATA_AMOUNT = sum(TOTAL_DATA)
WEIGHTS = [n_samples / DATA_AMOUNT for n_samples in TOTAL_DATA]
# Independent copy of the data-size weights. WEIGHTS is recomputed every epoch
# in the training cell, always starting from WEIGHTS_DATA.
WEIGHTS_DATA = list(WEIGHTS)

# Storage containers (data loaders, metrics, models, optimizers)

In [None]:
# DataLoaders keyed by client name: training loaders, test loaders, and the
# batch-size-1 training loaders that are handed to select_pl.
training_clients = {}
testing_clients = {}
training_clients_pl = {}

# Metric histories keyed by client name (each value becomes a list).
acc_train = {}
acc_test = {}
loss_train = {}
loss_test = {}

# Networks and optimizers keyed by name ('global', 'global_2', client names
# and their '_2' variants).
nets = {}
optimizers = {}

In [None]:
def _make_unet():
    """Build a UNet from the notebook configuration and move it to `device`."""
    # Uses the `bilinear` config flag instead of a hard-coded True so that
    # changing the configuration cell actually takes effect.
    return UNet(n_channels=N_CHANNELS, n_classes=N_CLASSES,
                bilinear=bilinear).to(device)


def _make_adam(net):
    """Build the Adam optimizer (LR / WD from the config) for `net`."""
    return optim.Adam(net.parameters(), lr=LR, weight_decay=WD)


# Two global (federated) networks; construction order is kept as
# global -> global_2 -> per-client pairs.
nets['global'] = _make_unet()
nets['global_2'] = _make_unet()

for client in CLIENTS:
    train_set = breast_dataset[client + '_train']
    # Mini-batch loader used for training; BATCH_SIZE comes from the config
    # cell rather than a duplicated literal.
    training_clients[client] = DataLoader(train_set, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=8)
    # Same training data, one sample at a time; this is the loader passed to
    # select_pl in the training cell.
    training_clients_pl[client] = DataLoader(train_set, batch_size=1,
                                             shuffle=True, num_workers=8)
    ###################################################################################
    testing_clients[client] = DataLoader(breast_dataset[client + '_test'], batch_size=1,
                                         shuffle=False, num_workers=1)

    acc_train[client], acc_test[client] = [], []
    loss_train[client], loss_test[client] = [], []

    # Each client owns two networks, each with its own optimizer.
    nets[client] = _make_unet()
    nets[client + '_2'] = _make_unet()
    optimizers[client] = _make_adam(nets[client])
    optimizers[client + '_2'] = _make_adam(nets[client + '_2'])
    

# FedMix

### hyperparameters

In [None]:
# Show the client order: CLIENTS_SUPERVISION (defined below) is matched to
# CLIENTS by position.
print(CLIENTS)

In [None]:
# Supervision type of each client, matched to CLIENTS by position.
# The training cell treats 'labeled' as fully supervised, and turns on the
# bbox / image flag of select_pl for 'bbox' / 'image'; any other value (for
# example 'unlabeled') leaves both flags off.
# Alternative configurations:
#   ['unlabeled', 'unlabeled', 'labeled']
#   ['image', 'unlabeled', 'labeled']
CLIENTS_SUPERVISION = [
    'bbox',
    'bbox',
    'labeled',
]

In [None]:
# LAMBDA_: factor applied to the normalised loss score before it is added to
#          the data-size weight of a client.
# BETA_:   exponent applied to a client's latest training loss to obtain its score.
# TH:      threshold forwarded to select_pl.
LAMBDA_, BETA_, TH = 10, 1.5, 0.9

In [None]:
# best_epoch is read after the loop, so it must exist even if no epoch ever
# beats the initial best_avg_acc (e.g. the test accuracy stays at 0 or is NaN).
best_avg_acc, best_epoch = 0, 0
# Name bound by the original cell; it is never updated and is kept only in
# case another cell reads it.
best_epoch_avg = 0
index = []

for client in CLIENTS:
    acc_train[client], acc_test[client] = [], []
    loss_train[client], loss_test[client] = [], []

# One loss-based score per client (sized from CLIENTS, not hard-coded).
score = [0] * len(CLIENTS)

for epoch in range(EPOCHS):
    index.append(epoch)
    #################### copy fed model ###################
    copy_fed(CLIENTS, nets, fed_name='global')
    copy_fed(CLIENTS_2, nets, fed_name='global_2')
    ######################################################
    # generate and refine pseudo labels ##################
    ######################################################
    for order, client in enumerate(CLIENTS):
        bbox, image = False, False
        if CLIENTS_SUPERVISION[order] == 'labeled':
            # fully supervised clients keep their original loader
            continue
        elif CLIENTS_SUPERVISION[order] == 'bbox':
            bbox = True
        elif CLIENTS_SUPERVISION[order] == 'image':
            image = True
        ##################################################
        # save pl ########################################
        ##################################################
        # select_pl receives these lists and is expected to fill them in place
        im_store, pl1_store, pl2_store = [], [], []

        select_pl(nets['global'], nets['global_2'], device,
                  training_clients_pl[client], im_store, pl1_store,
                  pl2_store, TH=TH, bbox=bbox, image=image)

        # NOTE(review): when nothing is selected the loader is not replaced,
        # so the client trains on whatever loader it had before (on the first
        # epoch that is the original ground-truth loader) -- confirm intended.
        if len(im_store) >= 1:
            tmp_dataset = cancer_v2(im_store, pl1_store, pl2_store)
            training_clients[client] = DataLoader(tmp_dataset, batch_size=BATCH_SIZE,
                                                  shuffle=True, num_workers=8)

    #######################################################
    #### Conduct training #################################
    #######################################################
    for order, (client, supervision_t) in enumerate(zip(CLIENTS, CLIENTS_SUPERVISION)):
        if supervision_t == 'labeled':
            # train network 1 #
            train_model(training_clients[client], nets[client], optimizers[client], device,
                        acc=acc_train[client], loss=loss_train[client],
                        supervision_type=supervision_t)

            # train network 2 # (metrics are not recorded for the second network)
            train_model(training_clients[client], nets[client + '_2'], optimizers[client + '_2'], device,
                        acc=None, loss=None,
                        supervision_type=supervision_t)

        else:  # train using pseudo label #
            # train network 1 #
            train_model(training_clients[client], nets[client], optimizers[client], device,
                        acc=acc_train[client], loss=loss_train[client],
                        supervision_type=supervision_t, FedMix_network=1)

            # train network 2 #
            train_model(training_clients[client], nets[client + '_2'], optimizers[client + '_2'], device,
                        acc=None, loss=None,
                        supervision_type=supervision_t, FedMix_network=2)

        # save loss for future reweighting #
        score[order] = loss_train[client][-1] ** BETA_
    ###################################
    ####### dynamic weighting #########
    ###################################
    # normalise the scores so they sum to 1
    denominator = sum(score)
    score = [s / denominator for s in score]
    for order, _ in enumerate(WEIGHTS):
        WEIGHTS[order] = WEIGHTS_DATA[order] + LAMBDA_ * score[order]

    ### normalize #####################
    denominator = sum(WEIGHTS)
    WEIGHTS = [w / denominator for w in WEIGHTS]

    ###################################
    ####### aggregation ###############
    ###################################
    aggr_fed(CLIENTS, WEIGHTS, nets, fed_name='global')
    aggr_fed(CLIENTS_2, WEIGHTS, nets, fed_name='global_2')

    ################### test ##############################
    # only the first global network is evaluated
    avg_acc = 0.0
    for client in CLIENTS:
        test(epoch, testing_clients[client], nets['global'], device, acc_test[client],
             loss_test[client])
        avg_acc += acc_test[client][-1]

    avg_acc = avg_acc / TOTAL_CLIENTS
    ############################################################
    ########################################################
    if avg_acc > best_avg_acc:
        best_avg_acc = avg_acc
        best_epoch = epoch

    # redraw the progress plots in place
    clear_output(wait=True)
    print(avg_acc, best_avg_acc)
    plt.figure(0)
    plt.plot(index, acc_train['UDIAT'], colors[0], label='UDIAT train')
    plt.plot(index, acc_train['BUS'], colors[1], label='BUS  train')
    plt.plot(index, acc_train['BUSIS'], colors[3], label='BUSIS  train')

    plt.legend()
    plt.show()

    plt.figure(1)
    plt.plot(index, loss_train['UDIAT'], colors[0], label='UDIAT loss train')
    plt.plot(index, loss_train['BUS'], colors[1], label='BUS  loss train')
    plt.plot(index, loss_train['BUSIS'], colors[3], label='BUSIS  loss train')
    plt.legend()
    plt.show()

    plot_graphs(2, CLIENTS, index, acc_test, ' acc_test')

# Final report: per-client test accuracy at the epoch with the best average,
# followed by each client's own maximum.
print(best_avg_acc, best_epoch)
for client in CLIENTS:
    print(client)
    print("shared epoch specific")
    print(acc_test[client][best_epoch])
    print("max client-specific")
    print(np.max(acc_test[client]))