In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import os
import gc
from glob import glob
from tqdm import tqdm
from collections import defaultdict
import json
import scipy.ndimage as ndimage
import nrrd
import torchio as tio
import monai
import nibabel as nib
import time

In [2]:
def show_cuda_memory():
    t = torch.cuda.get_device_properties(0).total_memory
    r = torch.cuda.memory_reserved(0)
    a = torch.cuda.memory_allocated(0)
    f = r-a  # free inside reserved
    print('Total:     {:0.2f} GiB'.format(t / 2**30))
    print('Reserved:  {:0.2f} GiB'.format(r / 2**30))
    print('Allocated: {:0.2f} GiB'.format(a / 2**30))
    print('Free:      {:0.2f} GiB'.format(f / 2**30))

show_cuda_memory()

Total:     11.17 GiB
Reserved:  0.00 GiB
Allocated: 0.00 GiB
Free:      0.00 GiB


# Dataset

In [3]:
city = 'Beijing_Zang'

modes = ['train', 'test']
total_size = 197
train_size, test_size = 158, 39
num_classes = 3

image_dir = f'../dataset/{city}/MRI'
image_arti_dir = f'../dataset/{city}/MRI_arti'
label_dir = f'../dataset/{city}/Segmentation'
autoencoder_gm_dir = f'../results/SCAE_GM_temp2/best_autoencoder.torch'

model_dir = f'../results/{city}_unet_autoencoder_gmwm'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
seed = 0
random_state = np.random.RandomState(seed=seed)
perm = random_state.permutation(total_size)
perm = {
    'train': perm[:train_size],
    'test': perm[-test_size:],
}

def get_subjects(mode):
    subjects = []
    image_src = image_dir #image_dir if mode == 'train' else image_arti_dir
    image_paths = [sorted(glob(f'{image_src}/*.nii.gz'))[i] for i in perm[mode]]
    for image_path in tqdm(image_paths, desc=mode):
        fn = image_path.split('/')[-1]
        label_path = f'{label_dir}/{fn}'
        subject = tio.Subject(
            image=tio.ScalarImage(image_path),
            label=tio.LabelMap(label_path),
        )
        subjects.append(subject)
    return subjects

def get_transform():
    resample = tio.Compose([
        tio.Resample(2),
        tio.CropOrPad((96, 128, 128)),
    ])
    signal = tio.Compose([ 
        tio.RescaleIntensity(percentiles=(0.1, 99.9), out_min_max=(0, 1)),
    ])
    # spatial = tio.Compose([
    #     tio.RandomAffine(translation=1),
    # ])
    remapping = dict()
    for i in range(139):
        # remapping[i] = 1 if (3<=i<=11 or 19<=i<=20 or 25<=i<=32 or 35<=i) else 0
        # remapping[i] = 1 if i in {12, 13, 16, 17} else 0
        remapping[i] = 1 if (3<=i<=11 or 19<=i<=20 or 25<=i<=32 or 35<=i) else 2 if i in {12, 13, 16, 17} else 0
    remapping = tio.RemapLabels(remapping)

    onehot = tio.OneHot(num_classes=num_classes)
    transform = {
        'train': tio.Compose([
            resample,
            # spatial,
            signal,
            remapping,
            onehot,
        ]),
        'test': tio.Compose([
            resample,
            signal,
            remapping,
            onehot,
        ]),
    }
    return transform

def get_dataloader(transform):
    dataloader = dict()
    for mode in modes:
        dataloader[mode] = torch.utils.data.DataLoader(
            tio.SubjectsDataset(
                subjects[mode], 
                transform=transform[mode]
            ),
            batch_size=1, 
            num_workers=os.cpu_count(),
            shuffle=(mode == 'train'),
        )
    return dataloader

subjects = {mode: get_subjects(mode) for mode in modes}
transform = get_transform()
dataloaders = get_dataloader(transform)

train: 100%|██████████| 158/158 [00:01<00:00, 88.90it/s]
test: 100%|██████████| 39/39 [00:00<00:00, 92.85it/s]


# Model

In [5]:
def convolution(in_channels, out_channels, stride):
    return torch.nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1)

def deconvolution(in_channels, out_channels, stride):
    return torch.nn.ConvTranspose3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, output_padding=1)

def normalization(channel):
    return torch.nn.BatchNorm3d(channel)

def activation():
    return torch.nn.PReLU()

def pooling(kernel_size):
    return torch.nn.MaxPool3d(kernel_size=kernel_size)

def upsampling(scale_factor):
    return torch.nn.Upsample(scale_factor=scale_factor, mode='trilinear', align_corners=True)

