In [1]:
# Enable IPython's autoreload so edits to the imported project modules
# (scripts.*, ml.*) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

<h3>Загрузка библиотек</h3>

In [2]:
import os
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [3]:
from scripts.load_and_save import (get_dcm_info, get_dcm_vol, vox_size2affine,
                                   save_vol_as_nii, load_sample_data)
from scripts.load_and_save import load_nii_vol, save_vol_as_nii, load_sample_data

from ml.models.unet3d import U_Net
from ml.models.rog import ROG

from ml.utils import get_total_params, save_model, load_pretrainned
from ml.dataset import preprocess_dataset, HVB_Dataset, norm_vol
from ml.trainer import Trainer
from ml.losses import ExponentialLogarithmicLoss, WeightedExpBCE, TverskyLoss, DiceMetric

In [4]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


<h3>Создание экземпляров датасета</h3>

In [5]:
# Training split: patches of 64^3 voxels; samples are kept in RAM.
train_dataset_settings = dict(
    data_dir="/home/msst/Documents/medtech/brain_seg_dataset",
    patch_shape=(64, 64, 64),
    number_of_patches=10,
    mode="train",
    RAM_samples=True,
)
# NOTE(review): preprocess_dataset presumably has to run before HVB_Dataset is built
# (confirm in ml.dataset); its two return values are not used later in this notebook.
patch_data_df, sample_data_df = preprocess_dataset(train_dataset_settings)
train_dataset = HVB_Dataset(train_dataset_settings)

In [6]:
# Evaluation split. The data root and RAM setting are taken from the training
# settings so the two datasets can never silently point at different locations.
# NOTE(review): in the recorded run the validation loop received full volumes
# (e.g. 512x512x256), so patch_shape may be unused in "eval" mode — confirm in HVB_Dataset.
val_dataset_settings = {
    "data_dir": train_dataset_settings["data_dir"],
    "patch_shape": (128, 128, 128),
    "mode": "eval",
    "RAM_samples": train_dataset_settings["RAM_samples"],
}
val_dataset = HVB_Dataset(val_dataset_settings)

In [7]:
# Shuffled mini-batches for training; one volume at a time, in order, for validation.
train_loader_params = dict(batch_size=5, shuffle=True, num_workers=6)
val_loader_params = dict(batch_size=1, shuffle=False, num_workers=6)

train_dataloader, val_dataloader = (
    DataLoader(train_dataset, **train_loader_params),
    DataLoader(val_dataset, **val_loader_params),
)

<h3>Создание экземпляра модели</h3>

In [8]:
# Constructor parameters for the alternative ROG network. It is not used below:
# the notebook trains U_Net instead. To switch, build it with: ROG(rog_params)
rog_params = dict(
    classes=1,
    modalities=1,
    strides=[[2, 2, 1], [2, 2, 1], [2, 2, 2]],
)

In [9]:
# 3D U-Net (ml.models.unet3d) built with its default constructor arguments;
# the ROG alternative configured above is not instantiated.
model = U_Net()

In [10]:
# Report the model size (value comes from ml.utils.get_total_params).
print(f'Number of parameters: {get_total_params(model)}')

Number of parameters: 103536449


<h3>Обучение модели</h3>

In [11]:
# Loss: ExponentialLogarithmicLoss (presumably a Tversky + weighted-BCE mix, judging
# by its parameter names). Alternatives tried earlier: nn.BCELoss(reduction='mean'),
# WeightedExpBCE(0.5), TverskyLoss(0.5).
loss_fn = ExponentialLogarithmicLoss(gamma_tversky=0.5, gamma_bce=0.5, lamb=1, freq=0.001)
# NOTE(review): this overrides an internal attribute after construction; confirm it
# is meant to replace whatever weight the constructor derived from `freq`.
loss_fn.weighted_bce_loss.bce_weight = 1000
print(loss_fn.weighted_bce_loss.bce_weight)
metric_fn = DiceMetric()

trainer_config = dict(
    n_epochs=10,
    loss=loss_fn,
    metric=metric_fn,
    device=device,
    lr=3e-4,
)
trainer = Trainer(trainer_config)

1000


In [13]:
# Train the model; the returned model is re-bound to `model`.
# In the recorded run training was stopped manually (KeyboardInterrupt) during
# epoch 2, so this assignment never happened and weights are loaded from disk below.
model = trainer.fit(model, train_dataloader, val_dataloader)

Epoch 1/10


100%|█████████████████████████████████████████████| 6/6 [00:08<00:00,  1.49s/it]


