In [1]:
%load_ext autoreload
%autoreload 2

import os
import random
import re
import subprocess
import sys
import time
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchio as tio
from torch.nn import MSELoss, BCELoss
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader

sys.path.insert(1, '../')
from scripts.utils import get_path
from scripts.load_and_save import save_vol_as_nii
from ml.ClassVesselTrainer import VesselTrainer
from ml.ClassVesselInferenceAgent import VesselInferenceAgent
from ml.ClassTioDataset import TioDataset
from ml.utils import get_total_params, load_pretrainned
from ml.metrics import (F1_BINARY, PRECISION_BINARY,
                        RECALL_BINARY, SPECIFICITY_BINARY,
                        ExponentialLogarithmicLoss)

from ml.models.GenUnet import GenUnet
from ml.models.HessNet_new import HessNet
from ml.models.unet3d import U_Net


In [2]:
# Notebook-wide configuration: compute device, patch geometry and log location.
N_JOBS = 10
DEVICE = 'cuda'
IS2D = 0  # truthy -> 2D single-slice patches, falsy -> 3D cubic patches
SEED = 42

LOG_PATH = '/home/msst/save_folder/VesselTrainer_log'

# Seed every RNG the pipeline might draw from so that "Restart & Run All"
# reproduces the run as far as the libraries allow (some CUDA kernels may
# still be nondeterministic).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

if IS2D:
    PATCH_SIZE_TRAIN = (512, 512, 1)
    PATCH_SIZE_TEST = (512, 512, 1)
    OVERLAP_TEST = (0, 0, 0)
else:
    PATCH_SIZE_TRAIN = (64, 64, 64)
    PATCH_SIZE_TEST = (64, 64, 64)
    OVERLAP_TEST = (4, 4, 4)

## Train

In [3]:
# Loader settings for training (randomly sampled patches fed through a queue)
# and for testing (patch grid with the configured overlap).
train_settings = dict(
    patch_shape=PATCH_SIZE_TRAIN,
    patches_per_volume=32,
    patches_queue_length=1440,
    batch_size=16,
    num_workers=4,
    sampler="uniform",
)

test_settings = dict(
    patch_shape=PATCH_SIZE_TEST,
    overlap_shape=OVERLAP_TEST,
    batch_size=1,
    num_workers=4,
)

data_dir = "/home/msst/Documents/medtech/data/HessData_IXI"
dataset = TioDataset(
    data_dir,
    train_settings=train_settings,
    val_settings=None,  # no separate validation split; fit() validates on the test loader
    test_settings=test_settings,
)

In [4]:
# Model under training. `model_name` prefixes the log directory, so keep it in
# sync with the constructor. Alternatives (swap in one pair):
#   model_name = "GenUnet"
#   model = GenUnet(dim=3, in_channels=1, out_channels=1, channels=8, depth=1)
#
#   model_name = "Unet3d"
#   model = U_Net()
model_name = "HessNet"
model = HessNet(
    in_channels=1,
    out_channels=1,
    n_hess_blocks=4,
)

print("total_params:", get_total_params(model))

total_params: 6630


In [5]:
N_epochs: int = 15  # total training epochs; also sets the LR restart period in the next cell

In [6]:
# Run tag appended to the log directory name (e.g. HessNet_12_05_2024_16:47).
dt_string = datetime.now().strftime("%d_%m_%Y_%H:%M")

metric_functions = {
    "DICE": F1_BINARY(),
    "PR": PRECISION_BINARY(),
    "RC": RECALL_BINARY(),
    "SP": SPECIFICITY_BINARY(),
}

# Previously tried setting:
# loss_fn = ExponentialLogarithmicLoss(gamma_tversky=0.5, gamma_bce=0.5, lamb=0.5,
#                                      freq=0.1, tversky_alfa=0.5)
loss_fn = ExponentialLogarithmicLoss(gamma_tversky=0.5, gamma_bce=0.5, lamb=0.25,
                                     freq=0.1, tversky_alfa=0.25)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.02, weight_decay=0.01)

# Cosine annealing that restarts every T_0 = N_epochs // 2 epochs.
# CosineAnnealingWarmRestarts raises ValueError for T_0 < 1, so clamp it for
# very short runs (N_epochs < 2).
# NOTE(review): when N_epochs is not a multiple of T_0 the run ends shortly after
# a restart, i.e. near the peak LR (the fit log shows "Set learning rate: [0.02]"
# at epoch 14 and 0.019 at epoch 15). Consider an N_epochs divisible by T_0.
scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=max(1, N_epochs // 2),
    T_mult=1,
    eta_min=0,
    last_epoch=-1,
)
sheduler = scheduler  # legacy misspelled name, kept for any cell still using it

