In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import os
import gc
from glob import glob
from tqdm import tqdm
from collections import defaultdict
import json
import scipy.ndimage as ndimage
import nrrd
import torchio as tio
import monai
import nibabel as nib
import time



# Dataset

In [2]:
city = 'Beijing_Zang'

modes = ['train', 'test']
total_size = 197
train_size, test_size = 158, 39

image_dir = f'../dataset/{city}/MRI'
image_arti_dir = f'../dataset/{city}/MRI_arti'
label_dir = f'../dataset/{city}/Ventricles'
autoencoder_dir = f'../results/SCAE_all_monai/best_autoencoder.torch'

model_dir = f'../results/{city}_unet_autoencoder_allvent'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
seed = 0
random_state = np.random.RandomState(seed=seed)
perm = random_state.permutation(total_size)
perm = {
    'train': perm[:train_size],
    'test': perm[-test_size:],
}

def get_subjects(mode):
    subjects = []
    image_src = image_dir #image_dir if mode == 'train' else image_arti_dir
    image_paths = [sorted(glob(f'{image_src}/*.nii.gz'))[i] for i in perm[mode]]
    for image_path in tqdm(image_paths, desc=mode):
        fn = image_path.split('/')[-1]
        label_path = f'{label_dir}/{fn}'
        subject = tio.Subject(
            image=tio.ScalarImage(image_path),
            label=tio.LabelMap(label_path),
        )
        subjects.append(subject)
    return subjects

def get_transform():
    resample = tio.Compose([
        tio.Resample(2),
        tio.CropOrPad((96, 128, 128)), #192, 256, 256
    ])
    signal = tio.Compose([ 
        tio.RescaleIntensity(percentiles=(0.1, 99.9), out_min_max=(0, 1)),
    ])
    # spatial = tio.OneOf(
    #     {tio.RandomAffine(degrees=(-3, 3), translation=(-0.1, 0.1)): 1.0},
    #     p=0.75,
    # )
    remapping = tio.RemapLabels({0:0, 1:1, 2:1, 3:1, 4:1})
    transform = {
        'train': tio.Compose([
            # spatial,
            resample,
            signal,
            remapping,
        ]),
        'test': tio.Compose([
            resample,
            signal,
            remapping,
        ]),
    }
    return transform

def get_dataloader(transform):
    dataloader = dict()
    for mode in modes:
        dataloader[mode] = torch.utils.data.DataLoader(
            tio.SubjectsDataset(
                subjects[mode], 
                transform=transform[mode]
            ),
            batch_size=1, 
            num_workers=os.cpu_count(),
            shuffle=(mode == 'train'),
        )
    return dataloader

subjects = {mode: get_subjects(mode) for mode in modes}
transform = get_transform()
dataloaders = get_dataloader(transform)

train: 100%|██████████| 158/158 [00:01<00:00, 89.48it/s]
test: 100%|██████████| 39/39 [00:00<00:00, 70.54it/s]


# Model

In [4]:
def convolution(in_channels, out_channels, stride):
    return torch.nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1)

def deconvolution(in_channels, out_channels, stride):
    return torch.nn.ConvTranspose3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, output_padding=1)

def normalization(channel):
    return torch.nn.BatchNorm3d(channel)

def activation():
    return torch.nn.PReLU()

def pooling(kernel_size):
    return torch.nn.MaxPool3d(kernel_size=kernel_size)

def upsampling(scale_factor):
    return torch.nn.Upsample(scale_factor=scale_factor, mode='trilinear', align_corners=True)

