In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import os
import gc
from glob import glob
from tqdm import tqdm
from collections import defaultdict
import json
import scipy.ndimage as ndimage
import nrrd
import torchio as tio
import monai
import nibabel as nib
import time

# Dataset

In [2]:
city = 'Beijing_Zang'

modes = ['train', 'test']
total_size = 197
train_size, test_size = 158, 39
num_classes = 3

image_dir = f'../dataset/{city}/MRI'
image_arti_dir = f'../dataset/{city}/MRI_arti'
label_dir = f'../dataset/{city}/Segmentation'
autoencoder_wm_dir = f'../results/SCAE_WM_temp/best_autoencoder.torch'

model_dir = f'../results/{city}_unet_autoencoder_gmwm'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [5]:
seed = 0
random_state = np.random.RandomState(seed=seed)
perm = random_state.permutation(total_size)
perm = {
    'train': perm[:train_size],
    'test': perm[-test_size:],
}

def get_subjects(mode):
    subjects = []
    image_src = image_dir #image_dir if mode == 'train' else image_arti_dir
    image_paths = [sorted(glob(f'{image_src}/*.nii.gz'))[i] for i in perm[mode]]
    for image_path in tqdm(image_paths, desc=mode):
        fn = image_path.split('/')[-1]
        label_path = f'{label_dir}/{fn}'
        subject = tio.Subject(
            image=tio.ScalarImage(image_path),
            label=tio.LabelMap(label_path),
        )
        subjects.append(subject)
    return subjects

def get_transform():
    resample = tio.Compose([
        tio.Resample(2),
        tio.CropOrPad((96, 128, 128)),
    ])
    signal = tio.Compose([ 
        tio.RescaleIntensity(percentiles=(0.1, 99.9), out_min_max=(0, 1)),
    ])
    spatial = tio.Compose([
        tio.RandomAffine(translation=1),
    ])
    remapping = dict()
    for i in range(139):
        # remapping[i] = 1 if (3<=i<=11 or 19<=i<=20 or 25<=i<=32 or 35<=i) else 0
        # remapping[i] = 1 if i in {12, 13, 16, 17} else 0
        remapping[i] = 1 if (3<=i<=11 or 19<=i<=20 or 25<=i<=32 or 35<=i) else 2 if i in {12, 13, 16, 17} else 0
    remapping = tio.RemapLabels(remapping)

    onehot = tio.OneHot(num_classes=num_classes)
    transform = {
        'train': tio.Compose([
            resample,
            spatial,
            signal,
            remapping,
            onehot,
        ]),
        'test': tio.Compose([
            resample,
            signal,
            remapping,
            onehot,
        ]),
    }
    return transform

def get_dataloader(transform):
    dataloader = dict()
    for mode in modes:
        dataloader[mode] = torch.utils.data.DataLoader(
            tio.SubjectsDataset(
                subjects[mode], 
                transform=transform[mode]
            ),
            batch_size=1, 
            num_workers=os.cpu_count(),
            shuffle=(mode == 'train'),
        )
    return dataloader

subjects = {mode: get_subjects(mode) for mode in modes}
transform = get_transform()
dataloaders = get_dataloader(transform)

train: 100%|██████████| 158/158 [00:01<00:00, 81.80it/s]
test: 100%|██████████| 39/39 [00:00<00:00, 87.30it/s]


In [7]:
tio.__version__

'0.18.86'

In [12]:
torch.__version__

'1.13.1'

In [11]:
!python --version

Python 3.8.5


# Model

In [17]:
def convolution(in_channels, out_channels, stride):
    return torch.nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1)

def deconvolution(in_channels, out_channels, stride):
    return torch.nn.ConvTranspose3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, output_padding=1)

def normalization(channel):
    return torch.nn.BatchNorm3d(channel)

def activation():
    return torch.nn.PReLU()

def pooling(kernel_size):
    return torch.nn.MaxPool3d(kernel_size=kernel_size)

def upsampling(scale_factor):
    return torch.nn.Upsample(scale_factor=scale_factor, mode='trilinear', align_corners=True)

