In [1]:
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from progressBar import printProgressBar


import medicalDataLoader
import argparse
import utils
from utils import inferenceTeacher
from torch.utils.data import ConcatDataset


from UNet_Base import *
import random
import torch
import torch.nn as nn
import torch.optim as optim
import pdb
import matplotlib.pyplot as plt
import numpy as np
import os

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
# Display name for each segmentation class, indexed by class id (0 = background);
# the names for ids 1-3 currently read "tbd".
classes = tuple(
    ["background", "1–tbd", "2–tbd", "3–tbd"]
)

In [4]:
# Inspired by https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html
def matplotlib_imshow(img, one_channel=False, unnormalize=True):
    """Draw an image tensor on the current matplotlib axes.

    Args:
        img: image tensor, channels first (C, H, W) when ``one_channel`` is False.
            It may live on any device and may require grad.
        one_channel: if True, average over the channel axis and draw with a grey colormap.
        unnormalize: if True, map values with ``x / 2 + 0.5`` (the inverse of
            ``Normalize(mean=0.5, std=0.5)``). NOTE(review): this notebook's transform is
            only ``ToTensor()``, so pass ``unnormalize=False`` for its images.
    """
    if one_channel:
        img = img.mean(dim=0)
    if unnormalize:
        img = img / 2 + 0.5
    # detach/cpu so tensors on GPU or attached to the autograd graph can be converted.
    npimg = img.detach().cpu().numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        # matplotlib expects channels last (H, W, C).
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

def images_to_probs(net, images):
    '''
    Generate predictions and their probabilities from a trained network.

    Args:
        net: callable returning per-class scores with the class axis at dim 1.
        images: batch passed straight to ``net``.

    Returns:
        (preds, probs): ``preds`` is ``np.squeeze`` of the argmax class indices (a 0-d
        array for a batch of one); ``probs`` is ``tolist()`` of the softmax probability
        of each predicted class.
    '''
    # Pure inference: no need to build an autograd graph.
    with torch.no_grad():
        output = net(images)
    # Predicted class = argmax of the raw scores along the class axis.
    _, preds_tensor = torch.max(output, 1)
    # Probability of the predicted class. torch.softmax replaces the original F.softmax,
    # because F is never imported in the visible file.
    top_probs = torch.softmax(output, dim=1).gather(1, preds_tensor.unsqueeze(1)).squeeze(1)
    # .cpu() so this also works for tensors on a GPU/MPS device.
    preds = np.squeeze(preds_tensor.cpu().numpy())
    return preds, top_probs.cpu().tolist()

def plot_classes_preds(net, images, labels, max_images=4):
    '''
    Build a matplotlib Figure showing, for up to ``max_images`` images of a batch, the
    network's top prediction with its probability and the true label.

    The title is green when prediction and label agree and red otherwise.
    Uses ``images_to_probs`` and ``matplotlib_imshow``.

    Args:
        net: network passed to ``images_to_probs``.
        images: batch of images, indexable per image.
        labels: per-image integer class labels (indices into ``classes``).
        max_images: maximum number of images to draw (default 4, as before).

    Returns:
        The matplotlib Figure.
    '''
    preds, probs = images_to_probs(net, images)
    # np.squeeze in images_to_probs turns a batch of one into a 0-d array; make it indexable.
    preds = np.atleast_1d(preds)
    # Never index past the end of a batch smaller than max_images.
    n_shown = min(max_images, len(images))
    fig = plt.figure(figsize=(12, 48))
    for idx in range(n_shown):
        ax = fig.add_subplot(1, n_shown, idx + 1, xticks=[], yticks=[])
        matplotlib_imshow(images[idx], one_channel=True)
        true_label = int(labels[idx])
        ax.set_title("{0}, {1:.1f}%\n(label: {2})".format(
            classes[preds[idx]],
            probs[idx] * 100.0,
            classes[true_label]),
                    color=("green" if preds[idx] == true_label else "red"))
    return fig

