In [5]:
%load_ext autoreload
%autoreload 2

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt

import time
from datetime import datetime
from tqdm import tqdm

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
import sys
import os
sys.path.append(os.path.dirname('../.'))
sys.path.append(os.path.dirname('../ml/.'))
from scripts.load_and_save import (get_dcm_info, get_dcm_vol, vox_size2affine,
                                   load_nii_vol, save_vol_as_nii, raw2nifti)

#from ml.get_model import get_model
from ml.utils import get_total_params, load_pretrainned, test_model
from ml.ControllerClass import Controller
from ml.tio_dataset import TioDataset
from ml.metrics import (DICE_Metric, JAC_Metric, SN_Metric, SP_Metric,
                        IOU_Metric, ExponentialLogarithmicLoss, WeightedExpBCE)

from ml.models.HessNet import HessNet
from ml.models.unet3d import U_Net
from ml.models.unet2d import U_Net2d
from ml.models.JoB_VS import Network
from ml.transformers_models.UNETR import UNETR
from ml.models.VesselConvs import JustConv, TwoConv

In [7]:
DEVICE = 'cuda'
IS2D = 0

In [8]:
if IS2D:
    PATCH_SIZE_TRAIN = (512, 512, 1)
    PATCH_SIZE_TEST = (512, 512, 1)
    OVERLAP_TEST = (0, 0, 0)
else:
    PATCH_SIZE_TRAIN = (64, 64, 64)
    PATCH_SIZE_TEST = (64, 64, 64)
    OVERLAP_TEST = (4, 4, 4)

train_settings  = {
    "patch_shape" : PATCH_SIZE_TRAIN,
    "patches_per_volume" : 32,
    "patches_queue_length" : 1440,
    "batch_size" : 8,
    "num_workers": 4,
    "sampler": "uniform",
}

test_settings = {
    "patch_shape" : PATCH_SIZE_TEST,
    "overlap_shape" : OVERLAP_TEST,
    "batch_size" : 1,
    "num_workers": 4,
}

data_dir = "/home/msst/Documents/medtech/data/HessData_IXI"
dataset = TioDataset(data_dir,
                 train_settings=train_settings,
                 val_settings=None,#val_settings,
                 test_settings=test_settings,)

In [9]:
#model = HessNet(start_scale=[0.8, 0.8, 1.2], device=DEVICE)
#model = U_Net(channels=16) #Unet_8ch, Unet_16ch
#model = U_Net2d(channels=16) #Unet_8ch, Unet_16ch

# model = UNETR(in_channels=1, out_channels=1, img_size=PATCH_SIZE_TRAIN,
#               feature_size=16, hidden_size=128,
#               mlp_dim=512, num_heads=4,
#               norm_name='batch')

#model = Network(modalities=1, num_classes=1)

#model = JustConv(1, 1)
model = TwoConv(1, 1, 5)

print("total_params:", get_total_params(model))

total_params: 1268


In [8]:
controller_config = {
    "loss" : ExponentialLogarithmicLoss(gamma_tversky=0.5, gamma_bce=0.5, lamb=0.5,
                                        freq = 0.1, tversky_alfa=0.5),
    "metric" : DICE_Metric(),
    'device' : DEVICE,
    "model" : model,
    "optimizer_fn" : lambda model: Adam(model.parameters(), lr=0.02),
    "sheduler_fn": None, #lambda optimizer: StepLR(optimizer, step_size=1, gamma=0.9),
    "is2d" : IS2D,
    'verbose':True,
    'stoper': None
}
controller = Controller(controller_config)

In [9]:
controller.fit(dataset, 45)

Epoch 1/45


100%|███████████████████████████████████████████████| 64/64 [00:17<00:00,  3.73it/s]


{'mean_loss': 1.2777031566947699}


100%|█████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.80s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.4890])}
{'sample': 'IXI080_0', 'metric': tensor([0.5398])}
{'sample': 'IXI020_0', 'metric': tensor([0.4084])}
new best!
Epoch 2/45


