In [None]:
import torch, os, sys
sys.path.append('../Packages')
import numpy as np
import SimpleITK as sitk
import util.riemann as riemann
import util.tensors as tensors
import data.convert as convert

from torch.utils.data import DataLoader
from skimage import data, filters
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm as tqdm
from itkwidgets import view
from dataset import DatasetHCP
from pde import *
from model3D import *

%matplotlib widget

# Overview
## Network Architecture
<img src="../Figures/architecture.png" alt="drawing" width="800"/>

## Eigen Composition
<img src="../Figures/eigencomposition.png" alt="drawing" width="500"/>

In [None]:
# Compute-device selection. 'gpu' uses CUDA when it is available; otherwise this cell
# falls back to 'cpu' so that later `if mode=='gpu'` branches (e.g. model.cuda()) stay consistent.
mode='gpu'
gpu_index = 1  # CUDA device to use when more than one GPU is present

if mode=='gpu' and not torch.cuda.is_available():
    print('CUDA is not available; falling back to CPU.')
    mode='cpu'

if mode=='gpu':
    # Only use the requested GPU if it exists; single-GPU machines fall back to device 0.
    if gpu_index >= torch.cuda.device_count():
        gpu_index = 0
    device = torch.device(f'cuda:{gpu_index}')
    # The default tensor type is global state: after switching device, restart the kernel.
    torch.cuda.set_device(gpu_index)
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    device = torch.device('cpu')
    torch.set_default_dtype(torch.float64)

In [None]:
# Input/output locations for this subject.
brain_id = 111312
input_dir = '../Brains'
output_dir = f'../Checkpoints/{brain_id}'
# makedirs also creates ../Checkpoints when missing (os.mkdir would raise FileNotFoundError),
# and exist_ok makes the cell safe to re-run.
os.makedirs(output_dir, exist_ok=True)

# Training

In [None]:
# Run flags: `resume` reloads previously saved weights before training;
# `save_model` gates checkpoint writes in the training loop.
resume, save_model = False, True
print(f'resume:{resume}, save_model:{save_model}')

In [None]:
# Hyperparameters, model, optimizer and data for this training session.
epoch_loss_list = []
epoch_num = 10000          # number of epochs to run in this session
start_epoch_num = 10001    # first epoch index when resuming; reset to 0 for a fresh run
batch_num = 1
learning_rate = 3e-4
blocks = [40,30,40]        # passed to DenseED as `blocks`; also logged to loss.txt

model = DenseED(in_channels=3, out_channels=7,
                imsize=100,
                blocks=blocks,
                growth_rate=16,
                init_features=48,
                drop_rate=0,
                out_activation=None,
                upsample='nearest')
model.train()
if mode=='gpu':
    model.cuda()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adadelta(model.parameters(), lr=learning_rate)
# Halve the learning rate when the stepped loss has not improved for 10 scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

dataset_id = DatasetHCP(input_dir, sample_name_list=[str(brain_id)])
dataloader_id = DataLoader(dataset_id, batch_size=1, shuffle=False, num_workers=0)

if resume:
    checkpoint_path = f'{output_dir}/epoch_{start_epoch_num-1}_checkpoint.pth.tar'
    # The training loop only ever writes model.pth.tar; fall back to it when no
    # per-epoch checkpoint exists, and continue numbering from the epoch it recorded.
    use_latest = not os.path.isfile(checkpoint_path)
    if use_latest:
        checkpoint_path = f'{output_dir}/model.pth.tar'
    # map_location lets a checkpoint saved on one device be loaded on another (e.g. GPU -> CPU).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    if use_latest:
        start_epoch_num = checkpoint['epoch'] + 1
    # NOTE(review): optimizer/scheduler state is not restored, so Adadelta's running
    # averages restart at the new learning_rate logged below -- confirm this is intended.

    with open(f'{output_dir}/loss.txt', 'a') as f:
        f.write(f'From {start_epoch_num} to {epoch_num+start_epoch_num}\n')
        f.write(f'Adadelta, lr={learning_rate};\n')
else:
    start_epoch_num = 0

    with open(f'{output_dir}/loss.txt', 'w+') as f:
        f.write(f'Architecture {blocks}\n')
        f.write(f'From {start_epoch_num} to {epoch_num+start_epoch_num}\n')
        f.write(f'Adadelta: lr={learning_rate};\n')

print(f'Starting from iteration {start_epoch_num} to iteration {epoch_num+start_epoch_num}')

