In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
import os
import gc
from glob import glob
from tqdm import tqdm
from collections import defaultdict
import json
import scipy.ndimage as ndimage
import nrrd
import torchio as tio
import monai
import nibabel as nib
from collections import OrderedDict, defaultdict



In [2]:
# Dataset variant to train/evaluate on; images, artifact images and masks share filenames.
city = 'Ellipsoids_FP'
IMAGE_DIR = f'../dataset/{city}/images'
IMAGE_ARTI_DIR = f'../dataset/{city}/images_arti'
MASK_DIR = f'../dataset/{city}/segmentations'

# Pretrained (frozen) shape encoder/decoder weights, trained on the plain Ellipsoids set.
ed_dir = '../dataset/Ellipsoids/models'
encoder_path = f'{ed_dir}/best_encoder.torch'
decoder_path = f'{ed_dir}/best_decoder.torch'

# Background + one foreground class (channel 1 is used as the predicted mask).
num_classes = 2
TRAIN_SIZE, VAL_SIZE, TEST_SIZE = 50, 10, 21
TOTAL_SIZE = TRAIN_SIZE + VAL_SIZE + TEST_SIZE

# Trailing slash is relied on when building per-sample output filenames.
output_dir = f'../results/{city}_ellp_ed/'
# exist_ok avoids the check-then-create race and keeps the cell idempotent.
os.makedirs(output_dir, exist_ok=True)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [3]:
def AE():
    """Build the MONAI 3D AutoEncoder whose halves Encoder/Decoder reuse.

    Single-channel input and output, 3x3x3 kernels, channel widths 1-2-4-8-16
    and strides 1-2-2-2-2 (i.e. four 2x downsamplings in the encoder).
    """
    # The `1 *` factor is a width multiplier (NOTE(review): presumably kept as a tuning knob).
    channel_widths = [1 * width for width in (1, 2, 4, 8, 16)]
    return monai.networks.nets.AutoEncoder(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        kernel_size=(3, 3, 3),
        channels=channel_widths,
        strides=(1, 2, 2, 2, 2),
    )

