In [13]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchio as tio
import torch.nn as nn
from torch.nn import MSELoss, BCELoss
from torch.utils.data import DataLoader
from datetime import datetime
import sys, os
import time

sys.path.insert(1, '../')

from ml.ClassVesselTrainer import VesselTrainer
from ml.ClassVesselInferenceAgent import VesselInferenceAgent
from ml.ClassTioDataset import TioDataset
from ml.utils import get_total_params, load_pretrainned
from ml.metrics import (F1_BINARY, PRECISION_BINARY,
                        RECALL_BINARY, SPECIFICITY_BINARY,
                        ExponentialLogarithmicLoss)
from ml.models.GenUnet import GenUnet

from scripts.utils import get_path
from scripts.load_and_save import save_vol_as_nii

# Notebook-wide configuration read by the training and inference cells below.
N_JOBS = 6
DEVICE = 'cuda'
IS2D = 0  # falsy -> volumetric 64^3 patches; truthy -> single-slice 512x512 patches

LOG_PATH = '/home/msst/save_folder/VesselTrainer_log'

# (train patch size, test patch size, test overlap) for the selected mode.
PATCH_SIZE_TRAIN, PATCH_SIZE_TEST, OVERLAP_TEST = (
    ((512, 512, 1), (512, 512, 1), (0, 0, 0))
    if IS2D else
    ((64, 64, 64), (64, 64, 64), (4, 4, 4))
)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Train

In [2]:
# Settings handed to TioDataset for the training split.
train_settings = dict(
    patch_shape=PATCH_SIZE_TRAIN,
    patches_per_volume=32,
    patches_queue_length=1440,
    batch_size=8,
    num_workers=4,
    sampler="uniform",
)

# Settings handed to TioDataset for the test split.
test_settings = dict(
    patch_shape=PATCH_SIZE_TEST,
    overlap_shape=OVERLAP_TEST,
    batch_size=1,
    num_workers=4,
)

data_dir = "/home/msst/Documents/medtech/data/HessData_IXI"

# No validation settings are supplied (val_settings=None); only the train and
# test loaders of this dataset are used by the cells below.
dataset = TioDataset(
    data_dir,
    train_settings=train_settings,
    val_settings=None,
    test_settings=test_settings,
)

In [3]:
# model_name is reused below to name the trainer's log directory.
model_name = "GenUnet"

# Small 3D GenUnet: one input channel, one output channel.
model = GenUnet(
    dim=3,
    in_channels=1,
    out_channels=1,
    channels=8,
    depth=1,
)

# Report the model size as a quick sanity check.
print("total_params:", get_total_params(model))

total_params: 21385


In [4]:
# Timestamp (day_month_year_hour:minute) used to tag this run's log directory.
dt_string = datetime.now().strftime("%d_%m_%Y_%H:%M")

# Metrics handed to the trainer, keyed by the short label used in its reports.
metric_functions = dict(
    DICE=F1_BINARY(),
    PR=PRECISION_BINARY(),
    RC=RECALL_BINARY(),
    SP=SPECIFICITY_BINARY(),
)

# Alternative setting tried previously: lamb=0.1, tversky_alfa=0.2.
loss_fn = ExponentialLogarithmicLoss(
    gamma_tversky=0.5,
    gamma_bce=0.5,
    lamb=0.5,
    freq=0.1,
    tversky_alfa=0.5,
)

# Everything VesselTrainer needs, passed as a single dict.
# No LR scheduler is used; a StepLR(step_size=5, gamma=0.5) factory was
# considered earlier and left disabled.
trainer_params = dict(
    device=DEVICE,
    model=model,
    loss_fn=loss_fn,
    optimizer=torch.optim.Adam(model.parameters(), lr=0.02),
    scheduler=None,
    metric_functions=metric_functions,
    with_warnings=True,
    log_path=f"{LOG_PATH}/{model_name}_{dt_string}",
    is2d=IS2D,
    verbose=True,
    stoper=None,
)

trainer = VesselTrainer(trainer_params)

Trainer.log_path: /home/msst/save_folder/VesselTrainer_log/GenUnet_26_03_2024_01:35


In [5]:
# Resume from the checkpoint saved after epoch 8 of an earlier "GenUnet" run;
# the training output below accordingly starts at RESULT 9.
trainer.load_trainer_state(
    '/home/msst/save_folder/VesselTrainer_log/GenUnet/state_dicts/state_dict_epoch_8'
)

