# Install

In [1]:
%pip install -q "monai-weekly[tqdm, nibabel, gdown, ignite]"

  from cryptography.utils import int_from_bytes
Note: you may need to restart the kernel to use updated packages.


# !python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline
!pip install einops
!python -c "import monai" || pip install 'monai[all]'

# Import libraries

In [2]:
import os
import warnings
import datetime
warnings.filterwarnings("ignore") # remove some scikit-image warnings

from torch.utils.tensorboard import SummaryWriter

import monai
# monai.config.print_config()

from monai.apps import DecathlonDataset
from monai.data import DataLoader, CacheDataset, decollate_batch
from monai.utils import first, set_determinism
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss, DiceCELoss
from monai.inferers import sliding_window_inference
from monai.networks.nets import UNETR
from monai.transforms import (
    LoadImage,
    LoadImageD,
    EnsureChannelFirstD,
    ScaleIntensityD,
    ToTensorD,
    Compose,
    AsDiscreteD,
    SpacingD,
    OrientationD,
    ResizeD,
    RandAffineD,
    AsDiscrete,
    AsDiscreted,
    EnsureTyped,
    EnsureType,
    LoadImageD,
    EnsureChannelFirstD,
    OrientationD,
    SpacingD,
    ScaleIntensityD,
    ResizeD,
    RandAffineD,
    RandFlipD,
    RandRotateD,
    RandZoomD,
    ToTensorD,
)

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import csv


from monai.utils.misc import set_determinism

set_determinism(seed = 42)

In [3]:
cuda = torch.cuda.is_available()

device = torch.device("cuda" if cuda else "cpu")
num_workers = 4 if cuda else 0

print('You are using gpu if true, cpu if false:', cuda)

You are using gpu if true, cpu if false: True


## Load training set

In [4]:
KEYS = ["image", "label"]

def get_dataset():# Set up transform
    transform = Compose([
        LoadImageD(keys = KEYS),
        EnsureChannelFirstD(keys = KEYS),
        OrientationD(KEYS, axcodes='RAS'),
        SpacingD(keys = KEYS,
                 pixdim = (1., 1., 1.),
                 mode = ('bilinear', 'nearest')),
        ScaleIntensityD(keys = "image"),
        ResizeD(KEYS,
                (128, 128, 64),
                mode=('trilinear', 'nearest')),
        RandAffineD(keys = KEYS,
                    spatial_size = (128, 128, 64),
                    rotate_range = (0, 0, np.pi/12),
                    scale_range = (0.1, 0.1, 0.1),
                    mode = ('bilinear', 'nearest'),
                    prob = 1.0),
        ToTensorD(KEYS),
    ])

    # Load data
    dataset = DecathlonDataset(root_dir = "./",
                               task = "Task06_Lung", section = "training",
                               transform = transform, download = False)

    train_loader = DataLoader(dataset, batch_size = 4, shuffle = True, num_workers = 4)
    return (train_loader, dataset)

## Load validation set

In [5]:

def get_validation_indexes():
    list_val_dataset = []
    csv_file = 'val_dataset_NEW_64slices.csv'

    with open(csv_file, 'r', newline='', encoding='utf-8') as file:
        # Create a CSV reader object
        reader = csv.reader((line.replace('\0', '') for line in file))
        
        # Iterate over each row in the CSV file
        for row in reader:
            # Append the row to the list
            list_val_dataset.append(row)
            
    list_val_dataset = [[int(num) for num in row] for row in list_val_dataset]
    return list_val_dataset

val_transform = Compose([
    LoadImageD(keys = KEYS),
    EnsureChannelFirstD(keys = KEYS),
    OrientationD(KEYS, axcodes='RAS'),
    ScaleIntensityD(keys = "image"),
])

# Load validation data
val_dataset = DecathlonDataset(root_dir = "./",
                           task = "Task06_Lung", section = "validation",
                           transform = val_transform, download = False)
list_val_dataset = get_validation_indexes()

for idx in range(len(val_dataset)):
    # Access individual elements by index
    sample = val_dataset[idx]
    
    # Access image and label from the sample
    image = sample["image"]
    label = sample["label"]
    
    # (1, 512, 512, <no_slices>)
    print(image.shape)
    print(label.shape)
    
    image_slices = list_val_dataset[idx]
    print("Index " + str(idx))
    print(image_slices)
    
    image = image[..., image_slices]
    label = label[..., image_slices]
    
    sample["image"] = image
    sample["label"] = label

# Apply the remaining tranforms
KEYS = ["image", "label"]
remaining_val_transform = Compose([
    SpacingD(keys = KEYS,
             pixdim = (1., 1., 1.),
             mode = ('bilinear', 'nearest')),
    ResizeD(KEYS,
            (128, 128, 64),
            mode=('trilinear', 'nearest')),
    ToTensorD(KEYS),
])


for idx in range(len(val_dataset)):
    # Access individual elements by index
    sample = val_dataset[idx]
    transformed_sample = remaining_val_transform(sample)
    
    sample["image"] = transformed_sample["image"]
    sample["label"] = transformed_sample["label"]
    
    image = sample["image"]
    label = sample["label"]
    
    
    print(image.shape)
    print(label.shape)

val_loader = DataLoader(val_dataset, batch_size = 1, shuffle = True, num_workers = 4)


Loading dataset: 100%|██████████| 12/12 [00:47<00:00,  3.95s/it]


