In [1]:
import os
import time
import json
import numpy as np
import matplotlib.pyplot as plt
import nrrd
import sys
from tqdm import tqdm
from glob import glob
import json
import gc
import torch
from torch import nn
from torch.nn import functional as F
import monai

# Dataloaders

In [2]:
class Ellipsoids(torch.utils.data.Dataset):
    def __init__(self, images, radii):
        self.images = torch.FloatTensor(np.array(images))
        self.radii = torch.FloatTensor(np.array(radii))
    def __len__(self):
        return len(self.images)
    def __getitem__(self, index):
        image = self.images[index]
        radius = self.radii[index]
        return image, radius

In [3]:
image_paths = sorted(glob('../dataset/Ellipsoids/segmentations/*.nrrd'))
images = np.array([nrrd.read(path)[0][None, :] for path in image_paths])
radii = np.array([float(path.split('_')[-1][:5]) for path in image_paths])[:, None]
images.shape, radii.shape

((81, 1, 64, 64, 64), (81, 1))

In [4]:
train_size = 61
perm = np.random.RandomState(seed=0).permutation(81)
perm = {
    'train': perm[:train_size],
    'val': perm[train_size:],
}
modes = list(perm.keys())
dataloaders = dict()
for mode in modes:
    dataloaders[mode] = torch.utils.data.DataLoader(
        Ellipsoids(images[perm[mode]], radii[perm[mode]]),
        batch_size=1,
        shuffle=(mode == 'train'),
        num_workers=6,
        pin_memory=torch.cuda.is_available()
    )

# Model

In [5]:
def init_autoencoder():
    autoencoder = monai.networks.nets.AutoEncoder(
        spatial_dims=3, in_channels=1, out_channels=1,
        kernel_size=(3, 3, 3),
        channels=[channel*1 for channel in (1, 2, 4, 8, 16)],
        strides=(1, 2, 2, 2, 2),
    )
    return autoencoder

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv = init_autoencoder().encode
        self.fc = nn.Sequential(
            nn.Linear(1024, 256),
            nn.PReLU(),
            nn.Linear(256, 64),
            nn.PReLU(),
            nn.Linear(64, 1)
        )
        
    def forward(self, x):
        x = self.conv(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        return x

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(1, 1024),
            nn.PReLU(),
            nn.Linear(1024, 1024),
            nn.PReLU(),
            nn.Linear(1024, 1024),
            nn.PReLU(),
        )
        self.deconv = init_autoencoder().decode
        
    def forward(self, x):
        x = self.fc(x)
        x = torch.reshape(x, (1, 16, 4, 4, 4))
        x = self.deconv(x)
        return x

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_dir = '../dataset/Ellipsoids/models/'
if not os.path.exists(model_dir):
    os.mkdir(model_dir)

# Encoder Training

In [10]:
def train_encoder(model, dataloaders, num_epochs, learning_rate):
    opt = torch.optim.Adam(model.parameters(), learning_rate)
    opt.zero_grad()
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.99)
    loss_mae = torch.nn.L1Loss().to(device)
    
    t0 = time.time()
    best_val_loss = np.Inf
    
    for epoch in range(1, num_epochs+1):
        print(f"Epoch {epoch}/{num_epochs}")
        for mode in ['train', 'val']:
            if mode == 'train':
                model.train()
            else:
                model.eval()
            
            losses = []
            for image, radius in dataloaders[mode]:
                image = image.to(device)
                radius = radius.to(device)
                
                pred_radius = model(image)
                loss = loss_mae(pred_radius, radius)
                
                if mode == 'train':
                    opt.zero_grad()
                    loss.backward()
                    opt.step()

                losses.append(loss.item())

            print(f'{mode} loss: {np.mean(losses)}')
        
        scheduler.step()

        val_loss = np.mean(losses)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            torch.save(model.state_dict(), f'{model_dir}/best_encoder.torch')
        print(f'Best val loss: {best_val_loss}')

        time_elapsed = time.time() - t0
        print(f'Time: {time_elapsed}\n')
        t0 = time.time()
        
    print(f"Training complete, model saved. Best model after epoch {best_epoch}")

In [11]:
encoder = Encoder().to(device)
train_encoder(model=encoder, dataloaders=dataloaders, num_epochs=200, learning_rate=1e-4)

Epoch 1/200
train loss: 17.52146376156416
val loss: 13.900359511375427
Best val loss: 13.900359511375427
Time: 1.7360072135925293

