In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
import os
import gc
from glob import glob
from tqdm import tqdm
from collections import defaultdict
import json
import scipy.ndimage as ndimage
import nrrd
import torchio as tio
import monai
import nibabel as nib
from collections import OrderedDict, defaultdict

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  method='lar', copy_X=True, eps=np.finfo(np.float).eps,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  method='lar', copy_X=True, eps=np.finfo(np.float).eps,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, copy_Gram=True, verbose=0,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, copy_X=True, fit_path=True,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, copy_X=True, fit_path=True,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes

In [7]:
# Dataset selection and directory layout (all paths are relative to the notebook).
city = 'Ellipsoids'
IMAGE_DIR = f'../dataset/{city}/images'
IMAGE_ARTI_DIR = f'../dataset/{city}/images_arti'
MASK_DIR = f'../dataset/{city}/segmentations'
num_classes = 2

# Number of subjects assigned to each split; TOTAL_SIZE caps how many files are read.
TRAIN_SIZE = 8
VAL_SIZE = 50
TEST_SIZE = 23
TOTAL_SIZE = TRAIN_SIZE + VAL_SIZE + TEST_SIZE

# Checkpoints and predicted masks are written here.
output_dir = f'../results/{city}_arti_train2/'
# exist_ok=True makes the cell idempotent and avoids the check-then-create race.
os.makedirs(output_dir, exist_ok=True)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [9]:
# Widths of the two fully connected layers and of the first convolution.
hidden1, hidden2 = 96, 48
base_filters = 12

class ConvolutionalBackbone(nn.Module):
    """3D convolutional feature extractor ending in two fully connected layers.

    Five unpadded 5x5x5 convolutions (each followed by PReLU) are interleaved
    with three 2x max-pools; the result is flattened and passed through
    ``fc1`` (-> ``hidden1``) and ``fc2`` (-> ``hidden2``).

    The sub-module names inside ``features`` ('conv1', 'relu1', ...) determine
    the ``state_dict`` keys, so they must not be renamed.

    Args:
        img_dims: spatial size of the input volume; the first three entries
            are used to size ``fc1``.
    """

    # Voxels lost per axis by the unpadded 5^3 convolutions preceding each
    # pooling stage: one conv removes 4, two stacked convs remove 8.
    _CONV_SHRINK_PER_STAGE = (4, 8, 8)

    def __init__(self, img_dims):
        super(ConvolutionalBackbone, self).__init__()
        self.img_dims = img_dims
        self.out_fc_dim = np.copy(img_dims)
        # The spatial size after each conv/pool stage is computed locally so the
        # class does not depend on a `net_utils` module, which is not imported
        # anywhere in the visible notebook.
        for shrink in self._CONV_SHRINK_PER_STAGE:
            for axis in range(3):
                self.out_fc_dim[axis] = self._pool_out_dim(self.out_fc_dim[axis] - shrink, 2)

        # Plain Python int so nn.Linear does not receive a numpy scalar.
        fc_in_features = int(np.prod(self.out_fc_dim[:3])) * base_filters * 16

        self.features = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv3d(1, base_filters, 5)),
            # ('bn1', nn.BatchNorm3d(base_filters)),
            ('relu1', nn.PReLU()),
            ('mp1', nn.MaxPool3d(2)),

            ('conv2', nn.Conv3d(base_filters, base_filters*2, 5)),
            # ('bn2', nn.BatchNorm3d(base_filters*2)),
            ('relu2', nn.PReLU()),
            ('conv3', nn.Conv3d(base_filters*2, base_filters*4, 5)),
            # ('bn3', nn.BatchNorm3d(base_filters*4)),
            ('relu3', nn.PReLU()),
            ('mp2', nn.MaxPool3d(2)),

            ('conv4', nn.Conv3d(base_filters*4, base_filters*8, 5)),
            # ('bn4', nn.BatchNorm3d(base_filters*8)),
            ('relu4', nn.PReLU()),
            ('conv5', nn.Conv3d(base_filters*8, base_filters*16, 5)),
            # ('bn5', nn.BatchNorm3d(base_filters*16)),
            ('relu5', nn.PReLU()),
            ('mp3', nn.MaxPool3d(2)),

            # Parameter-free, so the state_dict is unaffected by which Flatten is used.
            ('flatten', nn.Flatten()),

            ('fc1', nn.Linear(fc_in_features, hidden1)),
            ('relu6', nn.PReLU()),
            ('fc2', nn.Linear(hidden1, hidden2)),
            ('relu7', nn.PReLU()),
        ]))

    @staticmethod
    def _pool_out_dim(in_dim, kernel_size):
        """Output length of an unpadded max-pool whose stride equals its kernel size."""
        return int(in_dim) // kernel_size

    def forward(self, x):
        """Return the ``hidden2``-dimensional feature vector for each volume in ``x``."""
        x_features = self.features(x)
        return x_features