100%|███████████████████████████████████████████████| 64/64 [00:17<00:00,  3.75it/s]


{'mean_loss': 1.0383014557883143}


100%|█████████████████████████████████████████████████| 3/3 [00:04<00:00,  1.61s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.5639])}
{'sample': 'IXI080_0', 'metric': tensor([0.5766])}
{'sample': 'IXI020_0', 'metric': tensor([0.4482])}
new best!
Epoch 3/45


100%|███████████████████████████████████████████████| 64/64 [00:16<00:00,  3.80it/s]


{'mean_loss': 0.7895190250128508}


100%|█████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.82s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.6429])}
{'sample': 'IXI080_0', 'metric': tensor([0.6494])}
{'sample': 'IXI020_0', 'metric': tensor([0.5394])}
new best!
Epoch 4/45


100%|███████████████████████████████████████████████| 64/64 [00:15<00:00,  4.03it/s]


{'mean_loss': 0.6527853570878506}


100%|█████████████████████████████████████████████████| 3/3 [00:04<00:00,  1.60s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7119])}
{'sample': 'IXI080_0', 'metric': tensor([0.7227])}
{'sample': 'IXI020_0', 'metric': tensor([0.6571])}
new best!
Epoch 5/45


100%|███████████████████████████████████████████████| 64/64 [00:15<00:00,  4.02it/s]


{'mean_loss': 0.5017053592018783}


100%|█████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.69s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7569])}
{'sample': 'IXI080_0', 'metric': tensor([0.7627])}
{'sample': 'IXI020_0', 'metric': tensor([0.6849])}
new best!
Epoch 6/45


100%|███████████████████████████████████████████████| 64/64 [00:15<00:00,  4.09it/s]


{'mean_loss': 0.4354806449264288}


100%|█████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.72s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7765])}
{'sample': 'IXI080_0', 'metric': tensor([0.7875])}
{'sample': 'IXI020_0', 'metric': tensor([0.7277])}
new best!
Epoch 7/45


100%|███████████████████████████████████████████████| 64/64 [00:15<00:00,  4.16it/s]


{'mean_loss': 0.4129210598766804}


100%|█████████████████████████████████████████████████| 3/3 [00:04<00:00,  1.64s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7853])}
{'sample': 'IXI080_0', 'metric': tensor([0.8023])}
{'sample': 'IXI020_0', 'metric': tensor([0.7742])}
new best!
Epoch 8/45


100%|███████████████████████████████████████████████| 64/64 [00:15<00:00,  4.12it/s]


{'mean_loss': 0.35232304967939854}


100%|█████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.69s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7935])}
{'sample': 'IXI080_0', 'metric': tensor([0.8162])}
{'sample': 'IXI020_0', 'metric': tensor([0.8097])}
new best!
Epoch 9/45


100%|███████████████████████████████████████████████| 64/64 [00:16<00:00,  3.99it/s]


{'mean_loss': 0.3847206812351942}


100%|█████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.69s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.8001])}
{'sample': 'IXI080_0', 'metric': tensor([0.8220])}
{'sample': 'IXI020_0', 'metric': tensor([0.8126])}
new best!
Epoch 10/45


100%|███████████████████████████████████████████████| 64/64 [00:16<00:00,  3.78it/s]


{'mean_loss': 0.358798760920763}


100%|█████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.77s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7990])}
{'sample': 'IXI080_0', 'metric': tensor([0.8237])}
{'sample': 'IXI020_0', 'metric': tensor([0.8290])}
new best!
Epoch 11/45


100%|███████████████████████████████████████████████| 64/64 [00:16<00:00,  3.85it/s]


{'mean_loss': 0.35298960376530886}


100%|█████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.84s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7963])}
{'sample': 'IXI080_0', 'metric': tensor([0.8209])}
{'sample': 'IXI020_0', 'metric': tensor([0.8375])}
new best!
Epoch 12/45


