In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import os
import gc
from glob import glob
from tqdm import tqdm
from collections import defaultdict
import json
import scipy.ndimage as ndimage
import shapeworks as sw
import DeepSSMUtils
from DeepSSMUtils import model
import nrrd
import torchio as tio
import monai
import nibabel as nib

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  method='lar', copy_X=True, eps=np.finfo(np.float).eps,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  method='lar', copy_X=True, eps=np.finfo(np.float).eps,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, copy_Gram=True, verbose=0,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, copy_X=True, fit_path=True,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, copy_X=True, fit_path=True,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes

In [2]:
# --- Experiment configuration ---------------------------------------------
city = 'Beijing_Zang'
IMAGE_DIR = f'../dataset/{city}/MRI'
MASK_DIR = f'../dataset/{city}/Ventricles3'
num_classes = 2   # one-hot channels of the label map; channel 1 is treated as foreground
TRAIN_SIZE = 10   # first TRAIN_SIZE subjects (sorted by filename) are used for training
TEST_SIZE = 50    # the following TEST_SIZE subjects are used for validation
BATCH_SIZE = 1    # Model.forward and the per-subject indexing assume batch size 1

# Derive the run folder from TRAIN_SIZE so its name cannot drift out of sync.
output_dir = f'../results/{city}_train{TRAIN_SIZE}unc_art_part/'
os.makedirs(output_dir, exist_ok=True)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [3]:
def show_cuda_memory(device=0):
    """Print total / reserved / allocated CUDA memory for one device, in GiB.

    "Free" is reserved minus allocated: memory held by PyTorch's caching
    allocator but not currently backing a tensor. It is not free memory on the GPU.
    Prints a short notice instead of raising when CUDA is unavailable.

    Args:
        device: CUDA device index (default 0).
    """
    if not torch.cuda.is_available():
        print('CUDA is not available; no GPU memory to report.')
        return

    gib = 2 ** 30
    total = torch.cuda.get_device_properties(device).total_memory
    reserved = torch.cuda.memory_reserved(device)
    allocated = torch.cuda.memory_allocated(device)
    free = reserved - allocated  # free inside the reserved pool

    print('Total:     {:0.2f} GiB'.format(total / gib))
    print('Reserved:  {:0.2f} GiB'.format(reserved / gib))
    print('Allocated: {:0.2f} GiB'.format(allocated / gib))
    print('Free:      {:0.2f} GiB'.format(free / gib))


show_cuda_memory()

Total:     11.17 GiB
Reserved:  0.00 GiB
Allocated: 0.00 GiB
Free:      0.00 GiB


In [4]:
# Frozen, pretrained DeepSSM network. It is applied to 64^3 binary crops,
# both of ground-truth masks and of predicted masks inside Model.forward.
# Its weights are never updated.
deepssm_project = 'DeepSSM_rw1'
explained_var = 90  # selects the DeepSSM{explained_var} config, checkpoint and loader files

deepssm_dir = f'../dataset/All/Ventricles_64_3_cleaned/{deepssm_project}'
model_path = f'{deepssm_dir}/DeepSSM{explained_var}.json'
state_path = f'{deepssm_dir}/DeepSSM{explained_var}/best_model.torch'

deepssm = model.DeepSSMNet(model_path).to(DEVICE)
# map_location lets a checkpoint saved on GPU load when DEVICE is the CPU.
deepssm.load_state_dict(torch.load(state_path, map_location=DEVICE))
for param in deepssm.parameters():
    param.requires_grad = False
deepssm.eval()

# Weights normalised to sum to one, loaded from std_PCA.npy (presumably
# per-PCA-mode standard deviations). Not used by the active training losses.
std_PCA = np.load(f'{deepssm_dir}/torch_loaders{explained_var}/std_PCA.npy').reshape(1, -1)
std_PCA /= std_PCA.sum()
std_PCA = torch.Tensor(std_PCA).to(DEVICE)

MLP layers: 192 -> 96 -> 48 -> 23


In [5]:
# Reference crop centres and DeepSSM targets computed once from the
# ground-truth masks. Both are indexed by subject position
# (0 .. TRAIN_SIZE+TEST_SIZE-1). This assumes the sorted mask files line up
# one-to-one with the sorted MRI files used by get_subjects. TODO confirm.
mask256_dir = f'../dataset/{city}/Ventricles_256_3'
CROP_HALF = 32  # DeepSSM consumes (2 * CROP_HALF)^3 = 64^3 crops

