In [1]:
# Re-import edited project modules (e.g. ml.*, scripts.*) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [3]:
from ml.model import U_Net
from ml.utils import get_total_params

from scripts.load_dcm import get_dcm_info, get_dcm_vol, vox_size2affine, save_vol_as_nii

In [4]:
# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print(device)

cuda


In [5]:
import nibabel as nib
def load_data(path_to_folder):
    """Load the head volume, vessel mask and brain mask of one case as float32 arrays.

    The three files ``head.nii.gz``, ``vessels.nii.gz`` and ``brain.nii.gz`` are looked
    up directly inside ``path_to_folder`` (joined with a literal ``/``).

    Returns:
        Tuple ``(head_vol, vessels_vol, brain_vol)`` of ``np.float32`` arrays built from
        each image's ``dataobj``.
    """
    file_names = ("/head.nii.gz", "/vessels.nii.gz", "/brain.nii.gz")
    # Open all three images first, then materialise their voxel data.
    images = [nib.load(path_to_folder + name) for name in file_names]
    return tuple(np.array(image.dataobj, dtype=np.float32) for image in images)

In [6]:
# Dataset case used throughout this notebook. It is bound to a module-level name because
# the segmentation-export cell at the end reads `path_to_folder` (TODO: make configurable).
path_to_folder = "/home/msst/Documents/medtech/brain_seg_dataset/CT_S5020"
head_vol, vessels_vol, brain_vol = load_data(path_to_folder)

In [7]:
# Print the three shapes side by side; the volumes are expected to share one voxel grid.
for volume in (head_vol, vessels_vol, brain_vol):
    print(volume.shape)

(512, 512, 392)
(512, 512, 392)
(512, 512, 392)


In [8]:
# Wrap each volume as a (1, X, Y, Z) float32 tensor. torch.from_numpy shares memory with the
# freshly loaded arrays instead of copying ~400 MB per volume the way torch.tensor does.
# Not idempotent: run once, right after loading (the names are rebound to tensors).
head_vol = torch.from_numpy(head_vol).unsqueeze(0)
vessels_vol = torch.from_numpy(vessels_vol).unsqueeze(0)
brain_vol = torch.from_numpy(brain_vol).unsqueeze(0)
print(head_vol.shape)
print(vessels_vol.shape)
print(brain_vol.shape)  # previously printed vessels_vol twice and never brain_vol

torch.Size([1, 512, 512, 392])
torch.Size([1, 512, 512, 392])
torch.Size([1, 512, 512, 392])


In [11]:
def made_patch_dataset(head_vol, brain_vol, patch_shape, patches_number, rng=None):
    """Sample random, spatially aligned 3-D patches from a head volume and its brain mask.

    Args:
        head_vol: tensor of shape (1, X, Y, Z) — network input volume.
        brain_vol: tensor with the same (X, Y, Z) extent — target mask.
        patch_shape: (px, py, pz); each size must not exceed the matching volume size.
        patches_number: number of patches to draw; <= 0 yields an empty result.
        rng: optional ``np.random.Generator`` for reproducible sampling. When None the
            global ``np.random`` state is used, as before.

    Returns:
        Tensor of shape (patches_number, 2, 1, px, py, pz). Index 0 along dim 1 is the
        head patch, index 1 the brain patch at the same location.

    Raises:
        ValueError: if the two volumes differ in spatial shape, or a patch does not fit.
    """
    vol_shape = tuple(head_vol.shape[1:])
    patch_shape = tuple(patch_shape)
    brain_shape = tuple(brain_vol.shape[1:])
    if brain_shape != vol_shape:
        raise ValueError(f"head_vol and brain_vol differ in spatial shape: {vol_shape} vs {brain_shape}")
    if len(patch_shape) != len(vol_shape) or any(p > v for p, v in zip(patch_shape, vol_shape)):
        raise ValueError(f"patch_shape {patch_shape} does not fit inside volume shape {vol_shape}")
    if patches_number <= 0:
        # torch.stack([]) would raise; return a correctly shaped empty batch instead.
        return torch.empty((0, 2, 1, *patch_shape), dtype=head_vol.dtype, device=head_vol.device)

    randint = np.random.randint if rng is None else rng.integers
    patches = []
    for _ in range(patches_number):
        # `high` is exclusive, so "+ 1" makes the last valid start offset reachable
        # (and a patch exactly as large as the volume no longer raises).
        x, y, z = (int(randint(0, dim - size + 1)) for dim, size in zip(vol_shape, patch_shape))
        window = (0,
                  slice(x, x + patch_shape[0]),
                  slice(y, y + patch_shape[1]),
                  slice(z, z + patch_shape[2]))
        head_patch = head_vol[window].unsqueeze(0)
        brain_patch = brain_vol[window].unsqueeze(0)
        patches.append(torch.stack((head_patch, brain_patch)))
    return torch.stack(patches)