100%|███████████████████████████████████████████████| 64/64 [00:16<00:00,  3.85it/s]


{'mean_loss': 0.3595909043215215}


100%|█████████████████████████████████████████████████| 3/3 [00:04<00:00,  1.66s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7988])}
{'sample': 'IXI080_0', 'metric': tensor([0.8226])}
{'sample': 'IXI020_0', 'metric': tensor([0.8352])}
new best!
Epoch 13/45


100%|███████████████████████████████████████████████| 64/64 [00:16<00:00,  3.86it/s]


{'mean_loss': 0.3275302154943347}


100%|█████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.77s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7946])}
{'sample': 'IXI080_0', 'metric': tensor([0.8196])}
{'sample': 'IXI020_0', 'metric': tensor([0.8404])}
count_without_new_best_test_val: 1
Epoch 14/45


100%|███████████████████████████████████████████████| 64/64 [00:18<00:00,  3.47it/s]


{'mean_loss': 0.3226231914013624}


100%|█████████████████████████████████████████████████| 3/3 [00:04<00:00,  1.54s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7905])}
{'sample': 'IXI080_0', 'metric': tensor([0.8160])}
{'sample': 'IXI020_0', 'metric': tensor([0.8434])}
count_without_new_best_test_val: 2
Epoch 15/45


100%|███████████████████████████████████████████████| 64/64 [00:18<00:00,  3.49it/s]


{'mean_loss': 0.33573169680312276}


100%|█████████████████████████████████████████████████| 3/3 [00:04<00:00,  1.66s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7934])}
{'sample': 'IXI080_0', 'metric': tensor([0.8181])}
{'sample': 'IXI020_0', 'metric': tensor([0.8389])}
count_without_new_best_test_val: 3
Epoch 16/45


100%|███████████████████████████████████████████████| 64/64 [00:16<00:00,  3.90it/s]


{'mean_loss': 0.34523083781823516}


100%|█████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.87s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7912])}
{'sample': 'IXI080_0', 'metric': tensor([0.8154])}
{'sample': 'IXI020_0', 'metric': tensor([0.8401])}
count_without_new_best_test_val: 4
Epoch 17/45


100%|███████████████████████████████████████████████| 64/64 [00:16<00:00,  3.83it/s]


{'mean_loss': 0.3160396716557443}


100%|█████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.74s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7933])}
{'sample': 'IXI080_0', 'metric': tensor([0.8169])}
{'sample': 'IXI020_0', 'metric': tensor([0.8403])}
count_without_new_best_test_val: 5
Epoch 18/45


100%|███████████████████████████████████████████████| 64/64 [00:17<00:00,  3.70it/s]


{'mean_loss': 0.33882583188824356}


100%|█████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.83s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7941])}
{'sample': 'IXI080_0', 'metric': tensor([0.8178])}
{'sample': 'IXI020_0', 'metric': tensor([0.8385])}
count_without_new_best_test_val: 6
Epoch 19/45


100%|███████████████████████████████████████████████| 64/64 [00:16<00:00,  3.91it/s]


{'mean_loss': 0.34077844163402915}


100%|█████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.72s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7924])}
{'sample': 'IXI080_0', 'metric': tensor([0.8162])}
{'sample': 'IXI020_0', 'metric': tensor([0.8383])}
count_without_new_best_test_val: 7
Epoch 20/45


100%|███████████████████████████████████████████████| 64/64 [00:15<00:00,  4.15it/s]


{'mean_loss': 0.3182582794688642}


100%|█████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.67s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7903])}
{'sample': 'IXI080_0', 'metric': tensor([0.8140])}
{'sample': 'IXI020_0', 'metric': tensor([0.8401])}
count_without_new_best_test_val: 8
Epoch 21/45


100%|███████████████████████████████████████████████| 64/64 [00:16<00:00,  3.79it/s]


{'mean_loss': 0.3507692425046116}