class DenseNormalGamma(torch.nn.Module):
    """Linear head that emits four parameter groups (mu, v, alpha, beta) per output.

    A single linear layer produces ``4 * out_features`` values, which are split
    along the last axis into four equally sized groups.
    """

    def __init__(self, in_features, out_features):
        super(DenseNormalGamma, self).__init__()
        self.in_features, self.out_features = in_features, out_features
        # All four parameter groups come out of one projection, side by side.
        self.dense = nn.Linear(in_features, 4 * out_features)

    def evidence(self, x):
        """Map raw activations to non-negative values with softplus."""
        return F.softplus(x)

    def forward(self, x):
        """Return ``(mu, v, alpha, beta)``; mu is unconstrained, alpha is shifted by +1."""
        raw = self.dense(x)
        mu, raw_v, raw_alpha, raw_beta = raw.split(self.out_features, dim=-1)
        return (
            mu,
            self.evidence(raw_v),
            self.evidence(raw_alpha) + 1,
            self.evidence(raw_beta),
        )

class DeepSSMNet(nn.Module):
    """Convolutional backbone followed by a DenseNormalGamma head.

    Takes a single-channel volume of size ``img_dims`` and returns the
    ``(mu, v, alpha, beta)`` tuple produced by the head for ``num_latent``
    outputs.
    """

    def __init__(self):
        super().__init__()
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.num_latent = 1
        self.img_dims = (64, 64, 64)
        print(f'MLP layers: {base_filters*16} -> {hidden1} -> {hidden2} -> {self.num_latent}')
        # Attribute and sub-module names below define the checkpoint's state_dict keys.
        self.ConvolutionalBackbone = ConvolutionalBackbone(self.img_dims)
        head = DenseNormalGamma(hidden2, self.num_latent)
        self.pca_pred = nn.Sequential(OrderedDict(linear=head))

    def forward(self, x):
        """Run the backbone, then the distribution head, on ``x``."""
        return self.pca_pred(self.ConvolutionalBackbone(x))

class STEFunction(torch.autograd.Function):
    """Hard threshold at zero going forward; clipped pass-through gradient going backward."""

    @staticmethod
    def forward(ctx, input):
        # 1.0 where the input is strictly positive, 0.0 everywhere else.
        return torch.gt(input, 0).to(torch.float32)

    @staticmethod
    def backward(ctx, grad_output):
        # The incoming gradient is passed straight through, limited to [-1, 1].
        return torch.clamp(grad_output, min=-1.0, max=1.0)

class StraightThroughEstimator(torch.nn.Module):
    """Module wrapper that applies STEFunction, so the binarisation can be used as a layer."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Binarise ``x`` through STEFunction and return the result."""
        return STEFunction.apply(x)

class Model(torch.nn.Module):
    """3D U-Net segmenter whose binarised foreground mask is fed to `deepssm`.

    `deepssm` is read from module scope at call time and is not registered as a
    sub-module, so it is absent from this model's parameters and state_dict.
    """

    def __init__(self):
        super().__init__()
        unet_config = dict(
            spatial_dims=3,
            in_channels=1,
            out_channels=num_classes,
            channels=(32, 64, 128, 256, 512),
            strides=(2, 2, 2, 2),
            norm=monai.networks.layers.Norm.BATCH,
        )
        self.segmentation = monai.networks.nets.UNet(**unet_config)
        self.softmax = torch.nn.Softmax(dim=1)
        self.ste = StraightThroughEstimator()

    def forward(self, image):
        """Return ``(mask, binary_mask, dist_params)`` for a batch of volumes.

        ``mask`` is the channel-wise softmax of the U-Net output, ``binary_mask``
        is channel 1 of that softmax thresholded at 0.5, and ``dist_params`` is
        whatever `deepssm` returns for the binary mask.
        """
        logits = self.segmentation(image)
        mask = self.softmax(logits)
        # Shifting by 0.5 turns the estimator's ">0" test into a ">0.5" probability threshold.
        thresholded = self.ste(mask - 0.5)
        # Slice (rather than index) so the channel axis is kept with size 1.
        binary_mask = thresholded[:, 1:2, :, :, :]
        dist_params = deepssm(binary_mask)
        return mask, binary_mask, dist_params

