In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [3]:
from scripts.load_and_save import (get_dcm_info, get_dcm_vol, vox_size2affine,
                                   save_vol_as_nii, load_sample_data)
from scripts.load_and_save import load_nii_vol, save_vol_as_nii, load_sample_data

from ml.models.unet3d import U_Net
from ml.models.rog import ROG

from ml.utils import get_total_params, save_model, load_pretrainned
from ml.dataset import preprocess_dataset, HVB_Dataset
from ml.trainer import Trainer
from ml.losses import ComboLoss

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [14]:
PATCH_SHAPE = (64, 64, 64)

In [15]:
dataset_settings = {
    "data_dir" : "/home/msst/Documents/medtech/brain_seg_dataset",
    "patch_shape" : PATCH_SHAPE,
    "number_of_patches" : 128,
    "mode": "train",
    "RAM_samples" : True 
}

In [16]:
patch_data_df, sample_data_df = preprocess_dataset(dataset_settings)

In [17]:
dataset = HVB_Dataset(dataset_settings)

In [18]:
loader_params = {"batch_size": 6,
                 "shuffle": True,
                 "num_workers": 6
                }

train_dataloader = DataLoader(dataset, **loader_params)

In [10]:
rog_params = {
    'classes': 1,
    'modalities': 1,
    'strides': [[2, 2, 1], [2, 2, 1], [2, 2, 2]],
}
model = ROG(rog_params)

In [19]:
model = U_Net()

In [20]:
print('Number of parameters: {}'.format(get_total_params(model)))

Number of parameters: 103536449


In [24]:
#loss_fn = ComboLoss(lamb = 0)
loss_fn = nn.BCELoss(weight=0.1, reduction='mean')
#loss_fn = nn.BCELoss(reduction='mean')
trainer_config = {
    'n_epochs': 10,
    "loss" : loss_fn,
    'device' : device,
    'lr': 3e-4
}
trainer = Trainer(trainer_config)

TypeError: cannot assign 'float' object to buffer 'weight' (torch Tensor or None required)

In [22]:
model = trainer.fit(model, train_dataloader)

Epoch 1/10


100%|███████████████████████████████████████████| 22/22 [00:33<00:00,  1.51s/it]


{'loss': 0.9831557084213604}
Epoch 2/10


100%|███████████████████████████████████████████| 22/22 [00:35<00:00,  1.59s/it]


{'loss': 0.9772604378786954}
Epoch 3/10


100%|███████████████████████████████████████████| 22/22 [00:39<00:00,  1.79s/it]


{'loss': 0.971570071848956}
Epoch 4/10


 45%|███████████████████▌                       | 10/22 [00:19<00:23,  1.95s/it]


KeyboardInterrupt: 

In [15]:
model_name = "test_save_ROG"
trainer.save("/home/msst/repo/MSRepo/VesselSegmentation/saved_models/" + model_name)

In [16]:
model.load_state_dict(torch.load("/home/msst/repo/MSRepo/VesselSegmentation/saved_models/" + model_name)["model_state_dict"])

<All keys matched successfully>

In [17]:
head_vol = dataset.RAM_samples["CT_S5020_uint16"]['head']
vessels_vol = dataset.RAM_samples["CT_S5020_uint16"]['vessels']
affine = dataset.RAM_samples["CT_S5020_uint16"]['affine']
print(head_vol.shape)
print(vessels_vol.shape)

(512, 512, 384)
(512, 512, 384)


In [18]:
def np2torch(np_arr):
        return(torch.tensor(np_arr).unsqueeze(0).unsqueeze(0))


def seg_by_patch(model, head_tensor_5_dim, device, patch_shape=(64, 64, 64), thresh=0.5):
    ps = patch_shape
    model.to(device)
    vol_shape = head_tensor_5_dim.shape
    s1 = vol_shape[2]//ps[0]#+1
    s2 = vol_shape[3]//ps[1]#+1
    s3 = vol_shape[4]//ps[2]#+1
    
    seg = np.zeros_like(head_tensor_5_dim[0, 0])
    with torch.no_grad():
        model.eval()
        for i in range(s1):
            for j in range(s2):
                for k in range(s3):
                    patch = head_tensor_5_dim[:,
                                              :,
                                              i*ps[0]:(i+1)*ps[0],
                                              j*ps[1]:(j+1)*ps[1],
                                              k*ps[2]:(k+1)*ps[2]].to(device)
                    seg[i*ps[0]:(i+1)*ps[0],
                        j*ps[1]:(j+1)*ps[1],
                        k*ps[2]:(k+1)*ps[2]] = model(patch)[0].cpu()
    
    seg[seg<thresh] = 0
    seg[seg>0] = 1
    return(seg)

In [19]:
vessels_seg = seg_by_patch(model, np2torch(head_vol), device, patch_shape=(256, 256, 128), thresh=0.1)
#vessels_seg = seg_by_vol(model, np2torch(head_vol), 'cpu')

2 2 3 12


In [20]:
print(vessels_vol.sum())
print(vessels_seg.sum())

288187.0
226566.0


In [21]:
data_dir = "seg_data/CT_S5020_uint16"
if not os.path.exists(data_dir):
    os.mkdir(data_dir)


path_to_save_vessels = data_dir + '/' + model_name + '.nii.gz'
save_vol_as_nii(vessels_seg, affine, path_to_save_vessels)