In [None]:
import unittest
import random

import numpy as np
import astra

from skimage.transform import AffineTransform, warp
from skimage.metrics import structural_similarity as ssim
from skimage.util import random_noise
from skimage.metrics import peak_signal_noise_ratio

from pykeops.torch import Vi, Vj
from pykeops.torch import LazyTensor

from src.algs.arm import lv_indicator
from src.tools.recon.projector import forward_projector
from src.tools.lddmm.manifold_mapping import pull_back
from src.tools.manip.manip import normalize_volume

# data fetching and handling
from data.check_database import load_remote_data
from data.fetch_data import fetch_data
from src.tools.data.loadvolumes import LoadVolumes

from scipy.ndimage import convolve

import matplotlib.pyplot as plt

import torchvision.transforms as torch_transform
from torch.autograd import grad
import torch
from torchvision.transforms import Lambda

from math import prod

# Pick the first CUDA device when torch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    torchdeviceId = torch.device("cuda:0")
else:
    torchdeviceId = "cpu"

from kornia.filters import SpatialGradient as k_grad
from kornia.filters import GaussianBlur2d

from geomloss import SamplesLoss

from src.util.timer import tic, toc
import time

# Small numerical tolerance (not referenced by any cell of this notebook).
err_eps = 1.0e-10

## Setup for the data storage

In [None]:
# Placeholders for the left-ventricle model data; the next cell fills them in.
lv_model_volume = lv_model_frames = lv_motion_frames = None

In [None]:
# Synthetic left-ventricle phantom: build an indicator volume on a 64^3 grid,
# forward-project it, and allocate a zero buffer shaped like the projected frames.
volume = np.zeros((64, 64, 64))
params = {"a": 1, "c": 2, "sigma": -1}  # lv_indicator shape parameters (defined in src.algs.arm)
transform_params = [np.eye(3), [16, 16, 0], 1.5]  # presumably (rotation, translation, scale) — TODO confirm

recon_mode = 'basic'
fprojector = forward_projector(recon_mode)

# the model left-ventricle volume and its forward-projected frames
lv_model_volume = lv_indicator(volume, params, transform_params)
lv_model_frames = fprojector(lv_model_volume)
lv_motion_frames = np.zeros(lv_model_frames.shape)

## Hamiltonian ODE solvers

In [None]:
def RalstonIntegrator():
    """Return a fixed-step, two-stage Ralston (RK2) integrator.

    The returned ``f(ODESystem, x0, nt, deltat=1.0)`` advances the state tuple ``x0``
    over ``nt`` equal steps of size ``deltat / nt``. Each step evaluates the right-hand
    side at the current state and at a predictor placed 2/3 of a step ahead, then
    combines the two slopes with weights 1/4 and 3/4. The caller's tensors are cloned
    first and never modified.

    Instead of the whole state, only one time slice of the second component ``x[1]`` is
    recorded: slice 0 of the initial state, then, after step ``k`` (0-based), slice ``k``
    of the updated state. The result is a list of ``nt + 1`` tensors, and ``x[1]`` must
    have at least ``nt`` entries along its first axis.

    NOTE(review): frame index 0 is recorded twice and index ``nt`` never; if the intent
    is "slice k after step k" the append should probably use ``k + 1`` — confirm.
    Original TODO: correct, but it should be optimized.
    """
    def f(ODESystem, x0, nt, deltat=1.0):
        state = tuple(part.clone() for part in x0)
        h = deltat / nt
        frames_out = [state[1][0]]
        for step in range(nt):
            k1 = ODESystem(*state)
            predictor = tuple(s + (2 * h / 3) * d for s, d in zip(state, k1))
            k2 = ODESystem(*predictor)
            state = tuple(
                s + (0.25 * h) * (d1 + 3 * d2) for s, d1, d2 in zip(state, k1, k2)
            )
            frames_out.append(state[1][step])
        return frames_out

    return f

In [None]:
def Shooting(z0, im0, K, K_Q, lam, rho, nt=63, Integrator=RalstonIntegrator()):
    """Integrate the Hamiltonian system starting from ``(z0, im0)`` over ``nt`` steps.

    Returns whatever ``Integrator`` returns; for the default RalstonIntegrator, that is a
    list of ``nt + 1`` image time slices. Same body as ``Flow``; MetaMorphosisLoss calls
    this one.
    """
    system = HamiltonianSystem(K, K_Q, z0, lam, rho)
    initial_state = (z0, im0)
    return Integrator(system, initial_state, nt)

