In [6]:
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
import torchio as tio
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib

import time
from datetime import datetime
from tqdm import tqdm
from itertools import combinations_with_replacement

In [1]:
import os
import sys

# Make the parent of the current working directory importable, presumably
# where the local `scripts` and `ml` packages imported below live.
# Note: os.path.dirname('../.') == '..', so this is simply '..' (CWD-relative).
sys.path.append('..')

from scripts.load_and_save import (get_dcm_info, get_dcm_vol, vox_size2affine,
                                   load_nii_vol, save_vol_as_nii, raw2nifti)
from scripts.utils import print_img, print_imgs, get_path
from scripts.hessian_based import hessian_detect_2016

from ml.utils import get_total_params, load_pretrainned, test_model
from ml.ControllerClass import Controller
from ml.tio_dataset import TioDataset
from ml.metrics import (DICE_Metric, JAC_Metric, SN_Metric, SP_Metric,
                        IOU_Metric, ExponentialLogarithmicLoss)

from ml.models.HessNet import HessBlock, HessNet, HessNet2, GaussianBlur3D, HessianTorch
from ml.models.unet3d import U_Net, U_HessNet, ParallelNet
from ml.models.unet2d import U_Net2d

In [20]:
# Use the GPU when one is visible; fall back to CPU so the notebook can still
# run (slowly) on machines without CUDA instead of crashing.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# 1 -> 2D network on 512x512x1 slices, 0 -> 3D network on 64^3 patches.
IS2D = 1

In [21]:
# Patch geometry: the 2D network sees whole 512x512 slices (depth 1),
# the 3D network is trained on 64^3 cubes.
PATCH_SIZE = (512, 512, 1) if IS2D else (64, 64, 64)

# Patch sampling used for training.
train_settings = dict(
    patch_shape=PATCH_SIZE,
    patches_per_volume=64,
    patches_queue_length=1440,
    batch_size=16,
    num_workers=4,
    sampler="uniform",  # alternative: "weighted"
)

# Validation sampling; currently NOT passed to the dataset (val_settings=None below).
val_settings = dict(
    patch_shape=PATCH_SIZE,
    patches_per_volume=32,
    patches_queue_length=1440,
    batch_size=32,
    num_workers=4,
    sampler="uniform",  # alternative: "weighted"
)

# Sliding-window settings for testing/inference: non-overlapping slices in 2D,
# 256x256x64 windows with a (32, 32, 24) overlap in 3D.
if IS2D:
    test_settings = dict(patch_shape=(512, 512, 1), overlap_shape=(0, 0, 0),
                         batch_size=16, num_workers=4)
else:
    test_settings = dict(patch_shape=(256, 256, 64), overlap_shape=(32, 32, 24),
                         batch_size=1, num_workers=4)

data_dir = "/home/msst/Documents/medtech/data/HessData_IXI"
dataset = TioDataset(
    data_dir,
    train_settings=train_settings,
    val_settings=None,  # validation disabled; pass val_settings to enable it
    test_settings=test_settings,
)

In [22]:
# Other architectures tried with this notebook (swap into `model` to use):
#   HessNet2(start_scale=[0.8], device=DEVICE, channel_coef=4)
#   HessNet(start_scale=[0.8], device=DEVICE)             # isotropic start scale
#   HessNet(start_scale=[0.8, 0.8, 1.2], device=DEVICE)   # anisotropic start scale
#   U_HessNet(channels=16)
#   ParallelNet(in_ch=1, inter_ch=5)
# The default U-Net follows IS2D so its dimensionality matches PATCH_SIZE;
# a hard-coded U_Net2d would receive 64^3 patches when IS2D is 0.
model = U_Net2d(channels=16) if IS2D else U_Net(channels=16)

print("total_params:", get_total_params(model))

total_params: 2161841