class Autoencoder(torch.nn.Module):
    def __init__(self, channels):
        super(Autoencoder, self).__init__()
        self.encoder = torch.nn.Sequential(
            convolution(in_channels=1, out_channels=channels[0], stride=2),
            normalization(channels[0]),
            activation(),
            
            convolution(in_channels=channels[0], out_channels=channels[1], stride=2),
            normalization(channels[1]),
            activation(),
            
            convolution(in_channels=channels[1], out_channels=channels[2], stride=2),
            normalization(channels[2]),
            activation(),

            convolution(in_channels=channels[2], out_channels=channels[3], stride=2),
            normalization(channels[3]),
            activation(),
        )
        self.decoder = torch.nn.Sequential(
            deconvolution(in_channels=channels[3], out_channels=channels[2], stride=2),
            normalization(channels[2]),
            activation(),
            
            deconvolution(in_channels=channels[2], out_channels=channels[1], stride=2),
            normalization(channels[1]),
            activation(),
            
            deconvolution(in_channels=channels[1], out_channels=channels[0], stride=2),
            normalization(channels[0]),
            activation(),
            
            deconvolution(in_channels=channels[0], out_channels=1, stride=2),
            normalization(1),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

autoencoder = Autoencoder(channels=[128,256,512,1024]).to(device)
autoencoder.load_state_dict(torch.load(autoencoder_gm_dir))
for param in autoencoder.parameters():
    param.requires_grad = False

In [9]:
class STEFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        return F.hardtanh(grad_output)

class StraightThroughEstimator(torch.nn.Module):
    def __init__(self):
        super(StraightThroughEstimator, self).__init__()

    def forward(self, x):
        x = STEFunction.apply(x)
        return x

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.segmentation = monai.networks.nets.UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=num_classes,
            channels=(16, 32, 64, 128),
            strides=(2, 2, 2,),
            norm=monai.networks.layers.Norm.BATCH,
            dropout=0.3,
        )
        self.ste = StraightThroughEstimator()
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, image):
        prob = self.softmax(self.segmentation(image))
        mask = prob[:, 1:2]#self.ste(prob[:, 2:]-0.5)
        shape = autoencoder.encoder(mask)
        recon = autoencoder.decoder(shape)
        return prob, mask, shape, recon

# Training

In [7]:
def clean(model):
    del model
    gc.collect()
    torch.cuda.empty_cache()

class Metrics:
    def __init__(self):
        self.metrics = defaultdict(list)
        self.names = ['loss', 'loss_segm', 'loss_shape', 'loss_recon']
    def log(self, mode, epoch, values):
        for name, value in zip(self.names, values):
            self.metrics[(mode, epoch, name)].append(0 if value == 0 else value.item())
    def show(self, mode, epoch):
        print()
        for name in self.names:
            mean = np.mean(self.metrics[(mode, epoch, name)])
            print(f'{mode} {name}: {mean}')
            if name == 'loss':
                mean_loss = mean
        return mean_loss
'''
def predict(dataloaders):
    mode = 'test'
    model = Model().to(device)
    model.load_state_dict(torch.load(f'{model_dir}/best_model.torch'))
    model.eval()

    for i, subject in enumerate(tqdm(dataloaders[mode])):
        image = subject['image'][tio.DATA].to(device)
        label = subject['label'][tio.DATA].to(device)
        prob, mask, shape, recon = model(image)
        
        nrrd.write(f'{model_dir}/true{i}.nrrd', label[0, 0].detach().cpu().numpy())
        nrrd.write(f'{model_dir}/pred{i}.nrrd', mask[0, 0].detach().cpu().numpy())
        nrrd.write(f'{model_dir}/recon{i}.nrrd', recon[0, 0].detach().cpu().numpy())

    !rm file.zip
    !zip -r file.zip $model_dir

    del model
    gc.collect()
    torch.cuda.empty_cache()
'''
def train(model, n_epochs, dataloaders, learning_rate, lambda_shape, lambda_recon):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.2)

    loss_dice = monai.losses.DiceLoss(squared_pred=True)
    loss_mse = torch.nn.MSELoss()
    
    metric = monai.metrics.DiceMetric(reduction='mean_batch')
    metrics = Metrics()
    best_val_loss = np.Inf

    tol, tol_total = 0, 0
    for epoch in range(1, n_epochs+1):
        t0 = time.time()
        print(f'\nEpoch {epoch}/{n_epochs}')
        for mode in modes:
            if mode == 'train':
                model.train()
            else:
                model.eval()
            
            for subject in dataloaders[mode]:
                image = subject['image'][tio.DATA].to(device)
                label = subject['label'][tio.DATA].to(device).float()
                label_gm = label[:, 1:2]
                
                prob, mask, shape, recon = model(image)
                loss_segm = loss_dice(prob, label)
                loss = loss_segm
                loss_shape, loss_recon = 0, 0
                if lambda_shape > 0:
                    loss_shape = loss_mse(shape, autoencoder.encoder(label_gm))
                    loss += lambda_shape * loss_shape
                if lambda_recon > 0:
                    loss_recon = loss_dice(recon, label_gm)
                    loss += lambda_recon * loss_recon

                if mode == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                
                pred = monai.networks.utils.one_hot(
                    torch.argmax(prob, dim=1, keepdim=True), 
                    num_classes=num_classes, 
                    dim=1
                ).to(device)
                metric(pred, label)
                metrics.log(mode, epoch, [loss, loss_segm, loss_shape, loss_recon])

            dsc = metric.aggregate().tolist()
            metric.reset()
            mean_loss = metrics.show(mode, epoch)
            print(f'{mode} unet DSC: {dsc}')
        
        if mean_loss <= best_val_loss:
            best_val_loss = mean_loss
            best_epoch = epoch
            tol = 0
            tol_total = 0
        else:
            tol += 1
            tol_total += 1
        print(f'Best val loss: {best_val_loss}')
        
        if tol == 10:
            scheduler.step()
            print('Validation loss stopped to decrease for 10 epochs (LR /= 5).')
            tol = 0
        
        time_elapsed = time.time() - t0
        print(f'Time: {time_elapsed}\n')

        if tol_total == 22:
            print('Validation loss stopped to decrease for 30 epochs. Training terminated.')
            break
    print(f'Best epoch: {best_epoch}')