In [10]:
# Load the pre-trained DeepSSM regressor and freeze it; it is used only as a
# fixed, differentiable function inside Model.forward.
lamb = 1e-2
deepssm = DeepSSMNet().to(DEVICE)
deepssm_checkpoint = f'../dataset/Ellipsoids/segmentations/DeepSSM/best_model_{lamb}.torch'
# map_location lets a checkpoint saved on GPU load when DEVICE falls back to CPU.
deepssm.load_state_dict(torch.load(deepssm_checkpoint, map_location=DEVICE))
for param in deepssm.parameters():
    param.requires_grad = False
deepssm.eval();

MLP layers: 192 -> 96 -> 48 -> 1


In [21]:
# Fixed-seed shuffle of the subject indices, partitioned into the three splits.
seed = 0
random_state = np.random.RandomState(seed=seed)
shuffled_indices = random_state.permutation(TOTAL_SIZE)
val_end = TRAIN_SIZE + VAL_SIZE
perm = {
    'train': shuffled_indices[:TRAIN_SIZE],
    'validation': shuffled_indices[TRAIN_SIZE:val_end],
    # Explicit offsets instead of [-TEST_SIZE:], which would select the whole
    # permutation if TEST_SIZE were 0.
    'test': shuffled_indices[val_end:val_end + TEST_SIZE],
}

def get_subjects():
    """Build one ``tio.Subject`` per ``.nrrd`` image found in ``IMAGE_DIR``.

    Files are taken in sorted order and capped at ``TOTAL_SIZE``. For each file
    the same filename is looked up in ``IMAGE_ARTI_DIR`` (used as the image
    when it exists, otherwise the file from ``IMAGE_DIR`` is used) and in
    ``MASK_DIR`` (the label map).

    Returns:
        list of ``tio.Subject`` with keys ``t1``, ``label`` and ``radius``.
    """
    subjects = []
    image_paths = sorted(glob(f'{IMAGE_DIR}/*.nrrd'))[:TOTAL_SIZE]
    for image_path in tqdm(image_paths, desc='Creating Subjects'):
        # basename handles whichever path separator glob returns on this platform.
        filename = os.path.basename(image_path)
        mask_path = f'{MASK_DIR}/{filename}'
        image_arti_path = f'{IMAGE_ARTI_DIR}/{filename}'
        subject = tio.Subject(
            t1=tio.ScalarImage(
                image_arti_path if os.path.exists(image_arti_path) else image_path
            ),
            label=tio.LabelMap(mask_path),
            # The regression target is parsed from the first five characters of
            # the last '_'-separated token of the filename.
            # NOTE(review): assumes that token starts with a 5-character number — confirm naming scheme.
            radius=torch.Tensor([float(filename.split('_')[-1][:5])]),
        )
        subjects.append(subject)
    return subjects

def get_dataloader(transform):
    """Create one DataLoader per split from the module-level ``subjects`` dict.

    Args:
        transform: TorchIO transform applied to every subject of every split.

    Returns:
        dict mapping 'train' / 'validation' / 'test' to a DataLoader with
        batch size 1 and no shuffling.
    """
    # os.cpu_count() may return None, which DataLoader does not accept;
    # fall back to loading in the main process.
    num_workers = os.cpu_count() or 0
    dataloader = dict()
    for mode in ['train', 'validation', 'test']:
        dataloader[mode] = torch.utils.data.DataLoader(
            tio.SubjectsDataset(
                subjects[mode],
                transform=transform
            ),
            batch_size=1,
            num_workers=num_workers,
            shuffle=False,
        )
    return dataloader

# Materialise every subject once, then distribute them over the splits
# according to the shuffled indices in `perm`.
all_subjects = get_subjects()
subjects = {}
for mode in ('train', 'validation', 'test'):
    chosen = [all_subjects[index] for index in perm[mode]]
    # Only the first two subjects of the training split are kept.
    subjects[mode] = chosen[:2] if mode == 'train' else chosen

# Intensity preprocessing: rescale to [0, 1] using the 0.1 / 99.9 percentiles.
signal = tio.Compose([
    tio.RescaleIntensity(percentiles=(0.1, 99.9), out_min_max=(0, 1)),
])

# Full per-subject pipeline (currently just the intensity step).
transform = tio.Compose([signal])

dataloader = get_dataloader(transform)