trainer_params = {
    'device': DEVICE,
    "model": model,
    "loss_fn": loss_fn,
    "optimizer": optimizer,
    "scheduler": scheduler,
    "metric_functions": metric_functions,
    "with_warnings": True,
    "log_path": f"{LOG_PATH}/{model_name}_{dt_string}",
    "is2d": IS2D,
    'verbose': True,
    'stoper': None
}

trainer = VesselTrainer(trainer_params)

Trainer.log_path: /home/msst/save_folder/VesselTrainer_log/HessNet_12_05_2024_16:47


In [7]:
#trainer.load_trainer_state('/home/msst/save_folder/VesselTrainer_log/GenUnet/state_dicts/state_dict_epoch_8')

In [8]:
# Train for N_epochs, validating on the test loader after each epoch.
trainer.fit(
    n_epochs=N_epochs,
    train_loader=dataset.train_dataloader,
    test_loader=dataset.test_dataloader,
)

  0%|          | 0/68 [00:00<?, ?it/s]

Set learning rate: [0.01900968867902419]


  0%|          | 0/2 [00:00<?, ?it/s]

RESULT 1: Validation.
 DICE: 0.633 | PR: 0.475 | RC: 0.948 | SP: 0.997, 
threshold: 0.950


  0%|          | 0/68 [00:00<?, ?it/s]

Set learning rate: [0.016234898018587338]


  0%|          | 0/2 [00:00<?, ?it/s]

RESULT 2: Validation.
 DICE: 0.735 | PR: 0.603 | RC: 0.941 | SP: 0.998, 
threshold: 0.950


  0%|          | 0/68 [00:00<?, ?it/s]

Set learning rate: [0.012225209339563144]


  0%|          | 0/2 [00:00<?, ?it/s]

RESULT 3: Validation.
 DICE: 0.772 | PR: 0.653 | RC: 0.942 | SP: 0.999, 
threshold: 0.950


  0%|          | 0/68 [00:00<?, ?it/s]

Set learning rate: [0.007774790660436856]


  0%|          | 0/2 [00:00<?, ?it/s]

RESULT 4: Validation.
 DICE: 0.750 | PR: 0.618 | RC: 0.951 | SP: 0.998, 
threshold: 0.950


  0%|          | 0/68 [00:00<?, ?it/s]

Set learning rate: [0.0037651019814126654]


  0%|          | 0/2 [00:00<?, ?it/s]

RESULT 5: Validation.
 DICE: 0.773 | PR: 0.651 | RC: 0.951 | SP: 0.999, 
threshold: 0.950


  0%|          | 0/68 [00:00<?, ?it/s]

Set learning rate: [0.0009903113209758097]


  0%|          | 0/2 [00:00<?, ?it/s]

RESULT 6: Validation.
 DICE: 0.778 | PR: 0.656 | RC: 0.953 | SP: 0.999, 
threshold: 0.950


  0%|          | 0/68 [00:00<?, ?it/s]

Set learning rate: [0.02]


  0%|          | 0/2 [00:00<?, ?it/s]

RESULT 7: Validation.
 DICE: 0.778 | PR: 0.655 | RC: 0.956 | SP: 0.999, 
threshold: 0.950


  0%|          | 0/68 [00:00<?, ?it/s]

Set learning rate: [0.01900968867902419]


  0%|          | 0/2 [00:00<?, ?it/s]

RESULT 8: Validation.
 DICE: 0.772 | PR: 0.647 | RC: 0.959 | SP: 0.999, 
threshold: 0.950


  0%|          | 0/68 [00:00<?, ?it/s]

Set learning rate: [0.016234898018587338]


  0%|          | 0/2 [00:00<?, ?it/s]

RESULT 9: Validation.
 DICE: 0.770 | PR: 0.642 | RC: 0.961 | SP: 0.999, 
threshold: 0.950


  0%|          | 0/68 [00:00<?, ?it/s]

Set learning rate: [0.012225209339563144]


  0%|          | 0/2 [00:00<?, ?it/s]

RESULT 10: Validation.
 DICE: 0.778 | PR: 0.653 | RC: 0.961 | SP: 0.999, 