class Autoencoder(torch.nn.Module):
    def __init__(self, channels):
        super(Autoencoder, self).__init__()
        self.encoder = torch.nn.Sequential(
            convolution(in_channels=1, out_channels=channels[0], stride=2),
            normalization(channels[0]),
            activation(),
            
            convolution(in_channels=channels[0], out_channels=channels[1], stride=2),
            normalization(channels[1]),
            activation(),
            
            convolution(in_channels=channels[1], out_channels=channels[2], stride=2),
            normalization(channels[2]),
            activation(),

            convolution(in_channels=channels[2], out_channels=channels[3], stride=2),
            normalization(channels[3]),
            activation(),
        )
        self.decoder = torch.nn.Sequential(
            deconvolution(in_channels=channels[3], out_channels=channels[2], stride=2),
            normalization(channels[2]),
            activation(),
            
            deconvolution(in_channels=channels[2], out_channels=channels[1], stride=2),
            normalization(channels[1]),
            activation(),
            
            deconvolution(in_channels=channels[1], out_channels=channels[0], stride=2),
            normalization(channels[0]),
            activation(),
            
            deconvolution(in_channels=channels[0], out_channels=1, stride=2),
            normalization(1),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

autoencoder_wm = Autoencoder(channels=[64,128,256,512]).to(device)
autoencoder_wm.load_state_dict(torch.load(autoencoder_wm_dir))
for param in autoencoder_wm.parameters():
    param.requires_grad = False

In [5]:
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.segmentation = monai.networks.nets.UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=4,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            norm=monai.networks.layers.Norm.BATCH,
            dropout=0.3,
        )
        self.conv3d = torch.nn.Conv3d(in_channels=4, out_channels=1, kernel_size=1, stride=1, padding=0)

    def forward(self, image):
        prob4 = self.segmentation(image).sigmoid()
        prob1 = self.conv3d(prob4).sigmoid()
        shape = autoencoder.encoder(prob1)
        recon = autoencoder.decoder(shape)
        return prob4, prob1, shape, recon

# Training

In [6]:
def clean(model):
    del model
    gc.collect()
    torch.cuda.empty_cache()

class Metrics:
    def __init__(self):
        self.metrics = defaultdict(list)
        self.names = ['loss', 'loss_segm', 'loss_shape', 'loss_recon']
    def log(self, mode, epoch, values):
        for name, value in zip(self.names, values):
            self.metrics[(mode, epoch, name)].append(value.item())
    def show(self, mode, epoch):
        print()
        for name in self.names:
            mean = np.mean(self.metrics[(mode, epoch, name)])
            print(f'{mode} {name}: {mean}')
            if name == 'loss':
                mean_loss = mean
        return mean_loss
