In [1]:
from dotenv import load_dotenv
import os
import sys
# Load variables from a .env file into the process environment so that
# BASE_DIR (and the other paths read later via os.getenv) are available.
load_dotenv()
# Make the project packages importable. The append is skipped when BASE_DIR is
# unset (sys.path entries must be strings; None would be silently ignored by
# the import system) or already present (re-running this cell would otherwise
# add a duplicate entry each time).
_base_dir = os.getenv('BASE_DIR')
if _base_dir is not None and _base_dir not in sys.path:
    sys.path.append(_base_dir)
from Efficientunet.liver_abd_efficientunet import get_efficientunet_b0
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from liver_abd_efficientunet import get_efficientunet_b0_parallel, get_efficientunet_b0_shared_decoder, get_efficientunet_b0
from shared_functions.data import ImageMaskDataset


# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")



In [2]:
from shared_functions.metrics import dice_coef, fpr, hausdorff_distance, tpr, dice_loss_multi
from shared_functions.data import load_data_loaders
# --- Experiment configuration -------------------------------------------------
n_splits = 3      # number of cross-validation folds (one loader pair per fold)

n_epochs = 150    # training epochs per fold
batch_size = 4    # NOTE(review): not used in the visible cells; the loaders are loaded pre-built
alpha = 0.5       # NOTE(review): not used in the visible cells; kept in case a later cell reads it

# Release cached GPU memory left over from earlier runs in this kernel.
torch.cuda.empty_cache()

# Directory holding the serialized per-fold loaders. Fail early with a clear
# message: os.path.join(None, ...) would otherwise raise an opaque TypeError.
dataloader_dir = os.getenv("DATALOADER_MULTI_DIR")
if dataloader_dir is None:
    raise RuntimeError(
        "Environment variable DATALOADER_MULTI_DIR is not set; "
        "cannot locate train_loaders.pth / test_loaders.pth."
    )

# Both results are indexed by fold number in the training cell.
train_loaders, test_loaders = load_data_loaders(
    os.path.join(dataloader_dir, 'train_loaders.pth'),
    os.path.join(dataloader_dir, 'test_loaders.pth'),
)

  train_loaders = torch.load(train_path)
  test_loaders = torch.load(test_path)


In [3]:
torch.cuda.empty_cache()

# Result locations live under BASE_DIR. Fail early with a clear message when it
# is missing (concatenating None with a string raises an opaque TypeError).
_results_root = os.getenv('BASE_DIR')
if _results_root is None:
    raise RuntimeError("Environment variable BASE_DIR is not set; cannot locate the results directory.")
performance_file = _results_root + '/Efficientunet/results/multiclass/performance.txt'
model_dir = _results_root + '/Efficientunet/results/multiclass/models/'
# exist_ok avoids the check-then-create race; creating model_dir also creates
# its parent directory, which is where performance_file is written.
os.makedirs(model_dir, exist_ok=True)

# Run an intermediate evaluation on the test loader every `eval_every` epochs.
eval_every = 25


def _evaluate_segmentation(model, loader):
    """Average per-sample liver / abdominal-wall metrics of `model` over `loader`.

    Switches the model to eval mode, takes the argmax over dim 1 of the model
    output as the predicted class per pixel, and compares the class-1 (liver)
    and class-2 (abdominal wall) binary maps with the matching ground-truth
    masks using hausdorff_distance, dice_coef, tpr and fpr.

    Args:
        model: network to evaluate; it is left in eval mode on return.
        loader: iterable yielding (images, _, liver_masks, abd_masks) batches.

    Returns:
        dict with keys 'liver' and 'abd'; each value maps 'dice', 'tpr', 'fpr'
        and 'hausdorff' to the np.mean over every sample seen.
    """
    organs = ('liver', 'abd')
    per_sample = {organ: {'hausdorff': [], 'dice': [], 'tpr': [], 'fpr': []} for organ in organs}

    model.eval()
    with torch.no_grad():
        for images, _, liver_masks, abd_masks in loader:
            outputs = model(images.to(device))
            # Discrete class map; presumably (N, H, W) for (N, 3, H, W) outputs — TODO confirm.
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            # (ground truth, binary prediction) per organ. The masks are only
            # needed on the host, so they are not moved to the device first.
            pairs = {
                'liver': (liver_masks.cpu().numpy(), (preds == 1).astype('float32')),
                'abd': (abd_masks.cpu().numpy(), (preds == 2).astype('float32')),
            }
            for i in range(preds.shape[0]):
                for organ in organs:
                    gt, pred = pairs[organ]
                    scores = per_sample[organ]
                    scores['hausdorff'].append(hausdorff_distance(gt[i], pred[i]))
                    scores['dice'].append(dice_coef(gt[i], pred[i]))
                    scores['tpr'].append(tpr(gt[i], pred[i]))
                    scores['fpr'].append(fpr(gt[i], pred[i]))

    return {
        organ: {metric: np.mean(values) for metric, values in scores.items()}
        for organ, scores in per_sample.items()
    }


