### Import libraries

In [None]:
# Switch into the project directory so the local packages imported below
# (`dataloaders`, `model`, `utils`) resolve -- presumably they live under
# source_QMaxViT-Unet+ (confirm).
%cd source_QMaxViT-Unet+

In [None]:
import itertools
import os
import random
import re
from glob import glob
import matplotlib.pyplot as plt
import h5py
import numpy as np
import torch
from scipy import ndimage
from scipy.ndimage.interpolation import zoom
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import Sampler
import pandas as pd
import argparse
import importlib
from torch.nn.modules.loss import CrossEntropyLoss
import torch.backends.cudnn as cudnn
from torch.nn.functional import one_hot
from time import strftime
from torchvision.transforms import ToTensor
from torchvision import transforms
import torch.nn as nn
from torch.nn import functional as F
from tqdm.auto import tqdm

### Datasets

In [None]:
from dataloaders.datasets import ACDCDataset_Edge, MSCMRDataSets_Edge
from dataloaders.utils import *

In [None]:
# ACDC, fold1: training reads the edge-augmented copy, validation the preprocessed copy.
train_set_acdc = ACDCDataset_Edge(
    base_dir="/teamspace/studios/this_studio/ACDC_augmentated_onlyEdge",
    split="train",
    fold="fold1",
    sup_type="label",
    is_edge_mask=True,
    # Random 256x256 augmentation; is_edge_mask=True presumably makes it carry the edge mask along.
    transform=transforms.Compose([RandomGenerator([256, 256], is_edge_mask=True)]),
)

# Validation samples are returned untransformed.
val_set_acdc = ACDCDataset_Edge(
    base_dir="/teamspace/studios/this_studio/ACDC_preprocessed",
    split="val",
    fold="fold1",
    transform=None,
)

In [None]:
sample = train_set_acdc[0]
# Image (grayscale), label map and edge mask of the first training sample, side by side.
panels = [
    (sample['image'].squeeze(), 'gray'),
    (sample['label'], None),
    (sample['edge_mask'].squeeze(0), None),
]
for position, (panel, cmap) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(panel, cmap=cmap)

In [None]:
# MSCMR: augmented training slices (MAAGfold split) and preprocessed validation volumes.
train_set_mscmr = MSCMRDataSets_Edge(
    base_dir="/teamspace/studios/this_studio/MSCMR_augmentated",
    train_dir="/MSCMR_training_slices",
    split="train",
    fold="MAAGfold",
    sup_type="label",
    is_edge_mask=True,
    # Random 256x256 augmentation; is_edge_mask=True presumably makes it carry the edge mask along.
    transform=transforms.Compose([RandomGenerator([256, 256], is_edge_mask=True)]),
)

# Validation volumes are returned untransformed.
val_set_mscmr = MSCMRDataSets_Edge(
    base_dir="/teamspace/studios/this_studio/MSCMR_preprocessed",
    val_dir="/MSCMR_testing_volumes",
    split="val",
    transform=None,
)

In [None]:
sample = train_set_mscmr[0]
# Image (grayscale), label map and edge mask of the first training sample, side by side.
panels = [
    (sample['image'].squeeze(), 'gray'),
    (sample['label'], None),
    (sample['edge_mask'].squeeze(0), None),
]
for position, (panel, cmap) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(panel, cmap=cmap)

### Models

In [None]:
from model.qemaxvit_unet import QEMaxViT_Unet

# 4-class network; per the argument name, the backbone is initialized from the MaxViT
# checkpoint below (how it is loaded is up to QEMaxViT_Unet).
model = QEMaxViT_Unet(
    backbone_pretrained_pth="/teamspace/studios/this_studio/MIST/pretrained_pth/maxvit/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth",
    num_classes=4,
)

In [None]:
### test
# Smoke-test the forward pass and print the shape of every output.
# Run in eval mode without autograd so this random input neither builds a graph nor
# updates the running statistics of any normalization layers (e.g. BatchNorm) before
# training starts; the previous train/eval mode is restored afterwards.
# NOTE(review): this feeds 3 channels, while the validation code feeds 1-channel
# slices -- presumably the model accepts both; confirm.
was_training = model.training
model.eval()
with torch.no_grad():
    for i in model(torch.rand(1, 3, 256, 256)):
        print(i.shape)
model.train(was_training)

### Setup Training

In [None]:
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
import gc

# Drop unreachable Python objects and release cached CUDA memory from earlier cells.
gc.collect()
torch.cuda.empty_cache()
# To restrict the visible GPUs: os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'

