In [23]:
import numpy as np
from dipy.sims.voxel import multi_tensor, multi_tensor_odf
from dipy.core.sphere import disperse_charges, HemiSphere
from dipy.core.gradients import gradient_table
import torch
from DELIMIT.SphericalHarmonicTransformation import Signal2SH, SH2Signal
from DELIMIT.SphericalConvolution import LocalSphericalConvolution, SphericalConvolution
from DELIMIT.loss import MSESignal

# Usage example for each DELIMIT class

## Parameters that need to be set

In [24]:
# num_gradients: number of diffusion-weighted directions to simulate.
# sh_order:      spherical-harmonic order passed to every DELIMIT layer below.
num_gradients, sh_order = 30, 4

## Signal Generation

In [25]:
# Seeded generator so the random initial directions (and the signal derived
# from them) are identical on every Restart & Run All.
rng = np.random.default_rng(0)

# Draw random starting directions, then let dipy's charge-dispersion routine
# (5000 iterations) spread them out over the hemisphere.
theta = np.pi * rng.random(num_gradients)
phi = 2 * np.pi * rng.random(num_gradients)
hsph_initial = HemiSphere(theta=theta, phi=phi)
hsph_updated, potential = disperse_charges(hsph_initial, 5000)
gradients = hsph_updated.vertices

# One b=0 entry followed by `num_gradients` entries at b=1000. The count is
# taken from `num_gradients` rather than a literal so that b-values and
# b-vectors always have matching lengths when the parameter is changed.
bvals = np.concatenate((np.zeros(1), np.full(num_gradients, 1000.0)))
bvecs = np.concatenate((np.zeros((1, 3)), gradients))
# `bvecs` is passed by keyword, which does not depend on its positional slot
# in dipy's `gradient_table` signature.
gtab = gradient_table(bvals, bvecs=bvecs)

# Two identical tensors (same eigenvalues) crossing at 60 degrees, mixed
# 50/50, simulated without noise (snr=None).
mevals = np.array([[0.0015, 0.0003, 0.0003],
                   [0.0015, 0.0003, 0.0003]])
angles = [(0, 0), (60, 0)]
fractions = [50, 50]
signal, sticks = multi_tensor(gtab, mevals, S0=1, angles=angles,
                              fractions=fractions, snr=None)

## Signal Domain to Spherical Harmonic Domain transformation

In [26]:
# Layer that maps signals sampled along `gradients` to SH coefficients.
# lb_lambda is presumably a regularisation weight -- confirm in Signal2SH.
s2sh = Signal2SH(sh_order=sh_order, gradients=gradients, lb_lambda=0.006)

# Drop the leading b=0 sample and lay the remaining values out as a
# (1, num_gradients, 1, 1, 1) float32 tensor.
dw_signal = torch.from_numpy(signal[1:])
input_tensor = dw_signal.reshape(1, num_gradients, 1, 1, 1).float()

input_tensor_sh = s2sh(input_tensor)
print(input_tensor_sh.shape)

torch.Size([1, 15, 1, 1, 1])


## Local Spherical Convolution

In [27]:
# Collect the layer configuration first so it can be read at a glance:
# one input shell is mapped to three output shells at the same SH order.
lsc_config = dict(
    shells_in=1,
    shells_out=3,
    sh_order_in=sh_order,
    sh_order_out=sh_order,
    lb_lambda=0.006,
    sampled_gradients=gradients,
    kernel_sizes=[5, 5],
    angular_distance=np.pi / 10,
)
lsc = LocalSphericalConvolution(**lsc_config)
lsc_tensor_sh = lsc(input_tensor_sh)

# Only used to display the output per shell: number of SH coefficients for
# `sh_order`, i.e. (sh_order + 1) * (sh_order + 2) / 2.
num_coefficients = (sh_order + 1) * (sh_order + 2) // 2
per_shell_view = lsc_tensor_sh.reshape(1, -1, num_coefficients, 1, 1, 1)
print(per_shell_view.shape)

torch.Size([1, 3, 15, 1, 1, 1])


## Spherical Convolution

In [28]:
# Merge the three shells produced by the local convolution back into one.
sc = SphericalConvolution(sh_order=sh_order, shells_in=3, shells_out=1)

sc_tensor_sh = sc(lsc_tensor_sh)
print(sc_tensor_sh.shape)

torch.Size([1, 15, 1, 1, 1])


## Loss calculation

In [29]:
loss = MSESignal(gradients=gradients, sh_order=sh_order)

# Third argument is a float64 tensor of ones with shape (1, 1, 1, 1) --
# presumably a per-voxel mask/weight; confirm against MSESignal.
unit_mask = torch.ones(1, 1, 1, 1, dtype=torch.float64)
loss_value = loss(sc_tensor_sh, input_tensor_sh, unit_mask)
print(loss_value)

tensor(1.1650, grad_fn=<MeanBackward1>)


## Spherical Harmonic Domain to Signal Domain transformation

In [30]:
# Evaluate the SH coefficients along the original gradient directions to get
# back to the signal domain.
sh2s = SH2Signal(gradients=gradients, sh_order=sh_order)

output_signal = sh2s(sc_tensor_sh)