'''
def predict(dataloaders):
    mode = 'test'
    model = Model().to(device)
    model.load_state_dict(torch.load(f'{model_dir}/best_model.torch'))
    model.eval()

    for i, subject in enumerate(tqdm(dataloaders[mode])):
        image = subject['image'][tio.DATA].to(device)
        label = subject['label'][tio.DATA].to(device)
        prob, mask, shape, recon = model(image)
        
        nrrd.write(f'{model_dir}/true{i}.nrrd', label[0, 0].detach().cpu().numpy())
        nrrd.write(f'{model_dir}/pred{i}.nrrd', mask[0, 0].detach().cpu().numpy())
        nrrd.write(f'{model_dir}/recon{i}.nrrd', recon[0, 0].detach().cpu().numpy())

    !rm file.zip
    !zip -r file.zip $model_dir

    del model
    gc.collect()
    torch.cuda.empty_cache()
'''
def train(model, lambdas, n_epochs, dataloaders, learning_rate):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.2)

    loss_dice = monai.losses.DiceLoss()
    loss_mse = torch.nn.MSELoss()
    
    metric = monai.metrics.DiceMetric(reduction='mean_batch')
    metrics = Metrics()
    best_val_loss = np.Inf

    tol, tol50 = 0, 0
    for epoch in range(1, n_epochs+1):
        t0 = time.time()
        print(f'\nEpoch {epoch}/{n_epochs}')
        for mode in modes:
            if mode == 'train':
                model.train()
            else:
                model.eval()
            
            for subject in dataloaders[mode]:
                image = subject['image'][tio.DATA].to(device)
                label = subject['label'][tio.DATA].to(device).float()
                label1, label4 = 1 - label[:, :1], label[:, 1:]
                
                prob4, prob1, shape, recon = model(image)
                loss_segm4 = loss_dice(prob4, label4)
                loss_segm1 = loss_dice(prob1, label1)
                loss_segm = loss_segm4 + loss_segm1
                loss_shape = loss_mse(shape, autoencoder.encoder(label1))
                loss_recon = torch.nn.ReLU()(loss_dice(recon, label1) - 0.0517497765712249)

                loss = lambdas[0] * loss_segm + lambdas[1] * loss_shape + lambdas[2] * loss_recon

                if mode == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                
                metric((prob4 > 0.5).float(), label4)
                metrics.log(mode, epoch, [loss, loss_segm, loss_shape, loss_recon])

            dsc = metric.aggregate().tolist()
            metric.reset()
            mean_loss = metrics.show(mode, epoch)
            print(f'{mode} unet DSC: {dsc}')
        
        if mean_loss <= best_val_loss:
            best_val_loss = mean_loss
            best_epoch = epoch
            tol = 0
            tol50 = 0
        else:
            tol += 1
            tol50 += 1
        print(f'Best val loss: {best_val_loss}')
        
        if tol == 10:
            scheduler.step()
            print('Validation loss stopped to decrease for 10 epochs (LR /= 5).')
            tol = 0
        
        time_elapsed = time.time() - t0
        print(f'Time: {time_elapsed}\n')

        if tol50 == 50:
            print('Validation loss stopped to decrease for 50 epochs. Training terminated.')
            break
    
    print(f'Best epoch: {best_epoch}')

In [7]:
lambdas = (1, 0, 0)
if 'model' in globals(): clean(model)
model = Model().to(device)
train(model=model, lambdas=lambdas, n_epochs=250, dataloaders=dataloaders, learning_rate=1e-3)


Epoch 1/250

train loss: 1.9797150122968457
train loss_segm: 1.9797150122968457
train loss_shape: 0.9787265115146395
train loss_recon: 0.9399794027020659
train unet DSC: [0.07165322452783585, 0.07009866088628769, 9.672767191659659e-06, 0.004965457133948803]

test loss: 1.8971243760524652
test loss_segm: 1.8971243760524652
test loss_shape: 1.0499779964104676
test loss_recon: 0.9148685611211337
test unet DSC: [0.3220730125904083, 0.3492332994937897, 1.029572194966022e-05, 0.010743088088929653]
Best val loss: 1.8971243760524652
Time: 58.37938618659973


Epoch 2/250

train loss: 1.7802003584330595
train loss_segm: 1.7802003584330595
train loss_shape: 0.9393068714232384
train loss_recon: 0.9120554244970973
train unet DSC: [0.3479003310203552, 0.3744296431541443, 1.28859528558678e-05, 0.2490081489086151]

