In [None]:
import os
import copy

import numpy as np
import torch
import torch.nn as nn
import math

from torch import optim
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from unet import UNet
from dice_loss import dice_coeff
####################################################
# for data splitting
####################################################
import pandas as pd
####################################################
# for data preparation
####################################################
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from sklearn.metrics import roc_auc_score, accuracy_score, balanced_accuracy_score
####################################################
# for plotting
####################################################

import matplotlib.pyplot as plt
from IPython.display import clear_output
import random
############################
# Helper func
############################
from helper import * 

########################################
# --- Network input / output sizes ---
N_CHANNELS = 1  # greyscale input
N_CLASSES = 3   # classes: IRF, SRF, PED

# --- Training schedule ---
BATCH_SIZE = 16
EPOCHS = 150

# --- Spatial sizes (square 2-tuples) ---
IMAGE_SIZE = (224,) * 2
CROP_SIZE = (200,) * 2

In [None]:
# Directories holding the masks and the matching image slices. The same file
# name is appended to both further down in the notebook.
PTH = 'dataset/retouch/mask/'
PTH_IM = 'dataset/retouch/slices/'

# Sample-id ranges, one per group (np.arange upper bounds are exclusive).
Cirrus_ = np.arange(1,25)
Spectralis_ = np.arange(25,49)
Topcon_ = np.arange(49,71)

# Fixed seed so that "Restart & Run All" reproduces the same held-out ids.
SPLIT_SEED = 0
N_TEST_PER_GROUP = 5


def _split_ids(ids, n_test, rng):
    """Randomly hold out `n_test` ids from `ids`.

    Args:
        ids: 1-D array of integer sample ids.
        n_test: number of ids to draw (without replacement) for the test set.
        rng: a `numpy.random.Generator` used for the draw.

    Returns:
        (train_ids, test_ids) as lists of decimal strings; train_ids keeps the
        order of `ids`, test_ids keeps the draw order.
    """
    test_ids = [str(a) for a in rng.choice(ids, size=n_test, replace=False)]
    held_out = set(test_ids)
    train_ids = [str(a) for a in ids if str(a) not in held_out]
    return train_ids, test_ids


# Local generator: seeding here does not touch NumPy's global random state.
_split_rng = np.random.default_rng(SPLIT_SEED)

cirrus_train, cirrus_test = _split_ids(Cirrus_, N_TEST_PER_GROUP, _split_rng)
print(len(cirrus_train)/len(Cirrus_))  # fraction of ids kept for training

spectralis_train, spectralis_test = _split_ids(Spectralis_, N_TEST_PER_GROUP, _split_rng)
print(len(spectralis_train)/len(Spectralis_))

topcon_train, topcon_test = _split_ids(Topcon_, N_TEST_PER_GROUP, _split_rng)
print(len(topcon_train)/len(Topcon_))

# os.listdir returns entries in arbitrary order; sort so that the order of the
# train/test lists built from this listing is the same on every run.
whole_data = sorted(os.listdir(PTH))

In [None]:
# Held-out sample ids (lists of strings), one entry per source name.
TEST_SPLIT = {
    'Cirrus': cirrus_test,
    'Topcon': topcon_test,
    'Spectralis': spectralis_test,
}

In [None]:
# Per-source lists of (image_path, mask_path) tuples for each split.
WHOLE_DATA_TRAIN = {'Cirrus': [], 'Topcon': [], 'Spectralis': []}
WHOLE_DATA_TEST = {'Cirrus': [], 'Topcon': [], 'Spectralis': []}

for item in whole_data:
    separator = item.split('_')  # separator[0] identifies the source
    source, sample_token = separator[0], separator[1]

    # Sample id = trailing run of digits of the second token with the zero
    # padding removed, so it compares equal to the str(int) ids in TEST_SPLIT.
    # The previous `split('0')[-1]` returned '' for every id ending in 0
    # (10, 20, ..., 70), so those samples could never be matched as test data
    # and always ended up in the training lists.
    # NOTE(review): assumes the token ends with the zero-padded sample number
    # (e.g. 'TRAIN010') -- confirm against the actual file names.
    digits = sample_token[len(sample_token.rstrip('0123456789')):]
    sample_number = digits.lstrip('0')

    # Skip masks that contain a single pixel value only.
    # NOTE(review): `Image` is not imported in this notebook's import cell;
    # presumably PIL.Image re-exported by `from helper import *` -- confirm.
    # The context manager closes the file once the pixels have been read.
    with Image.open(PTH+item) as mask_image:
        label_slice = np.unique(np.array(mask_image))
    if len(label_slice) <= 1:
        continue

    data_path = (PTH_IM+item, PTH+item)  # x, y source
    if sample_number in TEST_SPLIT[source]:
        WHOLE_DATA_TEST[source].append(data_path)
    else:
        WHOLE_DATA_TRAIN[source].append(data_path)