{'loss': 2476.713094075521}


 33%|███████████████                              | 1/3 [00:02<00:05,  2.89s/it]

torch.Size([1, 1, 512, 512, 256])


 67%|██████████████████████████████               | 2/3 [00:04<00:01,  1.87s/it]

torch.Size([1, 1, 512, 512, 384])


100%|█████████████████████████████████████████████| 3/3 [00:05<00:00,  1.74s/it]


torch.Size([1, 1, 512, 512, 512])
Epoch 2/10


 33%|███████████████                              | 2/6 [00:05<00:11,  2.79s/it]


KeyboardInterrupt: 

In [42]:
# File name of the checkpoint inside saved_models/.
# NOTE(review): the name says "wBCE_100" but the BCE weight set in the training
# cell is 1000 — confirm which training run this checkpoint comes from.
model_name = "Unet_wBCE_100"
#trainer.save("/home/msst/repo/MSRepo/VesselSegmentation/saved_models/" + model_name)

In [43]:
# Restore previously saved weights. map_location remaps tensors saved on a GPU onto
# the currently selected `device`, so the checkpoint also loads on a CPU-only machine.
model.load_state_dict(
    torch.load(
        os.path.join("/home/msst/repo/MSRepo/VesselSegmentation/saved_models/", model_name),
        map_location=device,
    )["model_state_dict"]
)

<All keys matched successfully>

<h3>Сегментация с помощью обученной модели</h3>
    

In [28]:
# Pick one case to segment. The original cell read from an undefined `dataset`
# (NameError on a fresh kernel), so look the case up in whichever split holds it.
# Assumes RAM_samples is a mapping keyed by sample name, as the indexing implies.
test_data = "P12_CTA(no_brain)"
test_sample = next(
    (ds.RAM_samples[test_data]
     for ds in (train_dataset, val_dataset)
     if test_data in ds.RAM_samples),
    None,
)
if test_sample is None:
    raise KeyError(f"{test_data!r} not found in train_dataset or val_dataset RAM_samples")

head_vol = norm_vol(test_sample['head'])
vessels_vol = test_sample['vessels']
affine = test_sample['affine']
print(head_vol.shape)
print(vessels_vol.shape)

(512, 512, 256)
(512, 512, 256)


In [29]:
def np2torch(np_arr):
    """Copy a NumPy array into a torch tensor with two leading singleton axes.

    The result has shape (1, 1, *np_arr.shape) and the dtype torch infers from
    ``np_arr``. The data is copied, so later edits to ``np_arr`` do not affect it.
    """
    copied = torch.tensor(np_arr)
    return copied[None, None]


def _patch_starts(dim, size):
    """Start offsets of patches of length ``size`` that together cover [0, dim)."""
    if dim <= size:
        # Axis shorter than (or equal to) the patch: a single, possibly smaller, patch.
        return [0]
    starts = list(range(0, dim - size + 1, size))
    if starts[-1] + size < dim:
        # Remainder voxels: add one extra patch flush with the far edge so they are
        # segmented too (it overlaps the previous patch).
        starts.append(dim - size)
    return starts


def seg_by_patch(model, head_tensor_5_dim, device, patch_shape=(64, 64, 64), thresh=0.5):
    """Segment a volume by running ``model`` patch by patch, then binarize.

    The spatial axes (dims 2, 3, 4) are tiled with patches of ``patch_shape``.
    When an axis is not a multiple of the patch size, one more patch aligned to the
    far edge is added, so no voxel is left unsegmented. Where patches overlap, the
    later prediction overwrites the earlier one. An axis shorter than its patch size
    is processed as a single smaller patch, which ``model`` must be able to accept.

    Args:
        model: torch module; it is moved to ``device`` and switched to eval mode.
        head_tensor_5_dim: 5-D tensor indexed as [batch, channel, x, y, z]. It must
            be convertible by NumPy (e.g. a CPU tensor), because the output buffer
            is created with ``np.zeros_like(head_tensor_5_dim[0, 0])``.
        device: device on which each patch is evaluated.
        patch_shape: spatial size of one patch.
        thresh: voxels whose prediction is >= thresh become 1; all others become 0.

    Returns:
        np.ndarray with the spatial shape of the input, dtype matching it, holding 0/1.
    """
    model.to(device)
    spatial_shape = head_tensor_5_dim.shape[2:]
    axis_starts = [_patch_starts(d, p) for d, p in zip(spatial_shape, patch_shape)]

    seg = np.zeros_like(head_tensor_5_dim[0, 0])
    with torch.no_grad():
        model.eval()
        for i0 in axis_starts[0]:
            xs = slice(i0, i0 + patch_shape[0])
            for j0 in axis_starts[1]:
                ys = slice(j0, j0 + patch_shape[1])
                for k0 in axis_starts[2]:
                    zs = slice(k0, k0 + patch_shape[2])
                    patch = head_tensor_5_dim[:, :, xs, ys, zs].to(device)
                    # Prediction for the first batch item; NumPy drops its leading
                    # singleton channel axis on assignment.
                    seg[xs, ys, zs] = model(patch)[0].cpu().numpy()

    # Single comparison: also correct for thresh <= 0, unlike the old
    # "zero below thresh, then one above 0" two-step.
    return (seg >= thresh).astype(seg.dtype)