In [12]:
# Seed the global RNGs so the sampled patches are reproducible across kernel restarts.
# Seeding torch here also fixes later torch-side randomness (DataLoader shuffling,
# U_Net weight initialisation).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
patches = made_patch_dataset(head_vol, brain_vol, patch_shape=(64, 64, 64), patches_number=64)

In [13]:
# Shuffled mini-batches of 6 (head, brain) patch pairs, loaded by 6 worker processes.
loader_params = dict(batch_size=6, shuffle=True, num_workers=6)
dataloader = DataLoader(patches, **loader_params)

In [14]:
# Peek at the first batch. Dim 1 stacks (head, brain), so splitting on it yields the
# network input and its target.
for batch in dataloader:
    head_batch, brain_batch = batch[:, 0], batch[:, 1]
    print(batch.shape, head_batch.shape, brain_batch.shape, sep="\n")
    break

torch.Size([6, 2, 1, 64, 64, 64])
torch.Size([6, 1, 64, 64, 64])
torch.Size([6, 1, 64, 64, 64])


In [15]:
# Build the segmentation network with U_Net's default constructor arguments (from ml.model).
model = U_Net()

In [16]:
# Display the value from ml.utils.get_total_params — presumably the total parameter count.
get_total_params(model)

103536449

In [18]:
# Move the network's parameters and buffers to the device chosen above (GPU when available).
model = model.to(device)

In [19]:
# Adam over all U_Net parameters. The objective is binary cross-entropy averaged over all
# elements; BCELoss expects probabilities in [0, 1]
# (NOTE(review): confirm U_Net's output ends in a sigmoid).
learning_rate = 3e-4
optim = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.BCELoss(reduction='mean')

In [20]:
# Train for up to N_EPOCHS passes over the patch loader, printing the mean batch loss of
# each epoch. Interrupting the kernel keeps the weights reached so far in `model`
# (the trailing model.eval() is then skipped).
N_EPOCHS = 300
for epoch in range(N_EPOCHS):
    model.train()
    losses = []
    for batch in dataloader:
        head_batch = batch[:, 0].to(device)
        brain_batch = batch[:, 1].to(device)

        optim.zero_grad()
        # Call the module rather than .forward() so any registered hooks run.
        output = model(head_batch)
        # Presumably U_Net returns a sequence whose first element is the prediction shaped
        # like brain_batch — confirm in ml.model. (Indexing a plain tensor here would drop
        # the batch dim, and BCELoss would reject the size mismatch.)
        output = output[0]

        loss_train = loss_fn(output, brain_batch)
        loss_train.backward()
        optim.step()
        losses.append(loss_train.item())
    # Guard against an empty loader instead of dividing by zero.
    mean_loss = sum(losses) / len(losses) if losses else float("nan")
    print(f"epoch: {epoch}", f"loss: {mean_loss}")

model.eval()