class Encoder(nn.Module):
    """Map a 1-channel 3D mask to a single scalar shape code.

    The convolutional trunk is the ``encode`` stage of the AutoEncoder built by
    ``AE()``; its flattened features go through a small PReLU MLP that ends in
    one output unit (Model compares it against the subject radius).
    """

    def __init__(self):
        super().__init__()
        self.conv = AE().encode
        # 1024 = 16 channels x 4 x 4 x 4, the same latent layout Decoder reshapes into.
        # NOTE(review): that spatial size presumably corresponds to 64^3 inputs — confirm.
        self.fc = nn.Sequential(
            nn.Linear(1024, 256), nn.PReLU(),
            nn.Linear(256, 64), nn.PReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        features = torch.flatten(self.conv(x), start_dim=1)
        return self.fc(features)

class Decoder(nn.Module):
    """Map a scalar shape code back to a 1-channel 3D volume.

    A PReLU MLP lifts the (batch, 1) input to 1024 features, which are reshaped
    to the (16, 4, 4, 4) latent layout and decoded by the ``decode`` stage of the
    AutoEncoder built by ``AE()``. ``forward`` applies no extra activation after
    ``deconv``.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(1, 1024),
            nn.PReLU(),
            nn.Linear(1024, 1024),
            nn.PReLU(),
            nn.Linear(1024, 1024),
            nn.PReLU(),
        )
        self.deconv = AE().decode

    def forward(self, x):
        x = self.fc(x)
        # -1 leaves the batch dimension free: the previous hard-coded 1 made
        # reshape fail for any batch size > 1.
        x = torch.reshape(x, (-1, 16, 4, 4, 4))
        return self.deconv(x)

class STEFunction(torch.autograd.Function):
    """Straight-through step function.

    Forward: 1.0 where ``input > 0`` and 0.0 elsewhere, as a float32 tensor.
    Backward: the incoming gradient is passed straight through, clipped to
    [-1, 1] by hardtanh; it does not depend on the forward input.
    """

    @staticmethod
    def forward(ctx, input):
        is_positive = torch.gt(input, 0)
        return is_positive.float()

    @staticmethod
    def backward(ctx, grad_output):
        clipped_grad = F.hardtanh(grad_output, min_val=-1.0, max_val=1.0)
        return clipped_grad

class StraightThroughEstimator(torch.nn.Module):
    """Module wrapper around ``STEFunction``.

    Forward thresholds its input at 0 (1.0 where positive, else 0.0); backward
    passes the gradient through, clipped to [-1, 1].
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return STEFunction.apply(x)

class Model(torch.nn.Module):
    """3D UNet segmenter coupled to a frozen, pretrained shape encoder/decoder.

    ``forward(image)`` returns ``(prob, binary_mask, shape, recon)``:

    * ``prob`` -- softmax over the ``num_classes`` UNet output channels.
    * ``binary_mask`` -- channel 1 of ``prob`` thresholded at 0.5 through a
      straight-through estimator, so gradients still reach the UNet.
    * ``shape`` -- the frozen encoder's scalar output for that mask, clamped to
      [10, 30].
    * ``recon`` -- the frozen decoder's volume for that scalar.

    Encoder/decoder weights are loaded from ``encoder_path`` / ``decoder_path``,
    frozen (``requires_grad=False``) and kept in eval mode even when the model
    is switched to training mode.
    """

    def __init__(self):
        super().__init__()
        self.segmentation = monai.networks.nets.UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=num_classes,
            channels=(32, 64, 128, 256, 512),
            strides=(2, 2, 2, 2),
            norm=monai.networks.layers.Norm.BATCH,
        )
        self.softmax = torch.nn.Softmax(dim=1)
        self.ste = StraightThroughEstimator()

        self.encoder, self.decoder = Encoder(), Decoder()
        for coder, coder_path in zip([self.encoder, self.decoder], [encoder_path, decoder_path]):
            # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines;
            # the parameters live on CPU after construction either way, and callers
            # move the whole model with `.to(...)`.
            coder.load_state_dict(torch.load(coder_path, map_location='cpu'))
            coder.requires_grad_(False)
            coder.eval()

    def train(self, mode=True):
        """Set training mode but keep the frozen encoder/decoder in eval mode.

        ``nn.Module.train`` recurses into all children, so without this override
        the ``eval()`` applied in ``__init__`` would be undone as soon as the
        training loop calls ``model.train()``.
        """
        super().train(mode)
        self.encoder.eval()
        self.decoder.eval()
        return self

    def forward(self, image):
        prob = self.softmax(self.segmentation(image))
        # Hard 0.5 threshold on the foreground probability; the STE keeps it differentiable.
        binary_mask = self.ste(prob[:, 1:2, :, :, :] - 0.5)
        shape = self.encoder(binary_mask)
        # NOTE(review): 10..30 presumably bounds the dataset's radius range — confirm.
        # Gradients are zero for encoder outputs outside this interval.
        shape = torch.clamp(shape, min=10, max=30)
        recon = self.decoder(shape)
        return prob, binary_mask, shape, recon

In [4]:
# Reproducible train/validation/test split over TOTAL_SIZE sample indices.
seed = 0
random_state = np.random.RandomState(seed=seed)
shuffled = random_state.permutation(TOTAL_SIZE)
# Consecutive, non-overlapping slices: [train | validation | test].
perm = {
    'train': shuffled[:TRAIN_SIZE],
    'validation': shuffled[TRAIN_SIZE:TRAIN_SIZE + VAL_SIZE],
    # Slice from the end of validation: shuffled[-TEST_SIZE:] would return the
    # whole permutation when TEST_SIZE == 0.
    'test': shuffled[TRAIN_SIZE + VAL_SIZE:],
}