In [5]:
def dice_coefficient(prediction, target, epsilon=1e-07):
    """Dice overlap between a sign-thresholded prediction and a target tensor.

    Negative prediction values become 0, positive ones become 1, and exact zeros stay 0.
    The input tensor is not modified.

    Args:
        prediction: tensor of scores, thresholded at 0.
        target: tensor broadcastable against ``prediction``.
        epsilon: smoothing term added to numerator and denominator.

    Returns:
        ``(2 * |sum(bin * target)| + epsilon) / (|sum(bin) + sum(target)| + epsilon)``
        as a tensor.
    """
    # Two-step binarisation that mirrors the original in-place updates exactly.
    binarized = torch.where(prediction < 0, torch.zeros_like(prediction), prediction)
    binarized = torch.where(binarized > 0, torch.ones_like(binarized), binarized)

    overlap = torch.sum(binarized * target).abs()
    total = (torch.sum(binarized) + torch.sum(target)).abs()
    return (2. * overlap + epsilon) / (total + epsilon)




In [6]:
# Define hyperparameters
# -- batching: images processed together during training / validation --
batch_size, batch_size_val = 8, 4
# -- optimisation --
lr = 0.01            # learning rate
total_epochs = 50    # number of full passes over the training data
# -- model / knowledge distillation --
num_classes = 4      # segmentation classes, background included
temperature = 2.0    # softmax temperature used by the distillation loss

In [7]:
# Set device depending on accelerator availability: CUDA first, then Apple-silicon MPS, else CPU.
# torch.backends.mps.is_available() is the documented MPS check; the getattr guard keeps
# this cell working on torch builds that have no mps backend at all.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: cpu


In [8]:
# Define transformations: images and masks are only converted to tensors.
transform = transforms.Compose([
    transforms.ToTensor()
])

mask_transform = transforms.Compose([
    transforms.ToTensor()
])

# Define dataloaders
root_dir = './data/'
print(' Dataset: {} '.format(root_dir))

# Seed numpy explicitly, once. The loaders previously passed
# `worker_init_fn=np.random.seed(0)`, which *calls* seed(0) while the loader is built and
# hands DataLoader `None`. With num_workers=0 no worker init function is used anyway.
np.random.seed(0)

# Labeled training data.
train_set_full = medicalDataLoader.MedicalImageDataset('train',
                                                       root_dir,
                                                       transform=transform,
                                                       mask_transform=mask_transform,
                                                       augment=False,
                                                       equalize=False)

train_loader_full = DataLoader(train_set_full,
                               batch_size=batch_size,
                               num_workers=0,
                               shuffle=True)

# Unlabeled data on which the Teacher makes predictions (no mask transform).
unlabeledEval_set_full = medicalDataLoader.MedicalImageDataset('unlabeledEval',
                                                               root_dir,
                                                               transform=transform,
                                                               mask_transform=None,
                                                               augment=False,
                                                               equalize=False)  # no augmentation for now

unlabeledEval_loader_full = DataLoader(unlabeledEval_set_full,
                                       batch_size=batch_size,
                                       num_workers=0,
                                       shuffle=True)

# Validation data, evaluated in a fixed order.
val_set = medicalDataLoader.MedicalImageDataset('val',
                                                root_dir,
                                                transform=transform,
                                                mask_transform=mask_transform,
                                                equalize=False)

val_loader = DataLoader(val_set,
                        batch_size=batch_size_val,
                        num_workers=0,
                        shuffle=False)

 Dataset: ./data/ 


In [9]:
# Peek at a single training batch to check tensor shapes: [batch_size, channels, height, width].
sample = next(iter(train_loader_full), None)
if sample is not None:
    images, masks, _ = sample
    print('Image batch dimensions: ', images.size())
    print('Mask batch dimensions: ', masks.size())

Image batch dimensions:  torch.Size([8, 1, 256, 256])
Mask batch dimensions:  torch.Size([8, 1, 256, 256])


In [10]:
# Loss comparing the student's predictions with the teacher's temperature-softened outputs.
def distillation_loss(y_pred_student, y_pred_teacher, temp=None):
    """KL-divergence knowledge-distillation loss.

    Both inputs are divided by the temperature and softmaxed along dim 1 (the class axis).

    Args:
        y_pred_student: student scores (logits), class axis at dim 1.
        y_pred_teacher: teacher scores with the same shape, softened the same way.
        temp: softmax temperature; when None, the notebook-level ``temperature`` is used.

    Returns:
        ``nn.KLDivLoss(reduction='batchmean')`` of the student log-distribution against
        the teacher distribution, as a scalar tensor.

    NOTE(review): Hinton et al. scale this term by ``T**2`` to keep gradient magnitudes
    comparable across temperatures; that scaling is not applied here.
    """
    T = temperature if temp is None else temp
    soft_y_pred_teacher = torch.softmax(y_pred_teacher / T, dim=1)
    # log_softmax is the numerically stable form of log(softmax(.)). It stays finite where
    # softmax underflows to 0, which would otherwise give log(0) = -inf.
    log_soft_y_pred_student = torch.log_softmax(y_pred_student / T, dim=1)

    return nn.KLDivLoss(reduction='batchmean')(log_soft_y_pred_student, soft_y_pred_teacher)