In [None]:
def Gradient(fun):
    """Finite-difference gradient of ``fun`` along every axis, stacked into one tensor.

    ``torch.gradient`` returns one tensor per dimension. They are stacked on a new
    leading axis, so a 2-D input of shape (W, H) yields shape (2, W, H) — channel-first.
    """
    partials = torch.gradient(fun)
    return torch.stack(partials)

In [None]:
def Residual(K_Q):
    """Return ``R(z0)``: the sum of squared entries of ``z0``.

    ``K_Q`` is accepted but not used by ``R``.
    """
    def R(z0):
        return z0.pow(2).sum()

    return R

In [None]:
def DiffNorm(K):
    """Return ``V(z0, im)``, a K-weighted quadratic form of the field ``z0 * grad(im)``.

    The field ``q = z0 * Gradient(im)`` is formed in float64 (``z0`` is cast with
    ``.double()``). ``K`` is applied to ``q`` with a leading batch axis added and removed
    again, and ``V`` returns the sum of ``q * K(q)``.

    TODO (original note): fix this according to reference [2] (not listed in this notebook).
    """
    def V(z0, im):
        field = z0.double() * Gradient(im)
        batched = field[None, ...]
        return (field * K(batched)[0]).sum()

    return V

In [None]:
def Hamiltonian(K, K_Q):
    """Return ``H(im)``: the sum of all entries of the data term ``K_Q(im)``.

    ``K`` is not used by ``H``.
    TODO (original notes): try alternative norms; update to the shooting scheme of
    references [1], [2] (not listed in this notebook).
    """
    def H(im):
        data_term = K_Q(im)
        return data_term.sum()

    return H

In [None]:
def MetaMorphosisLoss(K, K_Q, lam=0, rho=1):
    """Return ``loss(z0, im0)`` for the metamorphosis fit.

    The trajectory returned by ``Shooting`` is stacked into ``imt`` along a new leading
    (time) axis. The loss is
    ``Hamiltonian(imt) + lam * (DiffNorm(z0, imt[-1]) + rho * Residual(z0))``.
    """
    def loss(z0, im0):
        trajectory = Shooting(z0, im0, K, K_Q, lam, rho)
        imt = torch.stack(trajectory, dim=0)
        data_term = Hamiltonian(K, K_Q)(imt)
        regularizer = DiffNorm(K)(z0, imt[-1]) + rho * Residual(K_Q)(z0)
        return data_term + lam * regularizer

    return loss

In [None]:
def HamiltonianSystem(K, K_Q, z0, lam, rho, create_graph=True):
    """Build the right-hand side ``HS(z, im)`` of the ODE integrated by the shooting scheme.

    The scalar energy is ``H(im) + lam * (D(z[-1], im[-1]) + rho * R(z[-1]))``, with
    H, R and D from Hamiltonian, Residual and DiffNorm. ``HS`` returns
    ``(-dE/dim, dE/dz)``, i.e. the time derivatives of ``(z, im)``.

    Note on ``z[-1]``: when the initial momentum is 2-D, ``z[-1]`` on the very first call
    is its last row. After one integrator update, ``z`` broadcasts against the 3-D
    ``-dE/dim`` and gains a leading time axis, so ``z[-1]`` becomes the last time slice.
    NOTE(review): confirm the first-call row selection is intended.

    Args:
        K: smoothing operator passed to Hamiltonian / DiffNorm.
        K_Q: data-attachment term passed to Hamiltonian / Residual.
        z0: not used (``HS`` receives the current state explicitly); kept for
            backward compatibility.
        lam, rho: weights of the regularization terms.
        create_graph: keep the autograd graph of the computed derivatives (default).
            Geodesic shooting needs this so that a loss evaluated on the integrated
            trajectory back-propagates through the dynamics to the initial momentum.
            Pass False for forward-only integration to save memory.

    Returns:
        The function ``HS(z, im) -> (-dE/dim, dE/dz)``.
    """
    H = Hamiltonian(K, K_Q)
    R = Residual(K_Q)
    D = DiffNorm(K)

    def HS(z, im):
        energy = H(im) + lam * (D(z[-1], im[-1]) + rho * R(z[-1]))
        # Without create_graph the returned derivatives are constants. loss.backward()
        # would then see no dependence of the integrated flow on the initial momentum.
        Gz, Gim = grad(energy, (z, im), allow_unused=True, create_graph=create_graph)
        # allow_unused=True yields None for an input the energy does not reach; use zeros
        # so the integrator's tensor arithmetic still works.
        if Gz is None:
            Gz = torch.zeros_like(z)
        if Gim is None:
            Gim = torch.zeros_like(im)
        return -Gim, Gz

    return HS