class Autoencoder(torch.nn.Module):
    def __init__(self, channels):
        super(Autoencoder, self).__init__()
        self.encoder = torch.nn.Sequential(
            convolution(in_channels=1, out_channels=channels[0], stride=2),
            normalization(channels[0]),
            activation(),
            
            convolution(in_channels=channels[0], out_channels=channels[1], stride=2),
            normalization(channels[1]),
            activation(),
            
            convolution(in_channels=channels[1], out_channels=channels[2], stride=2),
            normalization(channels[2]),
            activation(),

            convolution(in_channels=channels[2], out_channels=channels[3], stride=2),
            normalization(channels[3]),
            activation(),
        )
        self.decoder = torch.nn.Sequential(
            deconvolution(in_channels=channels[3], out_channels=channels[2], stride=2),
            normalization(channels[2]),
            activation(),
            
            deconvolution(in_channels=channels[2], out_channels=channels[1], stride=2),
            normalization(channels[1]),
            activation(),
            
            deconvolution(in_channels=channels[1], out_channels=channels[0], stride=2),
            normalization(channels[0]),
            activation(),
            
            deconvolution(in_channels=channels[0], out_channels=1, stride=2),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

autoencoder = Autoencoder(channels=[64,128,256,512]).to(device)
autoencoder.load_state_dict(torch.load(autoencoder_dir))
for param in autoencoder.parameters():
    param.requires_grad = False

In [5]:
class STEFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        return F.hardtanh(grad_output)

class StraightThroughEstimator(torch.nn.Module):
    def __init__(self):
        super(StraightThroughEstimator, self).__init__()

    def forward(self, x):
        x = STEFunction.apply(x)
        return x

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.segmentation = monai.networks.nets.UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=1,
            channels=(32, 64, 128, 256, 512),
            strides=(2, 2, 2, 2),
            norm=monai.networks.layers.Norm.BATCH,
        )
        self.sigmoid = torch.nn.Sigmoid()
        self.ste = StraightThroughEstimator()

    def forward(self, image):
        prob = self.sigmoid(self.segmentation(image))
        mask = self.ste(prob - 0.5)
        shape = autoencoder.encoder(mask)
        recon = autoencoder.decoder(shape)
        return prob, mask, shape, recon

# Training

In [6]:
def clean(model):
    del model
    gc.collect()
    torch.cuda.empty_cache()

class Metrics:
    def __init__(self):
        self.metrics = defaultdict(list)
        self.names = ['loss', 'loss_segm', 'loss_shape', 'loss_recon']
    def log(self, mode, epoch, values):
        for name, value in zip(self.names, values):
            self.metrics[(mode, epoch, name)].append(value.item())
    def show(self, mode, epoch):
        print()
        for name in self.names:
            mean = np.mean(self.metrics[(mode, epoch, name)])
            print(f'{mode} {name}: {mean}')
'''
def predict(dataloaders):
    mode = 'test'
    model = Model().to(device)
    model.load_state_dict(torch.load(f'{model_dir}/best_model.torch'))
    model.eval()

    for i, subject in enumerate(tqdm(dataloaders[mode])):
        image = subject['image'][tio.DATA].to(device)
        label = subject['label'][tio.DATA].to(device)
        prob, mask, shape, recon = model(image)
        
        nrrd.write(f'{model_dir}/true{i}.nrrd', label[0, 0].detach().cpu().numpy())
        nrrd.write(f'{model_dir}/pred{i}.nrrd', mask[0, 0].detach().cpu().numpy())
        nrrd.write(f'{model_dir}/recon{i}.nrrd', recon[0, 0].detach().cpu().numpy())

    !rm file.zip
    !zip -r file.zip $model_dir

    del model
    gc.collect()
    torch.cuda.empty_cache()
'''
def train(model, lambdas, n_epochs, dataloaders, learning_rate):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)

    loss_dice = monai.losses.DiceLoss(squared_pred=True).to(device)
    loss_mse = torch.nn.MSELoss().to(device)
    
    metric_unet = monai.metrics.DiceMetric(include_background=False, reduction='mean_batch')
    metric_recon = monai.metrics.DiceMetric(include_background=False, reduction='mean_batch')
    metrics = Metrics()
    best_val_dsc_unet = 0
    best_val_dsc_recon = 0

    for epoch in range(1, n_epochs+1):
        t0 = time.time()
        print(f'\nEpoch {epoch}/{n_epochs}')
        for mode in modes:
            if mode == 'train':
                model.train()
            else:
                model.eval()
            
            for subject in dataloaders[mode]:
                image = subject['image'][tio.DATA].to(device)
                label = subject['label'][tio.DATA].to(device).float()
                
                prob, mask, shape, recon = model(image)
                loss_segm = loss_dice(prob, label)
                loss_shape = loss_mse(shape, autoencoder.encoder(label))
                loss_recon = loss_dice(recon, label)

                loss = lambdas[0] * loss_segm + lambdas[1] * loss_shape + lambdas[2] * loss_recon

                if mode == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                
                metric_unet((prob > 0.5).float(), label)
                metric_recon((recon > 0.5).float(), label)
                metrics.log(mode, epoch, [loss, loss_segm, loss_shape, loss_recon])

            metrics.show(mode, epoch)
            mean_dsc_unet = metric_unet.aggregate().tolist()[0]
            metric_unet.reset()
            mean_dsc_recon = metric_recon.aggregate().tolist()[0]
            metric_recon.reset()
            print(f'{mode} unet DSC: {mean_dsc_unet}')
            print(f'{mode} recon DSC: {mean_dsc_recon}')
            
        scheduler.step()

        if mean_dsc_unet > best_val_dsc_unet:
            best_val_dsc_unet = mean_dsc_unet
        if mean_dsc_recon > best_val_dsc_recon:
            best_val_dsc_recon = mean_dsc_recon
        print(f'Best val unet DSC: {best_val_dsc_unet}')
        print(f'Best val recon DSC: {best_val_dsc_recon}')

        time_elapsed = time.time() - t0
        print(f'Time: {time_elapsed}\n')