(1, 512, 512, 369)
(1, 512, 512, 369)
Index 0
[259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322]
(1, 512, 512, 255)
(1, 512, 512, 255)
Index 1
[180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243]
(1, 512, 512, 288)
(1, 512, 512, 288)
Index 2
[110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,

# Training Modified

In [6]:
def get_train_indexes():
    list_dataset = []
    csv_file = 'dataset_NEW_64slices.csv'

    with open(csv_file, 'r', newline='', encoding='utf-8') as file:
        reader = csv.reader((line.replace('\0', '') for line in file))
        for row in reader:
            list_dataset.append(row)
            
    list_dataset = [[int(num) for num in row] for row in list_dataset]
    return list_dataset


In [7]:
def get_new_dataset():

    KEYS = ["image", "label"]
    transform = Compose([
        LoadImageD(keys = KEYS),
        EnsureChannelFirstD(keys = KEYS),
        OrientationD(KEYS, axcodes='RAS'),
        ScaleIntensityD(keys = "image"),
    ])

    dataset_modified_training = DecathlonDataset(root_dir = "./",
                               task = "Task06_Lung", section = "training", transform=transform,
                               download=False)
    list_dataset = get_train_indexes()
    for idx in range(len(dataset_modified_training)):
        sample = dataset_modified_training[idx]
        image = sample["image"]
        label = sample["label"]

        image_slices = list_dataset[idx]

        image = image[..., image_slices]
        label = label[..., image_slices]

        sample["image"] = image
        sample["label"] = label


    KEYS = ["image", "label"]
    remaining_transform = Compose([
        SpacingD(keys = KEYS,
                 pixdim = (1., 1., 1.),
                 mode = ('bilinear', 'nearest')),
        ResizeD(KEYS,
                (128, 128, 64),
                mode=('trilinear', 'nearest')),
        RandAffineD(keys = KEYS,
                    spatial_size = (128, 128, 64),
                    rotate_range = (0, 0, np.pi/12),
                    scale_range = (0.1, 0.1, 0.1),
                    mode = ('bilinear', 'nearest'),
                    prob = 1.0),
        ToTensorD(KEYS),
    ])


    for idx in range(len(dataset_modified_training)):
        sample = dataset_modified_training[idx]
        transformed_sample = remaining_transform(sample)

        sample["image"] = transformed_sample["image"]
        sample["label"] = transformed_sample["label"]

        image = sample["image"]
        label = sample["label"]


    modified_train_loader = DataLoader(dataset_modified_training, batch_size = 4, shuffle = True, num_workers = 4)
    return (modified_train_loader, dataset_modified_training)

# Model, Loss, Optimizer

In [12]:
device = torch.device("cuda:1")
def initialize_new_model():
    
    UNETR_metadata = dict(
        in_channels=1,
        out_channels=2,
        img_size=(128, 128, 64), 
        feature_size=16,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed='conv',
        norm_name='instance',
        conv_block=True,
        res_block=True,
        dropout_rate=0.0,
        spatial_dims=3,
    )

    model = UNETR(**UNETR_metadata).to(device)
    return model

dice_loss = DiceLoss(to_onehot_y=True, softmax=True)
cross_entropy_loss = DiceCELoss(to_onehot_y=True, softmax=True)

def loss_function(output, label):
    return dice_loss(output, label)

dice_metric = DiceMetric(include_background=False, reduction="mean")


In [17]:
def train_function(model, max_epochs, train_loader, optimizer_name, lrate, val_loader, dataset):
    if(optimizer_name=='AdamW'):
        optimizer = torch.optim.AdamW(model.parameters(), lr=lrate)
    if(optimizer_name=='Adam'):
        optimizer = torch.optim.Adam(model.parameters(), lr=lrate)
    if(optimizer_name=='Nadam'):
        optimizer = torch.optim.NAdam(model.parameters(), lr=lrate)
    if(optimizer_name=='SGD'):
        optimizer = torch.optim.SGD(model.parameters(), lr=lrate)
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    metric_values = []
    epoch_loss_values =[]
    lr = lrate
    post_pred = Compose([AsDiscrete(argmax = True, to_onehot = 2)])
    post_label = Compose([AsDiscrete(to_onehot = 2)])
    current_date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_name = f"lr_{lrate}_epoch_{max_epochs}_model_{optimizer_name}_day_{current_date}"
    save_model_name = f"lr_{lrate}_epoch_{max_epochs}_model_{optimizer_name}_day_{current_date}.pth"
    writer = SummaryWriter(log_dir=f"runs_combo_diceloss_2/{run_name}")
    for epoch in range(max_epochs):
      print("-" * 12)
      print(f"Epoch {epoch + 1}/{max_epochs}")
      # Turn model to "train" mode
      model.train()
      epoch_loss = 0
      step = 0
      for batch_data in train_loader:
        step += 1

        input, label = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )

        # A common pytorch Deep Learning format to train model
        optimizer.zero_grad()
        output = model(input)
        loss = loss_function(output, label)
        loss.backward() # Compute gradient
        optimizer.step() # Update model's parameters

        epoch_loss += loss.item()
        print(f"{step}/{len(dataset) // train_loader.batch_size}, "
              f"train_loss: {loss.item():.4f}")
        writer.add_scalar("Loss/train", loss.item(), epoch * len(train_loader) + step)


      epoch_loss /= step
      epoch_loss_values.append(epoch_loss)

      print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

      writer.add_scalar("Loss/epoch", epoch_loss, epoch)

      if (epoch + 1) % val_interval == 0:
        # Turn model to "eval" mode
        model.eval()

        with torch.no_grad():
          for val_data in val_loader:
            val_input, val_label = (
                val_data["image"].to(device),
                val_data["label"].to(device),
            )
            roi_size = (128, 128, 64)
            sw_batch_size = 1

            # Set AMP for MONAI validation
            val_output = sliding_window_inference(
                val_input, roi_size, sw_batch_size, model
            )
            val_output = [post_pred(i) for i in decollate_batch(val_output)]
            val_label = [post_label(i) for i in decollate_batch(val_label)]

            # Compute metric for current iteration
            dice_metric(y_pred = val_output, y = val_label)

          # Aggregate the final mean dice result
          metric = dice_metric.aggregate().item()

          # Reset the status for the next epoch
          dice_metric.reset()

          metric_values.append(metric)
          if metric > best_metric:
            best_metric = metric
            best_metric_epoch = epoch + 1
            print("saved new best metric model!!!")
            torch.save(model.state_dict(), f"{save_model_name}")
#             torch.save(model.state_dict(), "best_metric_model.pth")

          print(
              f"current epoch: {epoch + 1},"
              f" current mean dice: {metric:.4f},"
              f" best mean dice: {best_metric:.4f},"
              f" at epoch: {best_metric_epoch}"
          )
          writer.add_scalar("Dice/val", metric, epoch)

    print(
        f"train completed, best metric: {best_metric:.4f}"
        f" at epoch: {best_metric_epoch}"
    )
    writer.close()
    return model

In [18]:
# # Set up transform
# val_transform = Compose([
#     LoadImageD(keys = KEYS),
#     EnsureChannelFirstD(keys = KEYS),
#     OrientationD(KEYS, axcodes='RAS'),
#     SpacingD(keys = KEYS, pixdim = (1., 1., 1.), mode = ('bilinear', 'nearest')),
#     ScaleIntensityD(keys = "image"),
#     ResizeD(KEYS, (128, 128, 64), mode=('trilinear', 'nearest')),
#     ToTensorD(KEYS),
# ])

# # Load data
# val_dataset_normal = DecathlonDataset(root_dir = "./",
#                            task = "Task06_Lung", section = "validation",
#                            transform = val_transform, download = False)

# val_loadernormal = DataLoader(val_dataset_normal, batch_size = 1, shuffle = True, num_workers = 4)
# train_loader, dataset = get_dataset()

# Start training

In [21]:
epoch_array = [300]
lr_array = [0.001]
optimizer_array =['Nadam']


epoch_array_modif = [300]
lr_array_modif = [0.001]
optimizer_array_modif = ['AdamW']

for i in range(len(epoch_array)):
    for j in range(len(optimizer_array)):
        for k in range(len(lr_array)):
            for m in range(len(epoch_array_modif)):
                for n in range(len(optimizer_array_modif)):
                    for p in range(len(lr_array_modif)):
                        print("Epochs ", epoch_array[i])
                        print("Optimizer ", optimizer_array[j])
                        print("Lr Rate ", lr_array[k])
                        model = initialize_new_model()
                        train_loader, dataset = get_dataset()
                        model = train_function(model, epoch_array[i], train_loader, optimizer_array[j], lr_array[k], 
                                               val_loader, dataset)
                        print("Epochs Modif ", epoch_array_modif[m])
                        print("Optimizer Modif ", optimizer_array_modif[n])
                        print("Lr Rate Modif ", lr_array_modif[p])
                        train_loader, dataset = get_new_dataset()
                        model = train_function(model, epoch_array_modif[m], train_loader, optimizer_array_modif[n], 
                                               lr_array_modif[p], val_loader, dataset)
                       # combo += 1

# train_loader, dataset = get_new_dataset()
# for i in range(len(epoch_array)):
#     for j in range(len(optimizer_array)):
#         for k in range(len(lr_array)):
#             print("Epochs ", epoch_array[i])
#             print("Optimizer ", optimizer_array[j])
#             print("Lr Rate ", lr_array[k])
#             model = initialize_new_model()
#             model = train_function(model, epoch_array[i], train_loader, optimizer_array[j], lr_array[k], 
#                                    val_loader, dataset)

        
          

Epochs  300
Optimizer  Nadam
Lr Rate  0.001


Loading dataset: 100%|██████████| 51/51 [06:51<00:00,  8.07s/it]

------------
Epoch 1/300





1/12, train_loss: 0.6669
2/12, train_loss: 0.6405
3/12, train_loss: 0.6222
4/12, train_loss: 0.6102
5/12, train_loss: 0.6017
6/12, train_loss: 0.6126
7/12, train_loss: 0.5957
8/12, train_loss: 0.5922
9/12, train_loss: 0.5894
10/12, train_loss: 0.5866
11/12, train_loss: 0.5846
12/12, train_loss: 0.5828
13/12, train_loss: 0.5829
epoch 1 average loss: 0.6052
saved new best metric model!!!
current epoch: 1, current mean dice: 0.0010, best mean dice: 0.0010, at epoch: 1
------------
Epoch 2/300
1/12, train_loss: 0.5823
2/12, train_loss: 0.5790
3/12, train_loss: 0.5772
4/12, train_loss: 0.5760
5/12, train_loss: 0.5754
6/12, train_loss: 0.5731
7/12, train_loss: 0.5714
8/12, train_loss: 0.5871
9/12, train_loss: 0.5699
10/12, train_loss: 0.5687
11/12, train_loss: 0.5666
12/12, train_loss: 0.5666
13/12, train_loss: 0.5651
epoch 2 average loss: 0.5737
current epoch: 2, current mean dice: 0.0000, best mean dice: 0.0010, at epoch: 1
------------
Epoch 3/300
1/12, train_loss: 0.5650
2/12, train_loss