100%|█████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.74s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7869])}
{'sample': 'IXI080_0', 'metric': tensor([0.8110])}
{'sample': 'IXI020_0', 'metric': tensor([0.8406])}
count_without_new_best_test_val: 9
Epoch 22/45


100%|███████████████████████████████████████████████| 64/64 [00:15<00:00,  4.03it/s]


{'mean_loss': 0.3016397648025304}


100%|█████████████████████████████████████████████████| 3/3 [00:04<00:00,  1.60s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7882])}
{'sample': 'IXI080_0', 'metric': tensor([0.8128])}
{'sample': 'IXI020_0', 'metric': tensor([0.8397])}
count_without_new_best_test_val: 10
Epoch 23/45


100%|███████████████████████████████████████████████| 64/64 [00:17<00:00,  3.62it/s]


{'mean_loss': 0.31554580805823207}


100%|█████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.81s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7864])}
{'sample': 'IXI080_0', 'metric': tensor([0.8108])}
{'sample': 'IXI020_0', 'metric': tensor([0.8416])}
count_without_new_best_test_val: 11
Epoch 24/45


100%|███████████████████████████████████████████████| 64/64 [00:16<00:00,  3.89it/s]


{'mean_loss': 0.31749737937934697}


100%|█████████████████████████████████████████████████| 3/3 [00:04<00:00,  1.64s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7854])}
{'sample': 'IXI080_0', 'metric': tensor([0.8097])}
{'sample': 'IXI020_0', 'metric': tensor([0.8433])}
count_without_new_best_test_val: 12
Epoch 25/45


100%|███████████████████████████████████████████████| 64/64 [00:16<00:00,  3.97it/s]


{'mean_loss': 0.30775751220062375}


100%|█████████████████████████████████████████████████| 3/3 [00:04<00:00,  1.63s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7877])}
{'sample': 'IXI080_0', 'metric': tensor([0.8117])}
{'sample': 'IXI020_0', 'metric': tensor([0.8420])}
count_without_new_best_test_val: 13
Epoch 26/45


100%|███████████████████████████████████████████████| 64/64 [00:15<00:00,  4.05it/s]


{'mean_loss': 0.309324259404093}


100%|█████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.69s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7859])}
{'sample': 'IXI080_0', 'metric': tensor([0.8101])}
{'sample': 'IXI020_0', 'metric': tensor([0.8435])}
count_without_new_best_test_val: 14
Epoch 27/45


100%|███████████████████████████████████████████████| 64/64 [00:12<00:00,  5.10it/s]


{'mean_loss': 0.3034371023532003}


100%|█████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.05s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7870])}
{'sample': 'IXI080_0', 'metric': tensor([0.8110])}
{'sample': 'IXI020_0', 'metric': tensor([0.8429])}
count_without_new_best_test_val: 15
Epoch 28/45


100%|███████████████████████████████████████████████| 64/64 [00:14<00:00,  4.31it/s]


{'mean_loss': 0.32675103284418583}


100%|█████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.08s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7845])}
{'sample': 'IXI080_0', 'metric': tensor([0.8092])}
{'sample': 'IXI020_0', 'metric': tensor([0.8438])}
count_without_new_best_test_val: 16
Epoch 29/45


100%|███████████████████████████████████████████████| 64/64 [00:15<00:00,  4.17it/s]


{'mean_loss': 0.30972992326132953}


100%|█████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.10s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7871])}
{'sample': 'IXI080_0', 'metric': tensor([0.8110])}
{'sample': 'IXI020_0', 'metric': tensor([0.8437])}
count_without_new_best_test_val: 17
Epoch 30/45


100%|███████████████████████████████████████████████| 64/64 [00:11<00:00,  5.41it/s]


{'mean_loss': 0.29802958597429097}


100%|█████████████████████████████████████████████████| 3/3 [00:04<00:00,  1.63s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7845])}
{'sample': 'IXI080_0', 'metric': tensor([0.8093])}
{'sample': 'IXI020_0', 'metric': tensor([0.8449])}
count_without_new_best_test_val: 18
Epoch 31/45


