In [25]:
import os
import copy

import numpy as np
import torch
import torch.nn as nn
import math

from torch import optim
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from unet import UNet
from dice_loss import dice_coeff
####################################################
# for data splitting
####################################################
import pandas as pd
####################################################
# for data preparation
####################################################
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from sklearn.metrics import roc_auc_score, accuracy_score, balanced_accuracy_score
####################################################
# for plotting
####################################################
import matplotlib.pyplot as plt
from IPython.display import clear_output
############################
# Helper func
############################
from helper import * 
###################################
TRAIN_RATIO = 0.8   # fraction of each client's images used for training
RS = 30448          # seed for the train/test splits and (below) the global torch/numpy RNGs
N_CHANNELS, N_CLASSES = 1, 1
bilinear = True
BATCH_SIZE, EPOCHS = 16, 150
IMAGE_SIZE = (600, 600)
CROP_SIZE = (600, 600)

# Seed the global RNGs so weight initialisation, DataLoader shuffling and any random
# augmentation are repeatable under Restart & Run All. On its own, RS only fixes the
# train/test splits.
torch.manual_seed(RS)
np.random.seed(RS)

# NOTE(review): not used in the visible cells (the splits use train_test_split).
sss = StratifiedShuffleSplit(n_splits=1, test_size=1-TRAIN_RATIO, random_state=RS)

# Class names for the dataset; a single segmentation class is defined here.
lesion_type_dict = {
    'seg': 'Segmentation', # 0
}
lesion_type_dict_malignant = {
    'seg': 'ben', # 0
}
#########################################
DIR = 'data/'




In [26]:
import torch.multiprocessing
# Make DataLoader worker processes share tensors through the file system rather than
# by passing file descriptors.
# NOTE(review): presumably chosen to avoid hitting the open-file-descriptor limit with
# the num_workers=8 loaders below; confirm.
sharing_strategy = 'file_system'
torch.multiprocessing.set_sharing_strategy(sharing_strategy)

In [27]:
import os
from sklearn.model_selection import train_test_split

CLIENTS = ['miccai', 'bns']
CLIENTS_2 = [cl + '_2' for cl in CLIENTS]
###################################################################
TOTAL_CLIENTS = len(CLIENTS)

DIR_DATA = 'data/imagesTrAug/'
DIR_GT = 'data/labelsTrBW/'

# File stems (no extension) for each client dataset.
skin_dataset = dict()
skin_dataset['miccai'] = ['miccai_{:03d}'.format(i) for i in range(1, 201)]  # miccai_001 ... miccai_200
skin_dataset['bns'] = ['bns_{:03d}'.format(i) for i in range(1, 181)]  # bns_001 ... bns_180

split_dataset = dict()
# Number of training images per client, in CLIENTS order. The next cell normalises
# these into the aggregation weights.
STATIC_WEIGHT = []

# Iterate in CLIENTS order, not dict order, so that STATIC_WEIGHT[i] always belongs to
# CLIENTS[i]. Appending lets the list grow with the number of clients instead of being
# fixed at two slots.
for client in CLIENTS:
    stems = skin_dataset[client]
    x_ = [os.path.join(DIR_DATA, f + '.png') for f in stems]
    y_ = [os.path.join(DIR_GT, f + '.png') for f in stems]

    # The fixed random_state makes every client's split reproducible.
    x_train, x_test, y_train, y_test = train_test_split(
        x_, y_, test_size=1 - TRAIN_RATIO, random_state=RS)

    split_dataset[client + '_train'] = Cancer(x_train, y_train, train=True,
                                              IMAGE_SIZE=IMAGE_SIZE, CROP_SIZE=CROP_SIZE)
    STATIC_WEIGHT.append(len(x_train))

    split_dataset[client + '_test'] = Cancer(x_test, y_test, train=False,
                                             IMAGE_SIZE=IMAGE_SIZE, CROP_SIZE=CROP_SIZE)
    print(client)


miccai
bns


In [28]:
# Turn the per-client training-set sizes into fractions that sum to 1.
train_total = sum(STATIC_WEIGHT)
STATIC_WEIGHT = [n_train / train_total for n_train in STATIC_WEIGHT]
print(STATIC_WEIGHT)