11/12, train_loss: 0.5073
12/12, train_loss: 0.5073
13/12, train_loss: 0.5060
epoch 18 average loss: 0.5079
current epoch: 18, current mean dice: 0.0000, best mean dice: 0.0010, at epoch: 1
------------
Epoch 19/300
1/12, train_loss: 0.5079
2/12, train_loss: 0.5071
3/12, train_loss: 0.5073
4/12, train_loss: 0.5075
5/12, train_loss: 0.5075
6/12, train_loss: 0.5065
7/12, train_loss: 0.5075
8/12, train_loss: 0.5078
9/12, train_loss: 0.5075
10/12, train_loss: 0.5073
11/12, train_loss: 0.5076
12/12, train_loss: 0.5071
13/12, train_loss: 0.5068
epoch 19 average loss: 0.5073
current epoch: 19, current mean dice: 0.0000, best mean dice: 0.0010, at epoch: 1
------------
Epoch 20/300
1/12, train_loss: 0.5070
2/12, train_loss: 0.5073
3/12, train_loss: 0.5075
4/12, train_loss: 0.5073
5/12, train_loss: 0.5070
6/12, train_loss: 0.5067
7/12, train_loss: 0.5067
8/12, train_loss: 0.5067
9/12, train_loss: 0.5067
10/12, train_loss: 0.5065
11/12, train_loss: 0.5065
12/12, train_loss: 0.5064
13/12, train_l

2/12, train_loss: 0.4979
3/12, train_loss: 0.4994
4/12, train_loss: 0.4990
5/12, train_loss: 0.4958
6/12, train_loss: 0.5004
7/12, train_loss: 0.4996
8/12, train_loss: 0.4997
9/12, train_loss: 0.4956
10/12, train_loss: 0.4948
11/12, train_loss: 0.4936
12/12, train_loss: 0.4932
13/12, train_loss: 0.5018
epoch 36 average loss: 0.4977
current epoch: 36, current mean dice: 0.0178, best mean dice: 0.0920, at epoch: 31
------------
Epoch 37/300
1/12, train_loss: 0.4985
2/12, train_loss: 0.4987
3/12, train_loss: 0.4980
4/12, train_loss: 0.5013
5/12, train_loss: 0.4933
6/12, train_loss: 0.4935
7/12, train_loss: 0.4942
8/12, train_loss: 0.4915
9/12, train_loss: 0.4926
10/12, train_loss: 0.4979
11/12, train_loss: 0.5020
12/12, train_loss: 0.4921
13/12, train_loss: 0.5012
epoch 37 average loss: 0.4965
current epoch: 37, current mean dice: 0.0000, best mean dice: 0.0920, at epoch: 31
------------
Epoch 38/300
1/12, train_loss: 0.4998
2/12, train_loss: 0.4963
3/12, train_loss: 0.5014
4/12, train_lo

11/12, train_loss: 0.4817
12/12, train_loss: 0.4463
13/12, train_loss: 0.4713
epoch 53 average loss: 0.4647
current epoch: 53, current mean dice: 0.0823, best mean dice: 0.0989, at epoch: 44
------------
Epoch 54/300
1/12, train_loss: 0.4150
2/12, train_loss: 0.4658
3/12, train_loss: 0.4921
4/12, train_loss: 0.4420
5/12, train_loss: 0.4676
6/12, train_loss: 0.4539
7/12, train_loss: 0.4777
8/12, train_loss: 0.4226
9/12, train_loss: 0.4777
10/12, train_loss: 0.4428
11/12, train_loss: 0.4810
12/12, train_loss: 0.4436
13/12, train_loss: 0.4343
epoch 54 average loss: 0.4551
current epoch: 54, current mean dice: 0.0714, best mean dice: 0.0989, at epoch: 44
------------
Epoch 55/300
1/12, train_loss: 0.4561
2/12, train_loss: 0.4308
3/12, train_loss: 0.4238
4/12, train_loss: 0.4614
5/12, train_loss: 0.4835
6/12, train_loss: 0.4553
7/12, train_loss: 0.4435
8/12, train_loss: 0.4601
9/12, train_loss: 0.4338
10/12, train_loss: 0.4802
11/12, train_loss: 0.4334
12/12, train_loss: 0.4758
13/12, train

1/12, train_loss: 0.3273
2/12, train_loss: 0.4128
3/12, train_loss: 0.3889
4/12, train_loss: 0.2703
5/12, train_loss: 0.3263
6/12, train_loss: 0.2712
7/12, train_loss: 0.3658
8/12, train_loss: 0.3583
9/12, train_loss: 0.3889
10/12, train_loss: 0.2705
11/12, train_loss: 0.4211
12/12, train_loss: 0.3868
13/12, train_loss: 0.2339
epoch 71 average loss: 0.3402
current epoch: 71, current mean dice: 0.1272, best mean dice: 0.1934, at epoch: 57
------------
Epoch 72/300
1/12, train_loss: 0.3164
2/12, train_loss: 0.3339
3/12, train_loss: 0.3419
4/12, train_loss: 0.3092
5/12, train_loss: 0.2914
6/12, train_loss: 0.2374
7/12, train_loss: 0.3161
8/12, train_loss: 0.4256
9/12, train_loss: 0.3212
10/12, train_loss: 0.3641
11/12, train_loss: 0.4393
12/12, train_loss: 0.4123
13/12, train_loss: 0.3955
epoch 72 average loss: 0.3465
current epoch: 72, current mean dice: 0.1285, best mean dice: 0.1934, at epoch: 57
------------
Epoch 73/300
1/12, train_loss: 0.3782
2/12, train_loss: 0.2978
3/12, train_lo

11/12, train_loss: 0.2640
12/12, train_loss: 0.2590
13/12, train_loss: 0.3219
epoch 88 average loss: 0.2774
current epoch: 88, current mean dice: 0.1089, best mean dice: 0.1934, at epoch: 57
------------
Epoch 89/300
1/12, train_loss: 0.3871
2/12, train_loss: 0.2037
3/12, train_loss: 0.2798
4/12, train_loss: 0.2050
5/12, train_loss: 0.2754
6/12, train_loss: 0.2436
7/12, train_loss: 0.4051
8/12, train_loss: 0.2554
9/12, train_loss: 0.3164
10/12, train_loss: 0.3605
11/12, train_loss: 0.2058
12/12, train_loss: 0.3537
13/12, train_loss: 0.3076
epoch 89 average loss: 0.2923
current epoch: 89, current mean dice: 0.1063, best mean dice: 0.1934, at epoch: 57
------------
Epoch 90/300
1/12, train_loss: 0.4117
2/12, train_loss: 0.2841
3/12, train_loss: 0.3087
4/12, train_loss: 0.3260
5/12, train_loss: 0.2364
6/12, train_loss: 0.2423
7/12, train_loss: 0.1524
8/12, train_loss: 0.2369
9/12, train_loss: 0.3867
10/12, train_loss: 0.2559
11/12, train_loss: 0.2795
12/12, train_loss: 0.3163
13/12, train

2/12, train_loss: 0.2104
3/12, train_loss: 0.3413
4/12, train_loss: 0.1901
5/12, train_loss: 0.2607
6/12, train_loss: 0.3192
7/12, train_loss: 0.2700
8/12, train_loss: 0.1806
9/12, train_loss: 0.2476
10/12, train_loss: 0.1687
11/12, train_loss: 0.2696
12/12, train_loss: 0.2785
13/12, train_loss: 0.2691
epoch 106 average loss: 0.2595
current epoch: 106, current mean dice: 0.1632, best mean dice: 0.1934, at epoch: 57
------------
Epoch 107/300
1/12, train_loss: 0.1448
2/12, train_loss: 0.1579
3/12, train_loss: 0.2688
4/12, train_loss: 0.2288
5/12, train_loss: 0.4110
6/12, train_loss: 0.3234
7/12, train_loss: 0.1996
8/12, train_loss: 0.3453
9/12, train_loss: 0.2250
10/12, train_loss: 0.2540
11/12, train_loss: 0.1747
12/12, train_loss: 0.2626
13/12, train_loss: 0.3242
epoch 107 average loss: 0.2554
current epoch: 107, current mean dice: 0.1825, best mean dice: 0.1934, at epoch: 57
------------
Epoch 108/300
1/12, train_loss: 0.2297
2/12, train_loss: 0.2812
3/12, train_loss: 0.1952
4/12, tr

8/12, train_loss: 0.2167
9/12, train_loss: 0.2879
10/12, train_loss: 0.2395
11/12, train_loss: 0.2520
12/12, train_loss: 0.3492
13/12, train_loss: 0.1922
epoch 123 average loss: 0.2554
current epoch: 123, current mean dice: 0.1868, best mean dice: 0.1994, at epoch: 110
------------
Epoch 124/300
1/12, train_loss: 0.3223
2/12, train_loss: 0.1621
3/12, train_loss: 0.2134
4/12, train_loss: 0.1841
5/12, train_loss: 0.3195
6/12, train_loss: 0.2938
7/12, train_loss: 0.1842
8/12, train_loss: 0.2266
9/12, train_loss: 0.3157
10/12, train_loss: 0.2624
11/12, train_loss: 0.2995
12/12, train_loss: 0.2898
13/12, train_loss: 0.1765
epoch 124 average loss: 0.2500
current epoch: 124, current mean dice: 0.1560, best mean dice: 0.1994, at epoch: 110
------------
Epoch 125/300
1/12, train_loss: 0.3406
2/12, train_loss: 0.1566
3/12, train_loss: 0.2926
4/12, train_loss: 0.3470
5/12, train_loss: 0.2460
6/12, train_loss: 0.2290
7/12, train_loss: 0.2938
8/12, train_loss: 0.2873
9/12, train_loss: 0.3462
10/12,