threshold: 0.950


  0%|          | 0/68 [00:00<?, ?it/s]

Set learning rate: [0.007774790660436856]


  0%|          | 0/2 [00:00<?, ?it/s]

RESULT 11: Validation.
 DICE: 0.780 | PR: 0.655 | RC: 0.962 | SP: 0.999, 
threshold: 0.950


  0%|          | 0/68 [00:00<?, ?it/s]

Set learning rate: [0.0037651019814126654]


  0%|          | 0/2 [00:00<?, ?it/s]

RESULT 12: Validation.
 DICE: 0.787 | PR: 0.666 | RC: 0.961 | SP: 0.999, 
threshold: 0.950


  0%|          | 0/68 [00:00<?, ?it/s]

Set learning rate: [0.0009903113209758097]


  0%|          | 0/2 [00:00<?, ?it/s]

RESULT 13: Validation.
 DICE: 0.784 | PR: 0.661 | RC: 0.963 | SP: 0.999, 
threshold: 0.950


  0%|          | 0/68 [00:00<?, ?it/s]

Set learning rate: [0.02]


  0%|          | 0/2 [00:00<?, ?it/s]

RESULT 14: Validation.
 DICE: 0.785 | PR: 0.662 | RC: 0.963 | SP: 0.999, 
threshold: 0.950


  0%|          | 0/68 [00:00<?, ?it/s]

Set learning rate: [0.01900968867902419]


  0%|          | 0/2 [00:00<?, ?it/s]

RESULT 15: Validation.
 DICE: 0.784 | PR: 0.661 | RC: 0.964 | SP: 0.999, 
threshold: 0.950
Finished Training and Validating


## Inference

In [10]:
# Fresh metric objects for inference: the same four metrics as in training.
metric_functions = dict(
    DICE=F1_BINARY(),
    PR=PRECISION_BINARY(),
    RC=RECALL_BINARY(),
    SP=SPECIFICITY_BINARY(),
)

# Patch geometry matches the test loader; a larger batch is used for inference.
runner_params = dict(
    device=DEVICE,
    metric_functions=metric_functions,
    patch_shape=PATCH_SIZE_TEST,
    overlap_shape=OVERLAP_TEST,
    batch_size=16,
    num_workers=4,
)

runner = VesselInferenceAgent(runner_params)

In [34]:
# Training run to load (a directory name under LOG_PATH). Every assignment was
# commented out, which raised NameError on a fresh kernel; select one explicitly.
# Earlier runs: 'HessNet_07_05_2024_00:43', 'Unet3d_07_05_2024_00:13',
#               'Unet3d_12_05_2024_16:26'
model_name_with_date = 'HessNet_12_05_2024_16:47'  # the run trained above (see its log_path)

# NOTE(review): in the fit log, validation DICE peaks at epoch 12 (0.787) rather than 10.
epoch = 10
runner.load_from_trainer_state(
    get_path(f"{LOG_PATH}/{model_name_with_date}/state_dicts", f'state_dict_epoch_{epoch}')
)
print(runner.threshold)

tensor(0.9000)


In [35]:
# model_name_with_date = 'Unet3d_16ch_21.03_2_weights'
# model = U_Net()

# model.load_state_dict(torch.load(f'/home/msst/save_folder/saved_models/{model_name_with_date}.pth'))
# runner.set_model(model)
# model_name_with_date = 'Unet3d_16ch'

In [36]:
#runner.threshold = 0.5

In [37]:
# Move the loaded model to the configured device. Reuse DEVICE rather than
# re-hardcoding 'cuda'; the trailing ';' hides the model repr, which a stray
# `1` expression was previously used for.
device = DEVICE
runner.device = device
runner.model.to(device);

In [38]:
#indexes = ['341', '342', '344',]
# IXI subject IDs to segment. Earlier batches:
#   ['341', '342', '344']
#   ['252', '253', '254', '256', '257', '258', '259', '260', '261', '262', '263']
indexes = ['160', '161', '162', '163']

IXI_ROOT = "/home/msst/IXI_MRA_work"   # input volumes: <IXI_ROOT>/IXI<id>
SEG_ROOT = "/home/msst/new_segs"       # outputs:       <SEG_ROOT>/IXI<id>/<run>.nii.gz

