In [14]:
import os
import copy

import numpy as np
import torch
import torch.nn as nn
import math

from torch import optim
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from unet import UNet
from dice_loss import dice_coeff

####################################################
# for data preparation
####################################################
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from sklearn.metrics import roc_auc_score, accuracy_score, balanced_accuracy_score
####################################################
# for plotting
####################################################
import matplotlib.pyplot as plt
from IPython.display import clear_output
############################
# Helper func
############################
from helper import * 
#################################
# ---- model / training configuration ----
N_CHANNELS = 1   # passed to UNet as n_channels
N_CLASSES = 1    # passed to UNet as n_classes
bilinear = True
BATCH_SIZE = 16
EPOCHS = 300
IMAGE_SIZE = (600, 600)
CROP_SIZE = (224, 224)
# ---- data location and federated clients ----
DIR = 'dataset/breast'
CLIENTS = ['miccai', 'bns']
# Name of the second model kept for each client: "<client>_2".
CLIENTS_2 = [f'{name}_2' for name in CLIENTS]
TOTAL_CLIENTS = len(CLIENTS)
#####################################
# add the classification segment ####

In [15]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so that every later `.to(device)` call still works on a GPU-less machine.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Adam learning rate and weight decay (passed as lr= / weight_decay= below).
LR, WD = 1e-3, 1e-4

# NOTE(review): LAM / BETA / TH / VERSION are not used in the visible cells;
# presumably loss weights / a threshold consumed by helper code - confirm.
LAM, BETA, TH = 10, 1.5, 0.9
VERSION = 1

# Supervision type of each client, in the same order as CLIENTS.
CLIENTS_SUPERVISION = ['labeled', 'labeled']
# Per-client aggregation weights; filled in from the training-set sizes later.
WEIGHTS_CL = [0.0, 0.0]

# load training-test path

In [16]:
import os
import glob
import random

# Folders holding the input images and their segmentation labels.
# NOTE(review): these are absolute, machine-specific paths; consider deriving
# them from a configurable data directory.
IMAGES_DIR = "C:/Users/utente/Desktop/Università/Tesi magistrale/FedMix/data/imagesTrGray"
LABELS_DIR = "C:/Users/utente/Desktop/Università/Tesi magistrale/FedMix/data/labelsTrGray"

# Fraction of each client's samples held out for testing, and the seed that
# makes the random hold-out reproducible across kernel restarts.
TEST_FRACTION = 0.1
SPLIT_SEED = 0
split_rng = random.Random(SPLIT_SEED)

breast_dataset = dict()
denom_ = 0

for idx_, (client, sup) in enumerate(zip(CLIENTS, CLIENTS_SUPERVISION)):
    # glob returns files in arbitrary order, so sort both lists to keep the
    # i-th image aligned with the i-th label (assumes images and labels share
    # the same file names - TODO confirm).
    # NOTE(review): the pattern does not depend on `client`, so every client
    # loads the same files and one client's test images may end up in another
    # client's training set - confirm this is intended.
    x_train = sorted(glob.glob(os.path.join(IMAGES_DIR, "*.png")))
    y_train = sorted(glob.glob(os.path.join(LABELS_DIR, "*.png")))

    # Clients without supervision additionally get the "normal" images.
    if sup == 'unlabeled':
        DATA_TYPE = ['original', 'GT']
        for data in DATA_TYPE:
            normal_images = sorted(glob.glob(
                os.path.join(IMAGES_DIR, "miccai", "classification", data, "normal*.jpg")))
            if data == 'GT':
                y_train += normal_images
            else:
                x_train += normal_images

    # Fail early with a clear message instead of an IndexError (or silently
    # mismatched pairs) further down.
    if len(x_train) != len(y_train):
        raise ValueError(
            "Image/label count mismatch for client %r: %d images vs %d labels"
            % (client, len(x_train), len(y_train)))

    # Randomly hold out TEST_FRACTION of the (image, label) pairs as test set.
    num_test_samples = int(TEST_FRACTION * len(x_train))
    test_indices = split_rng.sample(range(len(x_train)), num_test_samples)
    held_out = set(test_indices)  # O(1) membership test for the filters below
    x_test = [x_train[i] for i in test_indices]
    y_test = [y_train[i] for i in test_indices]
    x_train = [path for i, path in enumerate(x_train) if i not in held_out]
    y_train = [path for i, path in enumerate(y_train) if i not in held_out]

    print('Full training data')
    # Un-normalised client weight = number of training samples.
    WEIGHTS_CL[idx_] = len(x_train)
    denom_ += len(x_train)

    # Wrap the path lists in the Cancer dataset class (presumably provided by
    # `helper` through the star import - confirm).
    breast_dataset[client+'_train'] = Cancer(x_train, y_train, train=True,
                                             IMAGE_SIZE=IMAGE_SIZE, CROP_SIZE=CROP_SIZE)

    breast_dataset[client+'_test'] = Cancer(x_test, y_test, train=False,
                                            IMAGE_SIZE=IMAGE_SIZE, CROP_SIZE=CROP_SIZE)

