In [None]:
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import transforms
from progressBar import printProgressBar

import medicalDataLoader
import argparse
from utils import *

from UNet_Base import *
import random
import torch
import pdb

import warnings
warnings.filterwarnings("ignore")

In [None]:
########### HYPERPARAMETERS ###########
# Module-level defaults for the keyword arguments of runTraining().
NUM_EPOCHS = 300          # number of passes of the main training loop
BATCH_SIZE = 16           # batch size for the training, unlabeled and pseudo-labeled loaders
BATCH_SIZE_VAL = 8        # batch size for the validation loader
LEARNING_RATE = 0.001     # learning rate handed to the Adam optimizer

### SELF LEARNING ###
SELF_LEARNING_EPOCHS = 50       # epoch index from which pseudo-labeling of unlabeled images starts
CONFIDENCE_THRESHOLD = 0.97     # minimum mean foreground softmax confidence for a pseudo-label candidate
SIMILARITY_THRESHOLD = 0.85     # minimum mean IoU between the plain and the augmented prediction
GAMMA_LOSS = 0.3                # weight of the pseudo-label loss; the labeled loss gets (1 - GAMMA_LOSS)

### LOSS FUNCTIONS ###
# Tversky Focal Loss: forwarded as TverskyFocalLoss(alpha=ALPHA, beta=BETA, gamma=GAMMA)
ALPHA, BETA, GAMMA = 0.3, 0.7, 2.0