current epoch: 140, current mean dice: 0.1711, best mean dice: 0.1997, at epoch: 139
------------
Epoch 141/300
1/12, train_loss: 0.3401
2/12, train_loss: 0.2510
3/12, train_loss: 0.2503
4/12, train_loss: 0.2800
5/12, train_loss: 0.2231
6/12, train_loss: 0.1489
7/12, train_loss: 0.2621
8/12, train_loss: 0.1975
9/12, train_loss: 0.1562
10/12, train_loss: 0.2461
11/12, train_loss: 0.2278
12/12, train_loss: 0.2003
13/12, train_loss: 0.3917
epoch 141 average loss: 0.2442
current epoch: 141, current mean dice: 0.1524, best mean dice: 0.1997, at epoch: 139
------------
Epoch 142/300
1/12, train_loss: 0.2641
2/12, train_loss: 0.2479
3/12, train_loss: 0.2511
4/12, train_loss: 0.2250
5/12, train_loss: 0.2239
6/12, train_loss: 0.3742
7/12, train_loss: 0.1805
8/12, train_loss: 0.3794
9/12, train_loss: 0.2541
10/12, train_loss: 0.2277
11/12, train_loss: 0.2299
12/12, train_loss: 0.1926
13/12, train_loss: 0.1970
epoch 142 average loss: 0.2498
current epoch: 142, current mean dice: 0.1423, best mean

4/12, train_loss: 0.2461
5/12, train_loss: 0.2663
6/12, train_loss: 0.2367
7/12, train_loss: 0.2266
8/12, train_loss: 0.2568
9/12, train_loss: 0.2022
10/12, train_loss: 0.2125
11/12, train_loss: 0.1835
12/12, train_loss: 0.3314
13/12, train_loss: 0.1568
epoch 158 average loss: 0.2246
current epoch: 158, current mean dice: 0.1643, best mean dice: 0.1997, at epoch: 139
------------
Epoch 159/300
1/12, train_loss: 0.2429
2/12, train_loss: 0.2252
3/12, train_loss: 0.2294
4/12, train_loss: 0.1582
5/12, train_loss: 0.2332
6/12, train_loss: 0.4001
7/12, train_loss: 0.1673
8/12, train_loss: 0.2258
9/12, train_loss: 0.2648
10/12, train_loss: 0.1937
11/12, train_loss: 0.2990
12/12, train_loss: 0.1339
13/12, train_loss: 0.2503
epoch 159 average loss: 0.2326
current epoch: 159, current mean dice: 0.1711, best mean dice: 0.1997, at epoch: 139
------------
Epoch 160/300
1/12, train_loss: 0.1134
2/12, train_loss: 0.2600
3/12, train_loss: 0.2982
4/12, train_loss: 0.2759
5/12, train_loss: 0.2406
6/12, 

9/12, train_loss: 0.3365
10/12, train_loss: 0.1530
11/12, train_loss: 0.1426
12/12, train_loss: 0.1929
13/12, train_loss: 0.2052
epoch 175 average loss: 0.2176
current epoch: 175, current mean dice: 0.1178, best mean dice: 0.2176, at epoch: 172
------------
Epoch 176/300
1/12, train_loss: 0.3757
2/12, train_loss: 0.1633
3/12, train_loss: 0.2564
4/12, train_loss: 0.2148
5/12, train_loss: 0.1369
6/12, train_loss: 0.2280
7/12, train_loss: 0.2423
8/12, train_loss: 0.1525
9/12, train_loss: 0.2496
10/12, train_loss: 0.1607
11/12, train_loss: 0.2332
12/12, train_loss: 0.3114
13/12, train_loss: 0.2592
epoch 176 average loss: 0.2295
current epoch: 176, current mean dice: 0.1577, best mean dice: 0.2176, at epoch: 172
------------
Epoch 177/300
1/12, train_loss: 0.1505
2/12, train_loss: 0.2743
3/12, train_loss: 0.1752
4/12, train_loss: 0.2941
5/12, train_loss: 0.2659
6/12, train_loss: 0.1421
7/12, train_loss: 0.2770
8/12, train_loss: 0.1622
9/12, train_loss: 0.2150
10/12, train_loss: 0.2590
11/12

1/12, train_loss: 0.1612
2/12, train_loss: 0.2826
3/12, train_loss: 0.1361
4/12, train_loss: 0.2578
5/12, train_loss: 0.1922
6/12, train_loss: 0.2610
7/12, train_loss: 0.3227
8/12, train_loss: 0.2694
9/12, train_loss: 0.1476
10/12, train_loss: 0.2203
11/12, train_loss: 0.3216
12/12, train_loss: 0.1276
13/12, train_loss: 0.1588
epoch 193 average loss: 0.2199
current epoch: 193, current mean dice: 0.1708, best mean dice: 0.2176, at epoch: 172
------------
Epoch 194/300
1/12, train_loss: 0.1944
2/12, train_loss: 0.2827
3/12, train_loss: 0.2038
4/12, train_loss: 0.2497
5/12, train_loss: 0.1347
6/12, train_loss: 0.2667
7/12, train_loss: 0.1140
8/12, train_loss: 0.2209
9/12, train_loss: 0.2102
10/12, train_loss: 0.1669
11/12, train_loss: 0.1803
12/12, train_loss: 0.3457
13/12, train_loss: 0.2709
epoch 194 average loss: 0.2185
current epoch: 194, current mean dice: 0.2127, best mean dice: 0.2176, at epoch: 172
------------
Epoch 195/300
1/12, train_loss: 0.1396
2/12, train_loss: 0.2271
3/12, 

8/12, train_loss: 0.1942
9/12, train_loss: 0.1962
10/12, train_loss: 0.3197
11/12, train_loss: 0.2879
12/12, train_loss: 0.2288
13/12, train_loss: 0.1841
epoch 210 average loss: 0.2136
current epoch: 210, current mean dice: 0.1595, best mean dice: 0.2176, at epoch: 172
------------
Epoch 211/300
1/12, train_loss: 0.1994
2/12, train_loss: 0.1082
3/12, train_loss: 0.1624
4/12, train_loss: 0.2995
5/12, train_loss: 0.1625
6/12, train_loss: 0.2402
7/12, train_loss: 0.2488
8/12, train_loss: 0.1410
9/12, train_loss: 0.2100
10/12, train_loss: 0.2756
11/12, train_loss: 0.3385
12/12, train_loss: 0.1133
13/12, train_loss: 0.2667
epoch 211 average loss: 0.2128
current epoch: 211, current mean dice: 0.1729, best mean dice: 0.2176, at epoch: 172
------------
Epoch 212/300
1/12, train_loss: 0.2004
2/12, train_loss: 0.1842
3/12, train_loss: 0.1913
4/12, train_loss: 0.0960
5/12, train_loss: 0.2611
6/12, train_loss: 0.2563
7/12, train_loss: 0.2014
8/12, train_loss: 0.2405
9/12, train_loss: 0.2367
10/12,

current epoch: 227, current mean dice: 0.1496, best mean dice: 0.2176, at epoch: 172
------------
Epoch 228/300
1/12, train_loss: 0.2228
2/12, train_loss: 0.1405
3/12, train_loss: 0.2406
4/12, train_loss: 0.1685
5/12, train_loss: 0.1339
6/12, train_loss: 0.2433
7/12, train_loss: 0.2458
8/12, train_loss: 0.2429
9/12, train_loss: 0.3706
10/12, train_loss: 0.1880
11/12, train_loss: 0.2199
12/12, train_loss: 0.1137
13/12, train_loss: 0.1026
epoch 228 average loss: 0.2025
current epoch: 228, current mean dice: 0.1567, best mean dice: 0.2176, at epoch: 172
------------
Epoch 229/300
1/12, train_loss: 0.2101
2/12, train_loss: 0.2825
3/12, train_loss: 0.2374
4/12, train_loss: 0.2275
5/12, train_loss: 0.1526
6/12, train_loss: 0.1083
7/12, train_loss: 0.1951
8/12, train_loss: 0.1502
9/12, train_loss: 0.2909
10/12, train_loss: 0.2414
11/12, train_loss: 0.1266
12/12, train_loss: 0.1797
13/12, train_loss: 0.2833
epoch 229 average loss: 0.2066
current epoch: 229, current mean dice: 0.1890, best mean

