In [None]:
import torch, os, sys
import torch.nn.functional as F
import SimpleITK as sitk
import numpy as np

sys.path.append('../Packages')
from torch import nn
from torch.utils.data import DataLoader, Dataset
from skimage import data, filters
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm as tqdm
from model import *
from plot import *
from util import riemann, tensors, diff
import data.convert as convert

%matplotlib widget

In [None]:
class ImageDataset(Dataset):
    """Single-item dataset that bundles the 'sin' and 'cos' vector fields with their masks.

    Files are read from ``{data_dir}/{name}/{name}_vector_field.nhdr`` and
    ``{data_dir}/{name}/{name}_filt_mask.nhdr`` for ``name`` in ``('sin', 'cos')``.

    Args:
        data_dir: root directory that contains the ``sin`` and ``cos`` sub-directories.
    """

    def __init__(self, data_dir):
        self.data_dir = data_dir

    def __len__(self):
        # Both fields are stacked into one sample, so there is exactly one item.
        return 1

    def _load_field_and_mask(self, name, device):
        """Load one vector field and its boundary-stripped mask.

        Args:
            name: sub-directory / file prefix, e.g. ``'sin'`` or ``'cos'``.
            device: torch.device the returned tensors are moved to.
        Returns:
            (vector_field, mask): float tensors on ``device``. The vector field has its
            last axis moved to the front and is multiplied by 1000.0; the mask has its
            two axes swapped and the positive-Laplacian pixels subtracted.
        """
        vector_field_path = f'{self.data_dir}/{name}/{name}_vector_field.nhdr'
        mask_path = f'{self.data_dir}/{name}/{name}_filt_mask.nhdr'

        vector_field = torch.from_numpy(sitk.GetArrayFromImage(sitk.ReadImage(vector_field_path)))
        vector_field = vector_field.permute(2,0,1).to(device).float()*1000.0

        mask = torch.from_numpy(sitk.GetArrayFromImage(sitk.ReadImage(mask_path))).permute(1,0)
        # skimage operates on numpy arrays, so hand it one explicitly instead of
        # relying on the implicit tensor -> array conversion.
        boundary_mask = torch.where(torch.from_numpy(filters.laplace(mask.numpy()))>0,1,0)
        mask = (mask-boundary_mask).float().to(device)
        return vector_field, mask

    def __getitem__(self, idx):
        """Return the single sample; ``idx`` is ignored because the dataset has one item.

        Returns:
            dict with
              'vector_field': the 'sin' and 'cos' fields concatenated along dim 0,
              'mask1': mask of the 'sin' field with a leading singleton dim,
              'mask2': mask of the 'cos' field with a leading singleton dim.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        vector_field1, mask1 = self._load_field_and_mask('sin', device)
        vector_field2, mask2 = self._load_field_and_mask('cos', device)

        sample = {  'vector_field'  : torch.cat((vector_field1,vector_field2),0),
                    'mask1'          : mask1.unsqueeze(0),
                    'mask2'          : mask2.unsqueeze(0)}
        return sample

def matrix_exp_2d(A):
    """
    Construct positive definite matrix from symmetric matrix field A using the
    closed-form 2x2 matrix exponential
        exp(A) = e^s * ((cosh(q) - s*sinh(q)/q) * I + sinh(q)/q * A),
    where s = tr(A)/2 and q = sqrt(-det(A - s*I)).
    Args:
        A, torch.Tensor of shape [..., 2, 2]
    Returns: 
        psd, torch.Tensor with the same shape, dtype and device as A
    """
    # Build the identity on A's own device/dtype (a hard-coded 'cuda' breaks CPU runs
    # and runs on a non-default GPU).
    I = torch.zeros_like(A)
    I[...,0,0] = 1
    I[...,1,1] = 1
    
    s = ((A[...,0,0]+A[...,1,1])/2.).unsqueeze(-1).unsqueeze(-1)
    # q^2 = -det(A - s*I), written out for the 2x2 case.
    B = A - s*I
    q_sq = (B[...,0,1]*B[...,1,0] - B[...,0,0]*B[...,1,1]).unsqueeze(-1).unsqueeze(-1)
    
    # sinh(q)/q is 0/0 at q == 0 (A proportional to I) and sqrt has an infinite
    # gradient there, so near q == 0 use the Taylor series in q^2 instead. The sqrt
    # argument is replaced by 1 on that branch to keep both value and gradient finite.
    small = q_sq.abs() < 1e-4
    q = torch.sqrt(torch.where(small, torch.ones_like(q_sq), q_sq))
    sinh_over_q = torch.where(small, 1 + q_sq/6 + q_sq*q_sq/120, torch.sinh(q)/q)
    cosh_q = torch.where(small, 1 + q_sq/2 + q_sq*q_sq/24, torch.cosh(q))
    
    psd = torch.exp(s)*(torch.mul((cosh_q-s*sinh_over_q),I)+sinh_over_q*A)
    return psd

def pde(u, vector_lin, mask, differential_accuracy=2):
    """
    Evaluate the PDE term for a predicted metric parametrisation.
    Args:
        u, torch.Tensor, network output; converted to a matrix field with tensors.lin2mat
           and mapped through matrix_exp_2d to obtain the metric
        vector_lin, torch.Tensor, vector field passed to the covariant derivative
        mask, torch.Tensor, domain mask passed to the covariant derivative
        differential_accuracy, int, forwarded unchanged to riemann.covariant_derivative_2d
    Returns:
        whatever riemann.covariant_derivative_2d returns for (vector_lin, metric, mask)
        -- presumably the covariant derivative of the field along itself; confirm in util.riemann
    """
    metric_mat = matrix_exp_2d(tensors.lin2mat(u))
    return riemann.covariant_derivative_2d(
        vector_lin,
        metric_mat,
        mask,
        differential_accuracy=differential_accuracy,
    )

In [None]:
mode='gpu'
# Preferred GPU when several are visible; falls back to GPU 0 if this index does not exist.
gpu_index = 1

if mode=='gpu' and torch.cuda.is_available():
    if gpu_index >= torch.cuda.device_count():
        gpu_index = 0
    device = torch.device(f'cuda:{gpu_index}')
    # after switching device, you need to restart the kernel
    torch.cuda.set_device(device)
    # New tensors are created on the selected GPU as float32 by default.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
elif mode=='gpu':
    # GPU requested but CUDA is unavailable: run on CPU instead of crashing in
    # torch.cuda.set_device. The default dtype is left untouched (float32).
    print('CUDA is not available, falling back to CPU')
    device = torch.device('cpu')
else:
    device = torch.device('cpu')
    torch.set_default_dtype(torch.float64)

### [Matrix exponential for 2D](https://en.wikipedia.org/wiki/Matrix_exponential)
$e^{tA}=e^{st}\left(\left(\cosh(qt)-s\,\frac{\sinh(qt)}{q}\right)I+\frac{\sinh(qt)}{q}\,A\right)$, where $t=1$, $s=\operatorname{tr}(A)/2$ and $q=\sqrt{-\det(A-sI)}$. The expression is even in $q$, so the sign chosen for the square root does not change the result.

In [None]:
# Location of the input data (masks and vector fields) and of the training outputs.
input_dir = '../Brains'
output_dir = '../Checkpoints/braid'
# makedirs also creates a missing parent ('../Checkpoints') and does nothing when the
# directory already exists, so the cell is safe to re-run on a fresh checkout.
os.makedirs(output_dir, exist_ok=True)

In [None]:
name = 'sin'
# Read the filtered mask, convert it to float64 and swap its two axes.
mask1 = torch.from_numpy(
    sitk.GetArrayFromImage(sitk.ReadImage(f'{input_dir}/{name}/{name}_filt_mask.nhdr'))
).double().permute(1,0)
height, width = mask1.shape
x_range, y_range = (0, height-1), (0, width-1)
# Pixels with a positive Laplacian response are marked as boundary (1), all others 0.
boundary_mask = torch.where(torch.from_numpy(filters.laplace(mask1.cpu().numpy()))>0,1,0)

# Three panels: the mask as loaded, the boundary pixels, and the mask with the boundary removed.
fig, (ax1,ax2,ax3) = plt.subplots(1, 3, figsize=(9, 3))
ax1.imshow(mask1.numpy())
ax1.set(title='orig mask')
ax2.imshow(boundary_mask.numpy())
ax2.set(title='boundary mask')
mask1 = torch.sub(mask1, boundary_mask)
ax3.imshow(mask1.numpy())
ax3.set(title='orig-boundary mask')

plt.show()

# Keep the boundary-stripped mask on the compute device for the following cells.
mask1 = mask1.to(device)

In [None]:
name = 'cos'
# Read the filtered mask, convert it to float64 and swap its two axes.
mask2 = torch.from_numpy(
    sitk.GetArrayFromImage(sitk.ReadImage(f'{input_dir}/{name}/{name}_filt_mask.nhdr'))
).double().permute(1,0)
height, width = mask2.shape
x_range, y_range = (0, height-1), (0, width-1)
# Pixels with a positive Laplacian response are marked as boundary (1), all others 0.
boundary_mask = torch.where(torch.from_numpy(filters.laplace(mask2.cpu().numpy()))>0,1,0)

# Three panels: the mask as loaded, the boundary pixels, and the mask with the boundary removed.
fig, (ax1,ax2,ax3) = plt.subplots(1, 3, figsize=(9, 3))
ax1.imshow(mask2.numpy())
ax1.set(title='orig mask')
ax2.imshow(boundary_mask.numpy())
ax2.set(title='boundary mask')
mask2 = torch.sub(mask2, boundary_mask)
ax3.imshow(mask2.numpy())
ax3.set(title='orig-boundary mask')

plt.show()

# Keep the boundary-stripped mask on the compute device for the following cells.
mask2 = mask2.to(device)

In [None]:
# Show the element-wise sum of the two boundary-stripped masks in a new figure.
plt.figure()
plt.imshow(torch.add(mask1, mask2).cpu().numpy())

In [None]:
name = 'sin'
# Load the vector field as stored on disk (no axis permutation here) and move it to the device.
vector_field1 = torch.from_numpy(
    sitk.GetArrayFromImage(sitk.ReadImage(f'{input_dir}/{name}/{name}_vector_field.nhdr'))
).to(device)
# Pixel-coordinate grid covering the mask extent.
x = torch.linspace(0,height-1,height)
y = torch.linspace(0,width-1,width)
xx, yy = torch.meshgrid(x,y)
plt.figure(figsize=(8,8))
# Both components are multiplied by the mask so arrows outside the domain vanish.
plt.quiver(
    xx.cpu(),
    yy.cpu(),
    vector_field1[...,0].cpu().numpy()*mask1.cpu().numpy(),
    vector_field1[...,1].cpu().numpy()*mask1.cpu().numpy(),
)

In [None]:
name = 'cos'
# Load the vector field as stored on disk (no axis permutation here) and move it to the device.
vector_field2 = torch.from_numpy(
    sitk.GetArrayFromImage(sitk.ReadImage(f'{input_dir}/{name}/{name}_vector_field.nhdr'))
).to(device)
# Pixel-coordinate grid covering the mask extent.
x = torch.linspace(0,height-1,height)
y = torch.linspace(0,width-1,width)
xx, yy = torch.meshgrid(x,y)
plt.figure(figsize=(8,8))
# Both components are multiplied by the mask so arrows outside the domain vanish.
plt.quiver(
    xx.cpu(),
    yy.cpu(),
    vector_field2[...,0].cpu().numpy()*mask2.cpu().numpy(),
    vector_field2[...,1].cpu().numpy()*mask2.cpu().numpy(),
    scale=3e1,
)

# Training

In [None]:
# Training switches: `resume` selects the checkpoint-loading branch of the setup cell,
# `save_model` enables the periodic checkpoint writes in the training loop.
resume, save_model = False, True
print(f'resume:{resume}, save_model:{save_model}')

In [None]:
# Per-epoch loss history, number of epochs to run, and the first epoch index
# (overwritten below: 0 for a fresh run, last saved epoch + 1 when resuming).
epoch_loss_list = []
epoch_num = 1001
start_epoch_num = 1001
learning_rate = 1e-2

# 4 input channels (two 2-component vector fields), 3 output channels that
# tensors.lin2mat turns into a matrix field inside pde().
model = DenseED(in_channels=4, out_channels=3, 
                imsize=100,
                blocks=[6, 8, 6],
                growth_rate=16, 
                init_features=48,
                drop_rate=0,
                out_activation=None,
                upsample='nearest')
model.train()
# Follow the configured device instead of calling model.cuda() unconditionally,
# which raises when 'gpu' mode was requested on a machine without CUDA.
model.to(device)
criterion = torch.nn.MSELoss()
optimizer_id = torch.optim.Adadelta(model.parameters(), lr=learning_rate)

dataset_id = ImageDataset(input_dir)
dataloader_id = DataLoader(dataset_id, batch_size=1, shuffle=True, num_workers=0)

if resume:
    # The training loop only ever writes '{output_dir}/model.pth.tar', so resume from
    # that file (the former 'epoch_{N}_checkpoint.pth.tar' name is never produced here).
    checkpoint = torch.load(f'{output_dir}/model.pth.tar', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Restore the optimizer state that is saved alongside the weights.
    optimizer_id.load_state_dict(checkpoint['optimizer_id_state_dict'])
    # Continue right after the epoch stored in the checkpoint.
    start_epoch_num = checkpoint['epoch'] + 1
    
    with open(f'{output_dir}/loss.txt', 'a') as f:
        f.write(f'From {start_epoch_num} to {epoch_num+start_epoch_num}\n')
        f.write(f'MSE; Adadelta, lr={learning_rate};\n')
else:
    start_epoch_num = 0  
    
    # Fresh run: truncate the loss log and write a new header.
    with open(f'{output_dir}/loss.txt', 'w+') as f:
        f.write(f'From {start_epoch_num} to {epoch_num+start_epoch_num}\n')
        f.write(f'MSE; Adadelta: lr={learning_rate};\n')
    
print(f'Starting from iteration {start_epoch_num} to iteration {epoch_num+start_epoch_num}')

In [None]:
for epoch in tqdm(range(start_epoch_num, start_epoch_num+epoch_num)):
    epoch_loss_id = 0
    epoch_loss_bd = 0

    for i, batched_id_sample in enumerate(dataloader_id):
        # ---- inner-domain forward / backward pass ----
        optimizer_id.zero_grad()

        input_id = batched_id_sample['vector_field'].to(device)
        input_id.requires_grad = True
        # network output, 3 channels per pixel (see out_channels of the model)
        u_pred_id = model(input_id)

        mask1 = batched_id_sample['mask1'].squeeze()
        mask2 = batched_id_sample['mask2'].squeeze()
        # channels 0-1 hold the first vector field, channels 2-3 the second one;
        # both are evaluated against the same predicted metric
        pde_loss1 = pde(u_pred_id.squeeze(), input_id[0,:2].squeeze(), mask1, differential_accuracy=2)
        pde_loss2 = pde(u_pred_id.squeeze(), input_id[0,2:].squeeze(), mask2, differential_accuracy=2)
        # zero the PDE term outside each mask (element-wise product with the mask
        # repeated over the two components)
        f_pred_1 = torch.mul(pde_loss1, mask1.unsqueeze(0).expand(2,-1,-1))
        f_pred_2 = torch.mul(pde_loss2, mask2.unsqueeze(0).expand(2,-1,-1))
        f_pred_id = torch.cat((f_pred_1,f_pred_2),0)
        # target is zero everywhere: the masked PDE term is driven towards 0
        f_true_id = torch.zeros_like(f_pred_id)

        loss_id = criterion(f_pred_id, f_true_id)
        loss_id.backward()
        epoch_loss_id += loss_id.item()
        optimizer_id.step()

    # append this epoch's accumulated loss to the text log
    with open(f'{output_dir}/loss.txt', 'a') as f:
        f.write(f'{epoch_loss_id}\n')

    print(f'epoch {epoch} innerdomain loss: {epoch_loss_id}, norm: {torch.norm(f_pred_id,2)**2}')
    epoch_loss_list.append(epoch_loss_id)

    # every 10th epoch overwrite the single checkpoint file
    if save_model and epoch%10==0:
        checkpoint_state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_id_state_dict': optimizer_id.state_dict(),
            'loss_id': epoch_loss_id
        }
        torch.save(checkpoint_state, f'{output_dir}/model.pth.tar')

In [None]:
# Plot the per-epoch loss history and store the figure next to the checkpoints.
plt.figure(figsize=(7,5))
plt.plot(epoch_loss_list)
plt.xlabel('Iteration')
plt.ylabel('MSE Loss')
plt.savefig(f'{output_dir}/adadelta_loss_{learning_rate}.png')

# Inference

In [None]:
# Rebuild the network with the training architecture and load the saved weights.
# map_location lets a checkpoint written on one device be loaded on another.
checkpoint = torch.load(f'{output_dir}/model.pth.tar', map_location=device)
model = DenseED(in_channels=4, out_channels=3, 
                imsize=100,
                blocks=[6, 8, 6],
                growth_rate=16, 
                init_features=48,
                drop_rate=0,
                out_activation=None,
                upsample='nearest')
model.load_state_dict(checkpoint['model_state_dict'])
# Make sure the weights live on the same device as the inputs built below.
model.to(device)
# NOTE(review): the model is left in its default mode, as before; confirm whether
# model.eval() is wanted here (it changes the output if DenseED has normalisation layers).

name = 'sin'
# The x1000 factor matches the scaling ImageDataset applies to the training inputs;
# without it the network is evaluated on inputs 1000x smaller than it was trained on.
vector_field1 = torch.from_numpy(sitk.GetArrayFromImage(sitk.ReadImage(f'{input_dir}/{name}/{name}_vector_field.nhdr'))).to(device).permute(2,0,1).unsqueeze(0).float()*1000.0
# Masks are only used on the CPU below, so they are never moved to the device.
mask1 = torch.from_numpy(sitk.GetArrayFromImage(sitk.ReadImage(f'{input_dir}/{name}/{name}_filt_mask.nhdr'))).permute(1,0)
boundary_mask1 = torch.where(torch.from_numpy(filters.laplace(mask1.numpy()))>0,1,0).cpu()
mask1 = (mask1-boundary_mask1).float()

name = 'cos'
vector_field2 = torch.from_numpy(sitk.GetArrayFromImage(sitk.ReadImage(f'{input_dir}/{name}/{name}_vector_field.nhdr'))).to(device).permute(2,0,1).unsqueeze(0).float()*1000.0
mask2 = torch.from_numpy(sitk.GetArrayFromImage(sitk.ReadImage(f'{input_dir}/{name}/{name}_filt_mask.nhdr'))).permute(1,0)
boundary_mask2 = torch.where(torch.from_numpy(filters.laplace(mask2.numpy()))>0,1,0).cpu()
mask2 = (mask2-boundary_mask2).float()

# Union of the two boundary-stripped masks, used only for display.
mask = torch.where(mask1+mask2>0,1,0)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    u_pred = model(torch.cat((vector_field1,vector_field2),1))
u_pred = u_pred.squeeze()
s_pred = tensors.lin2mat(u_pred)

metric_pred_mat = matrix_exp_2d(s_pred)
metric_pred_lin = tensors.mat2lin(metric_pred_mat)
# Mask the metric with a torch tensor (instead of mixing a tensor with a numpy array).
show_2d_tensors(metric_pred_mat.detach().cpu()*mask.cpu().unsqueeze(-1).unsqueeze(-1), scale=1e0, title='Learned Metric', margin=0.05, dpi=15)

file_name = f'{output_dir}/braid_learned_metric_final.nhdr'
sitk.WriteImage(sitk.GetImageFromArray(np.transpose(metric_pred_lin.cpu().detach().numpy(),(2,1,0))), file_name)
print(f'{file_name} saved')