In [None]:
# Print a train/test sample-count table, one row per source, and remember the
# number of Spectralis training samples.
LEN_SPECTRALIS = 0
print('name\t\t train\t test')
for item in WHOLE_DATA_TRAIN:
    n_train, n_test = len(WHOLE_DATA_TRAIN[item]), len(WHOLE_DATA_TEST[item])
    if item == 'Spectralis':
        # Longer name: one tab instead of two after the name.
        print(item,'\t', n_train,'\t', n_test)
        # Store the count; previously the list itself was assigned here,
        # contradicting both the name and the integer initial value.
        LEN_SPECTRALIS = n_train
    else:
        print(item,'\t\t', n_train,'\t', n_test)

In [None]:
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

CLIENTS = ['Cirrus', 'Topcon', 'Spectralis']
CLIENTS_2 = [cl +'_2' for cl in CLIENTS]
TOTAL_CLIENTS = len(CLIENTS)

LR = 1.5e-3  # passed as `lr` to the Adam optimizers
WD = 1e-5    # passed as `weight_decay` to the Adam optimizers
TH = 0.9     # (was assigned twice with the same value)

LAMBDA_ =2
BETA_=3

# Per-client weight = that client's share of all training samples. The counts
# are looked up by client name so that WEIGHTS_CL[i] always belongs to
# CLIENTS[i], independent of the key order of WHOLE_DATA_TRAIN.
WEIGHTS_CL = [len(WHOLE_DATA_TRAIN[client]) for client in CLIENTS]

total_weight = sum(WEIGHTS_CL)
WEIGHTS_CL = [s/total_weight for s in WEIGHTS_CL]

In [None]:
# One supervision tag per client position; all three start out as 'labeled'.
CLIENTS_SUPERVISION = ['labeled'] * 3

In [None]:
# Build one `retouch` dataset per client and split, stored under the keys
# '<client>_train' and '<client>_test'.
split_dataset = {}
for cl in CLIENTS:
    for suffix, pool, is_train in (('_train', WHOLE_DATA_TRAIN, True),
                                   ('_test', WHOLE_DATA_TEST, False)):
        split_dataset[cl+suffix] = retouch(pool[cl], train=is_train)

In [None]:
# Derived from CLIENTS instead of a literal 3 so the two cannot drift apart.
TOTAL_CLIENTS = len(CLIENTS)
GLOBAL_ACC = 0.0


def _new_unet():
    """Return a fresh UNet (N_CHANNELS in, N_CLASSES out, bilinear) on `device`."""
    return UNet(n_channels=N_CHANNELS, n_classes=N_CLASSES,
                bilinear=True).to(device)


def _train_loader(dataset, batch_size):
    """Return a shuffling DataLoader over `dataset` with 8 workers."""
    return DataLoader(dataset, batch_size=batch_size,
                      shuffle=True, num_workers=8)


training_clients, testing_clients = dict(), dict()
########## additional #####################
training_clients_pl = dict()  # batch-size-1 loaders over the training sets
acc_train, acc_test, loss_train, loss_test = dict(), dict(), dict(), dict()

nets, optimizers = dict(), dict()

# Two global models, created before any client model.
nets['global'] = _new_unet()
nets['global_2'] = _new_unet()

for client, c_sup in zip(CLIENTS, CLIENTS_SUPERVISION):
    train_set = split_dataset[client+'_train']

    # Every client gets a batched loader and a batch-size-1 loader; the batch
    # size comes from BATCH_SIZE rather than a repeated literal 16.
    training_clients[client] = _train_loader(train_set, BATCH_SIZE)
    training_clients_pl[client] = _train_loader(train_set, 1)
    if c_sup != 'labeled':
        # Clients that are not 'labeled' get a second batch-size-1 loader.
        training_clients_pl[client+'_2'] = _train_loader(train_set, 1)

    testing_clients[client] = DataLoader(split_dataset[client+'_test'],
                                         batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=1)

    acc_train[client], acc_test[client] = [], []
    loss_train[client], loss_test[client] = [], []

    # Two models per client ('<client>' and '<client>_2'), each with its own
    # Adam optimizer. Models are created in the same order as before.
    for net_name in (client, client+'_2'):
        nets[net_name] = _new_unet()
        optimizers[net_name] = optim.Adam(nets[net_name].parameters(),
                                          lr=LR, weight_decay=WD)

