## Training UNets in Pytorch

In [1]:

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from torchsummary import summary
    
import os
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim

In [2]:
! mkdir saved_images

mkdir: cannot create directory ‘saved_images’: File exists


In [3]:
import segmentation_models_pytorch as smp

def get_pretrained_unet(input_channel, output_channel):
    model = smp.Unet(
        encoder_name="resnet34", # Choose the encoder architecture
        encoder_weights="imagenet", # Use pre-trained weights from ImageNet
        in_channels=input_channel,
        classes=output_channel,
        activation=None, # No activation function for the final layer
    )
    return model


In [4]:
import segmentation_models_pytorch as smp

def get_pretrained_unetplusplus(input_channel, output_channel):
    model = smp.UnetPlusPlus(
        encoder_name="resnet34", # Choose the encoder architecture
        encoder_weights="imagenet", # Use pre-trained weights from ImageNet
        in_channels=input_channel,
        classes=output_channel,
        activation=None, # No activation function for the final layer
    )
    return model


In [5]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset

class LitsDataset(Dataset):
    def __init__(self, csv_file, base_path, type="train", split_ratio=0.2, transform=None):
        self.df = pd.read_csv(csv_file)
        self.base_path = base_path
        self.transform = transform
        
        # Filter out liver_mask_empty rows
        self.df = self.df[self.df['tumor_mask_empty'].isin(['TRUE', True])]
        
        # Train-test split
        if type=="train":
            self.df = self.df.iloc[:int(len(self.df)*(1-split_ratio))]

        elif type=="val":
            self.df = self.df.iloc[int(len(self.df)*(1-split_ratio)):]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        img_path = os.path.join(self.base_path, row['filepath'])
        image = np.array(Image.open(img_path).convert("RGB"), dtype=np.float32)  # Cast to float32
        image = np.transpose(image, (2, 0, 1))

        # Load tumor and liver masks
        #tumor_mask_path = os.path.join(self.base_path, row['tumor_maskpath'])
        liver_mask_path = os.path.join(self.base_path, row['tumor_maskpath'])
        #tumor_mask = np.array(Image.open(tumor_mask_path).point(lambda x: x * 255).convert("L"), dtype=np.float32)
        liver_mask = np.array(Image.open(liver_mask_path).point(lambda x: x * 255).convert("L"), dtype=np.float32)

        # Set mask values
        #tumor_mask[tumor_mask == 255.0] = 1.0
        liver_mask[liver_mask == 255.0] = 1.0

        # Create a two-channel mask
        #mask = np.stack((liver_mask , tumor_mask), axis=0)

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, liver_mask


In [6]:
import torchvision
from torch.utils.data import DataLoader

def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)

def load_checkpoint(checkpoint, model):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])

from torch.utils.data import DataLoader

def get_loaders(
    csv_file,
    base_path,
    batch_size=8,
    train_transform=None,
    val_transform=None,
    num_workers=4,
    pin_memory=True,):

    train_ds = LitsDataset(
        csv_file=csv_file,
        base_path=base_path,
        type="train",
        transform=train_transform,
    )

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )

    val_ds = LitsDataset(
        csv_file=csv_file,
        base_path=base_path,
        type="val",
        transform=val_transform,
    )

    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )
    
    test_ds = LitsDataset(
        csv_file="/apps/local/shared/HC701/assessment/assignment_3/data/hc701_lits_test.csv",
        base_path=base_path,
        type="test",
        transform=val_transform,
    )

    test_loader = DataLoader(
        test_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )
    

    return train_loader, val_loader, test_loader



# def check_accuracy(loader, model, device="cuda"):
#     num_correct = 0
#     num_pixels = 0
#     dice_score = 0
#     model.eval()

#     with torch.no_grad():
#         for x, y in loader:
#             x = x.to(device)
#             y = y.to(device).unsqueeze(1)
#             preds = torch.sigmoid(model(x))
# #             print(preds.shape)
#             preds = (preds > 0.5).float()
#             num_correct += (preds == y).sum()
#             num_pixels += torch.numel(preds)
#             dice_score += (2 * (preds * y).sum()) / (
#                 (preds + y).sum() + 1e-8
#             )

#     print(f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}")
#     print(f"Dice score: {dice_score/len(loader)}")
#     model.train()

def save_predictions_as_imgs(loader, model, folder="saved_images/", device="cuda"):
    
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
#             preds.shape
            preds = (preds > 0.5).float()
        torchvision.utils.save_image(
            preds, f"{folder}/pred_{idx}.png"
        )
        torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")

    model.train()