In [None]:
def Flow(z0, im0, K, K_Q, lam, rho, nt=63, Integrator=RalstonIntegrator()):
    """Integrate the Hamiltonian system from ``(z0, im0)`` over ``nt`` steps.

    Same body as ``Shooting``. ``metamorphosis`` uses it to produce the frames it plots.
    """
    rhs = HamiltonianSystem(K, K_Q, z0, lam, rho)
    return Integrator(rhs, (z0, im0), nt)

In [None]:
def Optimize(loss, z0, im0, lr=0.5, max_it=1):
    """Minimize ``loss(z0, im0)`` with respect to ``z0`` using L-BFGS.

    Only ``z0`` is given to the optimizer; ``im0`` is passed through to ``loss``
    unchanged. The optimizer uses max_iter=1, max_eval=1 and history_size=1, and
    ``step`` is called ``max_it`` times. Every closure evaluation prints and records the
    current loss. The total wall time is printed at the end.

    Returns:
        list of recorded loss values (0-d numpy arrays), in evaluation order.
    """
    lbfgs = torch.optim.LBFGS([z0], max_eval=1, max_iter=1, lr=lr, history_size=1)
    # alternative tried before: torch.optim.Adadelta([z0], lr=lr, rho=0.9, weight_decay=0)
    losses = []
    print("performing optimization...")
    t_start = time.time()

    def evaluate():
        lbfgs.zero_grad()
        objective = loss(z0, im0)
        value = objective.detach().cpu().numpy()
        print("Loss: ", value)
        losses.append(value)
        objective.backward()
        return objective

    for iteration in range(max_it):
        print("it ", iteration, ": ", end="")
        lbfgs.step(evaluate)

    elapsed = round(time.time() - t_start, 2)
    print("Optimization (L-BFGS) time: ", elapsed, " seconds")
    return losses

In [None]:
def PlotDeformations(list_deformations):
    """Display every frame of ``list_deformations`` as one panel of a near-square grid.

    Each entry is expected to be a 2-D torch tensor; it is detached and moved to the CPU
    before plotting. The grid has ceil(sqrt(n)) columns, so the 64 frames produced by a
    63-step integration keep the previous 8 x 8 layout. Shorter lists no longer raise
    IndexError and longer lists are no longer truncated. Unused panels are hidden.
    An empty list draws nothing.
    """
    import math

    n_frames = len(list_deformations)
    if n_frames == 0:
        return  # plt.subplots(0, 0) would raise
    ncols = math.ceil(math.sqrt(n_frames))
    nrows = math.ceil(n_frames / ncols)
    # squeeze=False keeps `axs` 2-D even for a single row or column
    _, axs = plt.subplots(nrows, ncols, squeeze=False)
    for ind, ax in enumerate(axs.flat):
        if ind < n_frames:
            ax.imshow(list_deformations[ind].detach().cpu().numpy())
        else:
            ax.axis("off")

    plt.show()

## Main metamorphosis implementation