In [18]:
def runTraining(writer: SummaryWriter, loader, model_type="Teacher", model=None):
    """Train a UNet segmentation network with cross-entropy, validating after every epoch.

    Args:
        writer: TensorBoard ``SummaryWriter`` receiving loss/Dice scalars and prediction
            figures. It is flushed (not closed) before returning.
        loader: DataLoader yielding ``(images, labels, img_names)`` training batches.
        model_type: tag used in TensorBoard tags, in the statistics directory
            ``Results/Statistics/{model_type}/`` and in the checkpoint directory
            ``./models/{model_type}/``.
        model: optional network to train. When None, a fresh ``UNet(num_classes)`` is
            created. In both cases the network is moved to ``device``.

    Returns:
        The trained network (last-epoch weights). Each time the validation loss improves,
        the weights are saved as ``./models/{model_type}/{epoch}_Epoch``.

    Raises:
        ValueError: if ``loader`` yields no batches.

    Relies on notebook globals: ``num_classes``, ``device``, ``lr``, ``total_epochs``,
    ``batch_size`` and ``val_loader``.
    """
    print("-" * 40)
    print(f"~~~~~~~~  Starting the training for {model_type}... ~~~~~~")
    print("-" * 40)

    print(f"~~~~~~~~~~~ Creating the UNet model for {model_type} ~~~~~~~~~~")
    modelName = f"{model_type}_Model"
    print(" Model Name: {}".format(modelName))

    ## CREATION OF YOUR MODEL
    net = UNet(num_classes) if model is None else model
    # A caller-supplied model may still be on another device; inputs are sent to `device`.
    net = net.to(device)

    print(
        "Total params: {0:,}".format(
            sum(p.numel() for p in net.parameters() if p.requires_grad)
        )
    )

    # CrossEntropyLoss takes the raw network scores (logits) and integer class targets.
    CE_loss = torch.nn.CrossEntropyLoss()

    # Adam optimizer (adaptive per-parameter learning rates).
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    ### To save statistics ####
    train_losses = []
    train_dc_losses = []
    val_losses = []
    val_dc_losses = []

    best_loss_val = float("inf")  # the first finite validation loss is always an improvement

    directory = f"Results/Statistics/{model_type}/" + modelName
    print(f"~~~~~~~~~~~ Saving results in: {directory} ~~~~~~~~~~")

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    os.makedirs(directory, exist_ok=True)

    checkpoint_dir = f"./models/{model_type}"
    num_batches = len(loader)
    num_val_batches = len(val_loader)
    if num_batches == 0:
        raise ValueError("runTraining: the training loader yields no batches")

    ## FOR EACH EPOCH
    for epoch in range(total_epochs):
        net.train()
        print("Number of batches: ", num_batches)

        running_train_loss = 0
        running_dice_loss = 0

        # Training loop
        for idx, (images, labels, _img_names) in enumerate(loader):
            # The optimizer owns every parameter of `net`, so one zero_grad call suffices.
            optimizer.zero_grad()

            labels = utils.to_var(labels).to(device)
            images = utils.to_var(images).to(device)

            # Forward pass: per-class scores for every pixel.
            net_predictions = net(images)
            # Target class index per pixel: 0 (background), 1, 2, 3.
            segmentation_classes = utils.getTargetSegmentation(labels)

            loss = CE_loss(net_predictions, segmentation_classes)
            running_train_loss += loss.item()
            # float(...) drops any autograd graph the Dice value might carry, so the
            # running sum does not keep every batch's graph alive for the whole epoch.
            running_dice_loss += float(utils.compute_dsc(net_predictions, labels))

            # Backprop
            loss.backward()
            optimizer.step()
            torch.cuda.empty_cache()

            global_step = epoch * num_batches + idx
            # Log running averages to TensorBoard every 10 batches.
            if idx % 10 == 0:
                writer.add_scalar(
                    f"Loss/train/{model_type}", running_train_loss / (idx + 1), global_step
                )
                writer.add_scalar(
                    f"Dice/train/{model_type}", running_dice_loss / (idx + 1), global_step
                )

            # Every 100 batches also log a predicted-vs-ground-truth figure.
            if idx % 100 == 0:
                probs = torch.softmax(net_predictions, dim=1)
                y_pred = torch.argmax(probs, dim=1)
                print("Images min/max:", images.min().item(), images.max().item())
                print("Predictions min/max:", y_pred.min().item(), y_pred.max().item())

                writer.add_figure('predictions vs. actuals',
                                  utils.plot_net_predictions(images, labels, y_pred, batch_size),
                                  global_step=global_step)

            # THIS IS JUST TO VISUALIZE THE TRAINING
            printProgressBar(
                idx + 1,
                num_batches,
                prefix="[Training] Epoch: {} ".format(epoch),
                length=15,
                suffix=" Loss: {:.4f}, ".format(running_train_loss / (idx + 1)),
            )

        train_loss = running_train_loss / num_batches
        train_losses.append(train_loss)

        train_dc_loss = running_dice_loss / num_batches
        train_dc_losses.append(train_dc_loss)

        net.eval()
        val_running_loss = 0
        val_running_dc = 0

        # Validation loop, run once at the end of every epoch.
        with torch.no_grad():
            for idx, (images, labels, _img_names) in enumerate(val_loader):
                labels = utils.to_var(labels).to(device)
                images = utils.to_var(images).to(device)

                net_predictions = net(images)
                segmentation_classes = utils.getTargetSegmentation(labels)

                loss = CE_loss(net_predictions, segmentation_classes)
                val_running_loss += loss.item()
                val_running_dc += float(utils.compute_dsc(net_predictions, labels))

                if idx % 10 == 0:
                    val_step = epoch * num_val_batches + idx
                    writer.add_scalar(
                        f"Loss/val/{model_type}",
                        val_running_loss / (idx + 1),
                        val_step,
                    )
                    writer.add_scalar(
                        f"Dice/val/{model_type}",
                        val_running_dc / (idx + 1),
                        val_step,
                    )

                printProgressBar(
                    idx + 1,
                    num_val_batches,
                    prefix="[Validation] Epoch: {} ".format(epoch),
                    length=15,
                    suffix=" Loss: {:.4f}, ".format(val_running_loss / (idx + 1)),
                )

        val_loss = val_running_loss / num_val_batches
        val_losses.append(val_loss)
        dc_loss = val_running_dc / num_val_batches
        val_dc_losses.append(dc_loss)

        # Save a checkpoint whenever the validation loss improves.
        if val_loss < best_loss_val:
            best_loss_val = val_loss
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(net.state_dict(), f"{checkpoint_dir}/{epoch}_Epoch")

        printProgressBar(
            num_batches,
            num_batches,
            done="[Epoch: {}, TrainLoss: {:.4f}, TrainDice: {:.4f}, ValLoss: {:.4f}".format(
                epoch, train_loss, train_dc_loss, val_loss
            ),
        )

        # Persist all curves after every epoch so an interrupted run keeps them.
        np.save(os.path.join(directory, "Losses.npy"), train_losses)
        np.save(os.path.join(directory, "Losses_val.npy"), val_losses)
        np.save(os.path.join(directory, "Dice_train.npy"), train_dc_losses)
        np.save(os.path.join(directory, "Dice_val.npy"), val_dc_losses)

    writer.flush()  # Flush the writer to ensure that all the data is written to disk
    return net