# Truthy -> reproducible cuDNN kernels (no autotuning); falsy -> autotuned, faster kernels.
deterministic = 1
cudnn.benchmark = not deterministic
cudnn.deterministic = bool(deterministic)

# Seed every RNG the notebook may draw from, in the same order as before.
seed = 444
for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed)

def worker_init_fn(worker_id):
    """Seed Python's and NumPy's RNGs inside a DataLoader worker process.

    PyTorch already gives each worker its own torch seed: inside a worker,
    ``torch.initial_seed()`` returns ``base_seed + worker_id``, where ``base_seed`` is
    drawn from the main process's (seeded) RNG each time the workers start. PyTorch does
    not reseed ``random`` or ``numpy``, so both are derived from that value here. The
    result is that each worker, and each epoch (the loader does not use persistent
    workers), gets a distinct but still reproducible stream.

    The previous version reseeded NumPy and torch with the same global ``seed`` in every
    worker. All workers then drew identical random numbers, restarting from the same
    state every epoch.

    Args:
        worker_id: index of the worker; it is already folded into ``torch.initial_seed()``.
    """
    worker_seed = torch.initial_seed() % 2**32  # np.random.seed only accepts 32-bit seeds
    random.seed(worker_seed)
    np.random.seed(worker_seed)

batch_size = 12

# Training: shuffled mini-batches from 12 worker processes, each reseeded by worker_init_fn.
trainloader = DataLoader(
    dataset=train_set_acdc,
    batch_size=batch_size,
    shuffle=True,
    num_workers=12,
    pin_memory=True,
    worker_init_fn=worker_init_fn,
)

# Validation: one sample per batch, in order, loaded in the main process.
valloader = DataLoader(dataset=val_set_acdc, batch_size=1, shuffle=False, num_workers=0)

In [None]:
from utils.losses import pDLoss
from utils import pyutils
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch import optim

num_classes = 4

# Pixels labelled 4 are ignored by cross-entropy (and presumably by pDLoss as well).
ce_loss = CrossEntropyLoss(ignore_index=4)
dice_loss = pDLoss(num_classes, ignore_index=4)
# Regression loss for the predicted edge map.
edge_loss_function = nn.MSELoss()
avg_meter = pyutils.AverageMeter('loss')

# Best-checkpoint bookkeeping and global iteration counter.
best_performance, best_epoch, iter_num = 0.0, 0, 0
max_epoches = 200

max_iterations = max_epoches * len(trainloader)
optimizer = optim.AdamW(params=model.parameters(), lr=0.001)
# NOTE(review): T_max spans 100 epochs, but training runs for up to `max_epoches` (200).
# Past T_max, CosineAnnealingLR keeps following the cosine, so the LR climbs back toward
# 0.001 during the second half. `max_iterations` above may have been the intended
# T_max -- confirm.
scheduler = CosineAnnealingLR(optimizer, T_max=100 * len(trainloader))

In [None]:
import numpy as np
import torch
from medpy import metric
from scipy.ndimage import zoom


def calculate_metric_percase(pred, gt):
    """Return ``(dice, hd95)`` for one prediction/ground-truth mask pair.

    Entries greater than 0 are treated as foreground. The inputs are NOT modified; the
    previous version binarized them in place, writing through the caller's arrays.

    Args:
        pred: predicted mask (array-like).
        gt: ground-truth mask (array-like), same shape as ``pred``.

    Returns:
        ``(nan, nan)`` if neither mask has foreground (class absent and not predicted);
        ``(0, 0)`` if only the prediction has foreground;
        ``(dice, nan)`` if only the ground truth has foreground (HD95 is undefined);
        ``(dice, hd95)`` otherwise.
    """
    # Fresh boolean copies instead of `pred[pred > 0] = 1` on the caller's arrays.
    pred = np.asarray(pred) > 0
    gt = np.asarray(gt) > 0
    pred_has_fg = bool(pred.any())

    if not gt.any():
        if not pred_has_fg:
            return np.nan, np.nan
        # NOTE(review): an HD95 of 0 is the best possible value, so false-positive-only
        # cases pull the mean HD95 down; confirm this is intended.
        return 0, 0

    dice = metric.binary.dc(pred, gt)
    hd95 = metric.binary.hd95(pred, gt) if pred_has_fg else np.nan
    return dice, hd95