In [None]:
def runTraining(
        debug=False, model_name='Test', loss = 'TF',
        num_epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, batch_size_val=BATCH_SIZE_VAL, lr=LEARNING_RATE,
        self_learning_epochs=SELF_LEARNING_EPOCHS, confidence_threshold=CONFIDENCE_THRESHOLD, similarity_threshold=SIMILARITY_THRESHOLD, gamma_loss=GAMMA_LOSS,
        alpha=ALPHA, beta=BETA, gamma=GAMMA
    ):
    """Train a 4-class UNet with supervised training followed by self-training.

    Every epoch runs a training pass over the labeled loader (and, once it
    exists, the pseudo-labeled loader), then a validation pass that stores
    per-class metrics and checkpoints the model whenever the mean validation
    loss improves. From epoch index ``self_learning_epochs`` onwards, the
    unlabeled set is scanned after validation: images whose prediction is
    confident and stable under augmentation are moved into the pseudo-labeled
    set used by the following epochs.

    Args:
        debug: When True, print extra diagnostics and display sample images.
        model_name: Name used for 'Results/Statistics/<name>' and
            './Models/<name>'.
        loss: Loss selector: 'TF' (Tversky-Focal), 'CE' (class-weighted
            cross-entropy), 'CE+TF' (their sum) or 'DSC' (Dice). Any other
            value falls back to 'TF'.
        num_epochs: Number of epochs.
        batch_size: Batch size for the training, unlabeled and pseudo loaders.
        batch_size_val: Batch size for the validation loader.
        lr: Learning rate of the Adam optimizer.
        self_learning_epochs: First epoch index at which pseudo-labeling runs.
        confidence_threshold: Minimum mean softmax confidence over predicted
            foreground pixels for an unlabeled image to be considered.
        similarity_threshold: Minimum mean IoU between the prediction on the
            image and the de-transformed prediction on its augmented copy.
        gamma_loss: Weight of the pseudo-label loss when both a labeled and a
            pseudo-labeled batch are available; the labeled loss is weighted
            by ``1 - gamma_loss``.
        alpha, beta, gamma: Forwarded to ``TverskyFocalLoss``.

    Returns:
        The best (lowest) mean validation loss observed.
    """

    root_dir = './Data/'


    ## DATA TRANSFORMS
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    mask_transform = transforms.Compose([
        transforms.ToTensor()
    ])


    ## CREATE THE DATASET
    # NOTE: the global NumPy RNG is reseeded right before every DataLoader is
    # built. The original code did this implicitly by writing
    # `worker_init_fn=np.random.seed(0)`, which calls the seed function on the
    # spot and hands its return value (None) to the DataLoader. The seeding is
    # kept, at the same points, but is now an explicit statement.

    # Training dataset
    original_train_set = medicalDataLoader.MedicalImageDataset('train',
                                                      root_dir,
                                                      transform=transform,
                                                      mask_transform=mask_transform,
                                                      augment=True,
                                                      equalize=True)
    train_set_full = original_train_set
    # Create a DataLoader for the training set
    np.random.seed(0)
    train_loader_full = DataLoader(train_set_full,
                              batch_size=batch_size,
                              num_workers=0,
                              shuffle=True)
    # Unlabelled dataset
    unlabeled_set = medicalDataLoader.MedicalImageDataset('unlabeled',
                                                          root_dir,
                                                          transform=transform,
                                                          mask_transform=mask_transform,
                                                          equalize=True)
    # Create a DataLoader for the unlabeled set
    np.random.seed(0)
    unlabeled_loader = DataLoader(unlabeled_set,
                                  batch_size=batch_size,
                                  num_workers=0,
                                  shuffle=False)
    # Validation dataset
    val_set = medicalDataLoader.MedicalImageDataset('val',
                                                    root_dir,
                                                    transform=transform,
                                                    mask_transform=mask_transform,
                                                    equalize=True)
    # Create a DataLoader for the validation set
    np.random.seed(0)
    val_loader = DataLoader(val_set,
                            batch_size=batch_size_val,
                            num_workers=0,
                            shuffle=False)


    ## INITIALIZING THE MODEL
    num_classes = 4 # NUMBER OF CLASSES
    print("~~~~~~~~~~~ Creating the UNet model ~~~~~~~~~~~")
    print(" Model Name: {}".format(model_name))
    net = UNet(num_classes)
    if debug:
        print("Total params: {0:,}".format(sum(p.numel() for p in net.parameters() if p.requires_grad)))

    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    ## OUTPUT COMPONENTS
    # Compute class weights for the cross-entropy loss
    initial_pixel_counts = compute_class_pixel_counts(train_set_full, num_classes)
    class_weights = compute_class_weights(initial_pixel_counts)

    softMax = torch.nn.Softmax(dim=1)
    TF_loss = TverskyFocalLoss(alpha=alpha, beta=beta, gamma=gamma)
    CE_loss = torch.nn.CrossEntropyLoss(weight=class_weights)
    DSC_loss = DiceLoss()
    if torch.cuda.is_available(): # Move to GPU if available
        net.cuda()
        softMax.cuda()
        TF_loss.cuda()
        CE_loss.cuda()
        DSC_loss.cuda()

    # Select the loss function.
    # The callable must NOT be named `loss`: that would shadow the `loss`
    # string argument and make every comparison against it fail, silently
    # forcing the Tversky-Focal fallback whatever the caller asked for.
    loss_functions = {
        'TF': TF_loss,
        'CE': CE_loss,
        'CE+TF': lambda predictions, targets: CE_loss(predictions, targets) + TF_loss(predictions, targets),
        'DSC': DSC_loss,
    }
    criterion = loss_functions.get(loss, TF_loss)  # unknown names fall back to Tversky-Focal

    def _batch_loss(batch):
        """Forward one (images, labels, names) batch; return (loss, n_images) or (None, 0) if it is empty."""
        if len(batch[0]) == 0:
            return None, 0
        batch_images, batch_labels, _ = batch
        batch_images = to_var(batch_images)
        batch_labels = to_var(batch_labels)
        predictions = net(batch_images)
        targets = getTargetSegmentation(batch_labels)
        return criterion(predictions, targets), batch_images.size(0)


    ## PSEUDO-LABELED DATA
    pseudo_dataset = None
    pseudo_loader = None
    # Float value stored in a pseudo mask for each class index 0..3.
    # NOTE(review): presumably these are the gray levels getTargetSegmentation
    # maps back to class indices — confirm against utils.
    value_map = torch.tensor([0.0, 0.33333334, 0.6666667, 0.94117647])


    ## LOSS AND METRICS
    # Loss
    directory = 'Results/Statistics/' + model_name
    os.makedirs(directory, exist_ok=True)
    lossTotalTraining = []
    val_losses = []  # List to store validation losses
    Best_loss_val = 1000
    BestEpoch = 0
    # Metrics
    metrics_directory = os.path.join(directory, 'Metrics')
    if os.path.exists(metrics_directory):
        # Clear the directory so metrics of a previous run are not mixed in
        for file_name in os.listdir(metrics_directory):
            os.remove(os.path.join(metrics_directory, file_name))
    else:
        os.makedirs(metrics_directory)



    ## MAIN TRAINING LOOP
    print("~~~~~~~~~~~~ Starting the training ~~~~~~~~~~~~")

    ## FOR EACH EPOCH
    for i in range(num_epochs):
        net.train()
        epoch_losses = []
        val_loss = []

        # Keep track of the number of images used for training
        num_images_trained = 0

        ########## TRAINING PHASE ##########
        # Create iterators for the data loaders
        train_loader_iter = iter(train_loader_full)
        pseudo_loader_iter = iter(pseudo_loader) if pseudo_dataset is not None else None

        # Iterate as long as the longer of the two loaders has batches left
        if pseudo_dataset is not None:
            num_batches = max(len(train_loader_full), len(pseudo_loader))
        else:
            num_batches = len(train_loader_full)

        ## BATCH TRAINING LOOP
        for _ in range(num_batches):

            # Set to zero all the gradients (the optimizer owns every parameter of `net`)
            optimizer.zero_grad()

            # Labeled batch; None once the labeled loader is exhausted
            loss_train, n_train = None, 0
            train_batch = next(train_loader_iter, None)
            if train_batch is not None:
                loss_train, n_train = _batch_loss(train_batch)

            # Pseudo-labeled batch; None if there is no pseudo loader or it is exhausted
            loss_pseudo, n_pseudo = None, 0
            if pseudo_loader_iter is not None:
                pseudo_batch = next(pseudo_loader_iter, None)
                if pseudo_batch is not None:
                    loss_pseudo, n_pseudo = _batch_loss(pseudo_batch)

            # Total loss. `None` (not a numeric 0) marks a missing batch, so a
            # loss that is genuinely 0.0 is still back-propagated and logged.
            if loss_train is None and loss_pseudo is None:
                continue  # Nothing to train on in this iteration
            elif loss_train is None:
                total_loss = loss_pseudo
            elif loss_pseudo is None:
                total_loss = loss_train
            else:
                total_loss = (1 - gamma_loss) * loss_train + gamma_loss * loss_pseudo

            num_images_trained += n_train + n_pseudo

            # Backward pass and optimization
            total_loss.backward()
            optimizer.step()

            # Append losses for monitoring
            for partial_loss in (loss_train, loss_pseudo):
                if partial_loss is not None:
                    epoch_losses.append(partial_loss.detach().cpu().numpy())

        # Compute mean training loss for the epoch
        epoch_loss_mean = np.asarray(epoch_losses).mean()
        lossTotalTraining.append(epoch_loss_mean)
        print(f'Epoch {i} - Training Loss: {epoch_loss_mean:.4f} - ', end='')
        # Save the training losses
        np.save(os.path.join(directory, 'Train_Losses.npy'), lossTotalTraining)


        ########## VALIDATION PHASE ##########
        # Set the model to evaluation mode
        net.eval()
        # Initialize metrics dictionary (foreground classes only: 1..num_classes-1)
        val_metrics = {
            metric_name: {cls: [] for cls in range(1, num_classes)}
            for metric_name in ('IoU', 'Precision', 'Recall', 'DSC')
        }

        ## BATCH VALIDATION LOOP
        with torch.no_grad():  # No need to compute gradients during validation
            for data in val_loader:
                if len(data[0]) == 0:
                    continue # Skip empty batches

                # Get the data
                images, labels, img_names = data
                # (images) Shape: (batch_size, 1, height, width) | torch.float32 | range: [0, 1]
                # (labels) Shape: (batch_size, height, width) | torch.int64 | values: {0, 1, 2, 3}

                # Move to GPU if available
                labels = to_var(labels)
                images = to_var(images)

                # Forward pass
                net_predictions = net(images) # Shape: (batch_size, num_classes, height, width) | torch.float32
                # Get the target segmentation
                segmentation_classes = getTargetSegmentation(labels) # Shape: (batch_size, height, width) | torch.int64 | values: {0, 1, 2, 3}

                # Compute loss
                loss_value = criterion(net_predictions, segmentation_classes)
                val_loss.append(loss_value.detach().cpu().numpy())

                # Get predicted segmentation masks
                pred_masks = torch.argmax(net_predictions, dim=1) # Shape: (batch_size, height, width) | torch.int64 | values: {0, 1, 2, 3}

                # Compute metrics
                batch_metrics = compute_metrics(pred_masks, segmentation_classes, num_classes)
                for metric_name in val_metrics:
                    for cls in range(1, num_classes):
                        val_metrics[metric_name][cls].extend(batch_metrics[metric_name][cls])

                # Print some of the validation images and predictions at the end of training
                # NOTE(review): images.size(1) is the channel dimension of a
                # (batch, channels, H, W) tensor, not the batch size — confirm
                # that this is the intended value for num_images.
                if i == num_epochs - 1 and debug:
                    display_segmented_images(images, labels, pred_masks, num_images=images.size(1))

        # Compute mean validation loss
        val_loss_mean = np.mean(val_loss)
        val_losses.append(val_loss_mean)  # Save the validation loss
        print(f'Validation Loss: {val_loss_mean:.4f} - ', end='')
        if pseudo_dataset is not None:
            print(f'Training on {len(train_set_full)} labeled images and {len(pseudo_dataset)} unlabeled images')
        else:
            print(f'Training on {len(train_set_full)} labeled images')

        # Save the metrics for plotting later
        metrics_save_path = os.path.join(metrics_directory, f'metrics_epoch_{i}.npy')
        np.save(metrics_save_path, val_metrics)

        # Save the model if the validation loss is the best so far
        if val_loss_mean < Best_loss_val:
            Best_loss_val = val_loss_mean
            BestEpoch = i
            # Save the model
            model_save_dir = os.path.join('./Models', model_name)
            os.makedirs(model_save_dir, exist_ok=True)
            torch.save(net.state_dict(), os.path.join(model_save_dir, f'best_model_epoch_{i}.pth'))
            if debug:
                print(f'Saved model at epoch {i} with validation loss {val_loss_mean}')


        ########## SELF-TRAINING PHASE ##########
        if i == self_learning_epochs:
            print("~~~~~~~~~~ Starting the self-training ~~~~~~~~~")

        if i >= self_learning_epochs and len(unlabeled_set) > 0:

            net.eval()  # Set the model to evaluation mode
            pseudo_labeled_data = []     # (image, mask, img_name) tuples accepted in this epoch
            img_paths_to_remove = set()  # Names of the accepted images, to drop from the unlabeled set

            # Prediction loop for the unlabeled dataset
            with torch.no_grad():
                for images, _, img_names in unlabeled_loader:
                    # (images) Shape: (batch_size, 1, height, width) | torch.float32 | range: [0, 1]
                    images = to_var(images)

                    # Forward pass + softmax: per-pixel class probabilities
                    probs = softMax(net(images)) # Shape: (batch_size, num_classes, height, width)

                    # Predicted class and its probability for every pixel
                    pred_masks = torch.argmax(probs, dim=1) # Shape: (batch_size, height, width) | torch.int64
                    max_probs = torch.max(probs, dim=1)[0]  # Shape: (batch_size, height, width) | torch.float32

                    # Foreground = pixels predicted as class 1, 2 or 3
                    foreground_mask = (pred_masks >= 1)

                    for idx_in_batch in range(images.size(0)):
                        # Mean confidence over the predicted foreground pixels;
                        # an image with no predicted foreground scores 0
                        probs_fg = max_probs[idx_in_batch][foreground_mask[idx_in_batch]]
                        mean_confidence = probs_fg.mean().item() if probs_fg.numel() > 0 else 0.0
                        if mean_confidence < confidence_threshold:
                            continue

                        # Get original image and predicted mask
                        original_image = images[idx_in_batch]
                        pred_mask_original = pred_masks[idx_in_batch]

                        # Apply random transformations and track them
                        transformed_image, _, transformations_dict = augment(original_image)

                        # Predict on the transformed image
                        net_predictions_transformed = net(transformed_image.unsqueeze(0))
                        probs_transformed = softMax(net_predictions_transformed)
                        pred_mask_transformed = torch.argmax(probs_transformed, dim=1)

                        # De-transform the transformed mask
                        pred_mask_de_transformed = de_transform_mask(pred_mask_transformed.squeeze(0), transformations_dict)

                        # Compute similarity between the two masks
                        mean_iou = compute_mean_iou(
                            pred_mask_original.cpu(), pred_mask_de_transformed.cpu(), num_classes
                        )

                        # Some debugging
                        if debug:
                            print(f"Mean IoU: {mean_iou} - Confidence: {mean_confidence}")

                        # Keep the image only if the prediction is stable under augmentation
                        if mean_iou >= similarity_threshold:
                            pseudo_image = original_image.cpu()          # Shape: (1, height, width) | torch.float32
                            pseudo_mask = value_map[pred_mask_original.cpu()]  # Shape: (height, width) | torch.float32
                            pseudo_img_name = img_names[idx_in_batch]

                            pseudo_labeled_data.append((pseudo_image, pseudo_mask, pseudo_img_name))
                            # Mark image path for removal
                            img_paths_to_remove.add(pseudo_img_name)


            # Add good pseudo-labeled data to the training set
            if pseudo_labeled_data:
                if debug:
                    print(f"Number of pseudo-labeled images added: {len(pseudo_labeled_data)}")

                # Create a dataset for the pseudo-labeled data
                # (item) Shape: (image, mask, img_name)
                #     - (image) Shape: (1, height, width) | torch.float32 | range: [0, 1]
                #     - (mask) Shape: (height, width) | torch.float32 | range: [0, 1]
                #     - (img_name) String
                new_pseudo_dataset = medicalDataLoader.MedicalImageDataset(
                    mode='pseudo',
                    root_dir=root_dir,
                    transform=transform,
                    mask_transform=mask_transform,
                    augment=True,
                    equalize=True,
                    data_list=pseudo_labeled_data
                )
                # Visualize the pseudo-labeled samples
                if debug:
                    display_dataset_samples(new_pseudo_dataset, num_samples=1)

                # First batch of pseudo-labels starts the dataset, later ones extend it
                if pseudo_dataset is None:
                    pseudo_dataset = new_pseudo_dataset
                else:
                    pseudo_dataset = ConcatDataset([pseudo_dataset, new_pseudo_dataset])

                # Create a DataLoader for the pseudo-labeled data
                # NOTE(review): reseeding here restarts the NumPy random stream
                # every time pseudo-labels are added (behaviour kept from the
                # original `worker_init_fn=np.random.seed(0)`) — confirm it is wanted.
                np.random.seed(0)
                pseudo_loader = DataLoader(
                    pseudo_dataset,
                    batch_size=batch_size,
                    num_workers=0,
                    shuffle=True,
                    drop_last=True
                )

                # Remove pseudo-labeled images from the unlabeled dataset
                unlabeled_set.remove_items(img_paths_to_remove)
                np.random.seed(0)
                unlabeled_loader = DataLoader(
                    unlabeled_set,
                    batch_size=batch_size,
                    num_workers=0,
                    shuffle=False
                )
            else:
                if debug:
                    print("No pseudo-labeled images were confident enough to be added.")


        # Print number of images trained
        print(f'Images trained on: {num_images_trained}')

    print(f"Training completed. Best model saved at epoch {BestEpoch} with validation loss {Best_loss_val}")

    # Save the validation losses
    np.save(os.path.join(directory, 'Val_Losses.npy'), val_losses)

    # Plot the metrics
    plot_metrics(model_name, num_epochs, num_classes)
    # Plot the losses
    plot_losses(model_name)

    # Return best loss
    return Best_loss_val