In [7]:
import torch
from scipy.spatial.distance import directed_hausdorff
import numpy as np

def check_accuracy(loader, model, device="cuda"):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    jaccard_score = 0
    precision = 0
    recall = 0
    hausdorff_distance = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()

            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)

            intersection = (preds * y).sum()
            union = (preds + y).sum() - intersection

            dice_score += (2 * intersection) / (preds.sum() + y.sum() + 1e-8)
            jaccard_score += intersection / (union + 1e-8)
            precision += intersection / (preds.sum() + 1e-8)
            recall += intersection / (y.sum() + 1e-8)

            # Calculate Hausdorff distance
            y_np = y.squeeze().cpu().numpy()
            preds_np = preds.squeeze().cpu().numpy()
            
            hausdorff_slices = []
            for idx in range(y_np.shape[0]):
                hd1 = directed_hausdorff(y_np[idx], preds_np[idx])[0]
                hd2 = directed_hausdorff(preds_np[idx], y_np[idx])[0]
                hausdorff_slices.append(max(hd1, hd2))
            hausdorff_distance += np.mean(hausdorff_slices)

    n = len(loader)
    print(f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}")
    print(f"Dice score: {dice_score/n}")
    print(f"Jaccard score: {jaccard_score/n}")
    print(f"Precision: {precision/n}")
    print(f"Recall: {recall/n}")
    print(f"F1-score: {(2 * precision * recall) / (precision + recall + 1e-8)/n}")
    print(f"Mean Hausdorff distance: {hausdorff_distance/n}")
    model.train()


In [8]:
# hyperparams
lr = 1e-6
dev = "cuda"
batch_size = 8
epochs = 50
workers= 8
img_h = 256
img_w = 256
pin_mem= True
load_model = False


In [9]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=dev)
        targets = targets.float().unsqueeze(1).to(device=dev)

        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())
        

In [10]:
import torch.nn.functional as F
from segmentation_models_pytorch.losses import DiceLoss, FocalLoss

