In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
import torchio as tio
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib

import time
from datetime import datetime
from tqdm import tqdm

In [2]:
import sys
import os
sys.path.append(os.path.dirname('../.'))

from scripts.load_and_save import (get_dcm_info, get_dcm_vol, vox_size2affine,
                                   load_nii_vol, save_vol_as_nii, raw2nifti)
from scripts.utils import print_img, print_imgs, get_path
from scripts.hessian_based import hessian_detect_2016

from ml.get_model import get_model
from ml.utils import get_total_params, load_pretrainned, test_model
from ml.ControllerClass import Controller
from ml.tio_dataset import TioDataset
from ml.metrics import (DICE_Metric, JAC_Metric, SN_Metric, SP_Metric,
                        IOU_Metric, ExponentialLogarithmicLoss)

from ml.models.HessNet import HessBlock, HessNet, HessFeatures, HessUNet, HessUNet2 
from ml.models.unet3d import U_Net
from ml.models.unet2d import U_Net2d

In [5]:
# Compute device for training/inference: use the GPU when one is visible to
# torch, otherwise fall back to CPU so the notebook still runs on a CPU-only kernel.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# False -> 3D patches (64^3 train, 256x256x64 test); True -> 512x512 2D slices.
IS2D = False

In [6]:
# Patch geometry depends on the model's dimensionality:
#   3D: 64^3 training cubes; 256x256x64 inference patches overlapping by 32/32/24.
#   2D: whole 512x512 single slices for both training and inference, no overlap.
if not IS2D:
    PATCH_SIZE_TRAIN, PATCH_SIZE_TEST, OVERLAP_TEST = (64, 64, 64), (256, 256, 64), (32, 32, 24)
else:
    PATCH_SIZE_TRAIN, PATCH_SIZE_TEST, OVERLAP_TEST = (512, 512, 1), (512, 512, 1), (0, 0, 0)

# Loader settings consumed by TioDataset for the training split.
train_settings = dict(
    patch_shape=PATCH_SIZE_TRAIN,
    patches_per_volume=4,
    patches_queue_length=1440,
    batch_size=16,
    num_workers=4,
    sampler="uniform",
)

# Loader settings for patch-wise inference on the test split.
test_settings = dict(
    patch_shape=PATCH_SIZE_TEST,
    overlap_shape=OVERLAP_TEST,
    batch_size=1,
    num_workers=4,
)

# Train/test data built from the HessData_IXI directory.
# No validation settings are passed (val_settings=None).
data_dir = "/home/msst/Documents/medtech/data/HessData_IXI"
dataset = TioDataset(
    data_dir,
    train_settings=train_settings,
    val_settings=None,  # a val_settings dict can be supplied here to enable validation
    test_settings=test_settings,
)

In [7]:
# Architecture under test. Alternatives tried earlier, with the labels that
# were noted next to each of them:
#   HessNet(start_scale=[0.8], device=DEVICE)            -> HessNet_isotropic_3 / HessNet_isotropic_4
#   HessNet(start_scale=[0.8, 0.8, 1.2], device=DEVICE)  -> HessNet_anisotropic_3 / HessNet_anisotropic_4
#   U_Net(channels=16)                                   -> Unet_8ch / Unet_16ch
#   U_Net2d(channels=16)                                 -> Unet_8ch / Unet_16ch
#   HessUNet(in_channels=1, out_channels=1, channels=16, act=nn.Sigmoid(), depth=3)
model = HessUNet2(
    in_channels=1,
    out_channels=1,
    channels=16,
    depth=3,
)

