In [12]:
import glob
import os
import random
import time
import warnings

import torch
import numpy as np
import nibabel as nib
from nilearn.image import reorder_img, resample_img
import matplotlib.pyplot as plt
from tqdm import tqdm

import plot_image
import transforms
import models
import losses

In [13]:
# Glob patterns for the data (relative to the notebook's working directory).
# Training uses label maps only; the network input is produced from them by image_transforms.
train_dir = dict()
train_dir['mask'] = r"label_samseg/*"
test_dir = dict()
test_dir['image'] = r'candi_oasis_aseg/raw123/*'
test_dir['mask'] = r'candi_oasis_aseg/label123/*'

# The augmentations below are random; seed every global RNG they (and weight
# initialisation) may draw from so that Restart & Run All reproduces the same run.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Label values the segmentation network predicts; the model is built with
# len(label_all['synsg']) output classes.
label_all = dict()
label_all['synsg'] = (
    2, 3, 4, 5, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 24, 26, 28, 31,
    41, 42, 43, 44, 46, 47, 49, 50, 51, 52, 53, 54, 58, 60, 63, 77,
)

# Augmentation applied to the label map.
# Alternative deformation: transforms.NonlinearDeform(max_std=4) instead of NonlinearDeformTio().
label_transforms = transforms.Compose([
    transforms.RandomSkullStrip(),
    transforms.LinearDeform(scales=(0.8, 1.2), degrees=(-20, 20), shears=(-0.015, 0.015), trans=(-30, 30)),
    transforms.NonlinearDeformTio(),
    transforms.RandomCrop(160),
])

# Pipeline that turns the (augmented) label map into the network's input image.
image_transforms = transforms.Compose([
    transforms.GMMSample(mean=(0, 255), std=(0, 35)),
    transforms.RandomBiasField(max_std=0.6),
    transforms.Rescale(),
    transforms.GammaTransform(std=0.4),
    transforms.RandomDownSample(max_slice_space=9, alpha=(0.95, 1.05), r_hr=1),
])

# NOTE(review): silences *all* warnings (including deprecations) for the whole session.
warnings.filterwarnings("ignore")

device = "cuda" if torch.cuda.is_available() else "cpu"  # assign "cpu" here to force CPU
print("device:", device)

device: cuda


In [14]:
def resample_voxel(data_nib, voxelsize, target_shape=None, interpolation='continuous'):
    """Resample a NIfTI image onto a grid with the requested voxel size.

    Each of the first three columns of the affine is rescaled so that its Euclidean
    length equals the matching entry of ``voxelsize``. Axis directions and the
    translation column are kept. The resampling itself is delegated to
    ``nilearn.image.resample_img``.

    Args:
        data_nib: nibabel-style image exposing ``.affine``.
        voxelsize: target voxel size along each of the three voxel axes (only the
            first three entries are used).
        target_shape: optional output grid shape, passed through to ``resample_img``.
        interpolation: interpolation mode, passed through to ``resample_img``.

    Returns:
        Whatever ``resample_img`` returns for the rescaled target affine.
    """
    src_affine = data_nib.affine
    # Current voxel size along each voxel axis = length of that axis' affine column.
    current_size = np.sqrt((src_affine[:3, :3] ** 2).sum(axis=0))
    scale = np.array([voxelsize[k] for k in range(3)]) / current_size

    new_affine = src_affine.copy()
    new_affine[:3, :3] = src_affine[:3, :3] * scale  # scales column k by scale[k]

    return resample_img(data_nib, target_affine=new_affine,
                        target_shape=target_shape, interpolation=interpolation)

def nib_to_tensor(input_nib, resample='continuous'):
    """Convert a nibabel image to a float32 tensor with two leading singleton dims.

    Args:
        input_nib: nibabel-style image exposing ``get_fdata()`` and ``.affine``.
        resample: accepted for call-site compatibility but currently unused. The
            reorientation / 1 mm resampling step (``reorder_img`` + ``resample_voxel``)
            is disabled, so the volume stays on its native voxel grid.

    Returns:
        ``(vol, affine)``, where ``vol`` has shape ``(1, 1, *image_shape)`` and
        ``affine`` is ``input_nib.affine``.
    """
    vol = torch.from_numpy(input_nib.get_fdata()).float()[None, None, ...]
    # Previously returned `input_nib_resp.affine`, which only exists in the disabled
    # resampling path and raised NameError on every call.
    return vol, input_nib.affine