@torch.no_grad()
def test_single_volume_for_training(image, label, net, classes, patch_size=(256, 256)):
    """Segment one volume slice-by-slice and score it per foreground class.

    Each slice is resized to ``patch_size`` (nearest neighbour) and run through ``net``.
    The argmax prediction is then resized back to the slice's original size.

    Args:
        image: tensor whose ``squeeze(0)`` is a (slices, H, W) volume or a single
            (H, W) slice -- assumes a leading batch dimension of 1 (TODO confirm).
        label: label tensor with the same layout as ``image``.
        net: model returning ``(logits_main, logits_aux, edge_map)`` for a
            (1, 1, h, w) input. It is left in eval mode.
        classes: number of classes including background; classes ``1..classes-1``
            are scored.
        patch_size: (h, w) network input size.

    Returns:
        ``(metric_list_one, metric_list_two)``: one ``(dice, hd95)`` tuple per
        foreground class. "one" scores the main head alone; "two" scores the argmax of
        the summed softmaxes of both heads.
    """
    image = image.squeeze(0).detach().cpu().numpy()
    label = label.squeeze(0).detach().cpu().numpy()

    # The old 2-D branch unpacked the model's outputs wrongly, fed an un-batched tensor
    # and then read undefined prediction arrays. A single slice is now handled as a
    # one-slice volume and squeezed back before scoring.
    single_slice = image.ndim == 2
    if single_slice:
        image, label = image[np.newaxis], label[np.newaxis]

    # Run on whatever device the model lives on instead of assuming CUDA.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    net.eval()
    prediction_1 = np.zeros_like(label)
    prediction_2 = np.zeros_like(label)
    for ind, slc in enumerate(image):
        h, w = slc.shape
        resized = zoom(slc, (patch_size[0] / h, patch_size[1] / w), order=0)
        batch = torch.from_numpy(resized).unsqueeze(0).unsqueeze(0).float().to(device)
        logits_1, logits_2, _ = net(batch)

        soft_1 = torch.softmax(logits_1, dim=1)
        soft_2 = torch.softmax(logits_2, dim=1)
        out_1 = torch.argmax(soft_1, dim=1).squeeze(0).cpu().numpy()
        # "Two" is the ensemble of both heads (sum of softmaxes), not head 2 alone.
        out_2 = torch.argmax(soft_2 + soft_1, dim=1).squeeze(0).cpu().numpy()

        restore = (h / patch_size[0], w / patch_size[1])
        prediction_1[ind] = zoom(out_1, restore, order=0)
        prediction_2[ind] = zoom(out_2, restore, order=0)

    if single_slice:
        # Score 2-D input as the same 2-D arrays the caller passed in.
        prediction_1, prediction_2, label = prediction_1[0], prediction_2[0], label[0]

    metric_list_one = [
        calculate_metric_percase(prediction_1 == c, label == c) for c in range(1, classes)
    ]
    metric_list_two = [
        calculate_metric_percase(prediction_2 == c, label == c) for c in range(1, classes)
    ]
    return metric_list_one, metric_list_two

In [None]:
from progress_table import ProgressTable

table = ProgressTable(
    interactive=1,
    pbar_embedded=False,  # keep the progress bar separate from the table rows
    pbar_style="angled alt red blue",
    pbar_show_eta=True,
)

# One column per logged quantity, in display order.
for column in ("epoch", "train_loss", "dice_one", "dice_two", "hd95_one", "hd95_two", "best_model"):
    table.add_columns(column)

In [None]:
## Early stopping: give up after this many consecutive epochs without a new best validation Dice.
max_epochs_without_improvement, epochs_without_improvement = 50, 0

# Placeholder for the path of the best checkpoint.
best_model_path = ""

In [None]:
from tqdm.auto import tqdm

import random
import torch.nn.functional as F

