In [1]:
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from progressBar import printProgressBar

import medicalDataLoader
import argparse
import utils
from utils import inferenceTeacher


from UNet_Base import *
import random
import torch
import pdb
import matplotlib.pyplot as plt
import numpy as np
import os

In [2]:
import warnings
# Suppress every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation notices from the libraries used below.
warnings.filterwarnings("ignore")

In [3]:
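# To inspect the TensorBoard logs written to ./runs, launch TensorBoard from the project environment:
#!poetry shell
#!poetry run tensorboard --logdir=runs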

In [4]:
# Display names of the classes; plot_classes_preds indexes this tuple by class id.
# NOTE(review): the "tbd" entries look like placeholders still to be named.
classes = ("background", "1–tbd", "2–tbd", "3–tbd")

In [5]:
# Inspired by https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html
def matplotlib_imshow(img, one_channel=False):
    """Draw one image tensor on the current matplotlib axes.

    Args:
        img: Image tensor with the channel axis first; the channel axis is
            averaged away when ``one_channel`` is True and moved last otherwise.
        one_channel: If True, display the channel-averaged image with the
            "Greys" colormap; otherwise display it as a (H, W, C) image.
    """
    if one_channel:
        img = img.mean(dim=0)
    # Undo a (mean=0.5, std=0.5) normalisation.
    # NOTE(review): the transforms defined in this notebook only apply ToTensor - confirm this rescale is wanted.
    img = img / 2 + 0.5
    # detach() + cpu() so tensors that carry autograd history or live on an
    # accelerator can be converted; a bare .numpy() raises for both.
    npimg = img.detach().cpu().numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

def images_to_probs(net, images):
    '''
    Generates predictions and corresponding probabilities from a trained
    network and a batch of images.

    Args:
        net: Callable returning one score vector per image.
            NOTE(review): assumes an output of shape (batch, num_classes);
            a per-pixel segmentation output would make ``.item()`` fail - confirm.
        images: Batch of inputs passed to ``net`` unchanged.

    Returns:
        (preds, probs): ``preds`` is a 1-D numpy array with the arg-max class of
        each image, ``probs`` a list with the softmax probability of that class.
    '''
    # Inference only: no autograd graph is needed for plotting
    with torch.no_grad():
        output = net(images)
    # convert output scores to predicted class
    _, preds_tensor = torch.max(output, 1)
    # .cpu() makes the conversion work for accelerator tensors; atleast_1d keeps a
    # batch of one iterable (squeeze alone would collapse it to a 0-d array).
    preds = np.atleast_1d(np.squeeze(preds_tensor.cpu().numpy()))
    # torch.softmax is used directly so the function does not depend on ``F``
    # being leaked into the namespace by a star import.
    return preds, [torch.softmax(el, dim=0)[i].item() for i, el in zip(preds, output)]

def plot_classes_preds(net, images, labels):
    '''
    Generates matplotlib Figure using a trained network, along with images
    and labels from a batch, that shows the network's top prediction along
    with its probability, alongside the actual label, coloring this
    information based on whether the prediction was correct or not.
    Uses the "images_to_probs" function.

    Args:
        net: Network forwarded to ``images_to_probs``.
        images: Batch of image tensors; at most the first four are drawn.
        labels: One integer class id per image, used to index ``classes``.

    Returns:
        The matplotlib Figure holding up to four annotated sub-plots.
    '''
    preds, probs = images_to_probs(net, images)
    # plot the images in the batch, along with predicted and true labels
    fig = plt.figure(figsize=(12, 48))
    # Never index past the end of a batch smaller than the four available slots
    shown = min(4, len(images))
    for idx in range(shown):
        ax = fig.add_subplot(1, 4, idx+1, xticks=[], yticks=[])
        # Hand over a detached CPU copy so plotting also works for accelerator tensors
        matplotlib_imshow(images[idx].detach().cpu(), one_channel=True)
        predicted = preds[idx]
        actual = labels[idx].item()
        ax.set_title("{0}, {1:.1f}%\n(label: {2})".format(
            classes[predicted],
            probs[idx] * 100.0,
            classes[actual]),
                    color=("green" if predicted==actual else "red"))
    return fig

In [6]:
def dice_coefficient(prediction, target, epsilon=1e-07):
    """Dice overlap between a sign-thresholded prediction and a target tensor.

    Strictly negative entries of ``prediction`` are mapped to 0, strictly
    positive entries to 1, and entries equal to zero are left untouched; the
    input tensor itself is not modified.

    Args:
        prediction: Tensor of raw scores.
        target: Tensor multiplied element-wise with the thresholded prediction.
        epsilon: Added to numerator and denominator so an empty
            prediction/target pair does not divide by zero.

    Returns:
        A scalar tensor: (2 * |sum(pred * target)| + eps) / (|sum(pred) + sum(target)| + eps).
    """
    # Both masks are taken from the untouched input; the two fills cannot interact
    # because an entry is never both negative and positive.
    negative_mask = prediction < 0
    positive_mask = prediction > 0

    binarized = prediction.clone()
    binarized[negative_mask] = 0
    binarized[positive_mask] = 1

    overlap = (binarized * target).sum().abs()
    total = (binarized.sum() + target.sum()).abs()
    return (2. * overlap + epsilon) / (total + epsilon)

In [7]:
# Define hyperparameters (read as notebook-level globals by the training code below)
batch_size = 8 # number of images per batch for the training and self-training loaders
batch_size_val = 4 # number of images per batch for the validation loader
lr =  0.01   # learning rate handed to the optimizer
total_epochs = 50  # number of full passes over the training loader
num_classes = 4 # number of output classes the UNet is built with

In [8]:
# Set device depending on the availability of an accelerator (CUDA first, then Apple MPS, else CPU).
# The MPS probe goes through torch.backends.mps, the documented availability check;
# torch.mps.is_available() is missing from many PyTorch releases and raises AttributeError there.
_mps_backend = getattr(torch.backends, "mps", None)  # absent on very old PyTorch builds

if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():  # Apple M-series of chips
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: cpu


In [9]:
# Define transformations (images and masks are only converted to tensors)
transform = transforms.Compose([
    transforms.ToTensor()
])

mask_transform = transforms.Compose([
    transforms.ToTensor()
])

# Define dataloaders
root_dir = './data/'
print(' Dataset: {} '.format(root_dir))

# NOTE on seeding: the loaders used to be built with
# `worker_init_fn=np.random.seed(0)`. That expression *calls* np.random.seed(0)
# immediately and hands its return value (None) to DataLoader, so no worker
# initialiser was ever registered. The seeding is now done explicitly at the
# same points and the misleading argument is dropped; runtime behaviour is
# unchanged. Shuffling order is drawn from torch's RNG, so it is not fixed by
# this NumPy seed.

# Labeled training split
train_set_full = medicalDataLoader.MedicalImageDataset('train',
                                                    root_dir,
                                                    transform=transform,
                                                    mask_transform=mask_transform,
                                                    augment=False,
                                                    equalize=False)

np.random.seed(0)
train_loader_full = DataLoader(train_set_full,
                            batch_size=batch_size,
                            num_workers=0,
                            shuffle=True)

# Unlabeled split used for self-training (no mask transform for now)
selftrain_set_full = medicalDataLoader.MedicalImageDataset('selfTrain',
                                                    root_dir,
                                                    transform=transform,
                                                    mask_transform=None,
                                                    augment=False,
                                                    equalize=False)

np.random.seed(0)
selftrain_loader_full = DataLoader(selftrain_set_full,
                            batch_size=batch_size,
                            num_workers=0,
                            shuffle=True)

# Validation split (kept in a fixed order)
val_set = medicalDataLoader.MedicalImageDataset('val',
                                                root_dir,
                                                transform=transform,
                                                mask_transform=mask_transform,
                                                equalize=False)

np.random.seed(0)
val_loader = DataLoader(val_set,
                        batch_size=batch_size_val,
                        num_workers=0,
                        shuffle=False)

 Dataset: ./data/ 


In [10]:
# Sanity check: fetch the first training batch only and report its tensor sizes.
for images, masks, _ in train_loader_full:
    # Sizes are reported as [batch_size, channels, height, width]
    print('Image batch dimensions: ', images.size())
    print('Mask batch dimensions: ', masks.size())
    break

Image batch dimensions:  torch.Size([8, 1, 256, 256])
Mask batch dimensions:  torch.Size([8, 1, 256, 256])


In [11]:
def runTraining(writer: SummaryWriter):
    """Train the teacher UNet on the labeled set, validating after every epoch.

    Reads the notebook-level globals ``num_classes``, ``device``, ``lr``,
    ``total_epochs``, ``train_loader_full`` and ``val_loader``.

    Side effects:
        * After every epoch the list of mean training losses is written to
          ``Results/Statistics/Test_Model/Losses.npy``.
        * Whenever the mean validation loss improves, the model weights are
          saved to ``./models/Test_Model/<epoch>_Epoch``.

    Args:
        writer: TensorBoard writer. It is only flushed here; the scalar and
            figure logging calls are currently commented out.

    Returns:
        The UNet as it is after the last epoch (left in eval mode). This is
        not necessarily the checkpoint with the best validation loss.
    """
    print("-" * 40)
    print("~~~~~~~~  Starting the training... ~~~~~~")
    print("-" * 40)

    print("~~~~~~~~~~~ Creating the UNet model ~~~~~~~~~~")
    modelName = "Test_Model"
    print(" Model Name: {}".format(modelName))

    ## CREATION OF YOUR MODEL
    teacherNet = UNet(num_classes).to(device)

    print(
        "Total params: {0:,}".format(
            sum(p.numel() for p in teacherNet.parameters() if p.requires_grad)
        )
    )

    # Loss: cross-entropy applied directly to the network output
    CE_loss = torch.nn.CrossEntropyLoss()

    ## DEFINE YOUR OPTIMIZER (Adam)
    optimizer = torch.optim.Adam(teacherNet.parameters(), lr=lr)

    ### To save statistics ####
    train_losses = []
    train_dc_losses = []
    val_losses = []
    val_dc_losses = []

    # Start from +inf so the first epoch always produces a checkpoint,
    # whatever the scale of the loss.
    best_loss_val = float("inf")

    directory = "Results/Statistics/" + modelName
    model_dir = "./models/" + modelName

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    os.makedirs(directory, exist_ok=True)

    ## START THE TRAINING

    ## FOR EACH EPOCH
    for epoch in range(total_epochs):
        teacherNet.train()

        num_batches = len(train_loader_full)
        print("Number of batches: ", num_batches)

        running_train_loss = 0
        running_dice_loss = 0

        # Training loop (idx: current batch number, data: content of that batch)
        for idx, data in enumerate(train_loader_full):
            # Reset the gradients; the optimizer owns every parameter of teacherNet,
            # so a separate teacherNet.zero_grad() is not needed.
            optimizer.zero_grad()

            ## GET IMAGES, LABELS and IMG NAMES
            images, labels, img_names = data

            ### Wrap with the project helper, then move to the selected device
            labels = utils.to_var(labels).to(device)
            images = utils.to_var(images).to(device)

            # Forward pass: raw network output for the batch
            teacherNet_predictions = teacherNet(images)

            # Target handed to the loss, derived from the mask by the project helper
            # (presumably integer class ids per pixel - see utils.getTargetSegmentation)
            segmentation_classes = utils.getTargetSegmentation(labels)

            # Compute the loss
            loss = CE_loss(teacherNet_predictions, segmentation_classes)
            running_train_loss += loss.item()
            # dice_loss = dice_coefficient(teacherNet_predictions, labels)
            # NOTE(review): despite the name this is presumably a Dice *score* - confirm in utils.compute_dsc
            dice_loss = utils.compute_dsc(teacherNet_predictions, labels)
            running_dice_loss += dice_loss

            # Backprop
            loss.backward()  # compute gradients of the loss
            optimizer.step()  # update the weights from those gradients

            # Add the running metrics to the tensorboard every 10 batches (currently disabled)
            # if idx % 10 == 0:
            #     writer.add_scalar(
            #         "Loss/train", running_train_loss / (idx + 1), epoch * len(train_loader_full) + idx
            #     )
            #     writer.add_scalar(
            #         "Dice/train", running_dice_loss / (idx + 1), epoch * len(train_loader_full) + idx
            #     )

            if idx % 100 == 0:
                # Debug output: value range of the inputs and of the predicted class map
                probs = torch.softmax(teacherNet_predictions, dim=1)
                y_pred = torch.argmax(probs, dim=1)
                print("Images min/max:", images.min().item(), images.max().item())
                print("Predictions min/max:", y_pred.min().item(), y_pred.max().item())

                # writer.add_figure('predictions vs. actuals',
                #             utils.plot_net_predictions(images, labels, y_pred, batch_size),
                #             global_step=epoch * len(train_loader_full) + idx)

            # THIS IS JUST TO VISUALIZE THE TRAINING
            printProgressBar(
                idx + 1,
                num_batches,
                prefix="[Training] Epoch: {} ".format(epoch),
                length=15,
                suffix=" Loss: {:.4f}, ".format(running_train_loss / (idx + 1)),
            )

        train_loss = running_train_loss / num_batches
        train_losses.append(train_loss)

        train_dc_loss = running_dice_loss / num_batches
        train_dc_losses.append(train_dc_loss)

        teacherNet.eval()
        val_running_loss = 0
        val_running_dc = 0

        # Validation loop, run once at the end of every epoch without gradients
        with torch.no_grad():
            for idx, data in enumerate(val_loader):
                images, labels, img_names = data

                labels = utils.to_var(labels).to(device)
                images = utils.to_var(images).to(device)

                teacherNet_predictions = teacherNet(images)

                segmentation_classes = utils.getTargetSegmentation(labels)

                loss = CE_loss(teacherNet_predictions, segmentation_classes)
                val_running_loss += loss.item()

                # dice_loss = dice_coefficient(teacherNet_predictions, labels)
                dice_loss = utils.compute_dsc(teacherNet_predictions, labels)
                val_running_dc += dice_loss

                # if idx % 10 == 0:
                #     writer.add_scalar(
                #         "Loss/val",
                #         val_running_loss / (idx + 1),
                #         epoch * len(val_loader) + idx,
                #     )
                #     writer.add_scalar(
                #         "Dice/val",
                #         val_running_dc / (idx + 1),
                #         epoch * len(val_loader) + idx,
                #     )

                printProgressBar(
                    idx + 1,
                    len(val_loader),
                    prefix="[Validation] Epoch: {} ".format(epoch),
                    length=15,
                    suffix=" Loss: {:.4f}, ".format(val_running_loss / (idx + 1)),
                )

        val_loss = val_running_loss / len(val_loader)
        val_losses.append(val_loss)
        dc_loss = val_running_dc / len(val_loader)
        val_dc_losses.append(dc_loss)

        # Save a checkpoint whenever the mean validation loss improves
        if val_loss < best_loss_val:
            best_loss_val = val_loss
            os.makedirs(model_dir, exist_ok=True)
            torch.save(
                teacherNet.state_dict(), model_dir + "/" + str(epoch) + "_Epoch"
            )

        printProgressBar(
            num_batches,
            num_batches,
            done="[Epoch: {}, TrainLoss: {:.4f}, TrainDice: {:.4f}, ValLoss: {:.4f}".format(
                epoch, train_loss, train_dc_loss, val_loss
            ),
        )

        # Rewritten every epoch so an interrupted run still leaves the loss history on disk
        np.save(os.path.join(directory, "Losses.npy"), train_losses)
    writer.flush()  # Flush the writer to ensure that all the data is written to disk
    return teacherNet

In [12]:
#selfTrain with unlabeled images
def runSelfTraining(teacherModel, writer: SummaryWriter):  # TODO: decide whether the writer parameter should stay
    """Run the teacher on the unlabeled loader and log sample predictions to TensorBoard.

    No weights are updated: the model is switched to eval mode and every batch
    of the notebook-level ``selftrain_loader_full`` is forwarded without
    gradients. Every 100th batch is rendered with
    ``utils.plot_net_predictions_without_ground_truth`` and added to ``writer``.

    Args:
        teacherModel: Trained network used to predict on the unlabeled images.
        writer: TensorBoard writer receiving the prediction figures.
    """
    teacherModel.eval()
    print("-" * 40)
    print("~~~~~~~~  Starting the self training... ~~~~~~")
    print("-" * 40)

    num_images = len(selftrain_loader_full.dataset)
    print("num images fourni : ", num_images)

    with torch.no_grad():
        for idx, data in enumerate(selftrain_loader_full):
            images, img_names = data

            images = utils.to_var(images).to(device)

            # Output of the teacher for this unlabeled batch
            teacher_predictions = teacherModel(images)

            if idx % 100 == 0:
                # Most likely class per position along the class axis (dim=1)
                probs = torch.softmax(teacher_predictions, dim=1)
                y_pred = torch.argmax(probs, dim=1)

                # global_step keeps the figures of different batches apart;
                # without it every figure is logged at the same step.
                writer.add_figure(
                    'predictions on unlabeled',
                    utils.plot_net_predictions_without_ground_truth(images, y_pred, img_names, batch_size),
                    global_step=idx,
                )
    # Make sure the logged figures reach the event file before returning
    writer.flush()
    
    

In [None]:
# Set up Tensorboard writer
writer = SummaryWriter()
try:
    # Train the teacher network
    teacherNet = runTraining(writer)
finally:
    # Close the event file even when training is interrupted or raises
    writer.close()

In [13]:
from utils import inference

epoch_to_load = 41  # epoch number of the checkpoint to load
model = UNet(num_classes)
# map_location="cpu": the model is created on the CPU here, and a checkpoint saved
# from an accelerator could not otherwise be restored on a machine without one.
# load_state_dict stays the last expression so the cell still displays its result.
model.load_state_dict(
    torch.load(f"./models/Test_Model/{epoch_to_load}_Epoch", map_location="cpu")
)
#inf_losses = inference(model, val_loader, "Student", epoch_to_load)

<All keys matched successfully>

In [14]:
writer = SummaryWriter()
try:
    # Predictions on the unlabeled images
    epoch_to_load = 41  # epoch number of the checkpoint to use
    model = UNet(num_classes)
    # A freshly built UNet has random weights, so the checkpoint is restored before inference.
    # NOTE(review): if utils.inferenceTeacher already loads the checkpoint from
    # `epoch_to_load` itself, this load is redundant but harmless - confirm.
    model.load_state_dict(
        torch.load(f"./models/Test_Model/{epoch_to_load}_Epoch", map_location="cpu")
    )
    inferenceTeacher(model, selftrain_loader_full, 'TeacherUnlabeledPredictions', epoch_to_load, device)

    # Alternative that logs sample predictions to TensorBoard instead:
    # runSelfTraining(model, writer)
finally:
    # Close the writer even if inference fails
    writer.close()

nb images à traiter  126
[Inference] Segmentation Done !                                                                              