[0.5263157894736842, 0.47368421052631576]


In [29]:
# Use the first GPU when present; fall back to CPU so the notebook still runs without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# LR / WD: Adam learning rate and weight decay for every client optimizer.
# NOTE(review): TH is not used in the visible cells; presumably a threshold consumed by
# helper code. Confirm.
LR, WD, TH = 1e-3, 1e-4, 0.9
best_avg_acc, best_epoch = 0.0, 0

In [30]:
training_clients, testing_clients = {}, {}
training_clients_pl = {}

acc_train, acc_test = {}, {}
loss_train, loss_test = {}, {}

nets, optimizers = {}, {}


def _new_unet():
    """Return a fresh UNet with the notebook's channel/class settings, moved to `device`."""
    return UNet(n_channels=N_CHANNELS, n_classes=N_CLASSES, bilinear=True).to(device)


def _new_adam(net):
    """Return an Adam optimizer over `net`'s parameters with the notebook-wide LR/WD."""
    return optim.Adam(net.parameters(), lr=LR, weight_decay=WD)


# Construction order is kept as global, global_2, then per client.
# NOTE(review): if UNet initialises its weights from the global torch RNG, as standard
# nn layers do, reordering these calls would change the initial weights.
for net_name in ('global', 'global_2'):
    nets[net_name] = _new_unet()

for client in CLIENTS:
    print(client)
    train_split = split_dataset[client + '_train']
    # Two shuffled loaders over the same training split: batches of 2, and
    # single-sample batches for the `_pl` variant (presumably pseudo-labelling).
    training_clients[client] = DataLoader(train_split, batch_size=2,
                                          shuffle=True, num_workers=8)
    training_clients_pl[client] = DataLoader(train_split, batch_size=1,
                                             shuffle=True, num_workers=8)
    ###################################################################################
    testing_clients[client] = DataLoader(split_dataset[client + '_test'], batch_size=1,
                                         shuffle=False, num_workers=1)

    acc_train[client], acc_test[client] = [], []
    loss_train[client], loss_test[client] = [], []

    client_keys = (client, client + '_2')
    for net_name in client_keys:
        nets[net_name] = _new_unet()
    for net_name in client_keys:
        optimizers[net_name] = _new_adam(nets[net_name])

miccai
bns


In [31]:
# Print the tensor shapes of up to 10 samples from each of up to 10 dataset splits.
for key, value in list(split_dataset.items())[:10]:
    print("Key:", key)
    print("Value:")
    n_shown = min(10, len(value))
    for sample_idx in range(n_shown):
        image_t, label_t, bbox_t = value[sample_idx]
        print("  - Sample", sample_idx + 1)
        print("    - Image shape:", image_t.shape)
        print("    - Label shape:", label_t.shape)
        print("    - BBox mask shape:", bbox_t.shape)
    print("\n")


Key: miccai_train
Value:
  - Sample 1
    - Image shape: torch.Size([1, 600, 600])
    - Label shape: torch.Size([1, 600, 600])
    - BBox mask shape: torch.Size([1, 600, 600])
  - Sample 2
    - Image shape: torch.Size([1, 600, 600])
    - Label shape: torch.Size([1, 600, 600])
    - BBox mask shape: torch.Size([1, 600, 600])
  - Sample 3
    - Image shape: torch.Size([1, 600, 600])
    - Label shape: torch.Size([1, 600, 600])
    - BBox mask shape: torch.Size([1, 600, 600])
  - Sample 4
    - Image shape: torch.Size([1, 600, 600])
    - Label shape: torch.Size([1, 600, 600])
    - BBox mask shape: torch.Size([1, 600, 600])
  - Sample 5
    - Image shape: torch.Size([1, 600, 600])
    - Label shape: torch.Size([1, 600, 600])
    - BBox mask shape: torch.Size([1, 600, 600])
  - Sample 6
    - Image shape: torch.Size([1, 600, 600])
    - Label shape: torch.Size([1, 600, 600])
    - BBox mask shape: torch.Size([1, 600, 600])
  - Sample 7
    - Image shape: torch.Size([1, 600, 600])
    -