# Each epoch: train with CE on both heads, Dice against a pseudo-label mixed from both
# heads, and an MSE edge loss; then validate, checkpoint on a new best Dice, and stop
# after `max_epochs_without_improvement` consecutive epochs without improvement.
model.to(device)
model.train()
iter_num = 0
for ep in table(max_epoches, show_throughput=False, show_eta=True):
    table["epoch"] = ep

    # ---- training pass ----
    for sampled_batch in table(trainloader, description="train epoch"):
        img = sampled_batch['image'].to(device)
        label = sampled_batch['label'].to(device)
        groundtruth_edge = sampled_batch['edge_mask'].to(device)

        output_main, output_aux, edge_map = model(img)

        outputs_soft1 = torch.softmax(output_main, dim=1)
        outputs_soft2 = torch.softmax(output_aux, dim=1)

        # Pseudo-label: argmax of a random convex mix of both heads' detached softmaxes.
        beta = random.random() + 1e-10
        pseudo_supervision = torch.argmax(
            beta * outputs_soft1.detach() + (1.0 - beta) * outputs_soft2.detach(),
            dim=1, keepdim=False)
        pseudo_target = pseudo_supervision.unsqueeze(1)
        loss_pse_sup = 0.5 * (dice_loss(outputs_soft1, pseudo_target)
                              + dice_loss(outputs_soft2, pseudo_target))

        target = label.long()
        loss_ce = 0.5 * (ce_loss(output_main, target) + ce_loss(output_aux, target))

        edge_loss = edge_loss_function(edge_map, groundtruth_edge)

        loss = loss_ce + 0.5 * loss_pse_sup + 0.2 * edge_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_meter.add({'loss': loss.item()})

        scheduler.step()
        iter_num += 1

    table.update("train_loss", avg_meter.get('loss'), color="blue")

    # ---- validation pass ----
    model.eval()
    metric_list_one, metric_list_two = [], []
    for sampled_batch in table(valloader, description="valid epoch"):
        metric_i_one, metric_i_two = test_single_volume_for_training(
            sampled_batch["image"], sampled_batch["label"], model, classes=num_classes)
        metric_list_one.append(metric_i_one)
        metric_list_two.append(metric_i_two)

    # (volumes, classes - 1, 2) -> per-class (dice, hd95) ignoring NaNs -> mean over classes.
    metric_list_one = np.nanmean(np.array(metric_list_one), axis=0)
    metric_list_two = np.nanmean(np.array(metric_list_two), axis=0)
    performance_one, mean_hd95_one = np.mean(metric_list_one, axis=0)
    performance_two, mean_hd95_two = np.mean(metric_list_two, axis=0)

    # Log before the early-stopping decision so the final epoch's row is complete too.
    table.update("dice_one", performance_one, color="green")
    table.update("hd95_one", mean_hd95_one, color="green")
    table.update("dice_two", performance_two, color="blue")
    table.update("hd95_two", mean_hd95_two, color="blue")

    stop_early = False
    if performance_one > best_performance or performance_two > best_performance:
        table["best_model"] = "✅"
        # Explicit comparison (not max()) so a NaN for head one falls through to head two.
        best_performance = performance_one if performance_one > performance_two else performance_two
        best_epoch = ep
        epochs_without_improvement = 0
        save_best = "/teamspace/studios/this_studio/acdc_bestmodel/Qmaxvitunet_aaa_test.pth"
        torch.save(model.state_dict(), save_best)
        best_model_path = save_best
    else:
        epochs_without_improvement += 1
        stop_early = epochs_without_improvement >= max_epochs_without_improvement

    model.train()
    avg_meter.pop()
    table.next_row()
    if stop_early:
        print(f'Early stopping at epoch {ep} as validation performance did not improve.')
        break
print('best model in epoch %5d  mean_dice : %.4f' % (best_epoch, best_performance))

### Setup Evaluation

In [None]:
# ACDC test split, untransformed.
# NOTE(review): validation uses fold="fold1" while this uses fold="MAAGfold" -- confirm
# the folds are meant to differ.
test_set_acdc = ACDCDataset_Edge(
    base_dir="/teamspace/studios/this_studio/ACDC_preprocessed",
    fold="MAAGfold",
    split="test",
    transform=None,
)

# One sample per batch, in order, loaded in the main process.
testloader_acdc = DataLoader(test_set_acdc, batch_size=1, shuffle=False, num_workers=0)

In [None]:
model_inference = QEMaxViT_Unet(
    num_classes=4,
    backbone_pretrained_pth="/teamspace/studios/this_studio/MIST/pretrained_pth/maxvit/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth",
)
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a CPU-only host;
# load_state_dict copies the tensors into the model's own parameters, which can be
# moved to the target device afterwards.
# NOTE(review): the training loop saves to .../Qmaxvitunet_aaa_test.pth, but this loads
# .../Qmaxvitunet_final_fold4.pth -- confirm which checkpoint should be evaluated.
model_inference.load_state_dict(
    torch.load(
        "/teamspace/studios/this_studio/acdc_bestmodel/Qmaxvitunet_final_fold4.pth",
        map_location="cpu",
    )
)

In [None]:
import pandas as pd
import h5py
from tqdm.auto import tqdm

model_inference.to(device)
model_inference.eval()