In [None]:
def _save_checkpoint(epoch, epoch_loss):
    """Save model/optimizer state and the epoch loss to `{output_dir}/model.pth.tar` (overwritten each call)."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_id_state_dict': optimizer.state_dict(),
        'loss_id': epoch_loss,
    }, f'{output_dir}/model.pth.tar')

checkpoint_every = 100
# NOTE(review): training stops as soon as the summed epoch loss is below this value.
# 1e6 is a very loose bound for an MSE -- confirm it is not meant to be something like 1e-6.
loss_stop_threshold = 1e6

last_epoch, last_loss = None, None
for epoch in tqdm(range(start_epoch_num, start_epoch_num+epoch_num)):
    epoch_loss_id = 0

    for batched_id_sample in dataloader_id:
        input_id = batched_id_sample['vector_field']
        # Track gradients w.r.t. the input; presumably pde() differentiates through it -- confirm.
        input_id.requires_grad = True
        optimizer.zero_grad()

        u_pred_id = model(input_id)
        mask = batched_id_sample['mask'].squeeze()
        pde_loss = pde(u_pred_id.squeeze(), input_id.squeeze(), mask)
        # Elementwise product of the PDE output with the mask expanded to 3 leading channels,
        # so only masked-in voxels contribute; the target is all zeros.
        f_pred_id = torch.einsum('...ij,...ij->...ij', pde_loss, mask.expand(3,-1,-1,-1))
        f_true_id = torch.zeros_like(f_pred_id)

        loss_id = criterion(f_pred_id, f_true_id)
        loss_id.backward()
        optimizer.step()
        epoch_loss_id += loss_id.item()
    scheduler.step(epoch_loss_id)

    with open(f'{output_dir}/loss.txt', 'a') as f:
        f.write(f'{epoch_loss_id}\n')

    print(f'epoch {epoch} MSE loss: {epoch_loss_id}, lr: ', optimizer.param_groups[0]['lr'])
    epoch_loss_list.append(epoch_loss_id)
    last_epoch, last_loss = epoch, epoch_loss_id

    if save_model and epoch % checkpoint_every == 0:
        _save_checkpoint(epoch, epoch_loss_id)

    if epoch_loss_id < loss_stop_threshold:
        break

# Persist the final weights after an early stop or normal completion. The inference cell
# reloads model.pth.tar, so without this up to `checkpoint_every - 1` epochs would be lost.
if save_model and last_epoch is not None:
    _save_checkpoint(last_epoch, last_loss)

In [None]:
# Training curve: one point per epoch (the epoch's summed MSE over its batches).
fig, ax = plt.subplots(figsize=(5,5))
ax.plot(epoch_loss_list)
ax.set_xlabel('Epoch')
ax.set_ylabel('MSE Loss')
ax.set_title(f'Adadelta training loss (lr={learning_rate})')
fig.savefig(f'{output_dir}/adadelta_loss_{learning_rate}.png')

# Inference

In [None]:
# Rebuild the network with the training hyperparameters and load the weights saved to
# model.pth.tar. map_location lets a GPU-saved checkpoint load on a CPU-only machine.
checkpoint = torch.load(f'{output_dir}/model.pth.tar', map_location=device)
model = DenseED(in_channels=3, out_channels=7,
                imsize=100,
                blocks=blocks,
                growth_rate=16,
                init_features=48,
                drop_rate=0,
                out_activation=None,
                upsample='nearest')
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
# NOTE(review): the model is left in training mode, as it was during training; if DenseED
# contains BatchNorm/Dropout layers, confirm whether model.eval() is intended here.

# Cast the input to the model's parameter dtype rather than hard-coding float32, so the
# float64 default used in 'cpu' mode does not cause a dtype mismatch in the forward pass.
param_dtype = next(model.parameters()).dtype
vector_lin = convert.read_nhdr(f'{input_dir}/{brain_id}/{brain_id}_shrinktensor_principal_vector_field.nhdr').to(device=device, dtype=param_dtype)
mask = convert.read_nhdr(f'{input_dir}/{brain_id}/{brain_id}_shrinktensor_filt_mask.nhdr').to(device).float()

# Inference only: no autograd graph is needed for the prediction or its inverse.
with torch.no_grad():
    u_pred = model(vector_lin.unsqueeze(0))
    u_pred = u_pred.squeeze()

    metric_pred_mat = eigen_composite(u_pred)
    metric_pred_lin = tensors.mat2lin(metric_pred_mat)
    tensor_pred_mat = torch.inverse(metric_pred_mat)
    tensor_pred_lin = tensors.mat2lin(tensor_pred_mat)

# Reverse the array axes before handing it to SimpleITK; presumably this converts the
# channel-first torch layout to ITK's (z, y, x, component) order -- confirm.
file_name = f'{output_dir}/{brain_id}_learned_metric_final.nhdr'
sitk.WriteImage(sitk.GetImageFromArray(np.transpose(metric_pred_lin.cpu().detach().numpy(),(3,2,1,0))), file_name)