100%|███████████████████████████████████████████████| 64/64 [00:15<00:00,  4.15it/s]


{'mean_loss': 0.337987374747172}


100%|█████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.91s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7872])}
{'sample': 'IXI080_0', 'metric': tensor([0.8111])}
{'sample': 'IXI020_0', 'metric': tensor([0.8444])}
count_without_new_best_test_val: 19
Epoch 32/45


100%|███████████████████████████████████████████████| 64/64 [00:12<00:00,  5.15it/s]


{'mean_loss': 0.3332180173601955}


100%|█████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.05s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7859])}
{'sample': 'IXI080_0', 'metric': tensor([0.8102])}
{'sample': 'IXI020_0', 'metric': tensor([0.8455])}
count_without_new_best_test_val: 20
Epoch 33/45


100%|███████████████████████████████████████████████| 64/64 [00:11<00:00,  5.69it/s]


{'mean_loss': 0.32659338088706136}


100%|█████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.04s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7867])}
{'sample': 'IXI080_0', 'metric': tensor([0.8111])}
{'sample': 'IXI020_0', 'metric': tensor([0.8450])}
count_without_new_best_test_val: 21
Epoch 34/45


100%|███████████████████████████████████████████████| 64/64 [00:11<00:00,  5.51it/s]


{'mean_loss': 0.32132465299218893}


100%|█████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.04s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7846])}
{'sample': 'IXI080_0', 'metric': tensor([0.8093])}
{'sample': 'IXI020_0', 'metric': tensor([0.8456])}
count_without_new_best_test_val: 22
Epoch 35/45


100%|███████████████████████████████████████████████| 64/64 [00:11<00:00,  5.72it/s]


{'mean_loss': 0.30084664607420564}


100%|█████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.04s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7826])}
{'sample': 'IXI080_0', 'metric': tensor([0.8076])}
{'sample': 'IXI020_0', 'metric': tensor([0.8465])}
count_without_new_best_test_val: 23
Epoch 36/45


100%|███████████████████████████████████████████████| 64/64 [00:11<00:00,  5.65it/s]


{'mean_loss': 0.3100172330159694}


100%|█████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.03s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7839])}
{'sample': 'IXI080_0', 'metric': tensor([0.8091])}
{'sample': 'IXI020_0', 'metric': tensor([0.8464])}
count_without_new_best_test_val: 24
Epoch 37/45


100%|███████████████████████████████████████████████| 64/64 [00:11<00:00,  5.41it/s]


{'mean_loss': 0.3228057858068496}


100%|█████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.01s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7820])}
{'sample': 'IXI080_0', 'metric': tensor([0.8073])}
{'sample': 'IXI020_0', 'metric': tensor([0.8475])}
count_without_new_best_test_val: 25
Epoch 38/45


100%|███████████████████████████████████████████████| 64/64 [00:12<00:00,  5.18it/s]


{'mean_loss': 0.31217886903323233}


100%|█████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.01s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7803])}
{'sample': 'IXI080_0', 'metric': tensor([0.8057])}
{'sample': 'IXI020_0', 'metric': tensor([0.8482])}
count_without_new_best_test_val: 26
Epoch 39/45


100%|███████████████████████████████████████████████| 64/64 [00:11<00:00,  5.36it/s]


{'mean_loss': 0.3108735845889896}


100%|█████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.03s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7798])}
{'sample': 'IXI080_0', 'metric': tensor([0.8053])}
{'sample': 'IXI020_0', 'metric': tensor([0.8487])}
count_without_new_best_test_val: 27
Epoch 40/45


100%|███████████████████████████████████████████████| 64/64 [00:11<00:00,  5.58it/s]


{'mean_loss': 0.3471141008194536}


100%|█████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.02s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7806])}
{'sample': 'IXI080_0', 'metric': tensor([0.8061])}
{'sample': 'IXI020_0', 'metric': tensor([0.8490])}
count_without_new_best_test_val: 28
Epoch 41/45