In [10]:
if 'model' in globals(): clean(model)
model = Model().to(device)
train(
    model=model, n_epochs=300, dataloaders=dataloaders, learning_rate=1e-3, 
    lambda_shape=0, lambda_recon=0,
    # lambda_shape=0.01, lambda_recon=0.001,
)


Epoch 1/300

train loss: 0.3140457736540444
train loss_segm: 0.3140457736540444
train loss_shape: 0.0
train loss_recon: 0.0
train unet DSC: [0.9389461278915405, 0.5425962805747986, 0.5318793654441833]

test loss: 0.1394583529386765
test loss_segm: 0.1394583529386765
test loss_shape: 0.0
test loss_recon: 0.0
test unet DSC: [0.9900426864624023, 0.7536183595657349, 0.7919955253601074]
Best val loss: 0.1394583529386765
Time: 63.4544415473938


Epoch 2/300

train loss: 0.1329192851163164
train loss_segm: 0.1329192851163164
train loss_shape: 0.0
train loss_recon: 0.0
train unet DSC: [0.9907858967781067, 0.7636953592300415, 0.7887741327285767]

test loss: 0.09654399905449305
test loss_segm: 0.09654399905449305
test loss_shape: 0.0
test loss_recon: 0.0
test unet DSC: [0.9930020570755005, 0.8238855004310608, 0.8505937457084656]
Best val loss: 0.09654399905449305
Time: 65.96640419960022


Epoch 3/300

train loss: 0.1064275563989259
train loss_segm: 0.1064275563989259
train loss_shape: 0.0
train

In [11]:
if 'model' in globals(): clean(model)
model = Model().to(device)
train(
    model=model, n_epochs=300, dataloaders=dataloaders, learning_rate=1e-3, 
    lambda_shape=0.01, lambda_recon=0.01,
)


Epoch 1/300

train loss: 0.29476251820974714
train loss_segm: 0.29476251820974714
train loss_shape: 0.20519594933036006
train loss_recon: 0.43805305678633194
train unet DSC: [0.9579780697822571, 0.5785843729972839, 0.5882948637008667]

test loss: 0.13700910294667268
test loss_segm: 0.13700910294667268
test loss_shape: 0.10304195032669948
test loss_recon: 0.2597254820359059
test unet DSC: [0.9906886219978333, 0.765700101852417, 0.7974656820297241]
Best val loss: 0.13700910294667268
Time: 86.2115626335144


Epoch 2/300

train loss: 0.13499636745339708
train loss_segm: 0.13499636745339708
train loss_shape: 0.1342805171314674
train loss_recon: 0.28173125318334075
train unet DSC: [0.9912410974502563, 0.7697640061378479, 0.7840234041213989]

test loss: 0.10009106000264485
test loss_segm: 0.10009106000264485
test loss_shape: 0.07873722547904038
test loss_recon: 0.20848735784872985
test unet DSC: [0.9929248094558716, 0.8198614120483398, 0.8447510004043579]
Best val loss: 0.10009106000264485
T

In [12]:
if 'model' in globals(): clean(model)
model = Model().to(device)
train(
    model=model, n_epochs=300, dataloaders=dataloaders, learning_rate=1e-3, 
    lambda_shape=0.01, lambda_recon=0.001,
)