# One (main-head metrics, two-head-ensemble metrics) pair per test volume;
# each element is a list of per-class (dice, hd95) tuples.
per_volume = [
    test_single_volume_for_training(batch["image"], batch["label"], model_inference, classes=4)
    for batch in tqdm(testloader_acdc)
]
metric_list_one = np.nanmean(np.array([one for one, _ in per_volume]), axis=0)
metric_list_two = np.nanmean(np.array([two for _, two in per_volume]), axis=0)

# Rows follow the scored class order 1, 2, 3.
df_one = pd.DataFrame(metric_list_one, columns=['Dice', 'HD95'], index=['RV', 'Myo', 'LV'])
df_two = pd.DataFrame(metric_list_two, columns=['Dice', 'HD95'], index=['RV', 'Myo', 'LV'])

In [None]:
for frame in (df_one, df_two):
    print(frame.round(3))
# Class-averaged Dice and HD95: main head first, then the two-head ensemble.
(df_one['Dice'].mean(), df_one['HD95'].mean(),
 df_two['Dice'].mean(), df_two['HD95'].mean())

### Inference

In [None]:

@torch.no_grad()
def inference(image, label, net, classes, patch_size=(256, 256)):
    """Predict a volume slice-by-slice with both decoder heads.

    Each slice is resized to ``patch_size`` (nearest neighbour) and run through ``net``.
    Each head's argmax is then resized back to the slice's original size.

    Args:
        image: tensor whose ``squeeze(0)`` is a (slices, H, W) volume or a single
            (H, W) slice -- assumes a leading batch dimension of 1 (TODO confirm).
        label: tensor laid out like ``image``; only its shape/dtype are used, to
            allocate the outputs.
        net: model returning ``(logits_main, logits_aux, edge_map)`` for a
            (1, 1, h, w) input. It is left in eval mode.
        classes: unused; kept for call compatibility.
        patch_size: (h, w) network input size.

    Returns:
        ``(prediction_1, prediction_2)``: argmax of the first and of the second head.
        Each has the shape/dtype of ``label.squeeze(0)``; 2-D input gives 2-D output.
    """
    # NOTE(review): prediction_2 uses the second head alone, whereas
    # test_single_volume_for_training scores argmax(softmax1 + softmax2) as "two".
    image = image.squeeze(0).detach().cpu().numpy()
    label = label.squeeze(0).detach().cpu().numpy()

    # Previously a 2-D input skipped the loop and returned unbound names; treat it as
    # a one-slice volume and squeeze the result back at the end.
    single_slice = image.ndim == 2
    if single_slice:
        image, label = image[np.newaxis], label[np.newaxis]

    # Run on whatever device the model lives on instead of assuming CUDA.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    net.eval()
    prediction_1 = np.zeros_like(label)
    prediction_2 = np.zeros_like(label)
    for ind, slc in enumerate(image):
        h, w = slc.shape
        resized = zoom(slc, (patch_size[0] / h, patch_size[1] / w), order=0)
        batch = torch.from_numpy(resized).unsqueeze(0).unsqueeze(0).float().to(device)
        logits_1, logits_2, _ = net(batch)

        out_1 = torch.argmax(torch.softmax(logits_1, dim=1), dim=1).squeeze(0).cpu().numpy()
        out_2 = torch.argmax(torch.softmax(logits_2, dim=1), dim=1).squeeze(0).cpu().numpy()

        restore = (h / patch_size[0], w / patch_size[1])
        prediction_1[ind] = zoom(out_1, restore, order=0)
        prediction_2[ind] = zoom(out_2, restore, order=0)

    if single_slice:
        return prediction_1[0], prediction_2[0]
    return prediction_1, prediction_2

In [None]:
# Follow the notebook's `device` (CUDA when available, else CPU) instead of requiring CUDA,
# consistent with the evaluation cell.
model_inference.to(device)
model_inference.eval()

slice_idx = 2  # which slice of each volume to display
# Show image / ground truth / head-1 / head-2 predictions for the first 200 test volumes.
for i_batch, sampled_batch in enumerate(tqdm(testloader_acdc)):
    if i_batch == 200:
        break
    prediction_1, prediction_2 = inference(
        sampled_batch["image"], sampled_batch["label"], model_inference, 4)

    panels = [
        (sampled_batch['image'].squeeze(0).cpu().numpy()[slice_idx], 'gray'),
        (sampled_batch['label'].squeeze(0).cpu().numpy()[slice_idx], None),
        (prediction_1[slice_idx], None),
        (prediction_2[slice_idx], None),
    ]
    plt.figure(figsize=(10, 10))
    for position, (panel, cmap) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        plt.imshow(panel, cmap=cmap)
        plt.axis("off")
    plt.show()