In [1]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchio as tio
import torch.nn as nn
from torch.nn import MSELoss, BCELoss
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from datetime import datetime
import sys, os
import time

sys.path.insert(1, '../')
from scripts.utils import get_path
from scripts.load_and_save import save_vol_as_nii
from ml.ClassVesselTrainer import VesselTrainer
from ml.ClassVesselInferenceAgent import VesselInferenceAgent
from ml.ClassTioDataset import TioDataset
from ml.utils import get_total_params, load_pretrainned
from ml.metrics import (F1_BINARY, PRECISION_BINARY,
                        RECALL_BINARY, SPECIFICITY_BINARY,
                        ExponentialLogarithmicLoss)

from ml.models.GenUnet import GenUnet
from ml.models.HessNet_new import HessNet
from ml.models.unet3d import U_Net

In [2]:
# Global run configuration.
# NOTE(review): N_JOBS is not referenced in the visible cells; the inference
# agent below is configured with its own `num_workers`.
N_JOBS = 10
# Fall back to CPU so the notebook still runs on machines without a GPU
# (inference will be much slower there).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Set to True to use 2D (single-slice) patches instead of 3D cubes.
IS2D = False


if IS2D:
    # Patches one voxel thick along the last axis; no overlap at test time.
    PATCH_SIZE_TRAIN = (512, 512, 1)
    PATCH_SIZE_TEST = (512, 512, 1)
    OVERLAP_TEST = (0, 0, 0)
else:
    # Cubic 64^3 patches; test-time patches use an overlap of (4, 4, 4).
    PATCH_SIZE_TRAIN = (64, 64, 64)
    PATCH_SIZE_TEST = (64, 64, 64)
    OVERLAP_TEST = (4, 4, 4)

In [3]:
# Metrics evaluated by the inference agent. None are enabled here, so an
# empty mapping is passed. To score predictions, add entries such as
# "DICE": F1_BINARY(), "PR": PRECISION_BINARY(), "RC": RECALL_BINARY(),
# "SP": SPECIFICITY_BINARY() (all imported in the first cell).
metric_functions = dict()

# Patch-based inference configuration handed to the agent.
runner_params = dict(
    device=DEVICE,
    metric_functions=metric_functions,
    patch_shape=PATCH_SIZE_TEST,
    overlap_shape=OVERLAP_TEST,
    batch_size=16,
    num_workers=4,
)

runner = VesselInferenceAgent(runner_params)

In [4]:
# Root folder holding the trainer's run directories. Set the VESSEL_LOG_PATH
# environment variable to point elsewhere instead of editing this cell.
LOG_PATH = os.environ.get('VESSEL_LOG_PATH', '/home/msst/save_folder/VesselTrainer_log')
# Run directory to evaluate: model name plus a timestamp suffix
# (format presumably DD_MM_YYYY_HH:MM — confirm against the trainer).
model_name_with_date = 'HessNet_12_05_2024_16:47'
# Epoch whose checkpoint (`state_dict_epoch_<epoch>`) is loaded below.
epoch = 10

In [5]:
# Load the selected epoch's checkpoint into the inference agent, then show
# the threshold the agent reports afterwards.
checkpoint_dir = f"{LOG_PATH}/{model_name_with_date}/state_dicts"
checkpoint_path = get_path(checkpoint_dir, f'state_dict_epoch_{epoch}')
runner.load_from_trainer_state(checkpoint_path)
print(runner.threshold)

tensor(0.9500)


### Run sample segmentation

In [6]:
# IXI subject to segment. The input volume is looked up under IXI_ROOT and
# the mask is written to SEG_ROOT/IXI<index>/. Both roots can be overridden
# through environment variables instead of editing this cell.
sample_index = '111'
IXI_ROOT = os.environ.get('IXI_MRA_ROOT', '/home/msst/IXI_MRA_work')
SEG_ROOT = os.environ.get('VESSEL_SEG_ROOT', '/home/msst/new_segs')

path_to_vol = get_path(f"{IXI_ROOT}/IXI{sample_index}", key="head")
path_to_save = f"{SEG_ROOT}/IXI{sample_index}"

# Build the TorchIO subject and Z-normalise its intensities before inference.
subject_dict = {'head': tio.ScalarImage(path_to_vol)}
subject = tio.Subject(subject_dict)
subject = tio.transforms.ZNormalization()(subject)

# perf_counter is monotonic and high-resolution, so it is the right clock for
# durations. time.time() can jump if the system clock is adjusted.
# NOTE(review): if single_predict returns before queued CUDA work finishes,
# call torch.cuda.synchronize() before reading the end time — confirm.
t = time.perf_counter()
seg = runner.single_predict(subject)
inference_seconds = time.perf_counter() - t
print(f"Inference time: {inference_seconds:.2f} s")

# Save the mask using the input volume's affine (presumably so it shares the
# scan's voxel-to-world mapping). The file is named after the evaluated run.
os.makedirs(path_to_save, exist_ok=True)
seg_path_to_save = f'{path_to_save}/{model_name_with_date}.nii.gz'
save_vol_as_nii(seg, subject.head.affine, seg_path_to_save)

3.729891300201416