In [22]:
def runTraining2(writer: SummaryWriter, loader, model_type="Student", model=None, isUnlabeled=False,
                 teacher_probs_dir="./Data/train/Img-UnlabeledProbabilities/"):
    """Train a UNet either supervised (cross-entropy) or by distillation from saved Teacher outputs.

    Args:
        writer: TensorBoard ``SummaryWriter``; flushed (not closed) before returning.
        loader: DataLoader yielding ``(images, labels, img_names)`` batches.
        model_type: tag for TensorBoard tags, ``Results/Statistics/{model_type}/`` and the
            checkpoint directory ``./models/{model_type}/``.
        model: optional network to train; when None a fresh ``UNet(num_classes)`` is
            created. In both cases the network is moved to ``device``.
        isUnlabeled: if False, train on ``labels`` with cross-entropy. If True, train
            with ``distillation_loss`` against the Teacher outputs loaded from
            ``teacher_probs_dir/<image base name>.npy``.
        teacher_probs_dir: directory holding the Teacher's per-image ``.npy`` files.
            NOTE(review): its default differs in case from ``root_dir`` ('./data/'). That
            matters on case-sensitive file systems — confirm where the files are written.

    Returns:
        The trained network (last-epoch weights). Each time the validation loss improves,
        the weights are saved as ``./models/{model_type}/{epoch}_Epoch``.

    Raises:
        ValueError: if ``loader`` yields no batches.

    Validation always uses cross-entropy on ``val_loader``. NOTE(review): two calls with
    the same ``model_type`` write to the same checkpoint directory and statistics files,
    so a second phase overwrites checkpoints with matching epoch numbers.
    Relies on notebook globals: ``num_classes``, ``device``, ``lr``, ``total_epochs``,
    ``batch_size``, ``val_loader`` and ``distillation_loss``.
    """
    print("-" * 40)
    print(f"~~~~~~~~  Starting the training for {model_type}... ~~~~~~")
    print("-" * 40)

    print(f"~~~~~~~~~~~ Creating the UNet model for {model_type} ~~~~~~~~~~")
    modelName = f"{model_type}_Model"
    print(" Model Name: {}".format(modelName))

    ## CREATION OF YOUR MODEL
    net = UNet(num_classes) if model is None else model
    # A caller-supplied model may still be on another device; inputs are sent to `device`.
    net = net.to(device)

    print(
        "Total params: {0:,}".format(
            sum(p.numel() for p in net.parameters() if p.requires_grad)
        )
    )

    # CrossEntropyLoss takes the raw network scores (logits) and integer class targets.
    CE_loss = torch.nn.CrossEntropyLoss()

    # Adam optimizer (adaptive per-parameter learning rates).
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    ### To save statistics ####
    train_losses = []
    train_dc_losses = []
    val_losses = []
    val_dc_losses = []

    best_loss_val = float("inf")  # the first finite validation loss is always an improvement

    directory = f"Results/Statistics/{model_type}/" + modelName
    print(f"~~~~~~~~~~~ Saving results in: {directory} ~~~~~~~~~~")

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    os.makedirs(directory, exist_ok=True)

    checkpoint_dir = f"./models/{model_type}"
    num_batches = len(loader)
    num_val_batches = len(val_loader)
    if num_batches == 0:
        raise ValueError("runTraining2: the training loader yields no batches")

    ## FOR EACH EPOCH
    for epoch in range(total_epochs):
        net.train()
        print("Number of batches: ", num_batches)

        running_train_loss = 0
        running_dice_loss = 0

        # Training loop
        for idx, (images, labels, img_names) in enumerate(loader):
            # The optimizer owns every parameter of `net`, so one zero_grad call suffices.
            optimizer.zero_grad()

            labels = utils.to_var(labels).to(device)
            images = utils.to_var(images).to(device)

            # Forward pass: per-class scores (logits) for every pixel.
            net_predictions = net(images)

            if isUnlabeled:
                # Targets are the Teacher's saved outputs, one .npy per image, found by the
                # image file's base name (directory and extension stripped).
                batch_teacher_probs = []
                for img_name in img_names:
                    stem = os.path.splitext(os.path.basename(img_name))[0]
                    prob_file_path = os.path.join(teacher_probs_dir, f"{stem}.npy")
                    # float32 so the target dtype matches the student's output dtype.
                    batch_teacher_probs.append(
                        torch.as_tensor(np.load(prob_file_path), dtype=torch.float32)
                    )
                segmentation_classes_teacher = torch.stack(batch_teacher_probs).to(device)

                # Pass the raw student logits: distillation_loss applies the temperature
                # softmax itself. Softmaxing here first, as before, applied softmax twice.
                # NOTE(review): the .npy files are described as Teacher probability
                # distributions, but distillation_loss also softmaxes the teacher input. If
                # they are probabilities rather than logits, pass their log instead —
                # confirm against utils.inferenceTeacher.
                loss = distillation_loss(net_predictions, segmentation_classes_teacher)
            else:
                # Target class index per pixel: 0 (background), 1, 2, 3.
                segmentation_classes = utils.getTargetSegmentation(labels)
                loss = CE_loss(net_predictions, segmentation_classes)

            running_train_loss += loss.item()
            # float(...) drops any autograd graph the Dice value might carry.
            # NOTE(review): for unlabeled batches, `labels` are not ground truth, so this
            # Dice score is presumably not meaningful.
            running_dice_loss += float(utils.compute_dsc(net_predictions, labels))

            # Backprop
            loss.backward()
            optimizer.step()
            torch.cuda.empty_cache()

            global_step = epoch * num_batches + idx
            # Log running averages to TensorBoard every 10 batches.
            if idx % 10 == 0:
                writer.add_scalar(
                    f"Loss/train/{model_type}", running_train_loss / (idx + 1), global_step
                )
                writer.add_scalar(
                    f"Dice/train/{model_type}", running_dice_loss / (idx + 1), global_step
                )

            # Every 100 batches also log a predicted-vs-ground-truth figure.
            if idx % 100 == 0:
                probs = torch.softmax(net_predictions, dim=1)
                y_pred = torch.argmax(probs, dim=1)
                print("Images min/max:", images.min().item(), images.max().item())
                print("Predictions min/max:", y_pred.min().item(), y_pred.max().item())

                writer.add_figure('predictions vs. actuals',
                                  utils.plot_net_predictions(images, labels, y_pred, batch_size),
                                  global_step=global_step)

            # THIS IS JUST TO VISUALIZE THE TRAINING
            printProgressBar(
                idx + 1,
                num_batches,
                prefix="[Training] Epoch: {} ".format(epoch),
                length=15,
                suffix=" Loss: {:.4f}, ".format(running_train_loss / (idx + 1)),
            )

        train_loss = running_train_loss / num_batches
        train_losses.append(train_loss)

        train_dc_loss = running_dice_loss / num_batches
        train_dc_losses.append(train_dc_loss)

        net.eval()
        val_running_loss = 0
        val_running_dc = 0

        # Validation loop (always supervised), run once at the end of every epoch.
        with torch.no_grad():
            for idx, (images, labels, _img_names) in enumerate(val_loader):
                labels = utils.to_var(labels).to(device)
                images = utils.to_var(images).to(device)

                net_predictions = net(images)
                segmentation_classes = utils.getTargetSegmentation(labels)

                loss = CE_loss(net_predictions, segmentation_classes)
                val_running_loss += loss.item()
                val_running_dc += float(utils.compute_dsc(net_predictions, labels))

                if idx % 10 == 0:
                    val_step = epoch * num_val_batches + idx
                    writer.add_scalar(
                        f"Loss/val/{model_type}",
                        val_running_loss / (idx + 1),
                        val_step,
                    )
                    writer.add_scalar(
                        f"Dice/val/{model_type}",
                        val_running_dc / (idx + 1),
                        val_step,
                    )

                printProgressBar(
                    idx + 1,
                    num_val_batches,
                    prefix="[Validation] Epoch: {} ".format(epoch),
                    length=15,
                    suffix=" Loss: {:.4f}, ".format(val_running_loss / (idx + 1)),
                )

        val_loss = val_running_loss / num_val_batches
        val_losses.append(val_loss)
        dc_loss = val_running_dc / num_val_batches
        val_dc_losses.append(dc_loss)

        # Save a checkpoint whenever the validation loss improves.
        if val_loss < best_loss_val:
            best_loss_val = val_loss
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(net.state_dict(), f"{checkpoint_dir}/{epoch}_Epoch")

        printProgressBar(
            num_batches,
            num_batches,
            done="[Epoch: {}, TrainLoss: {:.4f}, TrainDice: {:.4f}, ValLoss: {:.4f}".format(
                epoch, train_loss, train_dc_loss, val_loss
            ),
        )

        # Persist all curves after every epoch so an interrupted run keeps them.
        np.save(os.path.join(directory, "Losses.npy"), train_losses)
        np.save(os.path.join(directory, "Losses_val.npy"), val_losses)
        np.save(os.path.join(directory, "Dice_train.npy"), train_dc_losses)
        np.save(os.path.join(directory, "Dice_val.npy"), val_dc_losses)

    writer.flush()  # Flush the writer to ensure that all the data is written to disk
    return net
    