In [15]:
class GetData(torch.utils.data.Dataset):
    """Dataset of NIfTI images and/or label masks selected by glob patterns.

    File lists are sorted by path. In ``mode='both'`` the i-th image is paired with
    the i-th mask, so both patterns must match the same number of files. Each item
    is a dict with key ``'image'`` and/or ``'mask'`` holding the result of
    ``nib.load`` for that file.

    Args:
        image_dir: glob pattern for the images (unused in ``'mask'`` mode).
        mask_dir: glob pattern for the label masks (unused in ``'image'`` mode).
        mode: ``'image'``, ``'mask'`` or ``'both'``.

    Raises:
        ValueError: on an unknown ``mode``, or on differing image/mask counts in
            ``'both'`` mode (index-based pairing would silently mismatch files).
    """

    _MODES = ('image', 'mask', 'both')

    def __init__(self, image_dir='', mask_dir='', mode='both'):
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        self.image_ffs = sorted(glob.glob(image_dir))
        self.mask_ffs = sorted(glob.glob(mask_dir))
        self.mode = mode

        if mode == 'both' and len(self.image_ffs) != len(self.mask_ffs):
            raise ValueError(
                f"found {len(self.image_ffs)} images for {image_dir!r} but "
                f"{len(self.mask_ffs)} masks for {mask_dir!r}"
            )

    def __len__(self):
        # 'image' mode has no masks to count; every other mode is indexed by mask.
        if self.mode == 'image':
            return len(self.image_ffs)
        return len(self.mask_ffs)

    def __getitem__(self, index):
        data = dict()

        if self.mode in ('image', 'both'):
            data['image'] = nib.load(self.image_ffs[index])

        if self.mode in ('mask', 'both'):
            data['mask'] = nib.load(self.mask_ffs[index])

        return data

