In [17]:
import os
import sys
from tqdm import tqdm
from tensorboardX import SummaryWriter
import shutil
import argparse
import logging
import time
import random
import numpy as np
import collections
from collections import OrderedDict
from glob import glob
import copy
import natsort
import cv2

import albumentations as A
from torchvision import transforms
import torch
from torch.autograd import Variable
import torch.optim as optim
from torchvision import transforms
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import torch.nn as nn
from pytorch_metric_learning import losses
from pytorch_lightning import seed_everything

from networks.unet2d import Unet2D, UNet_OG
'''Efficient-Unet'''
from networks.EfficientUnet.efficientunet import *
from utils.losses import *
from utils.util import _eval_dice, _eval_dice_mri, _eval_haus, _connectivity_region_analysis, parse_fn_haus
from utils.metrics import dice_coef_metric, iou_metric
from utils.aggregators import *
#from utils.mislabel import *
from dataloaders.prostate_dataloader import Dataset_Prostate, normalize
from dataloaders.transforms import train_pro_tfm, eval_pro_tfm
from sklearn.model_selection import train_test_split

Global seed set to 1337
Global seed set to 1337


In [18]:
# define dataset, model, optimizer for each client
def worker_init_fn(worker_id):
    random.seed(1337+worker_id)

In [19]:
def ellipsis_mask(mask):
    #np.random.seed(1)
    rands = np.random.randint(1, 5, (2,))
    center = np.random.randint(-5, 5, (2,))
    
    x0 = center[0]; a = rands[0]  # x center, half width                                       
    y0 = center[1]; b = rands[1]  # y center, half height
    x = np.linspace(-10, 10, mask.shape[-1])  # x values of interest
    y = np.linspace(-10, 10, mask.shape[-1])[:,None]  # y values of interest, as a "column" array
    ellipse = ((x-x0)/a)**2 + ((y-y0)/b)**2 <= 1  # True for points inside the ellipse
    ellipse = ellipse.astype('int').reshape(mask.shape)
    return ellipse

In [20]:
def mislabeling(masklist, epc, mis_type=''):
    '''set seed'''
    random.seed(100+epc)
    np.random.seed(100+epc)
    
    '''convert to numpy'''
    masklist = masklist.data.cpu().numpy()
    '''split idx for zoom in or out'''
    idx = masklist.shape[0] // 2
    mismask_full = np.zeros_like(masklist)
    
    for i, mask in enumerate(masklist):
        if mis_type == 'both':
            if i <= idx:
                '''affine transformation'''
                mismask = A.Affine(scale=0.5, rotate=180, p=1)(image=mask)['image']
            else:
                '''ellipsis mislabeling'''
                mismask = ellipsis_mask(mask)
        if mis_type == 'affine':
            mismask = A.Affine(scale=0.5, rotate=180, p=1)(image=mask)['image']
        if mis_type == 'ellip':
            mismask = ellipsis_mask(mask)
            
        mismask_full[i] = mismask
    return torch.from_numpy(mismask_full)

## Prostate Evaluation

In [12]:
client_num = 6
mis_type = 'ellip'
source_site_idx = [0, 1, 2, 3, 4, 5]
client_name = ['client1', 'client2', 'client3', 'client4', 'client5', 'client6']
train_loader_clients = []
mislabel_rate = 0.4
iterations = 10
mis_pool, gd_pool = [], []
mse_list = []
w2b_list, b2w_list, flip_list = [], [], []

device = "cuda" if torch.cuda.is_available() else "cpu"