In [None]:
def metamorphosis(a_frames, a_lddmm_params, a_eps=0.5):
    """Fit the metamorphosis (geodesic shooting) model to a frame sequence and plot the flow.

    Builds a Sinkhorn data term against ``a_frames``. It then optimizes the initial
    momentum ``z0`` with ``Optimize`` (L-BFGS), starting from a random image trajectory.
    Finally it integrates the flow from ``a_frames`` with the fitted ``z0`` and shows
    every integrated frame with ``PlotDeformations``.

    Args:
        a_frames: numpy array indexed as ``a_frames[t]`` -> 2-D frame, i.e. shape
            (num_frames, H, W).
        a_lddmm_params: dict with exactly six entries, unpacked *positionally* (insertion
            order) as (T, iterations, sigma, alpha, beta/gamma, epsilon). Only the 2nd
            (number of L-BFGS steps) and the 3rd (Gaussian sigma) are used.
        a_eps: ``blur`` of the geomloss Sinkhorn loss.

    Computation runs on ``torchdeviceId`` (CUDA when available, otherwise CPU).
    Prints the loss history and timing, displays a figure, and returns None.
    """
    first_frame = a_frames[0]
    # frames are (H, W); every image-shaped tensor below uses this order
    height, width = first_frame.shape

    # target "time-series" for the data term
    frames = torch.from_numpy(a_frames).float().to(torchdeviceId)
    frames.requires_grad = True

    # positional unpacking: relies on the insertion order of a_lddmm_params
    _T, lddmm_iteration, sigma, _alpha, _beta, _epsilon = a_lddmm_params.values()

    # error bound / Sinkhorn blur
    eps = a_eps

    # number of image time slices; must be nt + 1 of the integrator (nt=63 in Shooting)
    time_res = 64

    torchdtype = torch.float32

    # Random initial momentum and random initial image trajectory, sampled on the CPU
    # generator and then moved. Shapes follow (height, width); the former
    # [width, height] order transposed them for non-square frames.
    z0 = torch.rand([height, width], dtype=torchdtype).to(torchdeviceId).requires_grad_(True)
    rand_init = torch.rand([time_res, height, width], dtype=torchdtype).to(torchdeviceId).requires_grad_(True)

    # Lagrange multiplier and intensity-change hyperparameters
    lam = 1e0
    rho = 1e0

    tic()
    S = SamplesLoss(loss='sinkhorn', p=2, blur=eps, potentials=True, reach=1.0, scaling=0.95, diameter=10.0, debias=False)

    def K_Q(q):
        """First Sinkhorn potential of q vs. the target frames minus that of q vs. itself."""
        q_flat = torch.flatten(q, start_dim=1)
        target_flat = torch.flatten(frames, start_dim=1)
        return S(q_flat, target_flat)[0] - S(q_flat, q_flat)[0]

    # NOTE(review): with kernel_size=1 the Gaussian kernel has a single tap, so this blur is
    # presumably the identity and `sigma` has no effect — confirm this is intended.
    K = GaussianBlur2d(kernel_size=1, sigma=(sigma, sigma))

    loss = MetaMorphosisLoss(K, K_Q, lam=lam, rho=rho)

    learning_rate = 0.16

    # geodesic shooting: fit z0 starting from the random trajectory
    Optimize(loss, z0, rand_init, learning_rate, lddmm_iteration)

    # NOTE(review): the flow is integrated from `frames`, not from `rand_init` used in the
    # fit. With nt=63, RalstonIntegrator reads frames[0..62], so fewer than 63 frames
    # raise IndexError — confirm the intended initial trajectory.
    list_def = Flow(z0, frames, K, K_Q, lam=lam, rho=rho, nt=63)

    PlotDeformations(list_def)  # calls plt.show() itself

    toc()

## Parallel hole tests

In [None]:
dicom_loader = LoadVolumes()

# initialize data fetching from remote, configuration is in data/remote.yml
data_loaded = False
url, datasets = load_remote_data()

# fetch specific patient data
dicom_name = datasets['raw/']['turkey_par/'][10]
data_url = url + '/raw/' + 'turkey_par/' + dicom_name

# fetch the data from remote
data = fetch_data(data_url)

# load data with the dicom loader; fail fast before `frames` is used
frames, data_loaded = dicom_loader.LoadSinglePatient(data)
assert data_loaded, "LoadSinglePatient failed for " + data_url

# normalizing the frame values (return value ignored — presumably in place; TODO confirm)
normalize_volume(frames)
frames = frames + 1  # just to get rid of NaNs in log computation

num_frames, width, height = frames.shape

# Order matters: metamorphosis unpacks these values positionally
# (K = number of L-BFGS steps, sigma = Gaussian sigma).
lddmm_params = dict(T=num_frames * 32, K=10, sigma=5, alpha=1000, gamma=1, epsilon=1e-6)

indices = [0, -1]

In [None]:
# Fit the full sequence; pass frames[indices] instead to use only the first and last frame.
metamorphosis(a_frames=frames, a_lddmm_params=lddmm_params, a_eps=1e-4)

In [None]:
import matplotlib.pyplot as plt
%matplotlib notebook
fig = plt.figure()
plt.imshow(frames[indices[1], :, :], aspect='equal')

## Experiments with loss functions, divergences, distances and Brenier maps

In [None]:
eps = 1e-4

first_frame = frames[0]
last_frame = frames[-1]

torchdtype = torch.float32

im0 = torch.rand([64, 64], dtype=torchdtype).cuda()
im0.requires_grad = True
im1 = torch.from_numpy(last_frame).float().cuda()
im1.requires_grad = True
mass = 0.75 * torch.ones([1]).cuda()

S = SamplesLoss(loss='sinkhorn', p=2, blur=eps, potentials=True, reach=1.0, scaling=0.95, diameter=10.0, debias=False)