In [16]:
# Loss Functions    
class FocalLoss(torch.nn.Module):
    """Focal loss built on top of ``torch.nn.functional.cross_entropy``.

    Each per-element cross-entropy term ``ce`` is scaled by
    ``alpha * (1 - exp(-ce)) ** gamma``. Here ``exp(-ce)`` is the probability assigned
    to the true class, so confidently correct elements are down-weighted.

    Args:
        alpha: scalar weight applied to every term.
        gamma: focusing exponent; ``gamma=0`` gives alpha-scaled cross-entropy.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the unreduced
            per-element loss.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        """Focal loss of logits ``inputs`` against ``targets`` (as accepted by cross_entropy)."""
        ce = torch.nn.functional.cross_entropy(inputs, targets, reduction='none')
        modulating = (1 - torch.exp(-ce)) ** self.gamma
        per_element = self.alpha * modulating * ce

        if self.reduction == 'sum':
            return per_element.sum()
        if self.reduction == 'mean':
            return per_element.mean()
        return per_element
    
class SoftDiceLossWithLogit(torch.nn.Module):
    """Soft Dice loss computed from raw network outputs (logits).

    Softmax is applied over the channel dimension (dim=1). A soft Dice coefficient is
    then computed per sample and per channel, summing over all remaining (spatial)
    dimensions, and the loss is ``1 - mean(Dice)``. Works for any number of spatial
    dimensions, e.g. ``(N, C, H, W)`` or ``(N, C, D, H, W)``.

    Args:
        smooth: constant added to numerator and denominator to avoid division by zero.
    """

    def __init__(self, smooth=1e-6):
        super(SoftDiceLossWithLogit, self).__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):
        """Return ``1 - mean Dice``.

        Args:
            y_pred: logits of shape ``(N, C, *spatial)`` with at least one spatial dim.
            y_true: target broadcastable to ``y_pred``, typically one-hot per channel.
        """
        # Probabilities, not logits (the original name `y_logit` was misleading).
        probs = torch.nn.functional.softmax(y_pred, dim=1)
        # Reduce over every dim after (batch, channel); equals (2, 3, 4) for 3-D volumes.
        spatial_dims = tuple(range(2, y_pred.dim()))
        intersection = (probs * y_true).sum(dim=spatial_dims)
        union = probs.sum(dim=spatial_dims) + y_true.sum(dim=spatial_dims)
        dice = (2 * intersection + self.smooth) / (union + self.smooth)
        return 1 - torch.mean(dice)

In [17]:
train_set = GetData(mask_dir=train_dir['mask'], mode='mask')
valid_set = GetData(test_dir['image'], test_dir['mask'], mode='both')
# Fail fast on a wrong working directory or a typo in a glob pattern; otherwise the
# problem only surfaces later, deep inside the training loop.
assert len(train_set) > 0, f"no training masks match {train_dir['mask']!r}"
assert len(valid_set) > 0, f"no validation files match {test_dir['mask']!r}"
print(f"Training Dataset has {len(train_set)} Nifti images.")
print(f"Test Dataset has {len(valid_set)} Nifti images.")

pred_labels = label_all['synsg']
learning_rate = 1e-4
# Unet3D(1, n_labels, 24): presumably (input channels, output classes, base width) — confirm in models.py.
model = models.Unet3D(1, len(pred_labels), 24).to(device)
print(f"Model predicts {len(pred_labels)} labels.")
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = SoftDiceLossWithLogit()

Training Dataset has 991 Nifti images.
Test Dataset has 123 Nifti images.
Model predicts 35 labels.


In [18]:
save_dir = "save"
os.makedirs(save_dir, exist_ok=True)

bestmodel_path = os.path.join(save_dir, "bestmodel.pth")
checkpoint_path = os.path.join(save_dir, "checkpoint.pth.tar")

# None -> ask interactively; set True/False to run the notebook unattended.
LOAD_CHECKPOINT = None
if LOAD_CHECKPOINT is None:
    load_checkpoint = input("Load checkpoint ? [y/n]") == "y"
else:
    load_checkpoint = LOAD_CHECKPOINT

if load_checkpoint:
    # map_location lets a checkpoint written on GPU be resumed on a CPU-only machine.
    # torch.load unpickles its input: only load checkpoints you produced yourself.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_step = checkpoint['step']
    loss_list = checkpoint['loss']
    dice_list = checkpoint['dice']
    best_score = checkpoint['best_score']
    print("load checkpoint succesfully")
else:
    start_step = 0
    dice_list = []
    loss_list = []
    best_score = -1  # sentinel: no validation score recorded yet

print("starting step: ", start_step)

Load checkpoint ? [y/n]n
starting step:  0


In [19]:
def validate(model, dataset, labels=None,
             eval_labels=(10, 49, 11, 50, 12, 51, 13, 52, 17, 53, 18, 54)):
    """Return the mean Dice score over selected structures, averaged over ``dataset``.

    For each item, the model's per-voxel argmax channel is mapped back to a label
    value through ``labels`` (channel k -> ``labels[k]``). The Dice coefficient is then
    computed against the ground-truth mask for every value in ``eval_labels``. The
    per-item score is the mean over ``eval_labels``; the result is the mean over items.

    The model is switched to eval mode for the pass, and its previous train/eval mode
    is restored on exit. This makes the function safe to call inside a training loop.

    Args:
        model: segmentation network, called on a ``(1, 1, *spatial)`` image tensor.
        dataset: sized iterable of dicts with ``'image'`` and ``'mask'`` nibabel images,
            e.g. ``GetData(..., mode='both')``.
        labels: label value of each output channel; defaults to the global ``pred_labels``.
        eval_labels: label values to score. Defaults to the non-background entries of
            the former ``deepgm_to_aseg`` table (deep grey-matter structures).

    NOTE(review): previously ``pred_mask`` (channel indices) was compared with 1..12
    directly, which scored the wrong structures. This version assumes the ground-truth
    masks store label values in the same space as ``labels`` (aseg-style, as the dataset
    path suggests) — confirm for candi_oasis_aseg/label123.
    NOTE(review): training images go through ``image_transforms`` (incl. Rescale) while
    validation images are fed with raw intensities — confirm this is intended.
    """
    def get_dice(mask1, mask2):
        # Dice of two boolean masks; the epsilon avoids 0/0 when both are empty.
        dice = torch.sum(mask1 & mask2) * 2
        dice = dice / (1e-6 + torch.sum(mask1) + torch.sum(mask2))
        return dice

    if labels is None:
        labels = pred_labels

    was_training = model.training
    model.eval()  # disable dropout / use running batch-norm statistics during scoring
    total_scores = 0
    try:
        with torch.no_grad():
            # Lookup table: output channel index -> label value.
            channel_to_label = torch.as_tensor(labels, device=device)
            for data in tqdm(dataset, desc="validating: "):
                image, _ = nib_to_tensor(data["image"], resample='continuous')
                mask, _ = nib_to_tensor(data["mask"], resample='nearest')

                image = image.to(device)
                mask = mask[0, 0].to(device)  # drop the two leading singleton dims
                # argmax of the logits equals argmax of their softmax.
                pred_channel = torch.argmax(model(image)[0], dim=0)
                pred_mask = channel_to_label[pred_channel]

                item_scores = [
                    get_dice(pred_mask == value, mask == value).item()
                    for value in eval_labels
                ]
                total_scores += sum(item_scores) / len(item_scores)
    finally:
        model.train(was_training)

    return total_scores / len(dataset)

In [21]:
# Validate and checkpoint every `savestep` optimisation steps.
# For a quick smoke test use e.g. savestep = 2; steps = 2 * savestep.
savestep = 200
steps = 100 * savestep

print(f"Number of Steps: {steps}, Learning Rate: {learning_rate}")
print(f"Optimizer: {optimizer}")
print(f"Loss Function: {loss_fn}")
print("Start Training...\n")
torch.cuda.empty_cache()
# One training file is consumed per step, so a resumed run continues after the last
# file it trained on instead of starting over from the first one.
idx = start_step % len(train_set)
for i in range(steps // savestep):
    # validate() may switch the model to eval mode; re-enable training behaviour
    # (dropout, batch-norm statistics) at the start of every chunk.
    model.train()
    total_loss = 0
    pbar = tqdm(range(savestep), desc="training")
    for _step in pbar:
        # Augment a real label map, then derive the one-hot target and the input image from it.
        mask, _ = nib_to_tensor(train_set[idx]['mask'], resample="nearest")
        mask = label_transforms(mask.to(device))

        label = transforms.split_labels(mask, pred_labels)
        image = image_transforms(mask)

        pred = model(image)
        loss = loss_fn(pred, label)
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        idx = (idx + 1) % len(train_set)  # cycle through the training files

    with torch.no_grad():
        total_loss /= savestep
        dice_score = validate(model, valid_set)

        step = (i + 1) * savestep + start_step
        print(f"[{step}/{steps + start_step}] Dice: {dice_score}, Loss: {total_loss}")

        if best_score == -1 or dice_score > best_score:
            best_score = dice_score
            # Saves the whole pickled nn.Module; saving model.state_dict() would be
            # more portable across code changes.
            torch.save(model, bestmodel_path)
            print("! save best model !")

        loss_list.append(total_loss)
        dice_list.append(dice_score)

        # Checkpoint everything needed to resume: weights, optimizer state, history.
        torch.save({
            'step': step,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss_list,
            'dice': dice_list,
            'best_score': best_score
        }, checkpoint_path)
        print("=> save checkpoint")

Number of Steps: 20000, Learning Rate: 0.0001
Optimizer: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0001
    maximize: False
    weight_decay: 0
)
Loss Function: SoftDiceLossWithLogit()
Start Training...



training:  12%|█▏        | 24/200 [02:08<15:44,  5.37s/it]


KeyboardInterrupt: 

In [None]:
# Training curves stored in the checkpoint: one point per save/validation interval.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
# Distinct names: `losses` would shadow the imported `losses` module.
loss_hist = checkpoint["loss"]
dice_hist = checkpoint["dice"]
n_points = len(loss_hist)
# The checkpoint stores the cumulative step count. With a fixed save interval, the
# i-th entry was logged after (i + 1) * interval steps, so the run length is not hard-coded.
interval = checkpoint["step"] // max(n_points, 1)
x = interval * np.arange(1, n_points + 1)

fig, ax = plt.subplots(figsize=(20, 5))
ax.plot(x, loss_hist, label="loss")
ax.plot(x, dice_hist, label="dice")
ax.set_ylim(0, 1)
ax.set_yticks(np.arange(1.1, step=0.1))
ax.set_xlabel("training step")
ax.set_title("Training loss and validation Dice")
ax.legend()
ax.grid()
plt.show()