3/12, train_loss: 0.1662
4/12, train_loss: 0.1895
5/12, train_loss: 0.2168
6/12, train_loss: 0.1889
7/12, train_loss: 0.1414
8/12, train_loss: 0.2855
9/12, train_loss: 0.2942
10/12, train_loss: 0.2090
11/12, train_loss: 0.0926
12/12, train_loss: 0.1912
13/12, train_loss: 0.1573
epoch 245 average loss: 0.1947
current epoch: 245, current mean dice: 0.1664, best mean dice: 0.2429, at epoch: 241
------------
Epoch 246/300
1/12, train_loss: 0.1220
2/12, train_loss: 0.1195
3/12, train_loss: 0.2544
4/12, train_loss: 0.2765
5/12, train_loss: 0.3359
6/12, train_loss: 0.1634
7/12, train_loss: 0.2342
8/12, train_loss: 0.1562
9/12, train_loss: 0.2369
10/12, train_loss: 0.2197
11/12, train_loss: 0.2177
12/12, train_loss: 0.0856
13/12, train_loss: 0.2196
epoch 246 average loss: 0.2032
current epoch: 246, current mean dice: 0.1794, best mean dice: 0.2429, at epoch: 241
------------
Epoch 247/300
1/12, train_loss: 0.1882
2/12, train_loss: 0.2129
3/12, train_loss: 0.2316
4/12, train_loss: 0.2752
5/12, 

10/12, train_loss: 0.2795
11/12, train_loss: 0.1785
12/12, train_loss: 0.2871
13/12, train_loss: 0.1285
epoch 262 average loss: 0.2001
current epoch: 262, current mean dice: 0.1721, best mean dice: 0.2429, at epoch: 241
------------
Epoch 263/300
1/12, train_loss: 0.1849
2/12, train_loss: 0.2063
3/12, train_loss: 0.2146
4/12, train_loss: 0.0976
5/12, train_loss: 0.1174
6/12, train_loss: 0.1813
7/12, train_loss: 0.1950
8/12, train_loss: 0.2056
9/12, train_loss: 0.2969
10/12, train_loss: 0.2028
11/12, train_loss: 0.1546
12/12, train_loss: 0.2279
13/12, train_loss: 0.2924
epoch 263 average loss: 0.1982
current epoch: 263, current mean dice: 0.1955, best mean dice: 0.2429, at epoch: 241
------------
Epoch 264/300
1/12, train_loss: 0.2111
2/12, train_loss: 0.0944
3/12, train_loss: 0.3122
4/12, train_loss: 0.1975
5/12, train_loss: 0.1263
6/12, train_loss: 0.1680
7/12, train_loss: 0.1643
8/12, train_loss: 0.1301
9/12, train_loss: 0.3114
10/12, train_loss: 0.2584
11/12, train_loss: 0.1716
12/1

1/12, train_loss: 0.1903
2/12, train_loss: 0.2428
3/12, train_loss: 0.1426
4/12, train_loss: 0.2317
5/12, train_loss: 0.1262
6/12, train_loss: 0.2460
7/12, train_loss: 0.2091
8/12, train_loss: 0.1488
9/12, train_loss: 0.1759
10/12, train_loss: 0.2156
11/12, train_loss: 0.2364
12/12, train_loss: 0.2160
13/12, train_loss: 0.2416
epoch 280 average loss: 0.2018
current epoch: 280, current mean dice: 0.1823, best mean dice: 0.2429, at epoch: 241
------------
Epoch 281/300
1/12, train_loss: 0.2954
2/12, train_loss: 0.1610
3/12, train_loss: 0.1079
4/12, train_loss: 0.1155
5/12, train_loss: 0.1268
6/12, train_loss: 0.1702
7/12, train_loss: 0.3117
8/12, train_loss: 0.2137
9/12, train_loss: 0.3563
10/12, train_loss: 0.2110
11/12, train_loss: 0.1124
12/12, train_loss: 0.1293
13/12, train_loss: 0.1449
epoch 281 average loss: 0.1889
current epoch: 281, current mean dice: 0.1275, best mean dice: 0.2429, at epoch: 241
------------
Epoch 282/300
1/12, train_loss: 0.2493
2/12, train_loss: 0.1313
3/12, 

8/12, train_loss: 0.2365
9/12, train_loss: 0.2202
10/12, train_loss: 0.1208
11/12, train_loss: 0.1702
12/12, train_loss: 0.2812
13/12, train_loss: 0.1649
epoch 297 average loss: 0.1895
current epoch: 297, current mean dice: 0.1641, best mean dice: 0.2429, at epoch: 241
------------
Epoch 298/300
1/12, train_loss: 0.1499
2/12, train_loss: 0.1259
3/12, train_loss: 0.0882
4/12, train_loss: 0.1788
5/12, train_loss: 0.2216
6/12, train_loss: 0.1254
7/12, train_loss: 0.3140
8/12, train_loss: 0.2335
9/12, train_loss: 0.4119
10/12, train_loss: 0.1172
11/12, train_loss: 0.1447
12/12, train_loss: 0.1819
13/12, train_loss: 0.3249
epoch 298 average loss: 0.2014
current epoch: 298, current mean dice: 0.2141, best mean dice: 0.2429, at epoch: 241
------------
Epoch 299/300
1/12, train_loss: 0.2483
2/12, train_loss: 0.1344
3/12, train_loss: 0.1507
4/12, train_loss: 0.2474
5/12, train_loss: 0.1554
6/12, train_loss: 0.2400
7/12, train_loss: 0.1670
8/12, train_loss: 0.2911
9/12, train_loss: 0.2544
10/12,

Loading dataset: 100%|██████████| 51/51 [03:18<00:00,  3.90s/it]


------------
Epoch 1/300
1/12, train_loss: 0.4158
2/12, train_loss: 0.3592
3/12, train_loss: 0.3935
4/12, train_loss: 0.4344
5/12, train_loss: 0.3687
6/12, train_loss: 0.3217
7/12, train_loss: 0.3001
8/12, train_loss: 0.3775
9/12, train_loss: 0.3605
10/12, train_loss: 0.3352
11/12, train_loss: 0.3571
12/12, train_loss: 0.4315
13/12, train_loss: 0.4189
epoch 1 average loss: 0.3749
saved new best metric model!!!
current epoch: 1, current mean dice: 0.2464, best mean dice: 0.2464, at epoch: 1
------------
Epoch 2/300
1/12, train_loss: 0.3835
2/12, train_loss: 0.3723
3/12, train_loss: 0.3680
4/12, train_loss: 0.4010
5/12, train_loss: 0.3019
6/12, train_loss: 0.3402
7/12, train_loss: 0.2983
8/12, train_loss: 0.3838
9/12, train_loss: 0.4058
10/12, train_loss: 0.2566
11/12, train_loss: 0.3954
12/12, train_loss: 0.4763
13/12, train_loss: 0.3061
epoch 2 average loss: 0.3607
saved new best metric model!!!
current epoch: 2, current mean dice: 0.3263, best mean dice: 0.3263, at epoch: 2
----------

5/12, train_loss: 0.2156
6/12, train_loss: 0.3925
7/12, train_loss: 0.1756
8/12, train_loss: 0.2702
9/12, train_loss: 0.2417
10/12, train_loss: 0.2822
11/12, train_loss: 0.2821
12/12, train_loss: 0.2922
13/12, train_loss: 0.4033
epoch 18 average loss: 0.2592
current epoch: 18, current mean dice: 0.4138, best mean dice: 0.4304, at epoch: 7
------------
Epoch 19/300
1/12, train_loss: 0.1853
2/12, train_loss: 0.2054
3/12, train_loss: 0.2214
4/12, train_loss: 0.2575
5/12, train_loss: 0.1554
6/12, train_loss: 0.2598
7/12, train_loss: 0.2438
8/12, train_loss: 0.3997
9/12, train_loss: 0.1968
10/12, train_loss: 0.2424
11/12, train_loss: 0.3552
12/12, train_loss: 0.2155
13/12, train_loss: 0.3378
epoch 19 average loss: 0.2520
current epoch: 19, current mean dice: 0.3697, best mean dice: 0.4304, at epoch: 7
------------
Epoch 20/300
1/12, train_loss: 0.1998
2/12, train_loss: 0.2303
3/12, train_loss: 0.2618
4/12, train_loss: 0.2465
5/12, train_loss: 0.2543
6/12, train_loss: 0.2272
7/12, train_loss