In [23]:
# Training setup: exponential-logarithmic loss (parameters suggest a
# Tversky + BCE combination — see ml.metrics), DICE as the test metric,
# Adam with lr=0.02 and StepLR multiplying the LR by 0.9 per scheduler step.
# NOTE(review): the key is spelled "sheduler_fn" — presumably what Controller
# looks up, so it must not be renamed here alone.
controller_config = dict(
    loss=ExponentialLogarithmicLoss(
        gamma_tversky=0.5, gamma_bce=0.5, lamb=0.5, freq=0.1, tversky_alfa=0.5,
    ),
    metric=DICE_Metric(),
    device=DEVICE,
    model=model,
    optimizer_fn=lambda model: torch.optim.Adam(model.parameters(), lr=0.02),
    sheduler_fn=lambda optimizer: StepLR(optimizer, step_size=1, gamma=0.9),
    is2d=IS2D,
    verbose=True,
)
controller = Controller(controller_config)

In [7]:
N_EPOCHS = 25  # number of training epochs passed to Controller.fit
controller.fit(dataset, N_EPOCHS)

Epoch 1/25


100%|███████████████████████████████████████████| 48/48 [00:23<00:00,  2.04it/s]


{'mean_loss': 0.7787109712759653}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.42s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.7759])}, {'sample': 'IXI131_0', 'metric1': tensor([0.7910])}, {'sample': 'IXI137_0', 'metric1': tensor([0.7603])}]}
Epoch 2/25


100%|███████████████████████████████████████████| 48/48 [00:25<00:00,  1.90it/s]


{'mean_loss': 0.2408290064583222}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.48s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8659])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8486])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8456])}]}
Epoch 3/25


100%|███████████████████████████████████████████| 48/48 [00:26<00:00,  1.85it/s]


{'mean_loss': 0.22912022875001034}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.47s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8721])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8519])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8485])}]}
Epoch 4/25


100%|███████████████████████████████████████████| 48/48 [00:25<00:00,  1.89it/s]


{'mean_loss': 0.2275803073619803}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.52s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8816])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8557])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8507])}]}
Epoch 5/25


100%|███████████████████████████████████████████| 48/48 [00:25<00:00,  1.91it/s]


{'mean_loss': 0.22468908845136562}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.47s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8850])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8585])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8509])}]}
Epoch 6/25


100%|███████████████████████████████████████████| 48/48 [00:25<00:00,  1.90it/s]


{'mean_loss': 0.22209235777457556}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.48s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8899])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8630])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8511])}]}
Epoch 7/25


100%|███████████████████████████████████████████| 48/48 [00:25<00:00,  1.89it/s]


{'mean_loss': 0.22164262427637973}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.49s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8879])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8612])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8506])}]}
Epoch 8/25


100%|███████████████████████████████████████████| 48/48 [00:24<00:00,  1.93it/s]


{'mean_loss': 0.21811112202703953}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.48s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8892])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8593])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8498])}]}
Epoch 9/25


100%|███████████████████████████████████████████| 48/48 [00:25<00:00,  1.90it/s]


{'mean_loss': 0.21747004985809326}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.51s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8900])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8586])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8500])}]}
Epoch 10/25


100%|███████████████████████████████████████████| 48/48 [00:25<00:00,  1.90it/s]


{'mean_loss': 0.21624622711290917}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.48s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8910])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8580])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8503])}]}
Epoch 11/25


100%|███████████████████████████████████████████| 48/48 [00:25<00:00,  1.86it/s]


{'mean_loss': 0.2171153969441851}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.51s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8893])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8568])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8502])}]}
Epoch 12/25


100%|███████████████████████████████████████████| 48/48 [00:25<00:00,  1.87it/s]


{'mean_loss': 0.21570382360368967}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.49s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8907])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8582])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8508])}]}
Epoch 13/25


100%|███████████████████████████████████████████| 48/48 [00:25<00:00,  1.88it/s]


{'mean_loss': 0.21569105765471855}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.50s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8920])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8581])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8507])}]}
Epoch 14/25


100%|███████████████████████████████████████████| 48/48 [00:26<00:00,  1.83it/s]