mask_paths = sorted(glob(f'{mask256_dir}/*.nrrd'))[:TRAIN_SIZE + TEST_SIZE]
assert len(mask_paths) == TRAIN_SIZE + TEST_SIZE, (
    f'expected {TRAIN_SIZE + TEST_SIZE} masks in {mask256_dir}, found {len(mask_paths)}'
)

centers = np.zeros((TRAIN_SIZE + TEST_SIZE, 3), dtype=int)
mus = []
with torch.no_grad():  # inference only; deepssm is frozen
    for i, path in enumerate(tqdm(mask_paths)):
        mask = nrrd.read(path)[0]
        # Clamp so the crop always lies inside the volume. An unclamped centre
        # within CROP_HALF voxels of a border gives an empty or short slice.
        centers[i] = np.clip(
            ndimage.center_of_mass(mask),
            CROP_HALF,
            np.array(mask.shape) - CROP_HALF,
        )
        a, b, c = centers[i]
        crop = mask[
            None, None,
            a - CROP_HALF:a + CROP_HALF,
            b - CROP_HALF:b + CROP_HALF,
            c - CROP_HALF:c + CROP_HALF,
        ]
        _, _, [mu, _, _, _] = deepssm(torch.from_numpy(crop).to(DEVICE).float())
        mus.append(mu)

100%|██████████| 60/60 [01:49<00:00,  1.82s/it]


In [6]:
class STEFunction(torch.autograd.Function):
    """Heaviside step with a straight-through gradient.

    forward: 1.0 where ``input > 0``, else 0.0 (float tensor).
    backward: the incoming gradient passed straight through, clipped to
    [-1, 1] by ``F.hardtanh``. The true step has zero gradient almost everywhere.
    """

    @staticmethod
    def forward(ctx, input):
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        return F.hardtanh(grad_output)


class StraightThroughEstimator(torch.nn.Module):
    """Module wrapper around ``STEFunction`` so it can be used as a layer."""

    def forward(self, x):
        return STEFunction.apply(x)


class Model(torch.nn.Module):
    """3D UNet segmenter coupled to the frozen DeepSSM shape network.

    The softmax output is binarised at 0.5 with a straight-through estimator.
    The foreground channel (index 1) is then cropped to a (2*CROP_HALF)^3 cube
    around its centre of mass, and the crop is fed to the module-level
    ``deepssm``. Relies on the module-level globals ``num_classes``,
    ``deepssm`` and ``centers``.
    """

    # Half the edge length of the cube passed to DeepSSM (64^3 crops).
    CROP_HALF = 32

    def __init__(self):
        super().__init__()
        self.segmentation = monai.networks.nets.UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=num_classes,
            channels=(32, 64, 128, 256, 512),
            strides=(2, 2, 2, 2),
            norm=monai.networks.layers.Norm.BATCH,
        )
        self.softmax = torch.nn.Softmax(dim=1)
        self.ste = StraightThroughEstimator()

    @staticmethod
    def _crop_center(np_mask, fallback, half=32):
        """Return the integer centre of a (2*half)^3 crop of ``np_mask``.

        Uses the (truncated) centre of mass of ``np_mask`` when the mask is
        non-empty and the crop fits inside the volume on every axis.
        Otherwise returns ``tuple(fallback)``.
        """
        if not np.any(np_mask):
            return tuple(fallback)
        center = tuple(int(v) for v in ndimage.center_of_mass(np_mask))
        fits = all(half <= c <= size - half for c, size in zip(center, np_mask.shape))
        return center if fits else tuple(fallback)

    def forward(self, x, i):
        """Segment ``x`` and run the frozen DeepSSM network on the predicted foreground.

        Args:
            x: image batch, assumed (1, 1, D, H, W); 256^3 after the CropOrPad
               transform. Only batch element 0 picks the crop centre. TODO
               confirm before using batch sizes > 1.
            i: subject index into the global ``centers``, used as the fallback
               crop centre when the prediction is empty or too close to a border.

        Returns:
            (mask, dist_params): the per-class softmax probabilities and the
            third output of ``deepssm`` on the cropped binary foreground.
        """
        probs = self.softmax(self.segmentation(x))
        # Hard {0, 1} mask whose straight-through gradient lets the DeepSSM
        # loss terms reach the segmentation network.
        binary_mask = self.ste(probs - 0.5)

        h = self.CROP_HALF
        # Centre selection needs plain integers, so it runs on a detached copy.
        a, b, c = self._crop_center(
            binary_mask[0, 1].detach().cpu().numpy(), centers[i], h
        )
        crop = binary_mask[:, 1:2, a - h:a + h, b - h:b + h, c - h:c + h]
        _, _, dist_params = deepssm(crop)
        return probs, dist_params