100%|███████████████████████████████████████████████| 64/64 [00:11<00:00,  5.54it/s]


{'mean_loss': 0.3229111300315708}


100%|█████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.06s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7792])}
{'sample': 'IXI080_0', 'metric': tensor([0.8047])}
{'sample': 'IXI020_0', 'metric': tensor([0.8494])}
count_without_new_best_test_val: 29
Epoch 42/45


100%|███████████████████████████████████████████████| 64/64 [00:11<00:00,  5.46it/s]


{'mean_loss': 0.3149724365212023}


100%|█████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.09s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7800])}
{'sample': 'IXI080_0', 'metric': tensor([0.8056])}
{'sample': 'IXI020_0', 'metric': tensor([0.8494])}
count_without_new_best_test_val: 30
Epoch 43/45


100%|███████████████████████████████████████████████| 64/64 [00:11<00:00,  5.55it/s]


{'mean_loss': 0.332792836939916}


100%|█████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.03s/it]


{'sample': 'IXI052_0', 'metric': tensor([0.7786])}
{'sample': 'IXI080_0', 'metric': tensor([0.8044])}
{'sample': 'IXI020_0', 'metric': tensor([0.8499])}
count_without_new_best_test_val: 31
Epoch 44/45


100%|███████████████████████████████████████████████| 64/64 [00:11<00:00,  5.46it/s]


{'mean_loss': 0.32940768729895353}


100%|█████████████████████████████████████████████████| 3/3 [00:02<00:00,  1.00it/s]


{'sample': 'IXI052_0', 'metric': tensor([0.7771])}
{'sample': 'IXI080_0', 'metric': tensor([0.8032])}
{'sample': 'IXI020_0', 'metric': tensor([0.8499])}
count_without_new_best_test_val: 32
Epoch 45/45


100%|███████████████████████████████████████████████| 64/64 [00:11<00:00,  5.43it/s]


{'mean_loss': 0.3268724805675447}


100%|█████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.05s/it]

{'sample': 'IXI052_0', 'metric': tensor([0.7784])}
{'sample': 'IXI080_0', 'metric': tensor([0.8041])}
{'sample': 'IXI020_0', 'metric': tensor([0.8501])}
count_without_new_best_test_val: 33





TwoConv(
  (conv1): ConvModule(
    (act): ReLU()
    (norm): BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv): Conv3d(1, 5, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2), padding_mode=reflect)
  )
  (conv2): ConvModule(
    (act): Sigmoid()
    (norm): BatchNorm3d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv): Conv3d(5, 1, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2), padding_mode=reflect)
  )
)

In [13]:
#model_name = "Unet3d_16ch_11.10_2"
#model_name = "Unet2d_16ch"
#model_name = "HessNet_isotropic_smartnorm"
#model_name = "UNETR"


In [14]:
#controller.save(f"/home/msst/save_folder/saved_models/{model_name}.pth")
#controller.save_weights(f"/home/msst/save_folder/saved_models/{model_name}_weights.pth")

In [16]:
#path_to_check = f"/home/msst/save_folder/saved_models/{model_name}.pth"
#controller.load(path_to_checkpoint=path_to_check)

In [8]:
from scripts.utils import get_path
import torchio as tio

model_name = "TwoConv"
sample_index = "092"

path_to_vol = get_path(f"/home/msst/IXI_MRA_work/IXI{sample_index}", key="head")
subject_dict = {'head': tio.ScalarImage(path_to_vol)}
subject = tio.Subject(subject_dict)
subject = tio.transforms.ZNormalization()(subject)

t = time.time()
seg = controller.single_predict(subject, test_settings)
print(time.time() - t)

#dt_string = datetime.now().strftime("%d_%m_%Y_%H:%M")
dir_name = os.path.dirname(path_to_vol)