def get_subjects(mode):
    """Build the list of torchio Subjects for one split.

    Args:
        mode: 'train', 'validation' or 'test'; selects the indices ``perm[mode]``
            into the sorted list of ``IMAGE_DIR/*.nrrd`` files.

    Returns:
        A list of ``tio.Subject`` with ``t1`` (the artifact image from
        IMAGE_ARTI_DIR, used for every split), ``label`` (the mask from MASK_DIR)
        and ``radius`` (parsed from the filename: the first 5 characters after
        the last '_').
    """
    # Glob and sort the directory once, instead of once per selected index.
    all_image_paths = sorted(glob(f'{IMAGE_DIR}/*.nrrd'))
    image_paths = [all_image_paths[i] for i in perm[mode]]

    subjects = []
    for image_path in tqdm(image_paths):
        # basename is separator-aware; split('/') breaks on Windows paths from glob.
        filename = os.path.basename(image_path)
        mask_path = f'{MASK_DIR}/{filename}'
        image_arti_path = f'{IMAGE_ARTI_DIR}/{filename}'
        subject = tio.Subject(
            # Disabled alternative: the clean `image_path` for 'train', artifact images otherwise.
            t1=tio.ScalarImage(image_arti_path),
            label=tio.LabelMap(mask_path),
            # NOTE(review): assumes filenames end like '..._RR.RR.nrrd' — confirm format.
            radius=torch.Tensor([float(filename.split('_')[-1][:5])]),
        )
        subjects.append(subject)
    return subjects

def get_dataloader(transform, subjects_by_mode=None, num_workers=None):
    """Wrap each split's subjects in a non-shuffling, batch-size-1 DataLoader.

    Args:
        transform: torchio transform applied to every subject.
        subjects_by_mode: optional dict mapping mode -> list of subjects; defaults
            to the notebook-global ``subjects``.
        num_workers: DataLoader worker count; defaults to ``os.cpu_count()``, or
            0 when the CPU count cannot be determined.

    Returns:
        dict mapping 'train' / 'validation' / 'test' to a DataLoader.
    """
    if subjects_by_mode is None:
        subjects_by_mode = subjects
    if num_workers is None:
        # os.cpu_count() may return None, which DataLoader cannot handle.
        num_workers = os.cpu_count() or 0
    return {
        mode: torch.utils.data.DataLoader(
            tio.SubjectsDataset(subjects_by_mode[mode], transform=transform),
            batch_size=1,
            num_workers=num_workers,
            shuffle=False,
        )
        for mode in ['train', 'validation', 'test']
    }

# Build the torchio subjects for every split (one progress bar per split).
subjects = {}
for mode in ('train', 'validation', 'test'):
    subjects[mode] = get_subjects(mode)

# Intensity normalisation: rescale the 0.1-99.9 percentile range to [0, 1].
rescale = tio.RescaleIntensity(percentiles=(0.1, 99.9), out_min_max=(0, 1))
signal = tio.Compose([rescale])

# The per-subject transform currently consists of the intensity step only.
transform = tio.Compose([signal])

dataloader = get_dataloader(transform)

100%|██████████| 50/50 [00:01<00:00, 36.20it/s]
100%|██████████| 10/10 [00:00<00:00, 36.46it/s]
100%|██████████| 21/21 [00:00<00:00, 36.67it/s]


In [5]:
class Metrics:
    """Collect scalar losses per (mode, epoch, name) and print their means."""

    def __init__(self):
        # (mode, epoch, metric name) -> list of python floats, one per logged step.
        self.metrics = defaultdict(list)
        # Order matches the `values` sequence passed to `log`.
        self.names = ['loss', 'loss_segm', 'loss_shape', 'loss_recon']

    def log(self, mode, epoch, values):
        """Append ``value.item()`` for each value, paired positionally with ``self.names``."""
        for metric_name, scalar in zip(self.names, values):
            key = (mode, epoch, metric_name)
            self.metrics[key].append(scalar.item())

    def show(self, mode, epoch):
        """Print a blank line, then one '<mode> <name>: <mean>' line per metric."""
        print()
        for metric_name in self.names:
            history = self.metrics[(mode, epoch, metric_name)]
            average = np.mean(history)
            print(f'{mode} {metric_name}: {average}')