Epoch 1/300

train loss: 0.30170830744731275
train loss_segm: 0.30170830744731275
train loss_shape: 0.2228647405399552
train loss_recon: 0.46985984603060954
train unet DSC: [0.9415330290794373, 0.5637030005455017, 0.5453333854675293]

test loss: 0.16016409565240908
test loss_segm: 0.16016409565240908
test loss_shape: 0.10993720782108796
test loss_recon: 0.2989827470901685
test unet DSC: [0.9884563088417053, 0.7014467716217041, 0.7868847846984863]
Best val loss: 0.16016409565240908
Time: 86.01402878761292


Epoch 2/300

train loss: 0.1422858934236478
train loss_segm: 0.1422858934236478
train loss_shape: 0.1412860945433001
train loss_recon: 0.3102436627768263
train unet DSC: [0.9908702373504639, 0.7530251741409302, 0.7671416401863098]

test loss: 0.10539026138110039
test loss_segm: 0.10539026138110039
test loss_shape: 0.0753032454313376
test loss_recon: 0.19482113612003815
test unet DSC: [0.9934009909629822, 0.8073424696922302, 0.8342234492301941]
Best val loss: 0.10539026138110039
Time

In [13]:
if 'model' in globals(): clean(model)
model = Model().to(device)
train(
    model=model, n_epochs=300, dataloaders=dataloaders, learning_rate=1e-3, 
    lambda_shape=0.001, lambda_recon=0.01,
)


Epoch 1/300

train loss: 0.3248138095759138
train loss_segm: 0.3248138095759138
train loss_shape: 0.22586101717964002
train loss_recon: 0.4618474124353143
train unet DSC: [0.9364240765571594, 0.535178005695343, 0.5045997500419617]

test loss: 0.16810433834027022
test loss_segm: 0.16810433834027022
test loss_shape: 0.11919112159655644
test loss_recon: 0.28791040640610915
test unet DSC: [0.9890183806419373, 0.7258719205856323, 0.7246515154838562]
Best val loss: 0.16810433834027022
Time: 85.78967618942261


Epoch 2/300

train loss: 0.14535284622371952
train loss_segm: 0.14535284622371952
train loss_shape: 0.14725861495620088
train loss_recon: 0.3009297259246247
train unet DSC: [0.9898287653923035, 0.7447307705879211, 0.7751407623291016]

test loss: 0.10818571998522832
test loss_segm: 0.10818571998522832
test loss_shape: 0.08088792822299859
test loss_recon: 0.21016045869925085
test unet DSC: [0.9930209517478943, 0.8082422018051147, 0.8318229913711548]
Best val loss: 0.10818571998522832
Ti

In [14]:
if 'model' in globals(): clean(model)
model = Model().to(device)
train(
    model=model, n_epochs=300, dataloaders=dataloaders, learning_rate=1e-3, 
    lambda_shape=0.1, lambda_recon=0.1,
)


Epoch 1/300

train loss: 0.405095580918125
train loss_segm: 0.405095580918125
train loss_shape: 0.19913736711951752
train loss_recon: 0.42969625735584693
train unet DSC: [0.9322459697723389, 0.5393196940422058, 0.45848700404167175]

test loss: 0.19002980910814726
test loss_segm: 0.19002980910814726
test loss_shape: 0.09654266635576884
test loss_recon: 0.2566548540041997
test unet DSC: [0.9892118573188782, 0.7091829776763916, 0.7857166528701782]
Best val loss: 0.19002980910814726
Time: 85.30607652664185


Epoch 2/300

train loss: 0.18207280573588383
train loss_segm: 0.18207280573588383
train loss_shape: 0.12333071071513091
train loss_recon: 0.26518886451479756
train unet DSC: [0.9898360967636108, 0.74237060546875, 0.7732837200164795]

test loss: 0.14960497655929664
test loss_segm: 0.14960497655929664
test loss_shape: 0.06922480177420837
test loss_recon: 0.19778631895016402
test unet DSC: [0.9912402629852295, 0.7452601194381714, 0.829781174659729]
Best val loss: 0.14960497655929664
Time

KeyboardInterrupt: 

In [None]:
if 'model' in globals(): clean(model)
model = Model().to(device)
train(
    model=model, n_epochs=300, dataloaders=dataloaders, learning_rate=1e-3, 
    lambda_shape=0.1, lambda_recon=0.01,
)

In [None]:
if 'model' in globals(): clean(model)
model = Model().to(device)
train(
    model=model, n_epochs=300, dataloaders=dataloaders, learning_rate=1e-3, 
    lambda_shape=0.01, lambda_recon=0.1,
)