for sample_index in indexes:
    path_to_vol = get_path(f"{IXI_ROOT}/IXI{sample_index}", key="head")

    # Load the head volume and z-normalise its intensities before prediction.
    subject_dict = {'head': tio.ScalarImage(path_to_vol)}
    subject = tio.Subject(subject_dict)
    subject = tio.transforms.ZNormalization()(subject)

    # perf_counter is monotonic and high-resolution; time.time() is wall-clock
    # and can jump. NOTE(review): if single_predict leaves GPU work queued,
    # call torch.cuda.synchronize() before reading the clock.
    t = time.perf_counter()
    seg = runner.single_predict(subject)
    print(time.perf_counter() - t)

    path_to_save = f"{SEG_ROOT}/IXI{sample_index}"
    os.makedirs(path_to_save, exist_ok=True)
    seg_path_to_save = f'{path_to_save}/{model_name_with_date}.nii.gz'
    save_vol_as_nii(seg, subject.head.affine, seg_path_to_save)

3.3835253715515137
3.437211513519287
3.5923876762390137
3.532599449157715


In [37]:
# path_to_save = f"/home/msst/IXI_MRA_work/IXI{sample_index}/{model_name_with_date}.nii.gz"
#save_vol_as_nii(seg, subject.head.affine, path_to_save)

In [30]:
# Score one saved segmentation with the external EvaluateSegmentation tool.
path_to_EvaluateSegmentation = '/home/msst/repo/MSRepo/VesselSegmentation/Inference/EvaluateSegmentation'

# `sample_index` / `seg_path_to_save` keep their values from the last iteration
# of the inference loop. The previous `SEG_path = path_to_save` passed the
# per-subject *directory*, not the .nii.gz file that was written there.
GT_path = get_path(f"/home/msst/IXI_MRA_work/IXI{sample_index}", key="vessels")
SEG_path = seg_path_to_save

completed = subprocess.run([path_to_EvaluateSegmentation, GT_path, SEG_path],
                           stdout=subprocess.PIPE, text=True)
if completed.returncode != 0:
    print(f"EvaluateSegmentation exited with code {completed.returncode}")
command_output = completed.stdout.split('\n')

# Tool lines look like "DICE\t= 0.613138\tDice Coefficient ...": match the metric
# code exactly at the start of the line instead of a substring search anywhere.
metrics = ["DICE", "AVGDIST", "SPCFTY", "PRCISON", "SNSVTY"]
metric_line = re.compile(r"^(\w+)\t= (\S+)")
parsed = {}
for line in command_output:
    match = metric_line.match(line)
    if match:
        parsed[match.group(1)] = match.group(2)
# Keep the order of `metrics` in the result.
metric_dict = {metric: parsed[metric] for metric in metrics if metric in parsed}


print(metric_dict)
print()
command_output

{'DICE': '0.613138', 'AVGDIST': '4.506680', 'SPCFTY': '0.997142', 'PRCISON': '0.474295', 'SNSVTY': '0.866915'}



['Similarity:',
 'DICE\t= 0.613138\tDice Coefficient (F1-Measure) ',
 'JACRD\t= 0.442105\tJaccard Coefficient ',
 'AUC\t= 0.932029\tArea under ROC Curve ',
 'KAPPA\t= 0.611649\tCohen Kappa ',
 'RNDIND\t= 0.993533\tRand Index ',
 'ADJRIND\t= 0.609648\tAdjusted Rand Index ',
 'ICCORR\t= 0.613125\tInterclass Correlation ',
 'VOLSMTY\t= 0.707264\tVolumetric Similarity Coefficient ',
 'MUTINF\t= 0.018736\tMutual Information ',
 '',
 'Distance:',
 'HDRFDST\t= 117.630778\tHausdorff Distance (in voxel)',
 'AVGDIST\t= 4.506680\tAverage Hausdorff Distance (in voxel)',
 'MAHLNBS\t= 0.132073\tMahanabolis Distance ',
 'VARINFO\t= 0.040303\tVariation of Information ',
 'GCOERR\t= 0.004990\tGlobal Consistency Error ',
 'PROBDST\t= 0.002474\tProbabilistic Distance ',
 '',
 'Classic Measures:',
 'SNSVTY\t= 0.866915\tSensitivity (Recall, true positive rate) ',
 'SPCFTY\t= 0.997142\tSpecificity (true negative rate) ',
 'PRCISON\t= 0.474295\tPrecision (Confidence) ',
 'FMEASR\t= 0.613138\tF-Measure ',
 'A