In [11]:
lambdas = (1, 0.1, 0)
clean(model)
model = Model().to(device)
train(model=model, lambdas=lambdas, n_epochs=60, dataloaders=dataloaders, learning_rate=1e-1)


Epoch 1/60

train loss: 0.48635669599605513
train loss_segm: 0.4689623185350925
train loss_shape: 0.17394377880647213
train loss_recon: 0.39367952678776996
train unet DSC: 0.5249937176704407
train recon DSC: 0.5760893821716309

test loss: 0.49298153053491545
test loss_segm: 0.4772439308655568
test loss_shape: 0.15737599172653297
test loss_recon: 0.42882718489720273
test unet DSC: 0.5186718702316284
test recon DSC: 0.556290864944458
Best val unet DSC: 0.5186718702316284
Best val recon DSC: 0.556290864944458
Time: 47.515700817108154


Epoch 2/60

train loss: 0.24581069536978686
train loss_segm: 0.23853012051763414
train loss_shape: 0.07280575080857246
train loss_recon: 0.21563810110092163
train unet DSC: 0.7540978789329529
train recon DSC: 0.7589353322982788

test loss: 0.7152521381011376
test loss_segm: 0.6923614877920884
test loss_shape: 0.22890660262260681
test loss_recon: 0.6475739570764395
test unet DSC: 0.30394211411476135
test recon DSC: 0.34401166439056396
Best val unet DSC: 0.5

In [12]:
lambdas = (1, 0, 0.1)
clean(model)
model = Model().to(device)
train(model=model, lambdas=lambdas, n_epochs=100, dataloaders=dataloaders, learning_rate=1e-3)


Epoch 1/100

train loss: 1.0333815160431439
train loss_segm: 0.9755964905400819
train loss_shape: 0.636299678607832
train loss_recon: 0.5778502342821676
train unet DSC: 0.03385075926780701
train recon DSC: 0.3756972551345825

test loss: 1.0103249366466815
test loss_segm: 0.9597185804293706
test loss_shape: 0.5272288750379514
test loss_recon: 0.5060635300783011
test unet DSC: 0.05250579118728638
test recon DSC: 0.4349593222141266
Best val unet DSC: 0.05250579118728638
Best val recon DSC: 0.4349593222141266
Time: 46.937642097473145


Epoch 2/100

train loss: 0.9946577598022509
train loss_segm: 0.95964319645604
train loss_shape: 0.45205046726933007
train loss_recon: 0.3501456413842455
train unet DSC: 0.1021006852388382
train recon DSC: 0.6077445149421692

test loss: 0.9967163296846243
test loss_segm: 0.9567486613224714
test loss_shape: 0.4150015367911412
test loss_recon: 0.39967676462271273
test unet DSC: 0.058223456144332886
test recon DSC: 0.526386559009552
Best val unet DSC: 0.0582234

In [13]:
lambdas = (1, 0.1, 0.1)
clean(model)
model = Model().to(device)
train(model=model, lambdas=lambdas, n_epochs=100, dataloaders=dataloaders, learning_rate=1e-3)


Epoch 1/100