In [7]:
def get_subjects(image_dir, mask_dir, artifact_dir=None, n_train=None, n_total=None):
    """Build torchio Subjects pairing each image with its label map.

    Images are taken in sorted filename order. The first ``n_train`` subjects
    load their image from ``artifact_dir``, presumably the artifact-corrupted
    copies. The remaining subjects use ``image_dir``. Labels are read from
    ``mask_dir`` under the same filename.

    Args:
        image_dir: directory containing ``*.nii.gz`` images.
        mask_dir: directory of label maps with matching filenames.
        artifact_dir: directory with the alternative training images;
            defaults to ``image_dir + '_arti'`` (e.g. ``.../MRI_arti``).
        n_train: number of leading subjects that use ``artifact_dir``;
            defaults to ``TRAIN_SIZE``.
        n_total: number of subjects to load; defaults to ``TRAIN_SIZE + TEST_SIZE``.

    Returns:
        list of ``tio.Subject`` with keys ``'t1'`` and ``'label'``.
    """
    if artifact_dir is None:
        artifact_dir = os.path.normpath(image_dir) + '_arti'
    if n_train is None:
        n_train = TRAIN_SIZE
    if n_total is None:
        n_total = TRAIN_SIZE + TEST_SIZE

    image_paths = sorted(glob(f'{image_dir}/*.nii.gz'))[:n_total]
    subject_list = []
    for i, image_path in enumerate(tqdm(image_paths, desc='Creating Subjects')):
        filename = os.path.basename(image_path)
        # Swap only the directory, not every 'MRI' substring in the path.
        t1_path = os.path.join(artifact_dir, filename) if i < n_train else image_path
        subject_list.append(tio.Subject(
            t1=tio.ScalarImage(t1_path),
            label=tio.LabelMap(os.path.join(mask_dir, filename)),
        ))
    return subject_list

all_subjects = get_subjects(IMAGE_DIR, MASK_DIR)
subjects = {
    'train': all_subjects[:TRAIN_SIZE],
    'validation': all_subjects[TRAIN_SIZE:],
}

Creating Subjects: 100%|██████████| 60/60 [00:02<00:00, 24.51it/s]


In [8]:
# Optional spatial augmentation for the training split (currently disabled).
# To enable it, uncomment this and add `spatial` first in the 'train' pipeline.
# spatial = tio.OneOf(
#     {tio.RandomAffine(degrees=(-3, 3), translation=(-0.1, 0.1)): 1.0},
#     p=0.75,
# )

# Resample to isotropic spacing 1 (image units, typically mm), then crop/pad
# to 256^3.
resample = tio.Compose([
    tio.Resample(1),
    tio.CropOrPad(256),
])

# Map the 0.1th..99.9th intensity percentiles to [0, 1].
signal = tio.Compose([
    tio.RescaleIntensity(percentiles=(0.1, 99.9), out_min_max=(0, 1)),
])


def get_transform(std):
    """Return the per-split torchio preprocessing pipelines.

    Both splits are resampled/cropped, then corrupted with additive Gaussian
    noise of standard deviation ``std`` (a fixed value, not a sampled range),
    then intensity-rescaled. Noise is added before rescaling, so ``std`` is in
    the intensity units of the resampled image.

    Args:
        std: noise standard deviation; 0 adds no noise.

    Returns:
        dict with ``'train'`` and ``'validation'`` ``tio.Compose`` transforms.
    """
    noise = tio.Compose([
        tio.RandomNoise(mean=0, std=(std, std)),
    ])
    return {
        'train': tio.Compose([resample, noise, signal]),
        'validation': tio.Compose([resample, noise, signal]),
    }