{'mean_loss': 0.21336474207540354}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.50s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8929])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8582])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8506])}]}
Epoch 15/25


100%|███████████████████████████████████████████| 48/48 [00:26<00:00,  1.84it/s]


{'mean_loss': 0.21209395055969557}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.56s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8932])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8574])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8507])}]}
Epoch 16/25


100%|███████████████████████████████████████████| 48/48 [00:26<00:00,  1.82it/s]


{'mean_loss': 0.21156283685316643}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.56s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8944])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8582])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8506])}]}
Epoch 17/25


100%|███████████████████████████████████████████| 48/48 [00:26<00:00,  1.83it/s]


{'mean_loss': 0.21162699162960052}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.59s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8945])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8573])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8508])}]}
Epoch 18/25


100%|███████████████████████████████████████████| 48/48 [00:26<00:00,  1.81it/s]


{'mean_loss': 0.21082745771855116}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.54s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8946])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8570])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8506])}]}
Epoch 19/25


100%|███████████████████████████████████████████| 48/48 [00:25<00:00,  1.86it/s]


{'mean_loss': 0.2101504923775792}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.51s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8957])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8579])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8506])}]}
Epoch 20/25


100%|███████████████████████████████████████████| 48/48 [00:25<00:00,  1.91it/s]


{'mean_loss': 0.20978030376136303}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.47s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8958])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8575])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8504])}]}
Epoch 21/25


100%|███████████████████████████████████████████| 48/48 [00:25<00:00,  1.89it/s]


{'mean_loss': 0.2108930175503095}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.52s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8964])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8575])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8503])}]}
Epoch 22/25


100%|███████████████████████████████████████████| 48/48 [00:25<00:00,  1.90it/s]


{'mean_loss': 0.20993711644162735}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.51s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8970])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8572])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8499])}]}
Epoch 23/25


100%|███████████████████████████████████████████| 48/48 [00:25<00:00,  1.86it/s]


{'mean_loss': 0.20860299561172724}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.50s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8973])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8576])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8498])}]}
Epoch 24/25


100%|███████████████████████████████████████████| 48/48 [00:25<00:00,  1.89it/s]


{'mean_loss': 0.20750207174569368}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.49s/it]


{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8979])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8570])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8496])}]}
Epoch 25/25


100%|███████████████████████████████████████████| 48/48 [00:25<00:00,  1.89it/s]


{'mean_loss': 0.20863836631178856}


100%|█████████████████████████████████████████████| 3/3 [00:04<00:00,  1.48s/it]

{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8977])}, {'sample': 'IXI131_0', 'metric1': tensor([0.8563])}, {'sample': 'IXI137_0', 'metric1': tensor([0.8494])}]}