10/12, train_loss: 0.1961
11/12, train_loss: 0.1285
12/12, train_loss: 0.2382
13/12, train_loss: 0.2060
epoch 35 average loss: 0.2093
current epoch: 35, current mean dice: 0.4286, best mean dice: 0.4508, at epoch: 25
------------
Epoch 36/300
1/12, train_loss: 0.1610
2/12, train_loss: 0.1967
3/12, train_loss: 0.1229
4/12, train_loss: 0.1991
5/12, train_loss: 0.1614
6/12, train_loss: 0.2120
7/12, train_loss: 0.2969
8/12, train_loss: 0.3545
9/12, train_loss: 0.1167
10/12, train_loss: 0.2200
11/12, train_loss: 0.1859
12/12, train_loss: 0.1729
13/12, train_loss: 0.2896
epoch 36 average loss: 0.2069
current epoch: 36, current mean dice: 0.4483, best mean dice: 0.4508, at epoch: 25
------------
Epoch 37/300
1/12, train_loss: 0.1388
2/12, train_loss: 0.1019
3/12, train_loss: 0.1356
4/12, train_loss: 0.1626
5/12, train_loss: 0.3147
6/12, train_loss: 0.3134
7/12, train_loss: 0.1812
8/12, train_loss: 0.1897
9/12, train_loss: 0.2173
10/12, train_loss: 0.1543
11/12, train_loss: 0.2357
12/12, train

1/12, train_loss: 0.2737
2/12, train_loss: 0.2232
3/12, train_loss: 0.2061
4/12, train_loss: 0.2527
5/12, train_loss: 0.1280
6/12, train_loss: 0.1335
7/12, train_loss: 0.2659
8/12, train_loss: 0.1010
9/12, train_loss: 0.2460
10/12, train_loss: 0.2269
11/12, train_loss: 0.1901
12/12, train_loss: 0.1363
13/12, train_loss: 0.0746
epoch 53 average loss: 0.1891
current epoch: 53, current mean dice: 0.4518, best mean dice: 0.4636, at epoch: 47
------------
Epoch 54/300
1/12, train_loss: 0.0972
2/12, train_loss: 0.2118
3/12, train_loss: 0.3008
4/12, train_loss: 0.1960
5/12, train_loss: 0.1334
6/12, train_loss: 0.1299
7/12, train_loss: 0.2692
8/12, train_loss: 0.2564
9/12, train_loss: 0.1523
10/12, train_loss: 0.2361
11/12, train_loss: 0.1626
12/12, train_loss: 0.2255
13/12, train_loss: 0.1494
epoch 54 average loss: 0.1939
saved new best metric model!!!
current epoch: 54, current mean dice: 0.4695, best mean dice: 0.4695, at epoch: 54
------------
Epoch 55/300
1/12, train_loss: 0.1081
2/12, tr

10/12, train_loss: 0.2046
11/12, train_loss: 0.2435
12/12, train_loss: 0.2252
13/12, train_loss: 0.1097
epoch 70 average loss: 0.1937
current epoch: 70, current mean dice: 0.4306, best mean dice: 0.4695, at epoch: 54
------------
Epoch 71/300
1/12, train_loss: 0.2988
2/12, train_loss: 0.2878
3/12, train_loss: 0.1372
4/12, train_loss: 0.2082
5/12, train_loss: 0.2119
6/12, train_loss: 0.2725
7/12, train_loss: 0.1331
8/12, train_loss: 0.1204
9/12, train_loss: 0.2788
10/12, train_loss: 0.0985
11/12, train_loss: 0.1455
12/12, train_loss: 0.1462
13/12, train_loss: 0.2008
epoch 71 average loss: 0.1954
current epoch: 71, current mean dice: 0.4126, best mean dice: 0.4695, at epoch: 54
------------
Epoch 72/300
1/12, train_loss: 0.1139
2/12, train_loss: 0.2153
3/12, train_loss: 0.2041
4/12, train_loss: 0.1083
5/12, train_loss: 0.1166
6/12, train_loss: 0.2032
7/12, train_loss: 0.2825
8/12, train_loss: 0.2306
9/12, train_loss: 0.1819
10/12, train_loss: 0.1022
11/12, train_loss: 0.1869
12/12, train

1/12, train_loss: 0.2131
2/12, train_loss: 0.3006
3/12, train_loss: 0.2167
4/12, train_loss: 0.1476
5/12, train_loss: 0.1304
6/12, train_loss: 0.2595
7/12, train_loss: 0.1669
8/12, train_loss: 0.0914
9/12, train_loss: 0.0803
10/12, train_loss: 0.1782
11/12, train_loss: 0.1488
12/12, train_loss: 0.1144
13/12, train_loss: 0.2614
epoch 88 average loss: 0.1776
current epoch: 88, current mean dice: 0.4218, best mean dice: 0.4695, at epoch: 54
------------
Epoch 89/300
1/12, train_loss: 0.0825
2/12, train_loss: 0.1729
3/12, train_loss: 0.2474
4/12, train_loss: 0.2342
5/12, train_loss: 0.1036
6/12, train_loss: 0.2282
7/12, train_loss: 0.1098
8/12, train_loss: 0.3002
9/12, train_loss: 0.0964
10/12, train_loss: 0.1927
11/12, train_loss: 0.1286
12/12, train_loss: 0.2054
13/12, train_loss: 0.1828
epoch 89 average loss: 0.1757
current epoch: 89, current mean dice: 0.4447, best mean dice: 0.4695, at epoch: 54
------------
Epoch 90/300
1/12, train_loss: 0.0588
2/12, train_loss: 0.1766
3/12, train_lo

9/12, train_loss: 0.1055
10/12, train_loss: 0.1984
11/12, train_loss: 0.1173
12/12, train_loss: 0.1136
13/12, train_loss: 0.0768
epoch 105 average loss: 0.1471
current epoch: 105, current mean dice: 0.3974, best mean dice: 0.4800, at epoch: 98
------------
Epoch 106/300
1/12, train_loss: 0.1162
2/12, train_loss: 0.1805
3/12, train_loss: 0.1068
4/12, train_loss: 0.2520
5/12, train_loss: 0.1392
6/12, train_loss: 0.1033
7/12, train_loss: 0.1913
8/12, train_loss: 0.1805
9/12, train_loss: 0.1853
10/12, train_loss: 0.1641
11/12, train_loss: 0.2441
12/12, train_loss: 0.2214
13/12, train_loss: 0.1143
epoch 106 average loss: 0.1692
current epoch: 106, current mean dice: 0.4487, best mean dice: 0.4800, at epoch: 98
------------
Epoch 107/300
1/12, train_loss: 0.2641
2/12, train_loss: 0.2119
3/12, train_loss: 0.1407
4/12, train_loss: 0.1855
5/12, train_loss: 0.0877
6/12, train_loss: 0.2412
7/12, train_loss: 0.0998
8/12, train_loss: 0.2080
9/12, train_loss: 0.1517
10/12, train_loss: 0.0943
11/12, 

current epoch: 122, current mean dice: 0.4412, best mean dice: 0.4817, at epoch: 119
------------
Epoch 123/300
1/12, train_loss: 0.1172
2/12, train_loss: 0.1348
3/12, train_loss: 0.1759
4/12, train_loss: 0.0883
5/12, train_loss: 0.2546
6/12, train_loss: 0.0659
7/12, train_loss: 0.2901
8/12, train_loss: 0.0774
9/12, train_loss: 0.1552
10/12, train_loss: 0.1017
11/12, train_loss: 0.1349
12/12, train_loss: 0.1851
13/12, train_loss: 0.2013
epoch 123 average loss: 0.1525
current epoch: 123, current mean dice: 0.4284, best mean dice: 0.4817, at epoch: 119
------------
Epoch 124/300
1/12, train_loss: 0.1147
2/12, train_loss: 0.1018
3/12, train_loss: 0.1759
4/12, train_loss: 0.1001
5/12, train_loss: 0.2317
6/12, train_loss: 0.3068
7/12, train_loss: 0.3051
8/12, train_loss: 0.0967
9/12, train_loss: 0.0823
10/12, train_loss: 0.0684
11/12, train_loss: 0.1830
12/12, train_loss: 0.1278
13/12, train_loss: 0.0684
epoch 124 average loss: 0.1510
current epoch: 124, current mean dice: 0.4531, best mean