def get_dataloader(transform, subjects_by_split=None, num_workers=None):
    """Build unshuffled DataLoaders for the ``'train'`` and ``'validation'`` splits.

    Order is preserved (``shuffle=False``) because the training and validation
    code match each batch to its subject by position.

    Args:
        transform: dict from ``get_transform``, keyed by split.
        subjects_by_split: dict mapping split -> list of ``tio.Subject``;
            defaults to the global ``subjects``.
        num_workers: DataLoader worker processes. Defaults to the CPU count, or
            0 (load in the main process) if that count cannot be determined.

    Returns:
        dict mapping split -> ``torch.utils.data.DataLoader``.
    """
    if subjects_by_split is None:
        subjects_by_split = subjects
    if num_workers is None:
        # os.cpu_count() may return None, which DataLoader does not accept.
        num_workers = os.cpu_count() or 0
    return {
        mode: torch.utils.data.DataLoader(
            tio.SubjectsDataset(subjects_by_split[mode], transform=transform[mode]),
            batch_size=BATCH_SIZE,
            num_workers=num_workers,
            shuffle=False,
        )
        for mode in ['train', 'validation']
    }

In [9]:
def validate(model, loss_dice, loss_mse, metric, losses1, losses2, dscs, std, dataloader, epoch):
    """Evaluate ``model`` on the validation split; print mean losses and Dice.

    Validation sample k is global subject ``TRAIN_SIZE + k``. That index is how
    it is matched to its precomputed crop centre (``centers``) and DeepSSM
    target (``mus``), so the 'validation' loader must be unshuffled with batch
    size 1.

    Args:
        model: segmentation ``Model``; switched to eval mode.
        loss_dice: Dice loss applied to softmax probabilities vs one-hot labels.
        loss_mse: unused; kept for interface compatibility.
        metric: MONAI ``DiceMetric``; aggregated and reset at the end.
        losses1, losses2, dscs: mappings of lists (e.g. ``defaultdict(list)``).
            This epoch's mean Dice loss, mean DeepSSM L1 loss and Dice score
            are appended under ``'validation'``.
        std: unused; the noise level is baked into ``dataloader``.
        dataloader: dict holding a ``'validation'`` DataLoader.
        epoch: unused.

    Raises:
        ValueError: if the validation loader yields no samples.
    """
    model.eval()
    totals = dict.fromkeys(('loss1', 'loss2', 'au', 'eu'), 0.0)
    n_seen = 0
    with torch.no_grad():
        for subject in dataloader['validation']:
            image = subject['t1'][tio.DATA].to(DEVICE)
            label = subject['label'][tio.DATA].to(DEVICE)
            batch = image.shape[0]
            idx = TRAIN_SIZE + n_seen  # global subject index of this sample

            mask, [mu, v, alpha, beta] = model(image, idx)
            one_hot_label = monai.networks.utils.one_hot(
                label, num_classes=num_classes, dim=1
            ).to(DEVICE)

            loss1 = loss_dice(mask, one_hot_label)
            loss2 = torch.abs(mu - mus[idx]).mean()
            # AU / EU terms. The formulas match the aleatoric and epistemic
            # variances of a Normal-Inverse-Gamma (evidential) head. This is an
            # assumption based on the formulas; confirm against DeepSSMNet.
            loss_au = torch.mean(beta / (alpha - 1))
            loss_eu = torch.mean(beta / (v * (alpha - 1)))

            # .item() keeps the running sums as plain Python floats.
            totals['loss1'] += loss1.item() * batch
            totals['loss2'] += loss2.item() * batch
            totals['au'] += loss_au.item() * batch
            totals['eu'] += loss_eu.item() * batch
            n_seen += batch

            one_hot_pred = monai.networks.utils.one_hot(
                torch.argmax(mask, dim=1, keepdim=True),
                num_classes=num_classes,
                dim=1,
            ).to(DEVICE)
            metric(one_hot_pred, one_hot_label)

            # Free this subject's full-resolution tensors before the next forward pass.
            del image, label, mask, one_hot_label, one_hot_pred, loss1, loss2, loss_au, loss_eu, mu, v, alpha, beta
            gc.collect()
            torch.cuda.empty_cache()

    if n_seen == 0:
        raise ValueError("dataloader['validation'] yielded no samples")
    # Average over the samples actually seen rather than assuming TEST_SIZE.
    mean_loss1 = totals['loss1'] / n_seen
    mean_loss2 = totals['loss2'] / n_seen
    mean_loss_au = totals['au'] / n_seen
    mean_loss_eu = totals['eu'] / n_seen
    print(f'Validation Loss1: {mean_loss1}')
    print(f'Validation Loss2: {mean_loss2}')
    print(f'Validation Loss AU: {mean_loss_au}')
    print(f'Validation Loss EU: {mean_loss_eu}')

    mean_dsc = metric.aggregate().tolist()
    metric.reset()
    print(f'Validation DSC: {mean_dsc}\n')

    losses1['validation'].append(mean_loss1)
    losses2['validation'].append(mean_loss2)
    dscs['validation'].append(mean_dsc)