Trainer.log_path: /home/msst/save_folder/VesselTrainer_log/GenUnet_26_03_2024_01:35


In [6]:
# Train for three more epochs. The dataset's test loader is what gets passed
# as test_loader; the recorded output labels those passes "Validation".
trainer.fit(
    n_epochs=3,
    train_loader=dataset.train_dataloader,
    test_loader=dataset.test_dataloader,
)

  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

RESULT 9: Validation.
 DICE: 0.796 | PR: 0.699 | RC: 0.924 | SP: 0.999, 
threshold: 0.950


  0%|          | 0/72 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

RESULT 10: Validation.
 DICE: 0.853 | PR: 0.828 | RC: 0.881 | SP: 0.999, 
threshold: 0.550


  0%|          | 0/72 [00:00<?, ?it/s]

KeyboardInterrupt


AttributeError: 'VesselTrainer' object has no attribute 'neptune_loger'

## Inference

In [38]:
# Metrics for the inference agent (same set as used during training).
metric_functions = dict(
    DICE=F1_BINARY(),
    PR=PRECISION_BINARY(),
    RC=RECALL_BINARY(),
    SP=SPECIFICITY_BINARY(),
)

# Inference uses the test-time patch size and overlap from the config cell.
runner_params = dict(
    device=DEVICE,
    metric_functions=metric_functions,
    patch_shape=PATCH_SIZE_TEST,
    overlap_shape=OVERLAP_TEST,
    batch_size=1,
    num_workers=4,
)

runner = VesselInferenceAgent(runner_params)

# Which training run and epoch to restore the weights from.
model_name_with_date = "GenUnet_26_03_2024_01:35"
epoch = 10

checkpoint_path = get_path(
    f"{LOG_PATH}/{model_name_with_date}/state_dicts",
    f'state_dict_epoch_{epoch}',
)
runner.load_from_trainer_state(checkpoint_path)

In [40]:
# Clear the runner's threshold before predicting. The prediction inspected in
# the following cells holds continuous values rather than a binary mask, so
# presumably None disables binarisation — confirm in VesselInferenceAgent.
runner.threshold = None

In [41]:
# IXI sample to segment; also used below to build the output and GT paths.
sample_index = "300"

# Load the head volume for this sample and z-normalise it before inference.
path_to_vol = get_path(f"/home/msst/IXI_MRA_work/IXI{sample_index}", key="head")
subject_dict = {'head': tio.ScalarImage(path_to_vol)}
subject = tio.Subject(subject_dict)
subject = tio.transforms.ZNormalization()(subject)

# Time a single prediction. perf_counter() is monotonic and high-resolution,
# so the interval cannot be skewed by system clock adjustments the way a
# time.time() difference can.
t = time.perf_counter()
seg = runner.single_predict(subject)
print(time.perf_counter() - t)

1.4138813018798828


In [42]:
# Sanity check: list the distinct values present in the prediction tensor.
seg.unique()

tensor([1.5237e-35, 2.6731e-35, 3.5477e-35,  ..., 1.0000e+00, 1.0000e+00,
        1.0000e+00])

In [43]:
# The prediction is written into the sample's folder, named after the run
# that produced the weights.
path_to_save = f"/home/msst/IXI_MRA_work/IXI{sample_index}/{model_name_with_date}.nii.gz"

# Directory containing the input volume (not used by the save call itself).
dir_name = os.path.dirname(path_to_vol)

# Reuse the input volume's affine for the saved prediction.
save_vol_as_nii(
    seg,
    subject.head.affine,
    path_to_save,
)

In [49]:
# NOTE: subprocess and re are also used by the masked-evaluation cell below.
import subprocess
import re

path_to_EvaluateSegmentation = '/home/msst/repo/MSRepo/VesselSegmentation/Inference/EvaluateSegmentation'

# Ground truth is passed first, the prediction second.
# path_to_vol is intentionally left untouched here: it holds the head-volume
# path from the prediction cell and must not be overwritten with the GT path.
GT_path = get_path(f"/home/msst/IXI_MRA_work/IXI{sample_index}", key="vessels")
SEG_path = path_to_save

