In [1]:
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import transforms
from progressBar import printProgressBar
import matplotlib.pyplot as plt
import medicalDataLoader
import argparse
from utils import *
from AttU_Net import *
import random
import torch
import pdb

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
def runTraining():
    print('-' * 40)
    print('~~~~~~~~  Starting the training... ~~~~~~')
    print('-' * 40)

    ## DEFINE HYPERPARAMETERS (batch_size > 1)
    batch_size = 16
    batch_size_val = 4
    lr = 0.001    # Learning Rate
    epoch = 20 # Number of epochs
    
    root_dir = './Data/'

    print(' Dataset: {} '.format(root_dir))

    ## DEFINE THE TRANSFORMATIONS TO DO AND THE VARIABLES FOR TRAINING AND VALIDATION
    
    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    mask_transform = transforms.Compose([
        transforms.ToTensor()
    ])
    
    transform1 = transforms.Compose([
        transforms.RandomVerticalFlip(p=1.0),
        transforms.ToTensor()
    ])

    mask_transform1 = transforms.Compose([
        transforms.RandomVerticalFlip(p=1.0),
        transforms.ToTensor()
    ])
    
    transform2 = transforms.Compose([
        transforms.RandomHorizontalFlip(p=1.0),
        transforms.ToTensor()
    ])

    mask_transform2 = transforms.Compose([
        transforms.RandomHorizontalFlip(p=1.0),
        transforms.ToTensor()
    ])

    train_set_full = medicalDataLoader.MedicalImageDataset('train',
                                                      root_dir,
                                                      transform=transform,
                                                      mask_transform=mask_transform,
                                                      augment=False,
                                                      equalize=False)
    
    train_set_full1 = medicalDataLoader.MedicalImageDataset('train',
                                                      root_dir,
                                                      transform=transform1,
                                                      mask_transform=mask_transform1,
                                                      augment=False,
                                                      equalize=False)
    
    train_set_full2 = medicalDataLoader.MedicalImageDataset('train',
                                                      root_dir,
                                                      transform=transform2,
                                                      mask_transform=mask_transform2,
                                                      augment=False,
                                                      equalize=False)

    concatenated_dataset = ConcatDataset([train_set_full, train_set_full1, train_set_full2])

    
    train_loader_full = DataLoader(concatenated_dataset,
                              batch_size=batch_size,
                              worker_init_fn=np.random.seed(0),
                              num_workers=0,
                              shuffle=True)

    val_set = medicalDataLoader.MedicalImageDataset('val',
                                                    root_dir,
                                                    transform=transform,
                                                    mask_transform=mask_transform,
                                                    equalize=False)

    val_loader = DataLoader(val_set,
                            batch_size=batch_size_val,
                            worker_init_fn=np.random.seed(0),
                            num_workers=0,
                            shuffle=False)


    ## INITIALIZE YOUR MODEL
    num_classes = 4 # NUMBER OF CLASSES

    print("~~~~~~~~~~~ Creating the UNet model ~~~~~~~~~~")
    modelName = 'Attention_U_Net'
    print(" Model Name: {}".format(modelName))

    ## CREATION OF YOUR MODEL
    net = AttU_Net(1, num_classes)
    #net.load_state_dict(torch.load('./models/AttU_Net_2/model'))

    print("Total params: {0:,}".format(sum(p.numel() for p in net.parameters() if p.requires_grad)))

    def dice_loss(predictions, targets, smooth=0.0001):
        total_loss = 0.0
        num_classes = 4
        weights = [0.25, 0.25, 0.25, 0.25]
        for class_idx in range(num_classes):
            class_targets = (targets == class_idx).float()
            class_predictions = predictions[:, class_idx, ...]

            intersection = torch.sum(class_predictions * class_targets)
            union = torch.sum(class_predictions) + torch.sum(class_targets) + smooth
            class_dice = (2.0 * intersection + smooth) / union
            class_loss = 1.0 - class_dice

            total_loss += weights[class_idx]*class_loss

        return total_loss
    
    # DEFINE YOUR OUTPUT COMPONENTS (e.g., SOFTMAX, LOSS FUNCTION, ETC)
    softMax = torch.nn.Softmax(dim=1)
    weights = torch.tensor([1.0, 3.0, 3.0, 3.0])
    CE_loss = torch.nn.CrossEntropyLoss()

    ## PUT EVERYTHING IN GPU RESOURCES    
    if torch.cuda.is_available():
        net.cuda()
        softMax.cuda()
        CE_loss.cuda()
        dice_loss.cuda()

    ## DEFINE YOUR OPTIMIZER
    optimizer = torch.optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=0.0001)

    ### To save statistics ####
    lossTotalTraining = []
    lossTotalVal = []
    Best_loss_val = 1000
    BestEpoch = 0
    
    directory = 'Results/Statistics/' + modelName

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    if os.path.exists(directory)==False:
        os.makedirs(directory)

    ## START THE TRAINING
    
    ## FOR EACH EPOCH
    for i in range(epoch):
        net.train()
        lossEpoch = []
        DSCEpoch = []
        DSCEpoch_w = []
        num_batches = len(train_loader_full)
        ## FOR EACH BATCH
        for j, data in enumerate(train_loader_full):
            ### Set to zero all the gradients
            net.zero_grad()
            optimizer.zero_grad()

            ## GET IMAGES, LABELS and IMG NAMES
            images, labels, img_names = data

            ### From numpy to torch variables
            labels = to_var(labels)
            images = to_var(images)

            ################### Train ###################
            #-- The CNN makes its predictions (forward pass)
            net_predictions = net(images)
            #-- Compute the losses --#
            # THIS FUNCTION IS TO CONVERT LABELS TO A FORMAT TO BE USED IN THIS CODE
            segmentation_classes = getTargetSegmentation(labels)
            # COMPUTE THE LOSS
            CE_loss_value = CE_loss(softMax(net_predictions), segmentation_classes) # XXXXXX and YYYYYYY are your inputs for the CE
            Dice_loss_value = dice_loss(softMax(net_predictions), segmentation_classes)
            lossTotal = 0.2*Dice_loss_value+CE_loss_value

            # DO THE STEPS FOR BACKPROP (two things to be done in pytorch)
            lossTotal.backward()
            # Update weights
            optimizer.step()
            
            # THIS IS JUST TO VISUALIZE THE TRAINING 
            lossEpoch.append(lossTotal.cpu().data.numpy())
            printProgressBar(j + 1, num_batches,
                             prefix="[Training] Epoch: {} ".format(i),
                             length=15,
                             suffix=" Loss: {:.4f}, ".format(lossTotal))

        lossEpoch = np.asarray(lossEpoch)
        lossEpoch = lossEpoch.mean()

        lossTotalTraining.append(lossEpoch)

        printProgressBar(num_batches, num_batches,
                             done="[Training] Epoch: {}, LossG: {:.4f}".format(i,lossEpoch))
        
        loss_val = []
        net.eval()
        for j, data in enumerate(val_loader):
            ## GET IMAGES, LABELS and IMG NAMES
            images_val, labels_val, img_names = data

            ### From numpy to torch variables
            labels_val = to_var(labels_val)
            images_val = to_var(images_val)
            
            pred = net(images_val)
            segmentation_classes_val = getTargetSegmentation(labels_val)
            
            CE_loss_value = CE_loss(softMax(pred), segmentation_classes_val) # XXXXXX and YYYYYYY are your inputs for the CE
            Dice_loss_value = dice_loss(softMax(pred), segmentation_classes_val)
            loss_val.append((0.2*Dice_loss_value+CE_loss_value).cpu().data.numpy())
        loss_val = np.asarray(loss_val).mean()
        print("Loss validation : " + str(loss_val.item()))
        
        
        ## THIS IS HOW YOU WILL SAVE THE TRAINED MODELS AFTER EACH EPOCH. 
        ## WARNING!!!!! YOU DON'T WANT TO SAVE IT AT EACH EPOCH, BUT ONLY WHEN THE MODEL WORKS BEST ON THE VALIDATION SET!!
        if not os.path.exists('./models/' + modelName):
                os.makedirs('./models/' + modelName)
        if loss_val < Best_loss_val:
            Best_loss_val = loss_val
            BestEpoch = i
            torch.save(net.state_dict(), './models/' + modelName + '/model')
            
        np.save(os.path.join(directory, 'Losses.npy'), lossTotalTraining)


In [4]:
runTraining()

----------------------------------------
~~~~~~~~  Starting the training... ~~~~~~
----------------------------------------
 Dataset: ./Data/ 
~~~~~~~~~~~ Creating the UNet model ~~~~~~~~~~
 Model Name: Attention_U_Net
Total params: 2,184,504
~~~~~~~~~~~ Starting the training ~~~~~~~~~~
[Training] Epoch: 0 [DONE]                                 
[Training] Epoch: 0, LossG: 1.3389                                                                           
Loss validation : 1.3399007320404053
[Training] Epoch: 1 [DONE]                                 
[Training] Epoch: 1, LossG: 1.1810                                                                           
Loss validation : 1.146612524986267
[Training] Epoch: 2 [DONE]                                 
[Training] Epoch: 2, LossG: 1.0749                                                                           
Loss validation : 1.036014199256897
[Training] Epoch: 3 [DONE]                                 
[Training] Epoch: 3, LossG: 1.006