In [32]:
# Supervision type per client, in CLIENTS order: both clients are fully labelled.
# Earlier four-client experiments mixed 'labeled' with 'unlabeled' or 'bbox' clients.
CLIENTS_SUPERVISION = ['labeled'] * 2

In [33]:
# Bounding-box supervision switched off for this run.
# NOTE(review): not referenced in the visible cells; presumably read by helper code. Confirm.
bbox_supervision: bool = False

# FedAvg

In [34]:
# FedAvg baseline: every client trains with full labels (same setting as the cell above).
CLIENTS_SUPERVISION = ['labeled' for _ in range(2)]

In [35]:
# Aggregation weights passed to aggr_fed: the normalised per-client training-set sizes.
# This is the same list object as STATIC_WEIGHT, not a copy.
WEIGHTS: list = STATIC_WEIGHT

In [36]:
import os
import torch

# Directory that receives the best checkpoint for each client.
# NOTE(review): absolute, machine-specific path; make it configurable before sharing.
best_model_dir = "C:\\Users\\utente\\Desktop"
os.makedirs(best_model_dir, exist_ok=True)

best_avg_acc, best_epoch = 0.0, 0
# client -> {'model': snapshot of the checkpointed state_dict, 'accuracy': its test accuracy}
best_models = {}

# Start fresh histories so that `index`, the per-client metric lists and `best_epoch`
# stay aligned. `best_epoch` is used as a list index in the report at the bottom, and
# alignment must hold even when this cell is re-run after an interruption.
# `index` was previously never initialised, which raised NameError on a fresh kernel.
index = []
for client in CLIENTS:
    acc_train[client], acc_test[client] = [], []
    loss_train[client], loss_test[client] = [], []

for epoch in range(EPOCHS):
    index.append(epoch)

    #################### copy fed model ###################
    # Copy the federated ('global') model into the client models before local training.
    copy_fed(CLIENTS, nets, fed_name='global')

    #### conduct training #####
    for client, supervision_t in zip(CLIENTS, CLIENTS_SUPERVISION):
        train_model(training_clients[client], nets[client],
                    optimizers[client], device,
                    acc=acc_train[client],
                    loss=loss_train[client],
                    supervision_type=supervision_t)

    # Combine the client models into nets['global'] using the per-client WEIGHTS
    # (presumably a weighted average, i.e. FedAvg).
    aggr_fed(CLIENTS, WEIGHTS, nets, fed_name='global')

    ################### test ##############################
    avg_acc = 0.0
    for client in CLIENTS:
        test(epoch, testing_clients[client], nets['global'], device, acc_test[client],
             loss_test[client])
        client_acc = acc_test[client][-1]
        avg_acc += client_acc

        # Checkpoint whenever this client's test accuracy improves.
        # - The accuracy was measured on nets['global'], so that is the model to save.
        #   Saving nets[client] stored weights whose accuracy was never evaluated.
        # - state_dict() returns references to the live tensors. Without a deep copy,
        #   the stored "best" weights would keep changing as training continues.
        if client_acc > best_models.get(client, {'accuracy': 0})['accuracy']:
            snapshot = copy.deepcopy(nets['global'].state_dict())
            best_models[client] = {'model': snapshot, 'accuracy': client_acc}
            torch.save(snapshot, os.path.join(best_model_dir, f"best_model_{client}.pt"))

    avg_acc = avg_acc / TOTAL_CLIENTS
    if avg_acc > best_avg_acc:
        best_avg_acc = avg_acc
        best_epoch = epoch

    ################################
    # plot #########################
    ################################
    clear_output(wait=True)
    print("Epoch:", epoch, "/", EPOCHS)
    print("Current Average Accuracy:", avg_acc, "Best Average Accuracy:", best_avg_acc)
    plot_graphs(0, CLIENTS, index, acc_train, 'acc_train')
    plot_graphs(1, CLIENTS, index, loss_train, 'loss_train')
    plot_graphs(2, CLIENTS, index, acc_test, ' acc_test')

print("Best Average Accuracy:", best_avg_acc, "at Epoch:", best_epoch)
for client in CLIENTS:
    print(client)
    # Accuracy at the epoch with the best average across clients.
    print("shared epoch specific")
    print(acc_test[client][best_epoch])
    # Best accuracy this client reached at any epoch.
    print("max client-specific")
    print(np.max(acc_test[client]))