class CombinedLoss(torch.nn.Module):
    def __init__(self, alpha=0.5, gamma=2, weight=None, reduction='mean', focal_weight=0.5, dice_weight=0.5):
        super(CombinedLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.weight = weight
        self.reduction = reduction
        self.focal_weight = focal_weight
        self.dice_weight = dice_weight
        self.focal_loss = FocalLoss(mode='binary', alpha=alpha, gamma=gamma)
        self.dice_loss = DiceLoss(mode='binary')

    def forward(self, logits, targets):
        focal_loss = self.focal_loss(logits, targets)
        dice_loss = self.dice_loss(logits, targets)
        return self.focal_weight * focal_loss + self.dice_weight * dice_loss


In [11]:
def eval_fn(loader, model, device, criterion):
    model.eval()
    with torch.no_grad():
        total_loss = 0.0
        for data in loader:
            inputs, targets = data
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            total_loss += loss.item() * inputs.size(0)
        
        avg_loss = total_loss / len(loader.dataset)
        
    return avg_loss


In [13]:
train_transform = A.Compose(
    [
        A.Resize(height=img_h, width=img_w),
        A.RandomRotate90(),
        A.Flip(),
        A.Transpose(),
        A.OneOf([
            A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
            A.GridDistortion(),
            A.OpticalDistortion(distort_limit=2, shift_limit=0.5),
        ], p=0.3),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

val_transforms = A.Compose(
    [
        A.Resize(height=img_h, width=img_w),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)


#model = UNet(input_channel=3, output_channel=1).to(dev)
input_channel = 3
output_channel = 1
#model = get_pretrained_unet(input_channel, output_channel).to(dev)
model = get_pretrained_unetplusplus(input_channel, output_channel).to(dev)

# weights = torch.tensor([1, 10], dtype=torch.float32).to(dev)
# loss_fn = nn.BCEWithLogitsLoss(pos_weight=weights)

from segmentation_models_pytorch.losses import DiceLoss
#loss_fn = DiceLoss(mode='binary')
#loss_fn = nn.BCEWithLogitsLoss()
loss_fn = CombinedLoss(alpha=0.25, gamma=2, focal_weight=0.2, dice_weight=0.8)


#optimizer = optim.AdamW(model.parameters(), lr=lr)
optimizer = optim.AdamW(model.parameters(), lr=1e-4 ,  weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, verbose=True)


train_loader, val_loader , test_loader= get_loaders(
    csv_file="/apps/local/shared/HC701/assessment/assignment_3/data/hc701_lits_train.csv",
    base_path="/apps/local/shared/HC701/assessment/assignment_3/data/",
    batch_size=8,
    train_transform=None,
    val_transform=None,
)


# if LOAD_MODEL:
#     load_checkpoint(torch.load("my_checkpoint.pth.tar"), model)


#check_accuracy(val_loader, model, device=dev)
scaler = torch.cuda.amp.GradScaler()
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
for epoch in range(epochs):
    print(epoch)
    train_fn(train_loader, model, optimizer, loss_fn, scaler)
    val_loss = eval_fn(val_loader, model, device=dev, criterion=loss_fn)
    scheduler.step(val_loss)

    # save model
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer":optimizer.state_dict(),
    }
    #save_checkpoint(checkpoint)
    check_accuracy(val_loader, model, device=dev)
    # check accuracy
    # print('Train Performance' , check_accuracy(train_loader, model, device=dev))
    # print('Val Performance' , check_accuracy(val_loader, model, device=dev))
    # print ('Test Performance' , check_accuracy(test_loader, model, device=dev))

    # print some examples to a folder
save_predictions_as_imgs(val_loader, model, folder="saved_images/", device=dev)

0


100%|██████████| 414/414 [00:38<00:00, 10.62it/s, loss=0.527]


Got 53844691/54198272 with acc 99.35
Dice score: 0.39618411660194397
Jaccard score: 0.29805994033813477
Precision: 0.4693100154399872
Recall: 0.41527435183525085
F1-score: 0.4406417906284332
Mean Hausdorff distance: 3.516203421837848
1


100%|██████████| 414/414 [00:39<00:00, 10.37it/s, loss=0.138] 


Got 53902054/54198272 with acc 99.45
Dice score: 0.4458152949810028
Jaccard score: 0.3377199172973633
Precision: 0.5604496002197266
Recall: 0.41336002945899963
F1-score: 0.4757961928844452
Mean Hausdorff distance: 3.2261187364341097
2


100%|██████████| 414/414 [00:37<00:00, 11.01it/s, loss=0.161] 


Got 53925652/54198272 with acc 99.50
Dice score: 0.45633387565612793
Jaccard score: 0.3463192582130432
Precision: 0.6343935132026672
Recall: 0.40416207909584045
F1-score: 0.4937584400177002
Mean Hausdorff distance: 3.1344488248634854
3


100%|██████████| 414/414 [00:38<00:00, 10.73it/s, loss=0.526] 


Got 53904091/54198272 with acc 99.46
Dice score: 0.4370352029800415
Jaccard score: 0.3271198868751526
Precision: 0.6295318007469177
Recall: 0.37877199053764343
F1-score: 0.472970575094223
Mean Hausdorff distance: 3.1651654696579623
4


100%|██████████| 414/414 [00:39<00:00, 10.49it/s, loss=0.141] 


Got 53936825/54198272 with acc 99.52
Dice score: 0.4832511246204376
Jaccard score: 0.3716176152229309
Precision: 0.6555336713790894
Recall: 0.4295060634613037
F1-score: 0.5189776420593262
Mean Hausdorff distance: 3.078646724885734
5


100%|██████████| 414/414 [00:38<00:00, 10.72it/s, loss=0.305] 


Got 53946370/54198272 with acc 99.54
Dice score: 0.4785108268260956
Jaccard score: 0.3692352771759033
Precision: 0.6568892002105713
Recall: 0.42317378520965576
F1-score: 0.5147445797920227
Mean Hausdorff distance: 3.084302496560128
6


100%|██████████| 414/414 [00:38<00:00, 10.73it/s, loss=0.0809]


Got 53946280/54198272 with acc 99.54
Dice score: 0.490526020526886
Jaccard score: 0.3792916536331177
Precision: 0.6631592512130737
Recall: 0.447778582572937
F1-score: 0.5345906615257263
Mean Hausdorff distance: 3.079715257763873
7


100%|██████████| 414/414 [00:38<00:00, 10.81it/s, loss=0.19]  


Got 53951556/54198272 with acc 99.54
Dice score: 0.4850725829601288
Jaccard score: 0.3759429156780243
Precision: 0.62287837266922
Recall: 0.4574204981327057
F1-score: 0.5274786949157715
Mean Hausdorff distance: 3.0815583058642297
8


100%|██████████| 414/414 [00:36<00:00, 11.26it/s, loss=0.499] 


Got 53952994/54198272 with acc 99.55
Dice score: 0.4980138838291168
Jaccard score: 0.3852366805076599
Precision: 0.6483213901519775
Recall: 0.4509125053882599
F1-score: 0.5318908095359802
Mean Hausdorff distance: 3.0443009715888607
9


100%|██████████| 414/414 [00:36<00:00, 11.24it/s, loss=0.8]   


Got 53939610/54198272 with acc 99.52
Dice score: 0.48976922035217285
Jaccard score: 0.3729501962661743
Precision: 0.6477521061897278
Recall: 0.4499916136264801
F1-score: 0.5310584306716919
Mean Hausdorff distance: 3.0806150894735795
10


100%|██████████| 414/414 [00:38<00:00, 10.79it/s, loss=0.0293]


Got 53838361/54198272 with acc 99.34
Dice score: 0.37663576006889343
Jaccard score: 0.2673940062522888
Precision: 0.6790729761123657
Recall: 0.2861822545528412
F1-score: 0.4026678800582886
Mean Hausdorff distance: 3.3175428468860195
11


100%|██████████| 414/414 [00:38<00:00, 10.83it/s, loss=0.0864]


Epoch 00012: reducing learning rate of group 0 to 1.0000e-05.
Got 53926820/54198272 with acc 99.50
Dice score: 0.4695315957069397
Jaccard score: 0.35701942443847656
Precision: 0.669028103351593
Recall: 0.41440054774284363
F1-score: 0.5117930173873901
Mean Hausdorff distance: 3.0997002876728703
12


100%|██████████| 414/414 [00:37<00:00, 11.05it/s, loss=0.0703]


Got 53943001/54198272 with acc 99.53
Dice score: 0.49031326174736023
Jaccard score: 0.3767741918563843
Precision: 0.6590667963027954
Recall: 0.4403355121612549
F1-score: 0.5279423594474792
Mean Hausdorff distance: 3.057276864498403
13


100%|██████████| 414/414 [00:38<00:00, 10.81it/s, loss=0.0191]


Got 53940980/54198272 with acc 99.53
Dice score: 0.4928587079048157
Jaccard score: 0.378490686416626
Precision: 0.6581636071205139
Recall: 0.4441027343273163
F1-score: 0.5303477644920349
Mean Hausdorff distance: 3.0589454074329336
14


100%|██████████| 414/414 [00:38<00:00, 10.84it/s, loss=0.0735]


Epoch 00015: reducing learning rate of group 0 to 1.0000e-06.
Got 53941152/54198272 with acc 99.53
Dice score: 0.49015742540359497
Jaccard score: 0.376149445772171
Precision: 0.6715083718299866
Recall: 0.4349646270275116
F1-score: 0.5279520750045776
Mean Hausdorff distance: 3.059608082510437
15


100%|██████████| 414/414 [00:38<00:00, 10.66it/s, loss=0.183] 


Got 53938851/54198272 with acc 99.52
Dice score: 0.4891809821128845
Jaccard score: 0.3746798038482666
Precision: 0.6682640910148621
Recall: 0.43521440029144287
F1-score: 0.5271297097206116
Mean Hausdorff distance: 3.065718228967846
16


100%|██████████| 414/414 [00:38<00:00, 10.66it/s, loss=0.0336]


Got 53940169/54198272 with acc 99.52
Dice score: 0.4917598068714142
Jaccard score: 0.37701061367988586
Precision: 0.6685816049575806
Recall: 0.4377239942550659
F1-score: 0.5290657877922058
Mean Hausdorff distance: 3.0607813806824327
17


100%|██████████| 414/414 [00:38<00:00, 10.88it/s, loss=0.12]  


Epoch 00018: reducing learning rate of group 0 to 1.0000e-07.
Got 53941569/54198272 with acc 99.53
Dice score: 0.49292758107185364
Jaccard score: 0.378203809261322
Precision: 0.6681854128837585
Recall: 0.44137042760849
F1-score: 0.5315952301025391
Mean Hausdorff distance: 3.0624466717037495
18


100%|██████████| 414/414 [00:39<00:00, 10.55it/s, loss=0.0202]


Got 53940916/54198272 with acc 99.53
Dice score: 0.49105319380760193
Jaccard score: 0.3766283392906189
Precision: 0.6781863570213318
Recall: 0.4363296329975128
F1-score: 0.5310158133506775
Mean Hausdorff distance: 3.06506997409143
19


100%|██████████| 414/414 [00:39<00:00, 10.55it/s, loss=0.01]  


Got 53939072/54198272 with acc 99.52
Dice score: 0.4917924106121063
Jaccard score: 0.37674200534820557
Precision: 0.6665008664131165
Recall: 0.4390020966529846
F1-score: 0.5293432474136353
Mean Hausdorff distance: 3.0621590729886785
20


100%|██████████| 414/414 [00:38<00:00, 10.66it/s, loss=0.0342]


Epoch 00021: reducing learning rate of group 0 to 1.0000e-08.
Got 53940049/54198272 with acc 99.52
Dice score: 0.4953796863555908
Jaccard score: 0.3796701431274414
Precision: 0.6726676225662231
Recall: 0.4428321123123169
F1-score: 0.5340724587440491
Mean Hausdorff distance: 3.0544373477997078
21


100%|██████████| 414/414 [00:39<00:00, 10.43it/s, loss=0.8]   


Got 53941179/54198272 with acc 99.53
Dice score: 0.4929441511631012
Jaccard score: 0.37798187136650085
Precision: 0.6726890206336975
Recall: 0.43988901376724243
F1-score: 0.5319330096244812
Mean Hausdorff distance: 3.063014041177901
22


100%|██████████| 414/414 [00:38<00:00, 10.84it/s, loss=0.0924]


Got 53941612/54198272 with acc 99.53
Dice score: 0.49398699402809143
Jaccard score: 0.3790588974952698
Precision: 0.671692967414856
Recall: 0.44139933586120605
F1-score: 0.5327227711677551
Mean Hausdorff distance: 3.058523475653625
23


  3%|▎         | 13/414 [00:01<00:53,  7.52it/s, loss=0.048] 


KeyboardInterrupt: 

In [14]:
save_predictions_as_imgs(val_loader, model, folder="saved_images/", device=dev)

In [16]:
#print('Train Performance' , check_accuracy(train_loader, model, device=dev))
print('Val Performance' , check_accuracy(val_loader, model, device=dev))
print ('Test Performance' , check_accuracy(test_loader, model, device=dev))


Got 53939751/54198272 with acc 99.52
Dice score: 0.4922533333301544
Jaccard score: 0.3772498071193695
Precision: 0.6656484007835388
Recall: 0.4380527436733246
F1-score: 0.5283842086791992
Mean Hausdorff distance: 3.0613920776508388
Val Performance None
Got 47897156/48562176 with acc 98.63
Dice score: 0.35177478194236755
Jaccard score: 0.2671577036380768
Precision: 0.6356946229934692
Recall: 0.2781559228897095
F1-score: 0.38698282837867737
Mean Hausdorff distance: 4.826022132369212
Test Performance None


In [17]:
 check_accuracy(test_loader, model, device=dev)

Got 47897156/48562176 with acc 98.63
Dice score: 0.35177478194236755
Jaccard score: 0.2671577036380768
Precision: 0.6356946229934692
Recall: 0.2781559228897095
F1-score: 0.38698282837867737
Mean Hausdorff distance: 4.826022132369212


In [None]:
#Only Liver

# Got 634087309/635895808 with acc 99.72
# Dice score: 0.978397011756897
# Jaccard score: 0.9577668905258179
# Precision: 0.979331374168396
# Recall: 0.9775193333625793
# F1-score: 0.9784245491027832
# Mean Hausdorff distance: 2.6984509143702238
# Train Performance None

# Got 157943585/158990336 with acc 99.34
# Dice score: 0.9115058779716492
# Jaccard score: 0.8643732666969299
# Precision: 0.9210348725318909
# Recall: 0.9184235334396362
# F1-score: 0.9197274446487427
# Mean Hausdorff distance: 3.3960617562368527
# Val Performance None

# Got 77909008/78643200 with acc 99.07
# Dice score: 0.9234818816184998
# Jaccard score: 0.8717976808547974
# Precision: 0.9464868307113647
# Recall: 0.9146184325218201
# F1-score: 0.9302798509597778
# Mean Hausdorff distance: 3.8074717452253646
# Test Performance None

In [None]:
# Only Tumor

# Got 53939751/54198272 with acc 99.52
# Dice score: 0.4922533333301544
# Jaccard score: 0.3772498071193695
# Precision: 0.6656484007835388
# Recall: 0.4380527436733246
# F1-score: 0.5283842086791992
# Mean Hausdorff distance: 3.0613920776508388
# Val Performance None

# Got 47897156/48562176 with acc 98.63
# Dice score: 0.35177478194236755
# Jaccard score: 0.2671577036380768
# Precision: 0.6356946229934692
# Recall: 0.2781559228897095
# F1-score: 0.38698282837867737
# Mean Hausdorff distance: 4.826022132369212