test loss: 1.6428748980546608
test loss_segm: 1.6428748980546608
test loss_shape: 0.8552038639019697
test loss_recon: 0.9094769771282489
test unet DSC: [0.3800884485244751, 0.46923828125, 

In [8]:
lambdas = (1, 0.1, 0.1)
if 'model' in globals(): clean(model)
model = Model().to(device)
train(model=model, lambdas=lambdas, n_epochs=250, dataloaders=dataloaders, learning_rate=1e-3)


Epoch 1/250

train loss: 2.1087590772894362
train loss_segm: 1.995371024819869
train loss_shape: 0.19677230866649484
train loss_recon: 0.9371084167232996
train unet DSC: [0.0003391536301933229, 0.0005034269415773451, 4.100931255379692e-05, 0.0006307260482572019]

test loss: 2.100265710781782
test loss_segm: 1.995046520844484
test loss_shape: 0.11667958838053238
test loss_recon: 0.9355125778760666
test unet DSC: [0.0001986196730285883, 0.0004923303495161235, 2.2826770873507485e-06, 0.0008464180282317102]
Best val loss: 2.100265710781782
Time: 57.82144379615784


Epoch 2/250

train loss: 2.0908771496784837
train loss_segm: 1.9949659780610967
train loss_shape: 0.16490313292870037
train loss_recon: 0.7942084360726273
train unet DSC: [0.0004432617861311883, 0.0004593630146700889, 3.7360787246143445e-05, 0.0012333305785432458]

test loss: 2.0738546603765244
test loss_segm: 1.9949707709825957
test loss_shape: 0.16495221127302218
test loss_recon: 0.6238865210459783
test unet DSC: [0.000660398

In [9]:
lambdas = (1, 0.01, 0.1)
if 'model' in globals(): clean(model)
model = Model().to(device)
train(model=model, lambdas=lambdas, n_epochs=250, dataloaders=dataloaders, learning_rate=1e-3)


Epoch 1/250

train loss: 2.0819810161107704
train loss_segm: 1.9950085184242152
train loss_shape: 0.4429322841423976
train loss_recon: 0.8254318154310878
train unet DSC: [0.0029797807801514864, 0.005523730535060167, 0.00018772351904772222, 0.00043243003892712295]

test loss: 2.0631533830593796
test loss_segm: 1.994207168236757
test loss_shape: 0.3576461122586177
test loss_recon: 0.6536975778066195
test unet DSC: [0.004681847989559174, 0.011594888754189014, 0.0001874046283774078, 0.0003627249097917229]
Best val loss: 2.0631533830593796
Time: 58.901795625686646


Epoch 2/250

train loss: 2.0470866387403466
train loss_segm: 1.9843992628628695
train loss_shape: 0.31069158357155474
train loss_recon: 0.5958046203927149
train unet DSC: [0.007809534668922424, 0.08312167227268219, 0.00018983325571753085, 0.00039585138438269496]

test loss: 1.9516642124224932
test loss_segm: 1.8990147205499501
test loss_shape: 0.2931606249931531
test loss_recon: 0.49717885102981174
test unet DSC: [0.01948983408

In [None]:
lambdas = (1, 0.1, 0.01)
if 'model' in globals(): clean(model)
model = Model().to(device)
train(model=model, lambdas=lambdas, n_epochs=250, dataloaders=dataloaders, learning_rate=1e-3)


Epoch 1/250

train loss: 2.02580789221993
train loss_segm: 1.9961237688607807
train loss_shape: 0.20212556308583368
train loss_recon: 0.9471587489677381
train unet DSC: [0.002621248597279191, 0.0036599356681108475, 0.0003427632909733802, 0.00010464298975421116]

test loss: 2.017950143569555
test loss_segm: 1.9958823827596812
test loss_shape: 0.1261158463282463
test loss_recon: 0.9456169070341648
test unet DSC: [0.0025205116253346205, 0.003937106113880873, 0.00029245304176583886, 0.00016121016233228147]
Best val loss: 2.017950143569555
Time: 58.691466093063354


Epoch 2/250

train loss: 2.0180559580839135
train loss_segm: 1.9958300190635874
train loss_shape: 0.12788370044171055
train loss_recon: 0.9437563585329659
train unet DSC: [0.002219975693151355, 0.003512290772050619, 0.00021634955191984773, 0.00026937341317534447]

test loss: 2.0172505256457205
test loss_segm: 1.995748153099647
test loss_shape: 0.12021416234664428
test loss_recon: 0.9480951061615577
test unet DSC: [0.00269799237

In [None]:
lambdas = (1, 0.01, 0.01)
if 'model' in globals(): clean(model)
model = Model().to(device)
train(model=model, lambdas=lambdas, n_epochs=250, dataloaders=dataloaders, learning_rate=1e-3)

In [None]:
lambdas = (1, 0.01, 0.001)
if 'model' in globals(): clean(model)
model = Model().to(device)
train(model=model, lambdas=lambdas, n_epochs=250, dataloaders=dataloaders, learning_rate=1e-3)

In [None]:
lambdas = (1, 0.001, 0.01)
if 'model' in globals(): clean(model)
model = Model().to(device)
train(model=model, lambdas=lambdas, n_epochs=250, dataloaders=dataloaders, learning_rate=1e-3)

In [None]:
# test unet DSC: [0.7999164462089539, 0.8284406661987305, 0.8151174187660217, 0.7770190238952637]