Creating Subjects: 100%|██████████| 81/81 [00:01<00:00, 47.34it/s]


In [22]:
class Metrics:
    """Accumulates per-batch scalar values keyed by (mode, epoch, name) and prints their means."""

    def __init__(self):
        # Missing keys start as empty lists so log() can append unconditionally.
        self.metrics = defaultdict(list)
        self.names = ['loss', 'loss_unet', 'loss_ssm', 'au', 'eu']

    def log(self, mode, epoch, values):
        """Append one value per tracked name; ``values`` follows the order of ``self.names``.

        Each value must expose ``.item()``.
        """
        for metric_name, scalar in zip(self.names, values):
            key = (mode, epoch, metric_name)
            self.metrics[key].append(scalar.item())

    def show(self, mode, epoch):
        """Print a blank line followed by the mean of every tracked name for (mode, epoch)."""
        lines = []
        for metric_name in self.names:
            average = np.mean(self.metrics[(mode, epoch, metric_name)])
            lines.append(f'{mode} {metric_name}: {average}')
        # The leading empty string reproduces the blank separator line.
        print('', *lines, sep='\n')

def test(save):
    """Evaluate the saved best checkpoint on the test split.

    Loads ``{output_dir}/best_model.torch`` into the module-level ``model``,
    runs it over ``dataloader['test']``, prints the mean of each logged metric
    and the mean Dice score.

    Args:
        save: when truthy, each predicted binary mask is written to
            ``output_dir`` as ``test_sample{i}.nrrd``.
    """
    mode = 'test'
    # map_location lets a checkpoint saved on GPU load when DEVICE is the CPU.
    model.load_state_dict(torch.load(f'{output_dir}/best_model.torch', map_location=DEVICE))
    model.eval()

    loss_dice = monai.losses.DiceLoss(squared_pred=True).to(DEVICE)
    loss_mae = torch.nn.L1Loss().to(DEVICE)
    metric = monai.metrics.DiceMetric(include_background=False, reduction='mean_batch')
    metrics = Metrics()

    # Nothing is back-propagated here, so skip building autograd graphs.
    with torch.no_grad():
        for i, subject in enumerate(dataloader['test']):
            image = subject['t1'][tio.DATA].to(DEVICE)
            label = subject['label'][tio.DATA].to(DEVICE)
            one_hot_label = monai.networks.utils.one_hot(label, num_classes=num_classes, dim=1).to(DEVICE)
            radius = subject['radius'].to(DEVICE)

            mask, binary_mask, [mu, v, alpha, beta] = model(image)
            loss_unet = loss_dice(mask, one_hot_label)
            loss_ssm = loss_mae(mu, radius)
            # Two summary statistics of the predicted distribution parameters
            # (presumably aleatoric / epistemic uncertainty, judging by the names).
            au = torch.mean(beta / (alpha - 1))
            eu = torch.mean(beta / (v * (alpha - 1)))

            # At test time only the segmentation loss is reported as the total.
            loss = loss_unet

            metric(binary_mask, one_hot_label)
            metrics.log(mode, 1, [loss, loss_unet, loss_ssm, au, eu])

            if save:
                dest = f'{output_dir}{mode}_sample{i}.nrrd'
                nrrd.write(dest, binary_mask.detach().cpu().numpy()[0, 0])

    metrics.show(mode, 1)
    mean_dsc = metric.aggregate().tolist()[0]
    metric.reset()
    print(f'{mode} DSC: {mean_dsc}')