def test(lambdas, save):
    mode = 'test'
    model.load_state_dict(torch.load(f'{output_dir}/best_model.torch'))
    model.eval()

    loss_dice_segm = monai.losses.DiceLoss(squared_pred=True).to(DEVICE)
    loss_mae_shape = torch.nn.L1Loss().to(DEVICE)
    loss_dice_recon = monai.losses.DiceLoss(sigmoid=True, squared_pred=True).to(DEVICE)
    
    metric = monai.metrics.DiceMetric(include_background=False, reduction='mean_batch')
    metrics = Metrics()

    for i, subject in enumerate(dataloader['test']):
        image = subject['t1'][tio.DATA].to(DEVICE)
        label = subject['label'][tio.DATA].to(DEVICE)
        one_hot_label = monai.networks.utils.one_hot(label, num_classes=num_classes, dim=1).to(DEVICE)
        radius = subject['radius'].to(DEVICE)

        prob, binary_mask, shape, recon = model(image)
        loss_segm = loss_dice_segm(prob, one_hot_label)
        loss_shape = loss_mae_shape(shape, radius)
        # loss_recon = loss_dice_recon(recon, one_hot_label[:, 1:2, :, :, :])
        loss_recon = loss_dice_recon(recon, binary_mask)

        loss = lambdas[0] * loss_segm + lambdas[1] * loss_shape + lambdas[2] * loss_recon
                
        metric(binary_mask, one_hot_label)
        metrics.log(mode, 1, [loss, loss_segm, loss_shape, loss_recon])
        
        if save:
            dest = f'{output_dir}{mode}_sample{i}.nrrd'
            nrrd.write(dest, binary_mask.detach().cpu().numpy()[0, 0])

    metrics.show(mode, 1)
    mean_dsc = metric.aggregate().tolist()[0]
    metric.reset()
    print(f'{mode} DSC: {mean_dsc}')
    if save:
        !rm file.zip
        !zip -r file.zip $output_dir

def train(model, lambdas, n_epochs, dataloader, learning_rate):
    """Train ``model`` and checkpoint the epoch with the best validation Dice.

    Each epoch makes one pass over ``dataloader['train']`` (with backprop) and
    one over ``dataloader['validation']`` (without gradients), printing the mean
    losses and the foreground Dice score of both. The learning rate is multiplied
    by 0.99 after every epoch. Whenever validation Dice improves, the weights are
    saved to ``{output_dir}/best_model.torch``.

    Args:
        model: module whose forward returns ``(prob, binary_mask, shape, recon)``.
        lambdas: weights for the (segmentation, shape, reconstruction) losses.
        n_epochs: number of epochs to run.
        dataloader: dict with 'train' and 'validation' loaders of torchio subjects.
        learning_rate: initial Adam learning rate.
    """
    # Frozen parameters never receive gradients, so only hand the trainable ones to Adam.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)

    loss_dice_segm = monai.losses.DiceLoss(squared_pred=True).to(DEVICE)
    loss_mae_shape = torch.nn.L1Loss().to(DEVICE)
    loss_dice_recon = monai.losses.DiceLoss(sigmoid=True, squared_pred=True).to(DEVICE)

    metric = monai.metrics.DiceMetric(include_background=False, reduction='mean_batch')
    metrics = Metrics()
    # Start below any real score so the first comparable validation DSC is saved.
    # (Starting at 0 left `best_epoch` unbound -- and no checkpoint for test() --
    # when validation DSC never exceeded 0, or when n_epochs == 0.)
    best_val_dsc = float('-inf')
    best_epoch = None

    for epoch in range(1, n_epochs + 1):
        print(f'\nEpoch {epoch}/{n_epochs}')
        val_dsc = float('nan')
        for mode in ['train', 'validation']:
            is_train = mode == 'train'
            model.train(is_train)

            # Only build the autograd graph for the training pass.
            with torch.set_grad_enabled(is_train):
                for subject in dataloader[mode]:
                    image = subject['t1'][tio.DATA].to(DEVICE)
                    label = subject['label'][tio.DATA].to(DEVICE)
                    one_hot_label = monai.networks.utils.one_hot(label, num_classes=num_classes, dim=1).to(DEVICE)
                    radius = subject['radius'].to(DEVICE)

                    prob, binary_mask, shape, recon = model(image)
                    loss_segm = loss_dice_segm(prob, one_hot_label)
                    loss_shape = loss_mae_shape(shape, radius)
                    # Alternative target: the ground-truth foreground channel one_hot_label[:, 1:2, :, :, :].
                    loss_recon = loss_dice_recon(recon, binary_mask)

                    loss = lambdas[0] * loss_segm + lambdas[1] * loss_shape + lambdas[2] * loss_recon

                    if is_train:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                    metric(binary_mask, one_hot_label)
                    metrics.log(mode, epoch, [loss, loss_segm, loss_shape, loss_recon])

            metrics.show(mode, epoch)
            mean_dsc = metric.aggregate().tolist()[0]
            metric.reset()
            print(f'{mode} DSC: {mean_dsc}')
            if not is_train:
                val_dsc = mean_dsc

        scheduler.step()
        # NaN never compares greater, so a NaN validation score is never saved.
        if val_dsc > best_val_dsc:
            best_val_dsc = val_dsc
            best_epoch = epoch
            torch.save(model.state_dict(), f'{output_dir}/best_model.torch')

    if best_epoch is None:
        print('No model saved: validation DSC never produced a comparable (non-NaN) value.')
    else:
        print(f'Best model saved after epoch {best_epoch} (val dsc = {best_val_dsc}).')