print(f"total_params: {get_total_params(model)}")

  self.HessBlocks.append(HessBlock(start_scale=(0.5+i/2)*torch.tensor(start_scale),


total_params: 1608103


In [8]:
# Training setup: exponential-logarithmic loss, Dice as the tracked metric,
# Adam (lr=0.02), and StepLR multiplying the LR by 0.9 on every scheduler step.
# Optimizer/scheduler are passed as factories (callables), presumably so the
# Controller can build them for the model it owns — confirm in ControllerClass.
controller_config = dict(
    loss=ExponentialLogarithmicLoss(gamma_tversky=0.5, gamma_bce=0.5, lamb=0.5,
                                    freq=0.1, tversky_alfa=0.5),
    metric=DICE_Metric(),
    device=DEVICE,
    model=model,
    optimizer_fn=lambda net: torch.optim.Adam(net.parameters(), lr=0.02),
    # NOTE(review): 'sheduler_fn' (sic) is presumably the key Controller reads;
    # verify before correcting the spelling.
    sheduler_fn=lambda opt: StepLR(opt, step_size=1, gamma=0.9),
    is2d=IS2D,
    verbose=True,
)
controller = Controller(controller_config)

In [9]:
# Train on `dataset` for 25 epochs (second positional argument; the log prints "Epoch i/25").
controller.fit(dataset, 25)

Epoch 1/25


100%|█████████████████████████████████████████████| 3/3 [00:07<00:00,  2.53s/it]


{'mean_loss': 1.2988348603248596}


  0%|                                                     | 0/3 [00:05<?, ?it/s]


KeyboardInterrupt: 

In [9]:
# Checkpoint name: used to locate the weights loaded below and as the prefix of
# saved predictions. Other names used in earlier experiments:
#   "Unet2d_16ch", "HessNet_isotropic_smartnorm", "HessUNet_depth3"
# NOTE(review): the model built above is HessUNet2, but this name refers to a
# 3D U-Net checkpoint — confirm the checkpoint actually matches the architecture.
model_name = "Unet3d_16ch"

In [9]:
#controller.save(f"/home/msst/save_folder/saved_models/{model_name}")

In [10]:
# Restore the weights saved under `model_name` into the controller.
path_to_check = os.path.join("/home/msst/save_folder/saved_models", model_name)
controller.load(path_to_checkpoint=path_to_check)

In [11]:
# Inference on a single IXI subject with the checkpoint loaded above; the
# predicted segmentation is written into the subject's directory as NIfTI.
sample_index = "020"
sample_dir = f"/home/msst/IXI_MRA_work/IXI{sample_index}"

# Optional references for evaluation (not used below):
#   GT = tio.ScalarImage(get_path(sample_dir, key="vessels")).data
#   brain_mask = tio.ScalarImage(get_path(sample_dir, key="brain")).data

path_to_vol = get_path(sample_dir, key="head")
subject_dict = {'head': tio.ScalarImage(path_to_vol)}
# Z-score intensity normalisation of the head volume before prediction.
subject = tio.transforms.ZNormalization()(tio.Subject(subject_dict))

# perf_counter is monotonic, unlike time.time, so the measured duration is not
# affected by system clock adjustments.
t = time.perf_counter()
seg = controller.single_predict(subject, test_settings)
print(f"inference time: {time.perf_counter() - t:.2f} s")

# NOTE: '%H:%M' puts a colon in the file name — fine on Linux, invalid on Windows.
dt_string = datetime.now().strftime("%d_%m_%Y_%H:%M")
dir_name = os.path.dirname(path_to_vol)  # NOTE(review): unused below; presumably equals sample_dir
path_to_save = os.path.join(sample_dir, f"{model_name}_{dt_string}.nii.gz")
seg_to_save = seg[0].numpy()  # presumably drops a leading channel axis — confirm single_predict's output shape
save_vol_as_nii(seg_to_save.astype(np.float32), subject.head.affine, path_to_save)

3.5003650188446045


### Global test

Train a fresh model on every dataset split under `test_data_path` and save one checkpoint per split.

In [3]:
def run_tests(model_name, path_to_save_models, is2d, test_data_path, epochs,
              device='cuda'):
    """Train and save one fresh model per dataset split found in ``test_data_path``.

    Every immediate subdirectory of ``test_data_path`` is treated as a separate
    dataset (presumably one held-out split each — confirm the data layout). For
    each split a new model is built with ``get_model(model_name)``, trained for
    ``epochs`` epochs, and saved to ``{path_to_save_models}/{model_name}_{split}``.

    Args:
        model_name: name understood by ``get_model``; also the checkpoint prefix.
        path_to_save_models: directory where checkpoints are written.
        is2d: if truthy, use 512x512 single-slice patches; otherwise 3D patches
            (64^3 for training, 256x256x64 with 32/32/24 overlap for testing).
        test_data_path: directory whose subdirectories are the dataset splits.
        epochs: number of training epochs per split.
        device: device string passed to the Controller (default ``'cuda'``,
            matching the previous hard-coded value).

    Returns:
        List of checkpoint paths passed to ``Controller.save``, in processing order.
    """
    if is2d:
        patch_size_train, patch_size_test, overlap_test = (512, 512, 1), (512, 512, 1), (0, 0, 0)
    else:
        patch_size_train, patch_size_test, overlap_test = (64, 64, 64), (256, 256, 64), (32, 32, 24)

    train_settings = {
        "patch_shape": patch_size_train,
        "patches_per_volume": 64,
        "patches_queue_length": 1440,
        "batch_size": 16,
        "num_workers": 4,
        "sampler": "uniform",
    }

    test_settings = {
        "patch_shape": patch_size_test,
        "overlap_shape": overlap_test,
        "batch_size": 1,
        "num_workers": 4,
    }

    # os.listdir returns entries in arbitrary order and includes plain files
    # (e.g. notes, hidden files); only directories are datasets, and sorting
    # makes the run order reproducible.
    splits = sorted(
        entry for entry in os.listdir(test_data_path)
        if os.path.isdir(os.path.join(test_data_path, entry))
    )

    saved_paths = []
    for split in tqdm(splits):
        dataset = TioDataset(os.path.join(test_data_path, split),
                             train_settings=train_settings,
                             test_settings=test_settings)

        # A fresh, untrained model per split so no weights leak between splits.
        model = get_model(model_name)

        controller_config = {
            "loss": ExponentialLogarithmicLoss(gamma_tversky=0.5, gamma_bce=0.5, lamb=0.5,
                                               freq=0.1, tversky_alfa=0.5),
            "metric": DICE_Metric(),
            'device': device,
            "model": model,
            "optimizer_fn": lambda net: torch.optim.Adam(net.parameters(), lr=0.02),
            "sheduler_fn": lambda opt: StepLR(opt, step_size=1, gamma=0.9),
            "is2d": is2d,
            'verbose': False,
        }
        controller = Controller(controller_config)
        controller.fit(dataset, epochs)

        save_path = f"{path_to_save_models}/{model_name}_{split}"
        controller.save(save_path)
        saved_paths.append(save_path)

    return saved_paths

In [4]:
# Architecture to benchmark across all splits (alternative: 'Unet3d_16ch').
# Passed through to run_tests — previously this variable was assigned but the
# call hard-coded "HessUNet2", so changing it had no effect.
model_name = 'HessUNet2'

run_tests(model_name=model_name,
          path_to_save_models="/home/msst/save_folder/models_for_tests",
          test_data_path='/home/msst/Documents/medtech/data/HessData_IXI_test',
          is2d=False,
          epochs=25)

  self.HessBlocks.append(HessBlock(start_scale=(0.5+i/2)*torch.tensor(start_scale),


Epoch 1/25
{'mean_loss': 1.0034169976909955}
{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.1961])}, {'sample': 'IXI100_0', 'metric1': tensor([0.1354])}, {'sample': 'IXI020_0', 'metric1': tensor([0.1839])}]}
Epoch 2/25
{'mean_loss': 0.5577186935891708}
{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.3921])}, {'sample': 'IXI100_0', 'metric1': tensor([0.3414])}, {'sample': 'IXI020_0', 'metric1': tensor([0.3726])}]}
Epoch 3/25
{'mean_loss': 0.4549638119836648}
{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.5136])}, {'sample': 'IXI100_0', 'metric1': tensor([0.4602])}, {'sample': 'IXI020_0', 'metric1': tensor([0.4841])}]}
Epoch 4/25
{'mean_loss': 0.40898488151530427}
{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.5588])}, {'sample': 'IXI100_0', 'metric1': tensor([0.5072])}, {'sample': 'IXI020_0', 'metric1': tensor([0.5356])}]}
Epoch 5/25
{'mean_loss': 0.3945464218656222}
{'metrics': [{'sample': 'IXI111_0', 'metric1': tensor([0.6253])}, {'sample': 'IXI1

  0%|                                                     | 0/2 [22:34<?, ?it/s]


KeyboardInterrupt: 