train loss: 1.0843037193334555
train loss_segm: 0.9928277734714218
train loss_shape: 0.3124498567437824
train loss_recon: 0.6023096295097207
train unet DSC: 0.004615956451743841
train recon DSC: 0.35769912600517273

test loss: 1.0573536157608032
test loss_segm: 0.9915795096984277
test loss_shape: 0.2507311824040535
test loss_recon: 0.4070098201433818
test unet DSC: 0.005692462436854839
test recon DSC: 0.5570905804634094
Best val unet DSC: 0.005692462436854839
Best val recon DSC: 0.5570905804634094
Time: 47.8427517414093


Epoch 2/100

train loss: 1.0432616555238072
train loss_segm: 0.9898624601243418
train loss_shape: 0.2257448431057266
train loss_recon: 0.308247098817101
train unet DSC: 0.0070691038854420185
train recon DSC: 0.6485552191734314

test loss: 1.052119410954989
test loss_segm: 0.988805109109634
test loss_shape: 0.21288877152479613
test loss_recon: 0.42025425495245516
test unet DSC: 0.008434081450104713
test recon DSC: 0.5262463092803955
Best val unet DSC: 0.0

In [14]:
lambdas = (1, 0.1, 0.01)
clean(model)
model = Model().to(device)
train(model=model, lambdas=lambdas, n_epochs=100, dataloaders=dataloaders, learning_rate=1e-3)


Epoch 1/100

train loss: 1.0181910403167145
train loss_segm: 0.9814558210252207
train loss_shape: 0.2960032386870324
train loss_recon: 0.7134892521779749
train unet DSC: 0.024405397474765778
train recon DSC: 0.22854971885681152

test loss: 0.9894745991780207
test loss_segm: 0.9530696120017614
test loss_shape: 0.31371526229075897
test loss_recon: 0.5033454390672537
test unet DSC: 0.07853485643863678
test recon DSC: 0.43557730317115784
Best val unet DSC: 0.07853485643863678
Best val recon DSC: 0.43557730317115784
Time: 46.99534463882446


Epoch 2/100

train loss: 0.9728201884257642
train loss_segm: 0.9379513705078559
train loss_shape: 0.29733825559857524
train loss_recon: 0.513499385571178
train unet DSC: 0.12624502182006836
train recon DSC: 0.4234963059425354

test loss: 0.9713067244260739
test loss_segm: 0.9355452595612942
test loss_shape: 0.300185868755365
test loss_recon: 0.5742874267773751
test unet DSC: 0.09229128062725067
test recon DSC: 0.3623772859573364
Best val unet DSC: 0.09

In [15]:
lambdas = (1, 0.01, 0.1)
clean(model)
model = Model().to(device)
train(model=model, lambdas=lambdas, n_epochs=100, dataloaders=dataloaders, learning_rate=1e-3)


Epoch 1/100

train loss: 1.0351380717150773
train loss_segm: 0.9696784374080126
train loss_shape: 0.5182229355920719
train loss_recon: 0.602774051925804
train unet DSC: 0.0520453080534935
train recon DSC: 0.34413987398147583

test loss: 1.0727004301853669
test loss_segm: 0.9849187410794772
test loss_shape: 0.5650630165369083
test loss_recon: 0.8213105675501701
test unet DSC: 0.014634245075285435
test recon DSC: 0.04745763540267944
Best val unet DSC: 0.014634245075285435
Best val recon DSC: 0.04745763540267944
Time: 47.59533667564392


Epoch 2/100

train loss: 1.0002686200262625
train loss_segm: 0.9485242853436289
train loss_shape: 0.4194762523792967
train loss_recon: 0.4754956802235374
train unet DSC: 0.09479192644357681
train recon DSC: 0.467365026473999

test loss: 1.0039181587023613
test loss_segm: 0.9562271650020893
test loss_shape: 0.3644883105388054
test loss_recon: 0.4404611281859569
test unet DSC: 0.058126915246248245
test recon DSC: 0.5147019624710083
Best val unet DSC: 0.058

In [16]:
lambdas = (1, 0.01, 0.01)
clean(model)
model = Model().to(device)
train(model=model, lambdas=lambdas, n_epochs=100, dataloaders=dataloaders, learning_rate=1e-3)


Epoch 1/100

