In [None]:
import torch
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch import nn
from torch.utils.data import DataLoader

import numpy as np
import SimpleITK as sitk
import os, skimage, sys

sys.path.append('../Packages')
from util import riemann, tensors, diff
import data.convert as convert
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm as tqdm
from plot import *
from pde import *
from dataset import *
from model import *

%matplotlib widget

In [None]:
# Compute-device configuration.
#   'gpu': use CUDA device `gpu_index` with float32 CUDA default tensors.
#   'cpu': use the CPU with float64 as the default dtype.
# If the requested CUDA device is not present, fall back to 'cpu' mode so that later
# cells branching on `mode` stay consistent with `device`.
# NOTE: the default tensor type/dtype is global state — after switching device, restart the kernel.
mode = 'gpu'
gpu_index = 1  # CUDA device index used in 'gpu' mode

if mode == 'gpu' and torch.cuda.device_count() <= gpu_index:
    print(f'CUDA device {gpu_index} is not available; falling back to CPU mode')
    mode = 'cpu'

if mode == 'gpu':
    torch.cuda.set_device(gpu_index)
    device = torch.device('cuda', gpu_index)
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    device = torch.device('cpu')
    torch.set_default_dtype(torch.float64)

### [Matrix exponential for 2D](https://en.wikipedia.org/wiki/Matrix_exponential)
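For a $2\times 2$ matrix $A$: $e^{tA}=e^{st}\left(\left(\cosh(qt)-s\,\frac{\sinh(qt)}{q}\right)I+\frac{\sinh(qt)}{q}A\right)$, where $t=1$, $s=\operatorname{tr}(A)/2$ and $q=-\sqrt{-\det(A-sI)}$. The sign of $q$ does not matter, because both $\cosh(qt)$ and $\sinh(qt)/q$ are even in $q$. When $q=0$, the ratio $\sinh(qt)/q$ is replaced by its limit $t$.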

In [None]:
# Subject and paths: inputs are read from `input_dir/<brain_id>/`,
# checkpoints and the loss log are written to `output_dir`.
brain_id = 100610
input_dir = '../Brains'
output_dir = f'../Checkpoints/{brain_id}'
# makedirs also creates a missing parent (../Checkpoints), and exist_ok makes the
# cell idempotent without a separate (racy) existence check.
os.makedirs(output_dir, exist_ok=True)

# Training

In [None]:
# Run switches:
#   resume     -> continue from the saved checkpoint instead of starting from scratch
#   save_model -> write checkpoints during training
resume, save_model = False, True
print(f'resume:{resume}, save_model:{save_model}')

In [None]:
# Training configuration: model, optimizer, LR scheduler, data, and (optionally) resume state.
epoch_loss_list = []
epoch_num = 10001            # number of epochs to run in this session
start_epoch_num = 10001      # overwritten below: 0 for a fresh run, checkpoint epoch + 1 when resuming
batch_num = 1
learning_rate = 1e-2
# Single checkpoint file; this is the same path the training loop saves to.
checkpoint_path = f'{output_dir}/model.pth.tar'

model = DenseED(in_channels=2, out_channels=3, 
                imsize=100,
                blocks=[6, 8, 6],
                growth_rate=16, 
                init_features=48,
                drop_rate=0,
                out_activation=None,
                upsample='nearest')
model.train()
# Move to the configured device before building the optimizer (no-op in 'cpu' mode).
model.to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adadelta(model.parameters(), lr=learning_rate)
# Halve the learning rate (factor=0.5) when the monitored loss stops decreasing.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5)

dataset_id = ImageDataset(input_dir, sample_name_list=[brain_id])
dataloader_id = DataLoader(dataset_id, batch_size=1, shuffle=False, num_workers=0)

if resume:
    # map_location lets a checkpoint written on one device load on another.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Restore the Adadelta accumulators as well; otherwise resuming restarts the optimizer from scratch.
    optimizer.load_state_dict(checkpoint['optimizer_id_state_dict'])
    if 'scheduler_state_dict' in checkpoint:  # only present in checkpoints that saved it
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    # Continue right after the epoch stored in the checkpoint.
    start_epoch_num = checkpoint['epoch'] + 1

    with open(f'{output_dir}/loss.txt', 'a') as f:
        f.write(f'From {start_epoch_num} to {epoch_num+start_epoch_num}\n')
        f.write(f'MSE; Adadelta, lr={learning_rate}; \n')
else:
    start_epoch_num = 0

    with open(f'{output_dir}/loss.txt', 'w+') as f:
        f.write(f'From {start_epoch_num} to {epoch_num+start_epoch_num}\n')
        f.write(f'MSE; Adadelta: lr={learning_rate}; \n')

print(f'Starting from iteration {start_epoch_num} to iteration {epoch_num+start_epoch_num}')

In [None]:
# Training loop: the network output and the input vector field are passed to `pde`,
# the returned residual is masked, and its MSE against zero is minimised.
checkpoint_every = 100  # epochs between checkpoint saves; the final epoch is always saved too
last_epoch = start_epoch_num + epoch_num - 1
epoch_progress = tqdm(range(start_epoch_num, start_epoch_num + epoch_num))

# Keep the loss log open for the whole run; flush every epoch so progress survives a crash.
with open(f'{output_dir}/loss.txt', 'a') as loss_file:
    for epoch in epoch_progress:
        epoch_loss_id = 0

        for batched_id_sample in dataloader_id:
            # Cast to the default dtype so inputs match the model parameters
            # (float32 in 'gpu' mode, float64 in 'cpu' mode); a fixed float32 cast
            # would mismatch the float64 model in 'cpu' mode.
            input_id = batched_id_sample['vector_field'].to(device=device, dtype=torch.get_default_dtype())
            input_id.requires_grad = True  # presumably `pde` needs gradients w.r.t. the input — TODO confirm
            mask = batched_id_sample['mask'].to(device=device, dtype=torch.get_default_dtype())
            # NOTE(review): fixed crop of the network output; presumably matches the data grid — confirm.
            u_pred_id = model(input_id)[..., 1:146, 1:175]
            pde_loss = pde(u_pred_id.squeeze(), input_id.squeeze(), mask.squeeze(), differential_accuracy=2)
            # Zero the residual outside the mask; the 2-D mask broadcasts over the residual's leading dim.
            f_pred_id = pde_loss * mask.squeeze()
            f_true_id = torch.zeros_like(f_pred_id)

            optimizer.zero_grad()
            loss_id = criterion(f_pred_id, f_true_id)
            loss_id.backward()
            epoch_loss_id += loss_id.item()
            optimizer.step()
        scheduler.step(epoch_loss_id)

        loss_file.write(f'{epoch_loss_id}\n')
        loss_file.flush()
        # Report the loss on the progress bar instead of printing one line per epoch.
        epoch_progress.set_postfix(MSELoss=epoch_loss_id)
        epoch_loss_list.append(epoch_loss_id)

        if save_model and (epoch % checkpoint_every == 0 or epoch == last_epoch):
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_id_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss_id': epoch_loss_id,
            }, f'{output_dir}/model.pth.tar')

In [None]:
# Training curve: one MSE value per epoch from this session's training loop.
fig, ax = plt.subplots(figsize=(5, 5))
ax.plot(epoch_loss_list)
ax.set(title=f'Adadelta training loss (lr={learning_rate})', xlabel='Iteration', ylabel='MSE loss')
fig.savefig(f'{output_dir}/adadelta_loss_{learning_rate}.png');

# Inference

In [None]:
# Inference: reload the trained weights and predict the metric field for `brain_id`.
# map_location lets a checkpoint written on one device (e.g. cuda:1) load on another.
checkpoint = torch.load(f'{output_dir}/model.pth.tar', map_location=device)
model = DenseED(in_channels=2, out_channels=3, 
                imsize=100,
                blocks=[6, 8, 6],
                growth_rate=16, 
                init_features=48,
                drop_rate=0,
                out_activation=None,
                upsample='nearest')
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()  # put layers that behave differently during training into inference mode

# Cast to the default dtype so the input matches the model parameters
# (float32 in 'gpu' mode, float64 in 'cpu' mode).
model_dtype = torch.get_default_dtype()
vector_lin = convert.read_nhdr(f'{input_dir}/{brain_id}/{brain_id}_vector_field.nhdr').to(device=device, dtype=model_dtype).permute(2,0,1)
mask = convert.read_nhdr(f'{input_dir}/{brain_id}/{brain_id}_filt_mask.nhdr').permute(1,0).float()
# NOTE(review): not used in this cell; kept in case later cells rely on it.
eroded_mask = skimage.morphology.erosion(mask.cpu().numpy(), skimage.morphology.square(4))

# No backward pass is needed here, so skip building the autograd graph.
with torch.no_grad():
    u_pred = model(vector_lin.unsqueeze(0))
    # NOTE(review): same fixed crop as in training; presumably matches the data grid — confirm.
    u_pred = u_pred[..., 1:146, 1:175].squeeze()
    # Presumably `lin2mat` packs the 3 output channels into 2x2 matrices — confirm.
    s_pred = tensors.lin2mat(u_pred)
    metric_pred_mat = matrix_exp_2d(s_pred)
    metric_pred_lin = tensors.mat2lin(metric_pred_mat)

# Move both operands to the CPU so the masked product does not depend on where `mask` was loaded.
show_2d_tensors(metric_pred_mat.cpu()*mask.cpu().unsqueeze(-1).unsqueeze(-1), scale=1e0, title='Learned Metric', margin=0.05, dpi=15)

file_name = f'{input_dir}/{brain_id}/{brain_id}_learned_metric_final.nhdr'
sitk.WriteImage(sitk.GetImageFromArray(np.transpose(metric_pred_lin.cpu().numpy(),(2,1,0))), file_name)
print(f'{file_name} saved')