completed = subprocess.run(
    [path_to_EvaluateSegmentation, GT_path, SEG_path],
    stdout=subprocess.PIPE,
    text=True,
)
# Do not fail silently: a non-zero exit usually means the output is unusable.
if completed.returncode != 0:
    print(f"EvaluateSegmentation exited with code {completed.returncode}")
command_output = completed.stdout.split('\n')

metrics = ["DICE", "AVGDIST", "SPCFTY", "PRCISON", "SNSVTY"]

# Metric lines have the form "NAME\t= VALUE\tdescription". Anchoring on the
# leading NAME avoids accidental matches inside the description text and
# skips header/blank lines that carry no tab-separated value.
metric_line = re.compile(r'^(\w+)\t=\s*(\S+)')
reported = {}
for line in command_output:
    match = metric_line.match(line)
    if match:
        reported[match.group(1)] = match.group(2)

# Keep the requested metrics, in the requested order, as the tool's strings.
metric_dict = {metric: reported[metric] for metric in metrics if metric in reported}
missing = [metric for metric in metrics if metric not in reported]
if missing:
    print(f"Metrics not found in EvaluateSegmentation output: {missing}")

print(metric_dict)
print()
command_output

{'DICE': '0.814698', 'AVGDIST': '1.107438', 'SPCFTY': '0.999706', 'PRCISON': '0.878064', 'SNSVTY': '0.759863'}



['Similarity:',
 'DICE\t= 0.814698\tDice Coefficient (F1-Measure) ',
 'JACRD\t= 0.687334\tJaccard Coefficient ',
 'AUC\t= 0.879785\tArea under ROC Curve ',
 'KAPPA\t= 0.814220\tCohen Kappa ',
 'RNDIND\t= 0.998084\tRand Index ',
 'ADJRIND\t= 0.813436\tAdjusted Rand Index ',
 'ICCORR\t= 0.831508\tInterclass Correlation ',
 'VOLSMTY\t= 0.927835\tVolumetric Similarity Coefficient ',
 'MUTINF\t= 0.018286\tMutual Information ',
 '',
 'Distance:',
 'HDRFDST\t= 58.949131\tHausdorff Distance (in voxel)',
 'AVGDIST\t= 1.107438\tAverage Hausdorff Distance (in voxel)',
 'MAHLNBS\t= 0.162250\tMahanabolis Distance ',
 'VARINFO\t= 0.015334\tVariation of Information ',
 'GCOERR\t= 0.001758\tGlobal Consistency Error ',
 'PROBDST\t= 0.000892\tProbabilistic Distance ',
 '',
 'Classic Measures:',
 'SNSVTY\t= 0.759863\tSensitivity (Recall, true positive rate) ',
 'SPCFTY\t= 0.999706\tSpecificity (true negative rate) ',
 'PRCISON\t= 0.878064\tPrecision (Confidence) ',
 'FMEASR\t= 0.814698\tF-Measure ',
 'AC

In [None]:
# Evaluate the masked prediction against the masked ground truth.
# path_to_save_masked and path_to_save_GT_masked are not defined in the cells
# shown in this notebook; they must be set by an earlier cell before running.
# NOTE(review): the two assignments were previously crossed (GT <- masked
# prediction, SEG <- masked GT), which swaps sensitivity and precision in the
# tool's report. Confirm the variable naming against the cell that sets them.
GT_mask_path = path_to_save_GT_masked
SEG_mask_path = path_to_save_masked

completed_mask = subprocess.run(
    [path_to_EvaluateSegmentation, GT_mask_path, SEG_mask_path],
    stdout=subprocess.PIPE,
    text=True,
)
# Do not fail silently: a non-zero exit usually means the output is unusable.
if completed_mask.returncode != 0:
    print(f"EvaluateSegmentation exited with code {completed_mask.returncode}")
command_output = completed_mask.stdout.split('\n')

metrics = ["DICE", "AVGDIST", "SNSVTY"]

# Metric lines have the form "NAME\t= VALUE\tdescription"; match on the
# leading NAME only and ignore lines without a tab-separated value.
metric_line = re.compile(r'^(\w+)\t=\s*(\S+)')
reported_mask = {}
for line in command_output:
    match = metric_line.match(line)
    if match:
        reported_mask[match.group(1)] = match.group(2)

# Keep the requested metrics, in the requested order, as the tool's strings.
metric_dict_mask = {metric: reported_mask[metric] for metric in metrics if metric in reported_mask}

metric_dict_mask