# Normalise the client weights so that they sum to 1.
if denom_ == 0:
    raise ValueError("No training images found under %r" % IMAGES_DIR)
for idx_ in range(len(WEIGHTS_CL)):
    WEIGHTS_CL[idx_] = WEIGHTS_CL[idx_]/denom_


Full training data
Full training data


# storage file

In [17]:
# Training-set size of the client left bound to `client` by the preceding loop.
len(breast_dataset[f"{client}_train"])

342

In [18]:
# DataLoaders, keyed by client name.
training_clients = {}
testing_clients = {}
training_clients_pl = {}

# Per-client metric histories, keyed by client name.
acc_train = {}
acc_test = {}
loss_train = {}
loss_test = {}

# Models and their optimizers, keyed by model name.
nets = {}
optimizers = {}

In [19]:
def _new_unet():
    """Return a freshly initialised UNet on `device`, built from the config constants."""
    return UNet(n_channels=N_CHANNELS, n_classes=N_CLASSES,
                bilinear=bilinear).to(device)


# Two shared ("global") models.
nets['global'] = _new_unet()
nets['global_2'] = _new_unet()

for client in CLIENTS:
    # Training loader uses the configured BATCH_SIZE rather than a literal.
    training_clients[client] = DataLoader(breast_dataset[client+'_train'],
                                          batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=8)
    # Same training data served one sample at a time.
    training_clients_pl[client] = DataLoader(breast_dataset[client+'_train'],
                                             batch_size=1,
                                             shuffle=True, num_workers=8)
    ###################################################################################
    testing_clients[client] = DataLoader(breast_dataset[client+'_test'],
                                         batch_size=1,
                                         shuffle=False, num_workers=1)

    # Start each client with empty metric histories.
    acc_train[client], acc_test[client] = [], []
    loss_train[client], loss_test[client] = [], []

    # Two local models per client, each with its own Adam optimizer.
    nets[client] = _new_unet()
    nets[client+'_2'] = _new_unet()
    optimizers[client] = optim.Adam(nets[client].parameters(),
                                    lr=LR, weight_decay=WD)
    optimizers[client+'_2'] = optim.Adam(nets[client+'_2'].parameters(),
                                         lr=LR, weight_decay=WD)

## FedAvg

In [20]:
# Best average test accuracy seen so far and the epoch it was reached at.
# `best_epoch` is initialised here so the final report cannot hit a NameError
# when no epoch ever beats the initial best_avg_acc of 0.
best_avg_acc, best_epoch_avg = 0, 0
best_epoch = 0
index = []

# Reset the metric histories so that re-running this cell starts clean.
for client in CLIENTS:
    acc_train[client], acc_test[client] = [], []
    loss_train[client], loss_test[client] = [], []