3/12, train_loss: 0.1366
4/12, train_loss: 0.1956
5/12, train_loss: 0.1437
6/12, train_loss: 0.1329
7/12, train_loss: 0.0690
8/12, train_loss: 0.2187
9/12, train_loss: 0.0943
10/12, train_loss: 0.1536
11/12, train_loss: 0.1408
12/12, train_loss: 0.0831
13/12, train_loss: 0.1883
epoch 140 average loss: 0.1443
current epoch: 140, current mean dice: 0.4767, best mean dice: 0.5095, at epoch: 125
------------
Epoch 141/300
1/12, train_loss: 0.1343
2/12, train_loss: 0.2436
3/12, train_loss: 0.1061
4/12, train_loss: 0.0652
5/12, train_loss: 0.0947
6/12, train_loss: 0.1651
7/12, train_loss: 0.1765
8/12, train_loss: 0.1347
9/12, train_loss: 0.1112
10/12, train_loss: 0.2084
11/12, train_loss: 0.0702
12/12, train_loss: 0.1924
13/12, train_loss: 0.0633
epoch 141 average loss: 0.1358
current epoch: 141, current mean dice: 0.4767, best mean dice: 0.5095, at epoch: 125
------------
Epoch 142/300
1/12, train_loss: 0.0858
2/12, train_loss: 0.2675
3/12, train_loss: 0.1875
4/12, train_loss: 0.0688
5/12, 

9/12, train_loss: 0.2028
10/12, train_loss: 0.3442
11/12, train_loss: 0.2761
12/12, train_loss: 0.1335
13/12, train_loss: 0.1458
epoch 157 average loss: 0.1572
current epoch: 157, current mean dice: 0.4460, best mean dice: 0.5172, at epoch: 154
------------
Epoch 158/300
1/12, train_loss: 0.2127
2/12, train_loss: 0.1653
3/12, train_loss: 0.1461
4/12, train_loss: 0.1252
5/12, train_loss: 0.2611
6/12, train_loss: 0.0654
7/12, train_loss: 0.1721
8/12, train_loss: 0.1651
9/12, train_loss: 0.1300
10/12, train_loss: 0.0565
11/12, train_loss: 0.0971
12/12, train_loss: 0.2067
13/12, train_loss: 0.1914
epoch 158 average loss: 0.1534
current epoch: 158, current mean dice: 0.5093, best mean dice: 0.5172, at epoch: 154
------------
Epoch 159/300
1/12, train_loss: 0.2233
2/12, train_loss: 0.1764
3/12, train_loss: 0.1693
4/12, train_loss: 0.1486
5/12, train_loss: 0.0720
6/12, train_loss: 0.1203
7/12, train_loss: 0.0748
8/12, train_loss: 0.1557
9/12, train_loss: 0.0792
10/12, train_loss: 0.1797
11/12

1/12, train_loss: 0.2172
2/12, train_loss: 0.1518
3/12, train_loss: 0.0793
4/12, train_loss: 0.1513
5/12, train_loss: 0.0534
6/12, train_loss: 0.1131
7/12, train_loss: 0.0744
8/12, train_loss: 0.1725
9/12, train_loss: 0.0944
10/12, train_loss: 0.0530
11/12, train_loss: 0.1923
12/12, train_loss: 0.1519
13/12, train_loss: 0.2619
epoch 175 average loss: 0.1359
current epoch: 175, current mean dice: 0.5052, best mean dice: 0.5172, at epoch: 154
------------
Epoch 176/300
1/12, train_loss: 0.0733
2/12, train_loss: 0.0798
3/12, train_loss: 0.0885
4/12, train_loss: 0.1024
5/12, train_loss: 0.1092
6/12, train_loss: 0.1912
7/12, train_loss: 0.0616
8/12, train_loss: 0.2494
9/12, train_loss: 0.0813
10/12, train_loss: 0.1684
11/12, train_loss: 0.2162
12/12, train_loss: 0.1168
13/12, train_loss: 0.0730
epoch 176 average loss: 0.1239
current epoch: 176, current mean dice: 0.4783, best mean dice: 0.5172, at epoch: 154
------------
Epoch 177/300
1/12, train_loss: 0.2149
2/12, train_loss: 0.1220
3/12, 

7/12, train_loss: 0.0836
8/12, train_loss: 0.0803
9/12, train_loss: 0.0698
10/12, train_loss: 0.1081
11/12, train_loss: 0.1163
12/12, train_loss: 0.1438
13/12, train_loss: 0.1157
epoch 192 average loss: 0.1266
current epoch: 192, current mean dice: 0.5095, best mean dice: 0.5184, at epoch: 184
------------
Epoch 193/300
1/12, train_loss: 0.0576
2/12, train_loss: 0.1544
3/12, train_loss: 0.0716
4/12, train_loss: 0.0593
5/12, train_loss: 0.1752
6/12, train_loss: 0.0871
7/12, train_loss: 0.1233
8/12, train_loss: 0.2136
9/12, train_loss: 0.1075
10/12, train_loss: 0.2234
11/12, train_loss: 0.2133
12/12, train_loss: 0.2006
13/12, train_loss: 0.1227
epoch 193 average loss: 0.1392
current epoch: 193, current mean dice: 0.4856, best mean dice: 0.5184, at epoch: 184
------------
Epoch 194/300
1/12, train_loss: 0.0714
2/12, train_loss: 0.1451
3/12, train_loss: 0.0855
4/12, train_loss: 0.1141
5/12, train_loss: 0.1559
6/12, train_loss: 0.2597
7/12, train_loss: 0.1984
8/12, train_loss: 0.1310
9/12, 

13/12, train_loss: 0.0723
epoch 209 average loss: 0.1450
current epoch: 209, current mean dice: 0.4960, best mean dice: 0.5392, at epoch: 208
------------
Epoch 210/300
1/12, train_loss: 0.0655
2/12, train_loss: 0.1637
3/12, train_loss: 0.0777
4/12, train_loss: 0.0702
5/12, train_loss: 0.1364
6/12, train_loss: 0.0950
7/12, train_loss: 0.0908
8/12, train_loss: 0.1106
9/12, train_loss: 0.0738
10/12, train_loss: 0.1895
11/12, train_loss: 0.2748
12/12, train_loss: 0.1178
13/12, train_loss: 0.1920
epoch 210 average loss: 0.1275
current epoch: 210, current mean dice: 0.4668, best mean dice: 0.5392, at epoch: 208
------------
Epoch 211/300
1/12, train_loss: 0.1464
2/12, train_loss: 0.1551
3/12, train_loss: 0.0808
4/12, train_loss: 0.1318
5/12, train_loss: 0.1569
6/12, train_loss: 0.0853
7/12, train_loss: 0.0986
8/12, train_loss: 0.1081
9/12, train_loss: 0.2455
10/12, train_loss: 0.0826
11/12, train_loss: 0.0716
12/12, train_loss: 0.1435
13/12, train_loss: 0.2209
epoch 211 average loss: 0.1329

1/12, train_loss: 0.1352
2/12, train_loss: 0.0715
3/12, train_loss: 0.1028
4/12, train_loss: 0.2562
5/12, train_loss: 0.1573
6/12, train_loss: 0.2180
7/12, train_loss: 0.0672
8/12, train_loss: 0.1291
9/12, train_loss: 0.1152
10/12, train_loss: 0.1976
11/12, train_loss: 0.1268
12/12, train_loss: 0.1184
13/12, train_loss: 0.0616
epoch 227 average loss: 0.1351
current epoch: 227, current mean dice: 0.4646, best mean dice: 0.5392, at epoch: 208
------------
Epoch 228/300
1/12, train_loss: 0.0626
2/12, train_loss: 0.1217
3/12, train_loss: 0.1961
4/12, train_loss: 0.1983
5/12, train_loss: 0.1432
6/12, train_loss: 0.0915
7/12, train_loss: 0.0935
8/12, train_loss: 0.2187
9/12, train_loss: 0.0510
10/12, train_loss: 0.0875
11/12, train_loss: 0.0823
12/12, train_loss: 0.1362
13/12, train_loss: 0.1846
epoch 228 average loss: 0.1282
current epoch: 228, current mean dice: 0.4903, best mean dice: 0.5392, at epoch: 208
------------
Epoch 229/300
1/12, train_loss: 0.0558
2/12, train_loss: 0.0668
3/12, 