niter = 40
lr = 0.16

for i in range(niter):
    im0.requires_grad = True
    F_bias, _ = S(im0.flatten()[None, :], im0.flatten()[None, :])
    F, _ = S(im0.flatten()[None, :], im1.flatten()[None, :])
    
    g_x_bias = grad(F_bias.sum(), [im0], allow_unused=True)[0]
    g_x = grad(F.sum(), [im0], allow_unused=True)[0]
    
    with torch.no_grad():
        im0 = im0 - lr * (g_x - g_x_bias)
    im0.requires_grad = False
    print(((im1 - im0)**2).sum().item())

    
import matplotlib.pyplot as plt
%matplotlib notebook
fig = plt.figure()
plt.imshow((im0).cpu().detach().numpy(), aspect='equal')

## MPH tests

In [None]:
dicom_loader = LoadVolumes()

# initialize data fetching from remote, configuration is in data/remote.yml
data_loaded = False
url, datasets = load_remote_data()

# --- pinhole border map ---------------------------------------------------------------
raw_file_name = datasets['utility/']['apt_72_bordermap/'][2]
data_url = url + '/utility/' + 'apt_72_bordermap/' + raw_file_name

# fetch the data from remote
pinhole_bordermap_dat = fetch_data(data_url)
bordermap_size = 1024  # the raw float32 map is bordermap_size x bordermap_size
pinhole_bordermap = np.reshape(np.frombuffer(pinhole_bordermap_dat.getvalue(), dtype=np.float32),
                               [bordermap_size, bordermap_size])
pinhole_bordermap = pinhole_bordermap[::4, ::4]  # keep every 4th pixel in each direction

# Single definition of "inside the aperture", reused for masking, counting and indexing.
# np.nonzero(pinhole_bordermap) would also pick negative/NaN entries and then disagree
# with the `> 0` pixel count used to size the compressed frames.
aperture_mask = pinhole_bordermap > 0

# --- simulated motion study -----------------------------------------------------------
dicom_file_name = datasets['simulated/']['motion_correction/'][0]['motion/'][
    0]  # massive data, Nystrom compression needed
data_url = url + '/simulated/' + 'motion_correction/motion/' + dicom_file_name

# fetch the data from remote
data = fetch_data(data_url)

# load data with the dicom loader; fail fast before `frames` is used
frames, data_loaded = dicom_loader.LoadSinglePatient(data)
assert data_loaded, "LoadSinglePatient failed for " + data_url

# zero every pixel outside the aperture (frames must share the mask's 2-D shape)
frames[:] = frames[:] * np.where(aperture_mask, 1, 0)

# pack the in-aperture pixels of each frame into a square of side ceil(sqrt(count))
sqrt_masked_pixels = np.ceil(np.sqrt(np.count_nonzero(aperture_mask))).astype(int)
frames_compressed = np.zeros([frames.shape[0], sqrt_masked_pixels, sqrt_masked_pixels])
non_zero_ind = np.nonzero(aperture_mask)

for i in range(frames.shape[0]):
    # NOTE: np.resize fills the tail by repeating the pixel list cyclically, not with zeros
    cur_frame = np.resize(frames[i, non_zero_ind[0], non_zero_ind[1]], sqrt_masked_pixels * sqrt_masked_pixels)
    frames_compressed[i] = np.reshape(cur_frame, [sqrt_masked_pixels, sqrt_masked_pixels])

first_frame = frames_compressed[0]

# normalizing the frame values (return value ignored — presumably in place; TODO confirm)
normalize_volume(frames)
frames = frames + 1  # just to get rid of NaNs in log computation

params = dict(T=16, K=100, sigma=64, alpha=1000, gamma=1,
              epsilon=1)  # might not be the most optimal parameters so far

num_frames, width, height = frames.shape

# Order matters: metamorphosis unpacks these values positionally
# (K = number of L-BFGS steps, sigma = Gaussian sigma).
lddmm_params = dict(T=num_frames * params.get('T'), K=30, sigma=1, alpha=1000, gamma=1, epsilon=1e-6)

In [None]:
# NOTE(review): frames[32:35] holds only 3 frames, but metamorphosis integrates
# Flow(z0, frames, nt=63), whose RalstonIntegrator reads image slices up to index 62.
# As written this call appears to raise IndexError — confirm.
metamorphosis(a_frames=frames[32:35], a_lddmm_params=lddmm_params, a_eps=0.5)

In [None]:
# release every open matplotlib figure
plt.close(fig="all")