In [None]:
BATCH_SIZE_TEST = 24

def test_model(model_name, batch_size_test=BATCH_SIZE_TEST, dataset='val'):
    # Paths
    model_dir = os.path.join('./models', model_name)
    results_dir = os.path.join('./Results/Images', model_name)
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)

    # Load the best model (highest epoch number)
    model_files = [f for f in os.listdir(model_dir) if f.endswith('.pth')]
    if not model_files:
        print(f"No model files found in {model_dir}")
    else:
        # Extract epoch numbers ("best_model_epoch_XXX.pth")
        epoch_numbers = [int(f.split('_')[-1].split('.')[0]) for f in model_files]
        # Find the model with the highest epoch number
        best_epoch = max(epoch_numbers)
        best_model_file = f'best_model_epoch_{best_epoch}.pth'
        best_model_path = os.path.join(model_dir, best_model_file)
        print(f"Loading model: {best_model_path}")

        # Initialize the model
        net = UNet(4)
        net.load_state_dict(torch.load(best_model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')))
        if torch.cuda.is_available():
            net.cuda()
        net.eval()

        # Prepare the test dataset (subset of validation set)
        root_dir = './Data/'

        # Data transforms
        transform = transforms.Compose([
            transforms.ToTensor()
        ])
        mask_transform = transforms.Compose([
            transforms.ToTensor()
        ])

        # Load the test dataset
        test_set = medicalDataLoader.MedicalImageDataset(dataset,
                                                         root_dir,
                                                         transform=transform,
                                                         mask_transform=mask_transform,
                                                         equalize=True)
        test_loader = DataLoader(test_set,
                                batch_size=batch_size_test,
                                worker_init_fn=np.random.seed(0),
                                num_workers=0,
                                shuffle=False)
        
        # Compute class weights for the cross-entropy loss
        initial_pixel_counts = compute_class_pixel_counts(test_set, 4)
        class_weights = compute_class_weights(initial_pixel_counts)

        # Run inference
        mean_loss, mean_dice = inference(net, test_loader, model_name, class_weights)

        print(f"Mean Cross-Entropy Loss on test set: {mean_loss}")
        print(f"Mean Dice Coefficient on test set: {mean_dice}\n")

        return mean_loss, mean_dice

In [None]:

# Train a model
#runTraining()

# Test a model
#test_model("0_Base_Model")