epoch: 0 0.5377492281523618
epoch: 1 0.41534620794382965
epoch: 2 0.4029851528731259
epoch: 3 0.34674580801617017
epoch: 4 0.34752940318801184
epoch: 5 0.2634409097107974
epoch: 6 0.34526861933144654
epoch: 7 0.30496888810938055
epoch: 8 0.27077665518630634
epoch: 9 0.2940451773730191
epoch: 10 0.336543780836192
epoch: 11 0.2495629692619497
epoch: 12 0.31375139545310626
epoch: 13 0.23769358748739416
epoch: 14 0.2210525247183713
epoch: 15 0.2513866289095445
epoch: 16 0.2425802539695393
epoch: 17 0.2411855079910972
epoch: 18 0.2651109045202082
epoch: 19 0.19683902236548337
epoch: 20 0.24212597446008163
epoch: 21 0.2190647558732466
epoch: 22 0.22314275259321387
epoch: 23 0.18985793536359613
epoch: 24 0.2023992809382352
epoch: 25 0.16845728863369336
epoch: 26 0.23605426265434784
epoch: 27 0.21794541383331473
epoch: 28 0.19281262836673044
epoch: 29 0.2064010419628837
epoch: 30 0.1396695219657638


KeyboardInterrupt: 

In [32]:
def save_model(model, path_to_save_model, epoch):
    """Write ``model.state_dict()`` to ``<path_to_save_model>/epoch_<epoch>/model``.

    The per-epoch directory is created when missing; a checkpoint already at that path
    is overwritten.
    """
    epoch_dir = f"{path_to_save_model}/epoch_{epoch}"
    os.makedirs(epoch_dir, exist_ok=True)
    checkpoint_path = f"{epoch_dir}/model"
    torch.save(model.state_dict(), checkpoint_path)

In [33]:
# Uncomment to checkpoint the current weights to ./epoch_69/model via save_model.
#save_model(model, '.', 69)

In [17]:
# Restore the epoch-69 checkpoint (presumably written by save_model — same
# epoch_<n>/model layout). map_location="cpu" lets a checkpoint saved from GPU tensors
# load on a CPU-only machine; load_state_dict then copies the values onto whatever
# device `model` currently lives on.
# torch.load unpickles its input: only point it at trusted checkpoints.
checkpoint_path = "/home/msst/repo/MSRepo/VesselSegmentation/epoch_69/model"
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

<All keys matched successfully>

In [18]:
# Run full-volume inference on the CPU (presumably because the whole volume does not fit
# in GPU memory — confirm). The model must follow the inputs: it was moved to the training
# device earlier, and a CUDA model fed CPU tensors raises a device-mismatch error.
device = torch.device("cpu")
model = model.to(device)

In [None]:
# Segment the whole head volume in one forward pass; no_grad skips autograd bookkeeping.
# NOTE(review): U_Net was trained on 64^3 patches. Whether it accepts 512x512x392 depends
# on its pooling depth (392 is not divisible by 16) — confirm, or use sliding-window
# inference.
with torch.no_grad():
    model.eval()
    # head_vol is (1, X, Y, Z); unsqueeze adds the batch dim -> (1, 1, X, Y, Z).
    prediction = model(head_vol.unsqueeze(0).to(device))
    # Same convention as the training loop: the prediction is the first element of the
    # model's output. Calling .cpu() on the whole output fails when it is a tuple/list.
    model_brain_seg = prediction[0].cpu()

In [None]:
# Write the predicted brain mask as NIfTI. The affine comes from the case's reference
# brain mask, so the prediction lands in the same world coordinates.
# `path_to_folder` is only a parameter inside load_data, so bind the case directory here.
path_to_folder = "/home/msst/Documents/medtech/brain_seg_dataset/CT_S5020"  # same case as the load cell
data_dir = "seg_data/CT_S5020"
# makedirs also creates a missing "seg_data" parent; os.mkdir would fail in that case.
os.makedirs(data_dir, exist_ok=True)

reference_file = nib.load(os.path.join(path_to_folder, "brain.nii.gz"))
path_to_save_seg = os.path.join(data_dir, "model_brain_seg.nii.gz")

# NOTE(review): model_brain_seg is a torch tensor, presumably shaped (1, 1, X, Y, Z).
# Confirm that save_vol_as_nii accepts it; it may expect a 3-D numpy array such as
# model_brain_seg[0, 0].numpy().
save_vol_as_nii(model_brain_seg, reference_file.affine, path_to_save_seg)