KeyboardInterrupt: 

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def load_test_image(image_path):
    """Open the image stored at ``image_path`` with Pillow and return the Image object.

    Pillow opens files lazily, so the pixel data is read on first use.
    """
    return Image.open(image_path)

def visualize(image, segmentation):
    """Show ``image`` and ``segmentation`` side by side in one 12x6-inch figure.

    The segmentation panel uses a grayscale colormap; the image panel uses matplotlib's
    default colormap handling. Both panels hide their axes.
    """
    _, (image_ax, seg_ax) = plt.subplots(1, 2, figsize=(12, 6))
    panels = (
        (image_ax, image, 'Test Image', {}),
        (seg_ax, segmentation, 'Segmentation', {'cmap': 'gray'}),
    )
    for ax, data, title, imshow_kwargs in panels:
        ax.imshow(data, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')
    plt.show()

# Directory that holds the checkpoints written by the training cell.
# NOTE(review): absolute, machine-specific path; adjust for your environment.
saved_models_dir = "C:\\Users\\utente\\Desktop"

# Client whose best checkpoint should be loaded.
client = "miccai"

# Pixels whose foreground probability exceeds this value are marked as foreground.
SEG_THRESHOLD = 0.5


def preprocess_test_image(image, size=IMAGE_SIZE, n_channels=N_CHANNELS):
    """Convert a PIL image into a (1, C, H, W) float32 tensor on ``device``.

    The image is converted to grayscale ('L') when ``n_channels == 1``, otherwise to RGB,
    then resized to ``size`` and scaled to [0, 1].
    NOTE(review): assumes the training ``Cancer`` dataset feeds the network [0, 1] pixels
    at IMAGE_SIZE resolution, and that IMAGE_SIZE is (height, width). Confirm against
    helper.py and mirror any extra normalisation it applies.
    """
    mode = 'L' if n_channels == 1 else 'RGB'
    # PIL's resize takes (width, height).
    resized = image.convert(mode).resize((size[1], size[0]))
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    tensor = torch.from_numpy(pixels)
    # (H, W) -> (1, H, W); (H, W, C) -> (C, H, W)
    tensor = tensor.unsqueeze(0) if tensor.ndim == 2 else tensor.permute(2, 0, 1)
    return tensor.unsqueeze(0).to(device)


def postprocess_segmentation(output, threshold=SEG_THRESHOLD):
    """Turn the network output of shape (1, 1, H, W) into an (H, W) boolean numpy mask.

    NOTE(review): assumes UNet returns raw logits (no final sigmoid). Confirm against
    unet.py, and drop the sigmoid if the network already outputs probabilities.
    """
    probabilities = torch.sigmoid(output)[0, 0]
    return (probabilities > threshold).cpu().numpy()


# Load the saved checkpoint for the chosen client. `model` stays None if it is missing.
model = None
model_path = os.path.join(saved_models_dir, f"best_model_{client}.pt")
if os.path.exists(model_path):
    # Same architecture the training cells instantiate.
    model = UNet(n_channels=N_CHANNELS, n_classes=N_CLASSES, bilinear=True).to(device)
    # map_location lets a checkpoint saved on the GPU load on the current device.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
else:
    print(f"Model for client {client} not found at {model_path}")

# Raw string: a normal literal would reject the "\U" and "\N" in this Windows path as
# invalid escape sequences (SyntaxError), so the cell would not run at all.
test_image_path = r"C:\Users\utente\Desktop\Università\Tesi magistrale\RC Nuclei Cellulari\RC_Nuclei\Nuclei_Segmentation_Experiments_Demo_dataset\MICCAI2017\Nuclei_segmentation_testing\imagesTs\image02.png"

test_image = load_test_image(test_image_path)

# Segment the test image with the loaded model and show the result.
if model is not None:
    with torch.no_grad():
        model_input = preprocess_test_image(test_image)
        segmentation_output = model(model_input)
        postprocessed_segmentation = postprocess_segmentation(segmentation_output)

    visualize(test_image, postprocessed_segmentation)
else:
    print(f"Model not found for client '{client}'")