for epoch in range(EPOCHS):
    index.append(epoch)
    #################### copy fed model ###################
    # presumably copies nets['global'] into every client model - see helper.copy_fed
    copy_fed(CLIENTS, nets, fed_name='global')

    #### conduct training #####
    for client, supervision_t in zip(CLIENTS, CLIENTS_SUPERVISION):
        train_model(training_clients[client], nets[client],
                    optimizers[client], device,
                    acc=acc_train[client],
                    loss=loss_train[client],
                    supervision_type=supervision_t)

    # presumably the WEIGHTS_CL-weighted average of the client models,
    # written back into nets['global'] - see helper.aggr_fed
    aggr_fed(CLIENTS, WEIGHTS_CL, nets, fed_name='global')
    ################### test ##############################
    avg_acc = 0.0
    for client in CLIENTS:
        test(epoch, testing_clients[client], nets['global'], device, acc_test[client],
             loss_test[client])
        avg_acc += acc_test[client][-1]

    avg_acc = avg_acc / TOTAL_CLIENTS
    if avg_acc > best_avg_acc:
        best_avg_acc = avg_acc
        best_epoch = epoch
        best_epoch_avg = epoch

    ################################
    # plot #########################
    ################################
    clear_output(wait=True)
    print(avg_acc, best_avg_acc)

    # One curve per client; colour indices start at 1, as in the original plots.
    plt.figure(0)
    for color_idx, client in enumerate(CLIENTS, start=1):
        plt.plot(index, acc_train[client], colors[color_idx], label=client + ' train')
    plt.legend()
    plt.show()

    plt.figure(1)
    for color_idx, client in enumerate(CLIENTS, start=1):
        plt.plot(index, loss_train[client], colors[color_idx], label=client + ' loss train')
    plt.legend()
    plt.show()

    plot_graphs(2, CLIENTS, index, acc_test, ' acc_test')

# Final report: accuracy of every client at the shared best epoch, next to the
# best accuracy that client reached at any epoch.
print(best_avg_acc, best_epoch)
for client in CLIENTS:
    print(client)
    print("shared epoch specific")
    print(acc_test[client][best_epoch])
    print("max client-specific")
    print(np.max(acc_test[client]))

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "C:\Users\utente\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\data\_utils\worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "C:\Users\utente\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\data\_utils\fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "C:\Users\utente\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\data\_utils\fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "C:\Users\utente\Desktop\Università\Tesi magistrale\FedMix\helper.py", line 165, in __getitem__
    w_min, w_max, h_min, h_max = torch.min(w)-margin[0], \
RuntimeError: min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.


In [22]:
# Display the first component of the configured IMAGE_SIZE.
# NOTE(review): the recorded output (256) does not match IMAGE_SIZE = (600, 600)
# from the configuration cell, so this output is stale - re-run the notebook
# from a fresh kernel.
IMAGE_SIZE[0]

256

In [21]:
from PIL import Image

def _report_size_mismatches(paths, expected_size, message):
    """Print `message` and the path for every file whose size differs from `expected_size`.

    paths:         iterable of image file paths.
    expected_size: (width, height) pair to compare against PIL's `Image.size`.
    message:       text printed in front of each offending path.
    """
    for path in paths:
        # The context manager closes the file handle right away; opening
        # hundreds of images without closing them leaks descriptors.
        with Image.open(path) as img:
            width, height = img.size
        if width != expected_size[0] or height != expected_size[1]:
            print(message, path)


# NOTE: x_train / x_test / y_train / y_test are the lists left over from the
# last client processed by the dataset-building loop.

# Check the image sizes.
_report_size_mismatches(x_train + x_test, IMAGE_SIZE,
                        "Immagine con dimensioni non valide:")

# Check the label sizes.
_report_size_mismatches(y_train + y_test, IMAGE_SIZE,
                        "Etichetta con dimensioni non valide:")


Immagine con dimensioni non valide: C:/Users/utente/Desktop/Università/Tesi magistrale/FedMix/data/imagesTrGray\bns_001.png
Immagine con dimensioni non valide: C:/Users/utente/Desktop/Università/Tesi magistrale/FedMix/data/imagesTrGray\bns_002.png
Immagine con dimensioni non valide: C:/Users/utente/Desktop/Università/Tesi magistrale/FedMix/data/imagesTrGray\bns_003.png
Immagine con dimensioni non valide: C:/Users/utente/Desktop/Università/Tesi magistrale/FedMix/data/imagesTrGray\bns_004.png
Immagine con dimensioni non valide: C:/Users/utente/Desktop/Università/Tesi magistrale/FedMix/data/imagesTrGray\bns_005.png
Immagine con dimensioni non valide: C:/Users/utente/Desktop/Università/Tesi magistrale/FedMix/data/imagesTrGray\bns_006.png
Immagine con dimensioni non valide: C:/Users/utente/Desktop/Università/Tesi magistrale/FedMix/data/imagesTrGray\bns_007.png
Immagine con dimensioni non valide: C:/Users/utente/Desktop/Università/Tesi magistrale/FedMix/data/imagesTrGray\bns_008.png
Immagine