U_Net2d(
  (Maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Maxpool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Maxpool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv1): conv_block(
    (conv): Sequential(
      (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (Conv2): conv_block(
    (conv): Sequential(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=

In [8]:
# Checkpoint name used by the save/load cells that follow. It must be defined
# here: with every assignment commented out, a fresh kernel hits NameError
# when the checkpoint path is built.
# Names used for other runs: "HessNet_isotropic_smartnorm", "HessNet2",
# "HessNet_anisotropic_smartnorm", "U_HessNet", "ParallelNet_5_lanczos".
model_name = "Unet2d_16ch" if IS2D else "Unet3d_16ch"

In [10]:
# Opt-in: set to True to persist the trained controller under `model_name`.
SAVE_MODEL = False
if SAVE_MODEL:
    controller.save(f"/home/msst/save_folder/saved_models/{model_name}")

In [10]:
# Load the checkpoint saved under `model_name` into `controller`.
# NOTE(review): this presumably replaces the weights/history trained above,
# and the history cell reads `controller.history` afterwards — confirm.
path_to_check = "/home/msst/save_folder/saved_models/{}".format(model_name)
controller.load(path_to_checkpoint=path_to_check)

In [26]:
# Learning curves stored by the controller (presumably the loaded
# checkpoint's history if the load cell ran, not only the fit above).
loss_train = controller.history['train_loss']
loss_val = controller.history['val_loss']
all_metrics = controller.history['test_quality']

N_ep = len(loss_train)
epochs = np.arange(1, N_ep + 1)

# Mean test metric ('metric1', DICE in this notebook's config) over the
# evaluated samples, one value per recorded epoch. This per-epoch series is
# what the DICE curve needs, not the last epoch's per-sample list.
avg_metrics = []
for epoch_metrics in all_metrics:
    per_sample = [entry['metric1'] for entry in epoch_metrics['metrics']]
    if not per_sample:
        avg_metrics.append(float('nan'))  # keep one entry per epoch; avoid 0-division
        continue
    epoch_mean = sum(per_sample) / len(per_sample)
    print(epoch_mean)
    avg_metrics.append(float(epoch_mean))

# Opt-in: set to True to draw and save the loss + DICE learning curve.
PLOT_LEARNING_CURVE = False
if PLOT_LEARNING_CURVE:
    fontsize = 12
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    twin1 = ax.twinx()

    p1, = ax.plot(epochs, loss_train, label="train loss")
    # p2, = ax.plot(epochs, loss_val, label="validation loss")
    p3, = twin1.plot(np.arange(1, len(avg_metrics) + 1), avg_metrics,
                     "g-", label="DICE metric")

    ax.set_xlabel("epoch", fontsize=fontsize + 3)
    ax.set_ylabel("Loss", fontsize=fontsize + 3)
    twin1.set_ylabel("DICE", fontsize=fontsize + 3)
    twin1.tick_params(axis='y', colors=p3.get_color())

    ax.set_xticks(np.linspace(0, N_ep, 6).astype(int))
    twin1.set_yticks(np.linspace(0, 1, 11))
    ax.set_xlim(0, N_ep)
    ax.set_ylim(0, 1)
    twin1.set_ylim(0, 1)

    ax.legend(handles=[p1, p3], fontsize=fontsize)
    plt.savefig(f'/home/msst/{model_name}_learning.jpg', dpi=250)

tensor([0.5349])
tensor([0.7922])
tensor([0.7529])
tensor([0.7998])
tensor([0.8129])
tensor([0.8130])
tensor([0.8174])
tensor([0.8199])
tensor([0.8227])
tensor([0.8252])
tensor([0.8261])
tensor([0.8275])
tensor([0.8279])
tensor([0.8292])
tensor([0.8289])
tensor([0.8291])
tensor([0.8293])
tensor([0.8300])
tensor([0.8304])
tensor([0.8305])
tensor([0.8306])
tensor([0.8309])
tensor([0.8312])
tensor([0.8315])
tensor([0.8318])
tensor([0.8320])
tensor([0.8321])
tensor([0.8322])
tensor([0.8323])
tensor([0.8324])
tensor([0.8324])
tensor([0.8324])
tensor([0.8326])
tensor([0.8326])
tensor([0.8328])
tensor([0.8328])


In [24]:
# Run sliding-window inference on a single IXI volume with the trained controller.
sample_index = "111"

# Ground truth / brain mask for the same sample, if needed:
# path_to_GT = get_path(f"/home/msst/IXI_MRA_work/IXI{sample_index}", key="vessels")
# GT = tio.ScalarImage(path_to_GT).data
# path_to_brain_mask = get_path(f"/home/msst/IXI_MRA_work/IXI{sample_index}", key="brain")
# brain_mask = tio.ScalarImage(path_to_brain_mask).data

path_to_vol = get_path(f"/home/msst/IXI_MRA_work/IXI{sample_index}", key="head")
subject_dict = {'head': tio.ScalarImage(path_to_vol)}
subject = tio.Subject(subject_dict)
subject = tio.transforms.ZNormalization()(subject)

# perf_counter is monotonic and intended for measuring elapsed intervals.
t = time.perf_counter()
seg = controller.single_predict(subject, test_settings)
print(time.perf_counter() - t)

# Timestamp for output file names; '-' instead of ':' keeps the name valid on
# filesystems and tools that reject colons in paths.
dt_string = datetime.now().strftime("%d_%m_%Y_%H-%M")
dir_name = os.path.dirname(path_to_vol)

# Opt-in: set to True to write the prediction as NIfTI into the sample's folder.
SAVE_SEGMENTATION = False
if SAVE_SEGMENTATION:
    path_to_save = f"/home/msst/IXI_MRA_work/IXI{sample_index}/{model_name}_{dt_string}.nii.gz"
    seg_to_save = seg[0].numpy()
    save_vol_as_nii(seg_to_save.astype(np.float32), subject.head.affine, path_to_save)

15.855446815490723


### Global test

Train a fresh model on every data split in `HessData_IXI_test` and save one checkpoint per split.

In [9]:
def get_model(model_name):
    """Build a fresh, untrained network for an experiment name.

    Args:
        model_name: one of 'Unet3d_16ch', 'Unet2d_16ch', 'HessNet'.

    Returns:
        The instantiated model.

    Raises:
        ValueError: if ``model_name`` is not recognised. Returning None here
            would only fail later, obscurely, inside the Controller.
    """
    if model_name == 'Unet3d_16ch':
        return U_Net(channels=16)
    if model_name == 'Unet2d_16ch':
        return U_Net2d(channels=16)
    if model_name == 'HessNet':
        # Anisotropic start scale (larger along the third axis).
        return HessNet(start_scale=[0.8, 0.8, 1.2], device='cuda')
    raise ValueError(
        f"Unknown model_name {model_name!r}; "
        "expected one of 'Unet3d_16ch', 'Unet2d_16ch', 'HessNet'"
    )

def run_tests(model_name, path_to_save_models, is2d, test_data_path, epochs):
    """Train and save one model per data split found under ``test_data_path``.

    Each sub-directory of ``test_data_path`` is loaded as its own TioDataset.
    A fresh ``model_name`` network is trained on it for ``epochs`` epochs, and
    the controller is saved as ``{path_to_save_models}/{model_name}_{split}``.

    Args:
        model_name: name understood by ``get_model``.
        path_to_save_models: directory for the checkpoints (created if missing).
        is2d: truthy for 2D slice training, falsy for 3D patch training.
        test_data_path: directory whose sub-directories are the data splits.
        epochs: number of training epochs per split.
    """
    if is2d:
        patch_size_train = (512, 512, 1)
        patch_size_test = (512, 512, 1)
        overlap_test = (0, 0, 0)
    else:
        patch_size_train = (64, 64, 64)
        patch_size_test = (256, 256, 64)
        overlap_test = (32, 32, 24)

    train_settings = {
        "patch_shape": patch_size_train,
        "patches_per_volume": 64,
        "patches_queue_length": 1440,
        "batch_size": 16,
        "num_workers": 4,
        "sampler": "uniform",
    }

    test_settings = {
        "patch_shape": patch_size_test,
        "overlap_shape": overlap_test,
        "batch_size": 1,
        "num_workers": 4,
    }

    # The checkpoint directory must exist before controller.save writes into it.
    os.makedirs(path_to_save_models, exist_ok=True)

    # os.listdir order is arbitrary, so sort for a reproducible run order.
    # Skip plain files (e.g. .DS_Store), which are not datasets.
    splits = sorted(
        entry for entry in os.listdir(test_data_path)
        if os.path.isdir(os.path.join(test_data_path, entry))
    )

    for split in tqdm(splits):
        dataset = TioDataset(os.path.join(test_data_path, split),
                             train_settings=train_settings,
                             test_settings=test_settings)

        model = get_model(model_name)

        controller_config = {
            "loss": ExponentialLogarithmicLoss(gamma_tversky=0.5, gamma_bce=0.5, lamb=0.5,
                                               freq=0.1, tversky_alfa=0.5),
            "metric": DICE_Metric(),
            'device': 'cuda',
            "model": model,
            "optimizer_fn": lambda model: torch.optim.Adam(model.parameters(), lr=0.02),
            "sheduler_fn": lambda optimizer: StepLR(optimizer, step_size=1, gamma=0.9),
            "is2d": is2d,
            'verbose': False,
        }
        controller = Controller(controller_config)
        controller.fit(dataset, epochs)
        controller.save(f"{path_to_save_models}/{model_name}_{split}")

In [10]:
# Retrain the anisotropic HessNet from scratch on every split (3D, 25 epochs)
# and save one checkpoint per split.
run_tests(
    model_name="HessNet",
    path_to_save_models="/home/msst/save_folder/models_for_tests",
    is2d=False,
    test_data_path='/home/msst/Documents/medtech/data/HessData_IXI_test',
    epochs=25,
)

  0%|                                                     | 0/2 [00:00<?, ?it/s]

Epoch 1/25
{'mean_loss': 0.5063172777493795}
{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.7086])}, {'sample': 'IXI100_0', 'metric1': tensor([0.6700])}, {'sample': 'IXI020_0', 'metric1': tensor([0.6788])}]}
Epoch 2/25
{'mean_loss': 0.3695343254754941}
{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.7609])}, {'sample': 'IXI100_0', 'metric1': tensor([0.7164])}, {'sample': 'IXI020_0', 'metric1': tensor([0.7280])}]}
Epoch 3/25
{'mean_loss': 0.3396744132041931}
{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8045])}, {'sample': 'IXI100_0', 'metric1': tensor([0.7383])}, {'sample': 'IXI020_0', 'metric1': tensor([0.7617])}]}
Epoch 4/25
{'mean_loss': 0.31643511572231847}
{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8182])}, {'sample': 'IXI100_0', 'metric1': tensor([0.7564])}, {'sample': 'IXI020_0', 'metric1': tensor([0.7801])}]}
Epoch 5/25
{'mean_loss': 0.3181426903853814}
{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8191])}, {'sample': 'IXI1

 50%|█████████████████████▌                     | 1/2 [17:10<17:10, 1030.07s/it]

{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8579])}, {'sample': 'IXI100_0', 'metric1': tensor([0.7935])}, {'sample': 'IXI020_0', 'metric1': tensor([0.8243])}]}
Epoch 1/25
{'mean_loss': 0.5711555108428001}
{'metrics': [{'sample': 'IXI077_0', 'metric1': tensor([0.3790])}, {'sample': 'IXI111_0', 'metric1': tensor([0.7176])}, {'sample': 'IXI115_0', 'metric1': tensor([0.6141])}]}
Epoch 2/25
{'mean_loss': 0.37248051539063454}
{'metrics': [{'sample': 'IXI077_0', 'metric1': tensor([0.4498])}, {'sample': 'IXI111_0', 'metric1': tensor([0.7594])}, {'sample': 'IXI115_0', 'metric1': tensor([0.6468])}]}
Epoch 3/25
{'mean_loss': 0.33731571957468987}
{'metrics': [{'sample': 'IXI077_0', 'metric1': tensor([0.5108])}, {'sample': 'IXI111_0', 'metric1': tensor([0.7810])}, {'sample': 'IXI115_0', 'metric1': tensor([0.6648])}]}
Epoch 4/25
{'mean_loss': 0.3103602963189284}
{'metrics': [{'sample': 'IXI077_0', 'metric1': tensor([0.5096])}, {'sample': 'IXI111_0', 'metric1': tensor([0.7961])}, {'sampl

100%|███████████████████████████████████████████| 2/2 [34:24<00:00, 1032.02s/it]

{'metrics': [{'sample': 'IXI077_0', 'metric1': tensor([0.6621])}, {'sample': 'IXI111_0', 'metric1': tensor([0.8457])}, {'sample': 'IXI115_0', 'metric1': tensor([0.7158])}]}