In [32]:
# Segment the whole test volume with 128^3 patches; voxels whose prediction is at
# least 0.1 are marked as vessels. (A whole-volume `seg_by_vol` variant was once
# tried here, but it is not defined in this notebook.)
vessels_seg = seg_by_patch(
    model,
    np2torch(head_vol),
    device,
    patch_shape=(128, 128, 128),
    thresh=0.1,
)

In [33]:
# Ground-truth vessel voxel count vs. predicted count, one per line.
# NOTE(review): the recorded run predicted 0.0 vessel voxels at thresh=0.1.
print(vessels_vol.sum(), vessels_seg.sum(), sep="\n")

78585.0
0.0


In [23]:
# Save the predicted vessel mask as NIfTI, named after the model that produced it.
data_dir = "seg_data/P12_CTA"
# os.mkdir failed when the "seg_data" parent was missing; makedirs creates the
# whole chain and exist_ok keeps the cell re-runnable.
os.makedirs(data_dir, exist_ok=True)

path_to_save_vessels = os.path.join(data_dir, model_name + '.nii.gz')
save_vol_as_nii(vessels_seg, affine, path_to_save_vessels)

In [4]:
# Repeat of the notebook's first cell. It is redundant on Restart & Run All: the
# recorded output below shows the extension was already loaded.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
# Toy volumes of shape (1, 1, 3, 3, 3) for sanity-checking the loss functions:
#   ideal - reference mask with a single positive voxel at the centre;
#   g_seg - "good" prediction: 0.9 at the centre, 0.1 on its six face neighbours;
#   b_seg - "bad" prediction: uniform random values in [0, 1).
ideal = torch.tensor([[[0, 0, 0],
                       [0, 0, 0],
                       [0, 0, 0]],
                      [[0, 0, 0],
                       [0, 1.0, 0],
                       [0, 0, 0]],
                      [[0, 0, 0],
                       [0, 0, 0],
                       [0, 0, 0]]]).unsqueeze(0).unsqueeze(0)
g_seg = torch.tensor([[[0, 0, 0],
                       [0, 0.1, 0],
                       [0, 0, 0]],
                      [[0, 0.1, 0],
                       [0.1, 0.9, 0.1],
                       [0, 0.1, 0]],
                      [[0, 0, 0],
                       [0, 0.1, 0],
                       [0, 0, 0]]]).unsqueeze(0).unsqueeze(0)
# The unseeded torch.rand made the printed losses for b_seg differ on every run.
# A local seeded generator makes them reproducible without touching the global RNG.
B_SEG_SEED = 0
b_seg = torch.rand((1, 1, 3, 3, 3), generator=torch.Generator().manual_seed(B_SEG_SEED))

In [6]:
from ml.losses import ExponentialLogarithmicLoss, DiceLoss, TverskyLoss, WeightedExpBCE

In [7]:
# freq=1/27 matches the fraction of positive voxels in `ideal` (1 of 3*3*3).
# NOTE(review): confirm ml.losses expects arguments in (target, prediction) order.
loss_fn = ExponentialLogarithmicLoss(
    gamma_tversky=0.3, gamma_bce=0.3, lamb=0.5, freq=1 / 27
)
print(*(loss_fn(ideal, pred) for pred in (g_seg, b_seg)), sep="\n")

tensor(0.4348)
tensor(1.1514)


In [8]:
# Tversky loss on the good and bad predictions; the recorded values match DiceLoss.
loss_fn = TverskyLoss(0.5)
print(*(loss_fn(ideal, pred) for pred in (g_seg, b_seg)), sep="\n")

tensor(0.2000)
tensor(0.8511)


In [9]:
# Dice loss on the good and bad predictions, for comparison with the cells above.
loss_fn = DiceLoss()
print(*(loss_fn(ideal, pred) for pred in (g_seg, b_seg)), sep="\n")

tensor(0.2000)
tensor(0.8511)