8/12, train_loss: 0.0370
9/12, train_loss: 0.1916
10/12, train_loss: 0.0352
11/12, train_loss: 0.0523
12/12, train_loss: 0.0720
13/12, train_loss: 0.0675
epoch 244 average loss: 0.0900
current epoch: 244, current mean dice: 0.4449, best mean dice: 0.5392, at epoch: 208
------------
Epoch 245/300
1/12, train_loss: 0.0487
2/12, train_loss: 0.0855
3/12, train_loss: 0.1964
4/12, train_loss: 0.0632
5/12, train_loss: 0.0884
6/12, train_loss: 0.0488
7/12, train_loss: 0.0666
8/12, train_loss: 0.1760
9/12, train_loss: 0.0363
10/12, train_loss: 0.1294
11/12, train_loss: 0.0885
12/12, train_loss: 0.0487
13/12, train_loss: 0.0868
epoch 245 average loss: 0.0895
current epoch: 245, current mean dice: 0.4667, best mean dice: 0.5392, at epoch: 208
------------
Epoch 246/300
1/12, train_loss: 0.0786
2/12, train_loss: 0.0990
3/12, train_loss: 0.0805
4/12, train_loss: 0.0953
5/12, train_loss: 0.0172
6/12, train_loss: 0.0723
7/12, train_loss: 0.1663
8/12, train_loss: 0.0412
9/12, train_loss: 0.0606
10/12,

current epoch: 261, current mean dice: 0.4505, best mean dice: 0.5392, at epoch: 208
------------
Epoch 262/300
1/12, train_loss: 0.0483
2/12, train_loss: 0.1079
3/12, train_loss: 0.1127
4/12, train_loss: 0.1343
5/12, train_loss: 0.1530
6/12, train_loss: 0.0644
7/12, train_loss: 0.2006
8/12, train_loss: 0.0891
9/12, train_loss: 0.0775
10/12, train_loss: 0.0659
11/12, train_loss: 0.0808
12/12, train_loss: 0.0702
13/12, train_loss: 0.0545
epoch 262 average loss: 0.0969
current epoch: 262, current mean dice: 0.4526, best mean dice: 0.5392, at epoch: 208
------------
Epoch 263/300
1/12, train_loss: 0.1688
2/12, train_loss: 0.1343
3/12, train_loss: 0.1570
4/12, train_loss: 0.1076
5/12, train_loss: 0.0629
6/12, train_loss: 0.1544
7/12, train_loss: 0.0756
8/12, train_loss: 0.0938
9/12, train_loss: 0.0800
10/12, train_loss: 0.0289
11/12, train_loss: 0.0493
12/12, train_loss: 0.0355
13/12, train_loss: 0.0646
epoch 263 average loss: 0.0933
current epoch: 263, current mean dice: 0.4440, best mean

4/12, train_loss: 0.0511
5/12, train_loss: 0.0988
6/12, train_loss: 0.0581
7/12, train_loss: 0.1194
8/12, train_loss: 0.0536
9/12, train_loss: 0.0504
10/12, train_loss: 0.1518
11/12, train_loss: 0.0495
12/12, train_loss: 0.0494
13/12, train_loss: 0.2039
epoch 279 average loss: 0.0899
current epoch: 279, current mean dice: 0.4426, best mean dice: 0.5392, at epoch: 208
------------
Epoch 280/300
1/12, train_loss: 0.1717
2/12, train_loss: 0.0969
3/12, train_loss: 0.0575
4/12, train_loss: 0.0699
5/12, train_loss: 0.0414
6/12, train_loss: 0.0467
7/12, train_loss: 0.0570
8/12, train_loss: 0.0501
9/12, train_loss: 0.0591
10/12, train_loss: 0.1407
11/12, train_loss: 0.0791
12/12, train_loss: 0.1687
13/12, train_loss: 0.1023
epoch 280 average loss: 0.0878
current epoch: 280, current mean dice: 0.4890, best mean dice: 0.5392, at epoch: 208
------------
Epoch 281/300
1/12, train_loss: 0.0736
2/12, train_loss: 0.0246
3/12, train_loss: 0.0635
4/12, train_loss: 0.0579
5/12, train_loss: 0.0995
6/12, 

11/12, train_loss: 0.1090
12/12, train_loss: 0.0446
13/12, train_loss: 0.0501
epoch 296 average loss: 0.0914
current epoch: 296, current mean dice: 0.4486, best mean dice: 0.5392, at epoch: 208
------------
Epoch 297/300
1/12, train_loss: 0.0600
2/12, train_loss: 0.0756
3/12, train_loss: 0.2612
4/12, train_loss: 0.0727
5/12, train_loss: 0.0823
6/12, train_loss: 0.0616
7/12, train_loss: 0.0641
8/12, train_loss: 0.0969
9/12, train_loss: 0.1012
10/12, train_loss: 0.0374
11/12, train_loss: 0.0357
12/12, train_loss: 0.1229
13/12, train_loss: 0.0474
epoch 297 average loss: 0.0861
current epoch: 297, current mean dice: 0.4474, best mean dice: 0.5392, at epoch: 208
------------
Epoch 298/300
1/12, train_loss: 0.0422
2/12, train_loss: 0.2088
3/12, train_loss: 0.0408
4/12, train_loss: 0.0766
5/12, train_loss: 0.1191
6/12, train_loss: 0.0489
7/12, train_loss: 0.0789
8/12, train_loss: 0.1133
9/12, train_loss: 0.0587
10/12, train_loss: 0.0383
11/12, train_loss: 0.0397
12/12, train_loss: 0.0611
13/1

In [None]:
plt.figure("train", (12, 6))
plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
plt.xlabel("epoch")
plt.plot(x, y)
plt.subplot(1, 2, 2)
plt.title("Val Mean Dice")
x = [val_interval * (i + 1) for i in range(len(metric_values))]
y = metric_values
plt.xlabel("epoch")
plt.plot(x, y)
plt.savefig('model_03_graph.png')
plt.show()
#plt.savefig(PATH_NAME+'model.png')

# Check model and Visualize

In [20]:
# print(
#     f"train completed, best_metric: {best_metric:.4f} "
#     f"at epoch: {best_metric_epoch}")
%load_ext tensorboard
%tensorboard --logdir=runs_combo_diceloss_2

In [None]:
slide = 28
plt.figure("Test Model", (12, 12))

plt.subplot(1, 3, 1)
plt.title("Input")
plt.imshow(image[0, 0, :, :, slide], cmap = "gray")

plt.subplot(1, 3, 2)
plt.title("Label")
plt.imshow(image[0, 0, :, :, slide], cmap = "gray")
plt.imshow(label[0, 0, :, :, slide], cmap = 'jet', alpha = 0.5)

plt.show()

In [None]:
model.cpu()
model.load_state_dict(torch.load("best_metric_model.pth"))

with torch.no_grad():
  output = model(image)
  output = torch.nn.Softmax()(output)
  output = torch.round(output)

  print(f"Output shape: {output.shape}")

  slide = 28
  plt.figure("Test Model", (12, 12))

  plt.subplot(1, 3, 1)
  plt.title("Input")
  plt.imshow(image[0, 0, :, :, slide], cmap = "gray")

  plt.subplot(1, 3, 2)
  plt.title("Label")
  plt.imshow(image[0, 0, :, :, slide], cmap = "gray")
  plt.imshow(label[0, 0, :, :, slide], cmap = 'jet', alpha = 0.5)

  plt.subplot(1, 3, 3)
  plt.title("Output")
  plt.imshow(image[0, 0, :, :, slide], cmap = "gray")
  plt.imshow(output[0, 1, :, :, slide], cmap = 'jet', alpha = 0.5)

  plt.show()

In [None]:
def predict(image, label, model):
    with torch.no_grad():
      output = model(image)
      output = torch.nn.Softmax()(output)
      output = torch.round(output)

      print(f"Output shape: {output.shape}")

      slide = 28
      plt.figure("Test Model", (12, 12))

      plt.subplot(1, 3, 1)
      plt.title("Input")
      plt.imshow(image[0, 0, :, :, slide], cmap = "gray")

      plt.subplot(1, 3, 2)
      plt.title("Label")
      plt.imshow(image[0, 0, :, :, slide], cmap = "gray")
      plt.imshow(label[0, 0, :, :, slide], cmap = 'jet', alpha = 0.5)

      plt.subplot(1, 3, 3)
      plt.title("Output")
      plt.imshow(image[0, 0, :, :, slide], cmap = "gray")
      plt.imshow(output[0, 1, :, :, slide], cmap = 'jet', alpha = 0.5)

      plt.show()

for i in range(50):
    checker = first(val_loader)
    image, label = checker['image'], checker['label']
    predict(image, label, model)