def _metric_report(fold, averages):
    """Return the two summary lines (liver, then abdominal wall) for one fold."""
    lines = []
    for organ, label in (('liver', 'Liver'), ('abd', 'Abd Wall')):
        stats = averages[organ]
        lines.append(
            f'Fold [{fold+1}/{n_splits}] - {label} - Average Dice Coef: {stats["dice"]}, '
            f'Average TPR: {stats["tpr"]}, Average FPR: {stats["fpr"]}, '
            f'Average Hausdorff: {stats["hausdorff"]}'
        )
    return lines


for fold in range(n_splits):
    train_loader = train_loaders[fold]
    test_loader = test_loaders[fold]

    # A fresh multi-class segmentation model, optimizer and scheduler per fold.
    model = get_efficientunet_b0(out_channels=3, concat_input=False, pretrained=False, multi_class=True).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    criterion = dice_loss_multi()

    # Training
    for epoch in range(n_epochs):
        model.train()  # also restores train mode after an intermediate evaluation
        running_loss = 0.0
        for images, _, liver_masks, abd_masks in train_loader:
            images = images.to(device)
            liver_masks = liver_masks.to(device)
            abd_masks = abd_masks.to(device)
            # Combine the two binary masks into a single multi-class target:
            #   background: 0, liver: 1, abd wall: 2.
            # Assumption: liver_masks and abd_masks are mutually exclusive;
            # a pixel set in both would produce the out-of-range label 3.
            target = liver_masks + abd_masks * 2

            optimizer.zero_grad()
            outputs = model(images)  # author's note: (N, 3, H, W) softmax probabilities — TODO confirm
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Mean batch loss for the epoch; it is both logged and fed to the scheduler.
        epoch_loss = running_loss / len(train_loader)
        print(f'Fold [{fold+1}/{n_splits}], Epoch [{epoch+1}/{n_epochs}] - Loss: {epoch_loss}')
        scheduler.step(epoch_loss)

        # Intermediate evaluation on the fold's test loader.
        if (epoch + 1) % eval_every == 0:
            for line in _metric_report(fold, _evaluate_segmentation(model, test_loader)):
                print(line)

    model_save_path = os.path.join(model_dir, f'model_fold_{fold+1}.pth')
    torch.save(model.state_dict(), model_save_path)
    print(f'Model saved at {model_save_path}')

    # Final evaluation: print the summary and append it to the performance log.
    final_report = _metric_report(fold, _evaluate_segmentation(model, test_loader))
    for line in final_report:
        print(line)
    with open(performance_file, "a") as f:
        for line in final_report:
            print(line, file=f)



Fold [1/3], Epoch [1/150] - Loss: 0.595152262184355
Fold [1/3], Epoch [2/150] - Loss: 0.49579522013664246
Fold [1/3], Epoch [3/150] - Loss: 0.42429322997728985
Fold [1/3], Epoch [4/150] - Loss: 0.3473660449186961
Fold [1/3], Epoch [5/150] - Loss: 0.282897031141652
Fold [1/3], Epoch [6/150] - Loss: 0.22209685875309837
Fold [1/3], Epoch [7/150] - Loss: 0.16451577262745964
Fold [1/3], Epoch [8/150] - Loss: 0.1254791021347046
Fold [1/3], Epoch [9/150] - Loss: 0.09939005722602208
Fold [1/3], Epoch [10/150] - Loss: 0.08007390631569757
Fold [1/3], Epoch [11/150] - Loss: 0.06840906188719803
Fold [1/3], Epoch [12/150] - Loss: 0.061529749590489596
Fold [1/3], Epoch [13/150] - Loss: 0.053349432431989245
Fold [1/3], Epoch [14/150] - Loss: 0.04634307718111409
Fold [1/3], Epoch [15/150] - Loss: 0.0415016615556346
Fold [1/3], Epoch [16/150] - Loss: 0.03844418345640103
Fold [1/3], Epoch [17/150] - Loss: 0.0365379870765739
Fold [1/3], Epoch [18/150] - Loss: 0.03476950888418489
Fold [1/3], Epoch [19/150