train loss: 0.5568197094186952
train loss_segm: 0.5502118985109692
train loss_shape: 0.20763914549982623
train loss_recon: 0.4531419688387762
train unet DSC: 0.5017874240875244
train recon DSC: 0.5184509754180908

test loss: 0.22217339735764724
test loss_segm: 0.21937500819181785
test loss_shape: 0.06035673092955198
test loss_recon: 0.21948215900323328
test unet DSC: 0.7504677176475525
test recon DSC: 0.7514169812202454
Best val unet DSC: 0.7504677176475525
Best val recon DSC: 0.7514169812202454
Time: 47.5423002243042


Epoch 2/100

train loss: 0.1736730830861798
train loss_segm: 0.17148833863342863
train loss_shape: 0.04456447775746825
train loss_recon: 0.17390993199770963
train unet DSC: 0.7975248694419861
train recon DSC: 0.7995679974555969

test loss: 0.16808490053965494
test loss_segm: 0.16594800429466444
test loss_shape: 0.04054885567762913
test loss_recon: 0.17314065572543022
test unet DSC: 0.8007994890213013
test recon DSC: 0.8012375831604004
Best val unet DSC: 0.

In [17]:
lambdas = (1, 0.001, 0.01)
clean(model)
model = Model().to(device)
train(model=model, lambdas=lambdas, n_epochs=100, dataloaders=dataloaders, learning_rate=1e-3)


Epoch 1/100

train loss: 0.564402615344977
train loss_segm: 0.5594379784940164
train loss_shape: 0.2882153672177958
train loss_recon: 0.46764201189898236
train unet DSC: 0.5037084817886353
train recon DSC: 0.5059838891029358

test loss: 0.30468896451668864
test loss_segm: 0.3016135707879678
test loss_shape: 0.09456849203277858
test loss_recon: 0.29808231041981625
test unet DSC: 0.6592626571655273
test recon DSC: 0.6835687160491943
Best val unet DSC: 0.6592626571655273
Best val recon DSC: 0.6835687160491943
Time: 47.64653444290161


Epoch 2/100

train loss: 0.18341623270247556
train loss_segm: 0.18152147871029528
train loss_shape: 0.04933967628763824
train loss_recon: 0.18454136501384688
train unet DSC: 0.7869527339935303
train recon DSC: 0.788482666015625

test loss: 0.18672175743640998
test loss_segm: 0.18474362446711615
test loss_shape: 0.049521334803639315
test loss_recon: 0.19286139806111655
test unet DSC: 0.7797231674194336
test recon DSC: 0.7761989235877991
Best val unet DSC: 0.

In [18]:
lambdas = (1, 0.01, 0.001)
clean(model)
model = Model().to(device)
train(model=model, lambdas=lambdas, n_epochs=100, dataloaders=dataloaders, learning_rate=1e-3)


Epoch 1/100

train loss: 0.5373284914825536
train loss_segm: 0.5345380336423463
train loss_shape: 0.23378767742763593
train loss_recon: 0.4525816572617881
train unet DSC: 0.5136671662330627
train recon DSC: 0.5201465487480164

test loss: 0.23085119785406652
test loss_segm: 0.2299970220296811
test loss_shape: 0.06298016785429074
test loss_recon: 0.2243752433703496
test unet DSC: 0.7369555830955505
test recon DSC: 0.7518736124038696
Best val unet DSC: 0.7369555830955505
Best val recon DSC: 0.7518736124038696
Time: 47.30076551437378


Epoch 2/100

train loss: 0.16705043920421903
train loss_segm: 0.1664584271515472
train loss_shape: 0.042245069884141034
train loss_recon: 0.16956099339678318
train unet DSC: 0.8018706440925598
train recon DSC: 0.8038728833198547

test loss: 0.2706321642184869
test loss_segm: 0.26954672122612977
test loss_shape: 0.08105813167416133
test loss_recon: 0.27486059298882115
test unet DSC: 0.6867274045944214
test recon DSC: 0.7071641683578491
Best val unet DSC: 0.7

In [None]:
lambdas = (1, 0.001, 0.001)
clean(model)
model = Model().to(device)
train(model=model, lambdas=lambdas, n_epochs=100, dataloaders=dataloaders, learning_rate=1e-3)


Epoch 1/100