In [None]:
# Baseline: train the Teacher on the labeled training set.
# The context manager closes the TensorBoard writer even if training raises.
with SummaryWriter() as writer:
    runTraining(writer, train_loader_full, model_type="Teacher")

In [12]:
# Use the trained Teacher (the epoch-28 checkpoint here) to generate predictions on the
# unlabeled data. Teacher checkpoints are stored in ./models/Teacher/.
epoch_to_load = 28  # epoch index of the Teacher checkpoint to load
model = UNet(num_classes).to(device)
# map_location lets a checkpoint saved on one device (e.g. GPU) load on this machine's device.
model.load_state_dict(torch.load(f"./models/Teacher/{epoch_to_load}_Epoch", map_location=device))
model.eval()  # inference: layers such as dropout/batch-norm (if present) must not run in training mode
inferenceTeacher(model, unlabeledEval_loader_full, 'TeacherUnlabeledPredictions', epoch_to_load, device)
# Predictions are written under Results/Images/TeacherUnlabeledPredictions/ (per-epoch
# location presumably — confirm in utils.inferenceTeacher).



[Inference] Teacher Inference Done !                                                                         


In [23]:
# Student training on labeled + unlabeled data.
# This cell must run after the Teacher inference cell: the unlabeled phase reads the
# Teacher's saved per-image predictions from disk.
unlabeledTrain_set_full = medicalDataLoader.MedicalImageDataset('unlabeledTrain',
                                                                root_dir,
                                                                transform=transform,
                                                                mask_transform=mask_transform,
                                                                augment=False,
                                                                equalize=False)  # no augmentation for now