#path_to_save = f"/home/msst/IXI_MRA_work/IXI{sample_index}/{model_name}_{dt_string}.nii.gz"
path_to_save = f"/home/msst/IXI_MRA_work/IXI{sample_index}/{model_name}.nii.gz"
save_vol_as_nii(seg, subject.head.affine, path_to_save)

0.9255907535552979


In [36]:
path_to_GT = get_path(f"/home/msst/IXI_MRA_work/IXI{sample_index}", key="vessels")
GT = tio.ScalarImage(path_to_GT).data
path_to_brain_mask = get_path(f"/home/msst/IXI_MRA_work/IXI{sample_index}", key="brain")
brain_mask = tio.ScalarImage(path_to_brain_mask).data

path_to_save_masked = f"/home/msst/IXI_MRA_work/IXI{sample_index}/{model_name}_mask.nii.gz"
path_to_save_GT_masked = f"/home/msst/IXI_MRA_work/IXI{sample_index}/vessels_mask.nii.gz"
save_vol_as_nii(seg*brain_mask, subject.head.affine, path_to_save_masked)
save_vol_as_nii(GT*brain_mask, subject.head.affine, path_to_save_GT_masked)

In [39]:
import subprocess
import re
path_to_EvaluateSegmentation = '/home/msst/repo/MSRepo/VesselSegmentation/Inference/EvaluateSegmentation'

GT_path = path_to_GT
SEG_path = path_to_save
GT_mask_path = path_to_save_masked
SEG_mask_path = path_to_save_GT_masked

command_output = subprocess.run([f"{path_to_EvaluateSegmentation}",
                                    GT_path, SEG_path], stdout=subprocess.PIPE, text=True)
command_output = command_output.stdout.split('\n')
metrics = ["DICE", "AVGDIST", "SNSVTY"]
metric_dict = {}
for metric in metrics:
    for line in command_output:
        if re.search(metric, line):
            metric_dict.update({metric : line.split('\t')[1][2:]})

command_output = subprocess.run([f"{path_to_EvaluateSegmentation}",
                                    GT_mask_path, SEG_mask_path], stdout=subprocess.PIPE, text=True)
command_output = command_output.stdout.split('\n')
metrics = ["DICE", "AVGDIST", "SNSVTY"]
metric_dict_mask = {}
for metric in metrics:
    for line in command_output:
        if re.search(metric, line):
            metric_dict_mask.update({metric : line.split('\t')[1][2:]})

metric_dict, metric_dict_mask

({'DICE': '0.823599', 'AVGDIST': '2.841585', 'SNSVTY': '0.861042'},
 {'DICE': '0.872858', 'AVGDIST': '0.263038', 'SNSVTY': '0.900524'})

**<h3>GLOBAL TEST**

In [3]:
def run_tests(model_name, path_to_save_models, is2d, test_data_path, epochs):
    if is2d:
        PATCH_SIZE_TRAIN = (512, 512, 1)
        PATCH_SIZE_TEST = (512, 512, 1)
        OVERLAP_TEST = (0, 0, 0)
    else:
        PATCH_SIZE_TRAIN = (64, 64, 64)
        PATCH_SIZE_TEST = (256, 256, 64)
        OVERLAP_TEST = (32, 32, 24)
    
    train_settings  = {
        "patch_shape" : PATCH_SIZE_TRAIN,
        "patches_per_volume" : 64,
        "patches_queue_length" : 1440,
        "batch_size" : 16,
        "num_workers": 4,
        "sampler": "uniform",
    }
    
    test_settings = {
        "patch_shape" : PATCH_SIZE_TEST,
        "overlap_shape" : OVERLAP_TEST,
        "batch_size" : 1,
        "num_workers": 4,
    }
    
    for test in tqdm(os.listdir(test_data_path)):
        dataset = TioDataset(test_data_path + '/' + test,
                  train_settings=train_settings,
                  test_settings=test_settings)
        
        model = get_model(model_name)
        
        controller_config = {
            "loss" : ExponentialLogarithmicLoss(gamma_tversky=0.5, gamma_bce=0.5, lamb=0.5,
                                                freq = 0.1, tversky_alfa=0.5),
            "metric" : DICE_Metric(),
            'device' : 'cuda',
            "model" : model,
            "optimizer_fn" : lambda model: torch.optim.Adam(model.parameters(), lr=0.02),
            "sheduler_fn": lambda optimizer: StepLR(optimizer, step_size=1, gamma=0.9),
            "is2d" : is2d,
            'verbose': False
        }
        controller = Controller(controller_config)
        controller.fit(dataset, epochs)
        controller.save(f"{path_to_save_models}/{model_name}_{test}")