In [10]:
def train(model, n_epochs, dataloader, std, weight, lr=3e-4 * 0.99 ** 30):
    """Fine-tune ``model`` with a Dice loss plus weighted DeepSSM-derived terms.

    For each training sample ``i`` (loader order = global subject index), the
    optimised loss is::

        loss1 + weight * (loss2 + loss_au + loss_eu)

    The terms are:

    - ``loss1``: Dice loss of the softmax output.
    - ``loss2``: mean absolute difference between the DeepSSM ``mu`` of the
      predicted crop and the precomputed ``mus[i]``.
    - ``loss_au``: ``mean(beta / (alpha - 1))``.
    - ``loss_eu``: ``mean(beta / (v * (alpha - 1)))``.

    ``validate`` runs after every epoch.

    Args:
        model: ``Model`` instance on ``DEVICE``.
        n_epochs: number of epochs to run.
        dataloader: dict with unshuffled ``'train'`` and ``'validation'``
            DataLoaders (batch size 1; sample position is the subject index).
        std: noise level, forwarded to ``validate``.
        weight: multiplier for the DeepSSM-derived loss terms.
        lr: initial Adam learning rate, decayed by ``gamma=0.99`` every epoch.
            The default is the rate a 3e-4 start reaches after 30 such decays.
            Presumably this continues the schedule of the epoch-30 checkpoint
            being fine-tuned.

    Returns:
        (losses1, losses2, dscs): ``defaultdict(list)`` histories. Per-epoch
        train means are appended under ``'train'``. The same dicts are also
        passed to ``validate``.

    Raises:
        ValueError: if the train loader yields no samples.
    """
    losses1, losses2, dscs = [defaultdict(list) for _ in range(3)]

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
    loss_dice = monai.losses.DiceLoss(squared_pred=True).to(DEVICE)
    loss_mse = torch.nn.MSELoss().to(DEVICE)  # not used for training; part of validate's signature
    metric = monai.metrics.DiceMetric(include_background=False, reduction='mean_batch')

    for epoch in range(n_epochs):
        print(f'Epoch {epoch+1}/{n_epochs}')
        model.train()

        totals = dict.fromkeys(('loss1', 'loss2', 'au', 'eu'), 0.0)
        n_seen = 0
        for subject in dataloader['train']:
            image = subject['t1'][tio.DATA].to(DEVICE)
            label = subject['label'][tio.DATA].to(DEVICE)
            one_hot_label = monai.networks.utils.one_hot(
                label, num_classes=num_classes, dim=1
            ).to(DEVICE)
            batch = image.shape[0]
            idx = n_seen  # training samples are global subjects 0 .. TRAIN_SIZE-1

            mask, [mu, v, alpha, beta] = model(image, idx)
            loss1 = loss_dice(mask, one_hot_label)
            loss2 = torch.abs(mu - mus[idx]).mean()
            loss_au = torch.mean(beta / (alpha - 1))
            loss_eu = torch.mean(beta / (v * (alpha - 1)))

            loss = loss1 + weight * (loss2 + loss_au + loss_eu)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Use .item(). Summing the loss tensors themselves would chain every
            # iteration's autograd graph into the running total for the whole epoch.
            totals['loss1'] += loss1.item() * batch
            totals['loss2'] += loss2.item() * batch
            totals['au'] += loss_au.item() * batch
            totals['eu'] += loss_eu.item() * batch
            n_seen += batch

            one_hot_pred = monai.networks.utils.one_hot(
                torch.argmax(mask, dim=1, keepdim=True),
                num_classes=num_classes,
                dim=1,
            ).to(DEVICE)
            metric(one_hot_pred, one_hot_label)

            # Free this subject's full-resolution tensors before the next forward pass.
            del image, label, mask, one_hot_label, one_hot_pred, loss, loss1, loss2, loss_au, loss_eu, mu, v, alpha, beta
            gc.collect()
            torch.cuda.empty_cache()

        scheduler.step()

        if n_seen == 0:
            raise ValueError("dataloader['train'] yielded no samples")
        # Average over the samples actually seen rather than assuming TRAIN_SIZE.
        mean_loss1 = totals['loss1'] / n_seen
        mean_loss2 = totals['loss2'] / n_seen
        mean_loss_au = totals['au'] / n_seen
        mean_loss_eu = totals['eu'] / n_seen
        print(f'Train Loss1: {mean_loss1}')
        print(f'Train Loss2: {mean_loss2}')
        print(f'Train Loss AU: {mean_loss_au}')
        print(f'Train Loss EU: {mean_loss_eu}')

        mean_dsc = metric.aggregate().tolist()
        metric.reset()
        print(f'Train DSC: {mean_dsc}')

        losses1['train'].append(mean_loss1)
        losses2['train'].append(mean_loss2)
        dscs['train'].append(mean_dsc)

        validate(model, loss_dice, loss_mse, metric, losses1, losses2, dscs, std, dataloader, epoch)

        # if (epoch+1) % 10 == 0:
        #     torch.save(model.state_dict(), f'{output_dir}model_epoch{epoch+1}.pth')

    return losses1, losses2, dscs