In [26]:
# Release a previously built model (if any) and cached GPU memory before
# building a new one. Guarded so the cell also runs on a fresh kernel, where
# `model` is not defined yet (the unguarded `del` raised NameError).
if 'model' in globals():
    del model
gc.collect()
torch.cuda.empty_cache()

In [6]:
# Loss weights for (segmentation, shape, reconstruction): segmentation-only run.
lambdas = (1, 0, 0)
model = Model().to(DEVICE)
train(
    model=model,
    dataloader=dataloader,
    lambdas=lambdas,
    n_epochs=30,
    learning_rate=1e-4,
)
# test() reloads the best checkpoint into the global `model` and saves predictions.
test(save=True, lambdas=lambdas)


Epoch 1/30

train loss: 0.4916696602106094
train loss_segm: 0.4916696602106094
train loss_shape: 6.958824768066406
train loss_recon: 0.8825270140171051
train DSC: 0.14639776945114136

validation loss: 0.45185542702674864
validation loss_segm: 0.45185542702674864
validation loss_shape: 3.964399528503418
validation loss_recon: 0.8222893357276917
validation DSC: 0.195532888174057

Epoch 2/30

train loss: 0.3885744571685791
train loss_segm: 0.3885744571685791
train loss_shape: 4.396340789794922
train loss_recon: 0.6954705333709716
train DSC: 0.33226895332336426

validation loss: 0.35789439976215365
validation loss_segm: 0.35789439976215365
validation loss_shape: 3.2054320335388184
validation loss_recon: 0.5427184939384461
validation DSC: 0.4494542181491852

Epoch 3/30

train loss: 0.29368232369422914
train loss_segm: 0.29368232369422914
train loss_shape: 2.259503574371338
train loss_recon: 0.2725664293766022
train DSC: 0.7043095231056213

validation loss: 0.26661414802074435
validation lo

In [25]:
# Debug aid (disabled): show slices 1, 11, ..., 61 along the first axis of the
# first validation image. Uncomment to inspect the loaded/normalised input.
# image = next(iter(dataloader['validation']))['t1'][tio.DATA].numpy()[0, 0]
# for i in range(1, 64, 10):
#     plt.imshow(image[i], cmap='gray')
#     plt.show()