In [4]:
#model_name = 'HessUNet2'
model_name = 'Unet3d_16ch'

run_tests(model_name=model_name, 
          path_to_save_models="/home/msst/save_folder/models_for_tests",
          test_data_path = '/home/msst/Documents/medtech/data/HessData_IXI_test',
          is2d=False,
          epochs=25)

  0%|                                                     | 0/2 [00:00<?, ?it/s]

Epoch 1/25
{'mean_loss': 0.7622763651112715}
{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.7776])}, {'sample': 'IXI100_0', 'metric1': tensor([0.7168])}, {'sample': 'IXI020_0', 'metric1': tensor([0.7513])}]}
Epoch 2/25
{'mean_loss': 0.2930100342879693}
{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8368])}, {'sample': 'IXI100_0', 'metric1': tensor([0.7817])}, {'sample': 'IXI020_0', 'metric1': tensor([0.8290])}]}
Epoch 3/25
{'mean_loss': 0.2731123532479008}
{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8429])}, {'sample': 'IXI100_0', 'metric1': tensor([0.7831])}, {'sample': 'IXI020_0', 'metric1': tensor([0.8305])}]}
Epoch 4/25
{'mean_loss': 0.26982324353108805}
{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8570])}, {'sample': 'IXI100_0', 'metric1': tensor([0.7953])}, {'sample': 'IXI020_0', 'metric1': tensor([0.8501])}]}
Epoch 5/25
{'mean_loss': 0.25067211718608934}
{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.8659])}, {'sample': 'IXI

 50%|█████████████████████▌                     | 1/2 [18:16<18:16, 1096.89s/it]

{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.9343])}, {'sample': 'IXI100_0', 'metric1': tensor([0.8491])}, {'sample': 'IXI020_0', 'metric1': tensor([0.9339])}]}
Epoch 1/25
{'mean_loss': 0.6672985454400381}
{'metrics': [{'sample': 'IXI077_0', 'metric1': tensor([0.3137])}, {'sample': 'IXI111_0', 'metric1': tensor([0.3520])}, {'sample': 'IXI115_0', 'metric1': tensor([0.4546])}]}
Epoch 2/25
{'mean_loss': 0.2853224550684293}
{'metrics': [{'sample': 'IXI077_0', 'metric1': tensor([0.7835])}, {'sample': 'IXI111_0', 'metric1': tensor([0.8429])}, {'sample': 'IXI115_0', 'metric1': tensor([0.7320])}]}
Epoch 3/25
{'mean_loss': 0.26316667441278696}
{'metrics': [{'sample': 'IXI077_0', 'metric1': tensor([0.7689])}, {'sample': 'IXI111_0', 'metric1': tensor([0.8603])}, {'sample': 'IXI115_0', 'metric1': tensor([0.7597])}]}
Epoch 4/25
{'mean_loss': 0.2678287575642268}
{'metrics': [{'sample': 'IXI077_0', 'metric1': tensor([0.7987])}, {'sample': 'IXI111_0', 'metric1': tensor([0.8774])}, {'sample

100%|███████████████████████████████████████████| 2/2 [36:29<00:00, 1094.53s/it]

{'metrics': [{'sample': 'IXI077_0', 'metric1': tensor([0.8792])}, {'sample': 'IXI111_0', 'metric1': tensor([0.9370])}, {'sample': 'IXI115_0', 'metric1': tensor([0.8047])}]}