for ite in range(iterations):
    for client_idx in range(client_num):
            #Prostate
            image_list = glob('./dataset/Prostate/{}/data_npy/*'.format(client_name[client_idx]))
            #train test split
            train, _ = train_test_split(image_list, test_size=0.1, random_state=1337)
            train, _ = train_test_split(train, test_size=0.1, random_state=1337)
            #we can perform augmentation
            train_set = Dataset_Prostate(train, train_pro_tfm)

            #dataloader
            train_loader = DataLoader(train_set, batch_size=64, shuffle=True,  num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
            train_loader_clients.append(train_loader)
            dataloader_current = train_loader_clients[client_idx]

            for i_batch, sampled_batch in enumerate(dataloader_current):
                time2 = time.time()

                # obtain training data
                #volume_batch: (64, 3, 256,256), label_batch: (64, 1, 256, 256)
                volume_batch_raw, label_batch = sampled_batch['image'].to(device), sampled_batch['label'].to(device)

                '''if mislabeling'''
                if mislabel_rate > 0:
                    mis_idx = round(mislabel_rate*label_batch.shape[0]) #compute mislabel idx for normal/mislabel split
                    sd = ite + i_batch #seed for mislabeling
                    mis_label = mislabeling(label_batch[:mis_idx], sd, mis_type).to(device)
                    mis_batch= torch.vstack([mis_label, label_batch[mis_idx:]])
                    gd_label = copy.deepcopy(label_batch)
                    #append to list
                    mis_pool.append(mis_batch.data.cpu().numpy())
                    gd_pool.append(gd_label.data.cpu().numpy())
            
    #stack 6 clients            
    mis_masks = np.vstack(mis_pool)
    gd_masks = np.vstack(gd_pool)
    
    tot_pixels = len(gd_masks.flatten())
    tot_whites = gd_masks.sum()
    tot_blacks = tot_pixels - tot_whites
    
    #compute difference and count
    diff = gd_masks - mis_masks
    w2b = (diff > 0).sum()
    b2w = (diff < 0).sum()
    #compute rate
    w2b_rate = w2b/tot_whites
    b2w_rate = b2w/tot_blacks
    flip_rate = (w2b+b2w) / tot_pixels
    
    print(f'iteration: {ite}, w2b_rate={round(w2b_rate,3)}, b2w_rate={round(b2w_rate,3)}, flip_rate={round(flip_rate,3)}')
    w2b_list.append(w2b_rate)
    b2w_list.append(b2w_rate)
    flip_list.append(flip_rate)
    #eval_mse = round(mse_loss(mis_masks, gd_masks),3)
    #print(f'iteration: {ite}, mse={eval_mse}')
    #mse_list.append(eval_mse)


mean_w2b_rate, mean_b2w_rate, mean_flip_rate = round(np.mean(w2b_list),3), round(np.mean(b2w_list),3), round(np.mean(flip_list),3) 
print(mean_w2b_rate, mean_b2w_rate, mean_flip_rate)
#mean_mse = round(np.mean(mse_list),3)
#mean_mse

total 421 slices
total 782 slices
total 233 slices
total 984 slices
total 468 slices
total 257 slices
iteration: 0, w2b_rate=0.356, b2w_rate=0.019, flip_rate=0.025
total 421 slices
total 782 slices
total 233 slices
total 984 slices
total 468 slices
total 257 slices
iteration: 1, w2b_rate=0.335, b2w_rate=0.019, flip_rate=0.025
total 421 slices
total 782 slices
total 233 slices
total 984 slices
total 468 slices
total 257 slices
iteration: 2, w2b_rate=0.33, b2w_rate=0.019, flip_rate=0.025
total 421 slices
total 782 slices
total 233 slices
total 984 slices
total 468 slices
total 257 slices
iteration: 3, w2b_rate=0.327, b2w_rate=0.019, flip_rate=0.024
total 421 slices
total 782 slices
total 233 slices
total 984 slices
total 468 slices
total 257 slices
iteration: 4, w2b_rate=0.328, b2w_rate=0.019, flip_rate=0.024
total 421 slices
total 782 slices
total 233 slices
total 984 slices
total 468 slices
total 257 slices
iteration: 5, w2b_rate=0.332, b2w_rate=0.019, flip_rate=0.024
total 421 slices


# LGG Evaluation

In [23]:
client_num = 4
mis_type = 'both'
source_site_idx = [0, 1, 2, 3]
client_name = ['client1', 'client2', 'client3', 'client4']
train_loader_clients = []
mislabel_rate = 0.8
iterations = 10
mis_pool, gd_pool = [], []
mse_list = []
w2b_list, b2w_list, flip_list = [], [], []

device = "cuda" if torch.cuda.is_available() else "cpu"

for ite in range(iterations):
    for client_idx in range(client_num):
            #Prostate
            image_list = glob('./dataset/LGG/{}/data_npy/*'.format(client_name[client_idx]))
            #train test split
            train, _ = train_test_split(image_list, test_size=0.1, random_state=1337)
            train, _ = train_test_split(train, test_size=0.1, random_state=1337)
            #we can perform augmentation
            train_set = Dataset_Prostate(train, train_pro_tfm)

            #dataloader
            train_loader = DataLoader(train_set, batch_size=64, shuffle=True,  num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
            train_loader_clients.append(train_loader)
            dataloader_current = train_loader_clients[client_idx]

            for i_batch, sampled_batch in enumerate(dataloader_current):
                time2 = time.time()

                # obtain training data
                #volume_batch: (64, 3, 256,256), label_batch: (64, 1, 256, 256)
                volume_batch_raw, label_batch = sampled_batch['image'].to(device), sampled_batch['label'].to(device)

                '''if mislabeling'''
                if mislabel_rate > 0:
                    mis_idx = round(mislabel_rate*label_batch.shape[0]) #compute mislabel idx for normal/mislabel split
                    sd = ite + i_batch #seed for mislabeling
                    mis_label = mislabeling(label_batch[:mis_idx], sd, mis_type).to(device)
                    mis_batch= torch.vstack([mis_label, label_batch[mis_idx:]])
                    gd_label = copy.deepcopy(label_batch)
                    #append to list
                    mis_pool.append(mis_batch.data.cpu().numpy())
                    gd_pool.append(gd_label.data.cpu().numpy())
            
    #stack 6 clients            
    mis_masks = np.vstack(mis_pool)
    gd_masks = np.vstack(gd_pool)
    
    tot_pixels = len(gd_masks.flatten())
    tot_whites = gd_masks.sum()
    tot_blacks = tot_pixels - tot_whites
    
    #compute difference and count
    diff = gd_masks - mis_masks
    w2b = (diff > 0).sum()
    b2w = (diff < 0).sum()
    #compute rate
    w2b_rate = w2b/tot_whites
    b2w_rate = b2w/tot_blacks
    flip_rate = (w2b+b2w) / tot_pixels
    
    print(f'iteration: {ite}, w2b_rate={round(w2b_rate,3)}, b2w_rate={round(b2w_rate,3)}, flip_rate={round(flip_rate,3)}')
    w2b_list.append(w2b_rate)
    b2w_list.append(b2w_rate)
    flip_list.append(flip_rate)
    #eval_mse = round(mse_loss(mis_masks, gd_masks),3)
    #print(f'iteration: {ite}, mse={eval_mse}')
    #mse_list.append(eval_mse)


mean_w2b_rate, mean_b2w_rate, mean_flip_rate = round(np.mean(w2b_list),3), round(np.mean(b2w_list),3), round(np.mean(flip_list),3) 
print(mean_w2b_rate, mean_b2w_rate, mean_flip_rate)
#mean_mse = round(np.mean(mse_list),3)
#mean_mse

total 1521 slices
total 833 slices
total 518 slices
total 289 slices
iteration: 0, w2b_rate=0.654, b2w_rate=0.02, flip_rate=0.027
total 1521 slices
total 833 slices
total 518 slices
total 289 slices
iteration: 1, w2b_rate=0.661, b2w_rate=0.02, flip_rate=0.027
total 1521 slices
total 833 slices
total 518 slices
total 289 slices
iteration: 2, w2b_rate=0.661, b2w_rate=0.02, flip_rate=0.027
total 1521 slices
total 833 slices
total 518 slices
total 289 slices
iteration: 3, w2b_rate=0.66, b2w_rate=0.02, flip_rate=0.027
total 1521 slices
total 833 slices
total 518 slices
total 289 slices
iteration: 4, w2b_rate=0.662, b2w_rate=0.02, flip_rate=0.027
total 1521 slices
total 833 slices
total 518 slices
total 289 slices
iteration: 5, w2b_rate=0.661, b2w_rate=0.02, flip_rate=0.027
total 1521 slices
total 833 slices
total 518 slices
total 289 slices
iteration: 6, w2b_rate=0.662, b2w_rate=0.02, flip_rate=0.027
total 1521 slices
total 833 slices
total 518 slices
total 289 slices
iteration: 7, w2b_rate