def train(model, n_epochs, dataloader, learning_rate, save,
          warmup_epochs=20, ssm_weight=0.2, eu_weight=10):
    """Train ``model`` and track the epoch with the best validation Dice score.

    Every epoch runs the 'train' split (with optimisation) and then the
    'validation' split (forward only), printing mean metrics and Dice for each.

    Args:
        model: network returning ``(mask, binary_mask, (mu, v, alpha, beta))``.
        n_epochs: number of epochs to run.
        dataloader: dict with 'train' and 'validation' loaders.
        learning_rate: initial Adam learning rate (decayed by 0.99 per epoch).
        save: when truthy, writes one predicted mask per mode and epoch and the
            best model's state_dict to ``output_dir``.
        warmup_epochs: epochs during which only the Dice loss is optimised.
        ssm_weight: weight of the L1 loss between ``mu`` and the subject radius
            after the warm-up.
        eu_weight: weight of the ``eu`` term after the warm-up.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
    loss_dice = monai.losses.DiceLoss(squared_pred=True).to(DEVICE)
    loss_mae = torch.nn.L1Loss().to(DEVICE)
    metric = monai.metrics.DiceMetric(include_background=False, reduction='mean_batch')
    metrics = Metrics()
    best_val_dsc = 0
    # Stays None if no epoch ever beats the initial threshold, so the final
    # report cannot hit an unbound variable.
    best_epoch = None

    for epoch in range(1, n_epochs+1):
        print(f'\nEpoch {epoch}/{n_epochs}')
        for mode in ['train', 'validation']:
            is_train = mode == 'train'
            model.train(is_train)
            for subject in dataloader[mode]:
                image = subject['t1'][tio.DATA].to(DEVICE)
                label = subject['label'][tio.DATA].to(DEVICE)
                one_hot_label = monai.networks.utils.one_hot(label, num_classes=num_classes, dim=1).to(DEVICE)
                radius = subject['radius'].to(DEVICE)

                # Autograd graphs are only needed when we are going to back-propagate.
                with torch.set_grad_enabled(is_train):
                    mask, binary_mask, [mu, v, alpha, beta] = model(image)
                    loss_unet = loss_dice(mask, one_hot_label)
                    loss_ssm = loss_mae(mu, radius)
                    # Summary statistics of the predicted distribution parameters
                    # (presumably aleatoric / epistemic uncertainty, judging by the names).
                    au = torch.mean(beta / (alpha - 1))
                    eu = torch.mean(beta / (v * (alpha - 1)))

                    if epoch <= warmup_epochs:
                        loss = loss_unet
                    else:
                        loss = loss_unet + ssm_weight * loss_ssm + eu_weight * eu

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                metric(binary_mask, one_hot_label)
                metrics.log(mode, epoch, [loss, loss_unet, loss_ssm, au, eu])

                if save:
                    # Only the first sample of each mode/epoch is written; later
                    # ones find the file already present.
                    dest = f'{output_dir}{mode}_epoch{epoch}.nrrd'
                    if not os.path.exists(dest):
                        nrrd.write(dest, binary_mask.detach().cpu().numpy()[0, 0])

            metrics.show(mode, epoch)
            mean_dsc = metric.aggregate().tolist()[0]
            metric.reset()
            print(f'{mode} DSC: {mean_dsc}')

        scheduler.step()
        # 'validation' is the last mode in the loop above, so mean_dsc is the
        # validation Dice at this point.
        if mean_dsc > best_val_dsc:
            best_val_dsc = mean_dsc
            best_epoch = epoch
            if save:
                torch.save(model.state_dict(), f'{output_dir}/best_model.torch')

    if best_epoch is None:
        print('Validation DSC never exceeded 0; no best model was selected.')
    else:
        print(f'Best model saved after epoch {best_epoch}.')

In [23]:
model = Model().to(DEVICE)
# save=True so that `{output_dir}/best_model.torch` is written by this run:
# test() reloads that file, and with save=False it would evaluate a stale
# checkpoint from an earlier run (or fail if none exists).
train(model=model, n_epochs=100, dataloader=dataloader, learning_rate=3e-4, save=True)
# `model` is intentionally kept alive: test() loads the checkpoint into this global.


Epoch 1/100

train loss: 0.5895840227603912
train loss_unet: 0.5895840227603912
train loss_ssm: 16.737570762634277
train au: 1.1403746604919434
train eu: 6.71121621131897
train DSC: 0.04935355484485626

validation loss: 0.5385707068443298
validation loss_unet: 0.5385707068443298
validation loss_ssm: 18.922767543792723
validation au: 0.4119389832019806
validation eu: 1.3886494088172912
validation DSC: 0.09845607727766037

Epoch 2/100

train loss: 0.5558838546276093
train loss_unet: 0.5558838546276093
train loss_ssm: 13.272321701049805
train au: 0.5839272439479828
train eu: 1.1228052973747253
train DSC: 0.07550905644893646

validation loss: 0.5357460451126098
validation loss_unet: 0.5357460451126098
validation loss_ssm: 18.15577075958252
validation au: 0.22167615711688995
validation eu: 0.5099799835681915
validation DSC: 0.15805105865001678

Epoch 3/100

train loss: 0.535557895898819
train loss_unet: 0.535557895898819
train loss_ssm: 8.435646533966064
train au: 0.2773507907986641
train 

In [24]:
# Evaluate on the test split. test() first reloads `{output_dir}/best_model.torch`
# into the global `model`; with save=True every predicted binary mask is also
# written to `output_dir` as an .nrrd file.
test(save=True)