Epoch 2/200
train loss: 5.3993022402779
val loss: 1.872096633911133
Best val loss: 1.872096633911133
Time: 1.6278278827667236

Epoch 3/200
train loss: 1.1942282817402825
val loss: 1.0452473640441895
Best val loss: 1.0452473640441895
Time: 1.6498723030090332

Epoch 4/200
train loss: 0.8686908190367651
val loss: 0.7083781719207763
Best val loss: 0.7083781719207763
Time: 1.690495252609253

Epoch 5/200
train loss: 0.6139494317476867
val loss: 0.4434977054595947
Best val loss: 0.4434977054595947
Time: 1.6232569217681885

Epoch 6/200
train loss: 0.41545977357958186
val loss: 0.2651317596435547
Best val loss: 0.2651317596435547
Time: 1.5865592956542969

Epoch 7/200
train loss: 0.5464769113259237
val loss: 0.3924841403961182
Best val loss: 0.2651317596435547
Time: 1.5332262516021729

Epoch 8/200
train loss: 0.34663602172351277
val loss: 0.2639397144317627
Best val 

# Decoder Training

In [15]:
def train_decoder(model, dataloaders, num_epochs, learning_rate):
    opt = torch.optim.Adam(model.parameters(), learning_rate)
    opt.zero_grad()
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.99)
    loss_dice = monai.losses.DiceLoss(sigmoid=True, squared_pred=True).to(device)

    metric = monai.metrics.DiceMetric(include_background=False, reduction='mean_batch')

    t0 = time.time()
    best_val_dsc = 0
    
    for epoch in range(1, num_epochs+1):
        print(f"Epoch {epoch}/{num_epochs}")
        for mode in ['train', 'val']:
            if mode == 'train':
                model.train()
            else:
                model.eval()
            
            losses = []
            for image, radius in dataloaders[mode]:
                image = image.to(device)
                radius = radius.to(device)
                
                pred_segm = model(radius)
                loss = loss_dice(pred_segm, image)
                
                if mode == 'train':
                    opt.zero_grad()
                    loss.backward()
                    opt.step()

                losses.append(loss.item())
                metric((pred_segm > 0.5).float(), image)

            print(f'{mode} loss: {np.mean(losses)}')
            mean_dsc = metric.aggregate().tolist()[0]
            metric.reset()
            print(f'{mode} DSC: {mean_dsc}')
        
        scheduler.step()

        if mean_dsc > best_val_dsc:
            best_val_loss = mean_dsc
            best_epoch = epoch
            torch.save(model.state_dict(), f'{model_dir}/best_decoder.torch')

        time_elapsed = time.time() - t0
        print(f'Time: {time_elapsed}\n')
        t0 = time.time()
        
    print(f"Training complete, model saved. Best model after epoch {best_epoch}")

In [20]:
decoder = Decoder().to(device)
train_decoder(model=decoder, dataloaders=dataloaders, num_epochs=400, learning_rate=3e-4)

Epoch 1/400
train loss: 0.8785346687817183
train DSC: 0.37331223487854004
val loss: 0.8469746828079223
val DSC: 0.5152209401130676
Time: 2.258800506591797

Epoch 2/400
train loss: 0.8348241272519846
train DSC: 0.5664443373680115
val loss: 0.8111542105674744
val DSC: 0.6486846208572388
Time: 2.0835392475128174

Epoch 3/400
train loss: 0.8177562336452672
train DSC: 0.6284284591674805
val loss: 0.8024275362491607
val DSC: 0.6639474630355835
Time: 2.0355422496795654

Epoch 4/400
train loss: 0.8106378727271909
train DSC: 0.6417031288146973
val loss: 0.7956203520298004
val DSC: 0.6958715319633484
Time: 2.068946599960327

Epoch 5/400
train loss: 0.8042841964080686
train DSC: 0.6601043343544006
val loss: 0.7885804742574691
val DSC: 0.7123348116874695
Time: 2.034174919128418

Epoch 6/400
train loss: 0.7972859202838335
train DSC: 0.6764822602272034
val loss: 0.7806931883096695
val DSC: 0.7121955752372742
Time: 2.007211446762085

Epoch 7/400
train loss: 0.7895635464152352
train DSC: 0.68335330486

In [19]:
del decoder
gc.collect()
torch.cuda.empty_cache()