In [11]:
# Noise standard deviation used by tio.RandomNoise in both splits (0 = no added noise).
std = 0
transform = get_transform(std)
dataloader = get_dataloader(transform=transform)

In [13]:
# Fine-tune the same epoch-30 checkpoint once per shape-loss weight.
# NOTE: model_epoch30.pth must already exist in output_dir (from an earlier
# run that saved checkpoints). This cell does not create it.
checkpoint_path = f'{output_dir}model_epoch30.pth'
for weight in [0.3, 0.5]:
    print(f'weight: {weight}')
    # Named seg_model, not `model`. Otherwise this loop would overwrite and
    # then delete the DeepSSMUtils `model` module imported in the first cell,
    # breaking any re-run of the DeepSSM-loading cell.
    seg_model = Model().to(DEVICE)
    # map_location keeps the checkpoint loadable when DEVICE is the CPU.
    seg_model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    train(seg_model, n_epochs=30, dataloader=dataloader, std=std, weight=weight)

    # Drop the model and return cached GPU memory before the next weight.
    del seg_model
    gc.collect()
    torch.cuda.empty_cache()

weight: 0.3
Epoch 1/30
Train Loss1: 0.18345942497253417
Train Loss2: 0.1873113751411438
Train Loss AU: 0.3035589218139648
Train Loss EU: 0.38634819984436036
Train DSC: [0.6925975680351257]
Validation Loss1: 0.20671138763427735
Validation Loss2: 0.228212890625
Validation Loss AU: 0.33244895935058594
Validation Loss EU: 0.40496665954589844
Validation DSC: [0.6427181959152222]

Epoch 2/30
Train Loss1: 0.1561685562133789
Train Loss2: 0.15501117706298828
Train Loss AU: 0.28419544696807864
Train Loss EU: 0.3742222309112549
Train DSC: [0.7347933650016785]
Validation Loss1: 0.21047082901000977
Validation Loss2: 0.2594340515136719
Validation Loss AU: 0.34664169311523435
Validation Loss EU: 0.41710739135742186
Validation DSC: [0.5866115689277649]

Epoch 3/30
Train Loss1: 0.14797120094299315
Train Loss2: 0.1644556403160095
Train Loss AU: 0.28671631813049314
Train Loss EU: 0.3681331634521484
Train DSC: [0.7377884387969971]
Validation Loss1: 0.2704115295410156
Validation Loss2: 0.3139496803283691
V