# `worker_init_fn=np.random.seed(0)` used to *call* seed(0) and pass `None` to DataLoader.
# Seed explicitly instead; with num_workers=0 no worker init function is needed.
np.random.seed(0)
unlabeledTrain_loader_full = DataLoader(unlabeledTrain_set_full,
                                        batch_size=batch_size,
                                        num_workers=0,
                                        shuffle=True)

epoch_to_load = 28  # epoch index of the Teacher checkpoint whose weights initialise the Student
teacher_model_path = f"./models/Teacher/{epoch_to_load}_Epoch"

student_model = UNet(num_classes).to(device)

# Initialise the Student with the Teacher's weights. map_location allows loading a
# checkpoint saved on another device. On failure the error is printed and the Student
# keeps its fresh initialisation.
try:
    student_model.load_state_dict(torch.load(teacher_model_path, map_location=device))
    print(f"Les poids du modèle Teacher (époque {epoch_to_load}) ont été chargés avec succès dans le modèle Student.")
except Exception as e:
    print(f"Erreur lors du chargement des poids : {e}")

# Phase 1: supervised training on labeled data. Phase 2: distillation on unlabeled data.
# The context manager flushes and closes the writer even if training raises.
with SummaryWriter() as writer:
    runTraining2(writer, train_loader_full, 'Student', student_model)
    runTraining2(writer, unlabeledTrain_loader_full, 'Student', student_model, True)

Les poids du modèle Teacher (époque 28) ont été chargés avec succès dans le modèle Student.
----------------------------------------
~~~~~~~~  Starting the training for Student... ~~~~~~
----------------------------------------
~~~~~~~~~~~ Creating the UNet model for Student ~~~~~~~~~~
 Model Name: Student_Model
Total params: 60,664
~~~~~~~~~~~ Saving results in: Results/Statistics/Student/Student_Model ~~~~~~~~~~
~~~~~~~~~~~ Starting the training ~~~~~~~~~~
Number of batches:  26
Images min/max: 0.0 1.0
Predictions min/max: 0 3
[Training] Epoch: 0 [DONE]                                 
[Validation] Epoch: 0 [DONE]                                 
[Epoch: 0, TrainLoss: 0.0250, TrainDice: 0.0544, ValLoss: 0.0712                                             
Number of batches:  26
Images min/max: 0.0 1.0
Predictions min/max: 0 3
[Training] Epoch: 1 [DONE]                                 
[Validation] Epoch: 1 [DONE]                                 
[Epoch: 1, TrainLoss: 0.0245, TrainDice