In [None]:
# Switch torch.multiprocessing to the 'file_system' sharing strategy.
# NOTE(review): presumably chosen to avoid running out of file descriptors with
# the multi-worker DataLoaders created above -- confirm.
torch.multiprocessing.set_sharing_strategy('file_system')

# FedST or FedRGD

In [None]:
# Aggregation mode; anything other than 'FedST' takes the second weight set.
MODE = 'FedST'
# MODE = 'FedRGD'

WARMUP_EPOCH = 100
CLIENTS_SUPERVISION = ['unlabeled', 'unlabeled', 'labeled']

# Aggregation weights used until WARMUP_EPOCH is reached.
WEIGHTS = [0.0, 0.0, 1.0]

# Aggregation weights that replace WEIGHTS once WARMUP_EPOCH is reached.
WEIGHTS_POSTWARMUP = (
    [0.1/2, 0.1/2, 0.9]   # put more weight to client with strong supervision
    if MODE == 'FedST' else
    [0.5/2, 0.5/2, 0.5]   # random client, but effectively, put 1/2 trust to strongly supervised client
)

CE_ = nn.BCELoss()


In [None]:
# Federated training loop: per epoch, train the participating clients,
# aggregate, test every client model, and track the best average test score.
best_avg_acc, best_epoch_avg = 0,0
# Initialised up front: previously `best_epoch` only came into existence when
# some epoch beat `best_avg_acc`, so the final report raised NameError if the
# average accuracy never rose above 0.
best_epoch = 0
index = []

# Start from empty metric histories.
for client in CLIENTS:
    acc_train[client], acc_test[client] = [], []
    loss_train[client], loss_test[client] = [], []

USE_UNLABELED_CLIENT = False
for epoch in range(EPOCHS):
    index.append(epoch)
    if epoch == WARMUP_EPOCH:
        # From this epoch on: switch aggregation weights and let the
        # 'unlabeled' clients train as well.
        WEIGHTS = WEIGHTS_POSTWARMUP
        USE_UNLABELED_CLIENT = True

    ####### conduct training #####
    #################### copy fed model ###################
    # NOTE(review): presumably copies nets['global'] into each client model --
    # confirm in helper.
    copy_fed(CLIENTS, nets, fed_name='global')

    for client, supervision_t in zip(CLIENTS, CLIENTS_SUPERVISION):
        if supervision_t == 'unlabeled' and not USE_UNLABELED_CLIENT:
            # Client sits this epoch out; pad its histories with 0 so they
            # stay as long as `index` for plotting.
            acc_train[client].append(0)
            loss_train[client].append(0)
            continue
        print(client)
        # NOTE(review): `warmup=True` is passed on every epoch, including those
        # at or after WARMUP_EPOCH -- confirm this is intended.
        train_model_multiclasses(training_clients[client], nets[client],
                                 optimizers[client], device,
                                 acc=acc_train[client],
                                 loss=loss_train[client],
                                 supervision_type=supervision_t,
                                 warmup=True, CE_LOSS=CE_)

    aggr_fed(CLIENTS, WEIGHTS, nets, fed_name='global')

    ################### test ##############################
    avg_acc = 0.0
    for client in CLIENTS:
        test_multiclasses(epoch, testing_clients[client], nets[client], device,
                          acc_test[client], loss_test[client])
        avg_acc += acc_test[client][-1]

    avg_acc = avg_acc / TOTAL_CLIENTS

    if avg_acc > best_avg_acc:
        best_avg_acc = avg_acc
        best_epoch = epoch

    ################################
    # plot #########################
    ################################
    clear_output(wait=True)
    print(avg_acc, best_avg_acc)
    plot_graphs(0, CLIENTS, index, acc_train, 'acc_train')
    plot_graphs(1, CLIENTS, index, loss_train, 'loss_train')
    plot_graphs(2, CLIENTS, index, acc_test, ' acc_test')

# Final report: each client's score at the shared best epoch and that client's
# own maximum over all epochs.
print(best_avg_acc, best_epoch)
for client in CLIENTS:
    print(client)
    print("shared epoch specific")
    print(acc_test[client][best_epoch])
    print("max client-specific")
    print(np.max(acc_test[client]))