# make this fully parameterized

S0 in range [100, 2_000] with 100 std around mean S0 of the voxel  
b either 0 or 1_000  
100 gradient vectors uniformly distributed on sphere  
10 measurements for b = 0 per voxel  
100 measurements for b = 1_000 per voxel  
1_000_000 voxels in total  

In [2]:
import torch
from typing import Tuple
import scipy

In [None]:
def exp_model(S0: float, b: float, D: np.ndarray, g: np.ndarray) -> float:
    '''
    Implements the Stejskal-Tanner equation.

    Parameters:
    S0 (float): The baseline signal intensity.
    b (float): The b-value.
    D (numpy.ndarray): The diffusion tensor. Shape (3, 3).
    g (numpy.ndarray): The gradient vector. Shape (3,).

    Returns:
    S (float): The expected signal intensity.
    '''

    return S0 * np.exp(-b * g @ D @ g)

In [None]:
def sample_single_S0(lower_bound: float, upper_bound: float) -> float:
    '''
    Chooses a random S0 value from a uniform distribution.

    Parameters:
    lower_bound (float): The lower bound of the uniform distribution.
    upper_bound (float): The upper bound of the uniform distribution.

    Returns:
    S0 (float): The chosen S0 value.
    '''

    return np.random.uniform(lower_bound, upper_bound)

In [None]:
def sample_multiple_S0(mean_S0: float, std_S0: float, n: int) -> np.ndarray:
    '''
    Samples n S0 values from a normal distribution.

    Parameters:
    mean_S0 (float): The mean of the normal distribution.
    std_S0 (float): The standard deviation of the normal distribution.
    n (int): The number of samples to take.

    Returns:
    S0 (numpy.ndarray): The sampled S0 values. Shape (n,).
    '''

    return np.random.normal(mean_S0, std_S0, n)

In [None]:
def sample_gradients(n: int) -> np.ndarray:
    '''
    Samples n gradient vectors from a uniform distribution on the unit sphere.

    Parameters:
    n (int): The number of gradient vectors to sample.

    Returns:
    g (numpy.ndarray): The sampled gradient vectors. Shape (n, 3).
    '''

    # Sample n angles from a uniform distribution on [0, 2pi].
    theta_angles = np.random.uniform(0, 2 * np.pi, n)

    # Sample n angles from a uniform distribution on [0, pi].
    phi_angles = np.random.uniform(0, np.pi, n)

    # Convert the angles to cartesian coordinates.
    x = np.sin(phi_angles) * np.cos(theta_angles)
    y = np.sin(phi_angles) * np.sin(theta_angles)
    z = np.cos(phi_angles)

    # Stack the coordinates together
    g = np.stack((x, y, z), axis=-1)

    return g

In [None]:
single_S0 = sample_single_S0(0, 1000)
print(single_S0)
multiple_S0 = sample_multiple_S0(single_S0, 100, 10)
print(multiple_S0)

In [None]:
gradients = sample_gradients(10)
print(gradients)

In [81]:
def create_random_diffusion_tensor() -> torch.Tensor:
    '''
    Creates a random diffusion tensor. 
    The tensor is symmetric and positive-definite.

    Returns:
    D (torch.Tensor): The diffusion tensor. Shape (3, 3).
    '''

    # Create a random lower triangular matrix with positive diagonal elements.
    L = torch.tril(torch.rand(3, 3)) + torch.eye(3)

    # Get diffusion tensor by Cholesky composition.
    D = L @ L.T

    return D

# polar decomposition

In [82]:
def polar_decomposition(D: torch.Tensor) \
    -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    Calculates the polar decomposition parameters of a diffusion tensor.

    Parameters:
    D (torch.Tensor): The diffusion tensor. Shape (3, 3).

    Returns:
    angle_x (torch.Tensor): The angle of rotation around the x-axis.
    angle_y (torch.Tensor): The angle of rotation around the y-axis.
    angle_z (torch.Tensor): The angle of rotation around the z-axis.
    eig_val_1 (torch.Tensor): The eigenvalue of the first eigenvector.
    eig_val_2 (torch.Tensor): The eigenvalue of the second eigenvector.
    eig_val_3 (torch.Tensor): The eigenvalue of the third eigenvector.
    '''

    # Calculate the eigenvalues and eigenvectors.
    eig_vals, eig_vecs = torch.linalg.eigh(D)

    # Sort the eigenvalues and eigenvectors.
    eig_vals, indices = torch.sort(eig_vals, descending=True)
    eig_vecs = eig_vecs[:, indices]

    # Calculate the angles of rotation.
    angle_x = torch.atan2(eig_vecs[2, 1], eig_vecs[2, 2])
    angle_y = torch.atan2(-eig_vecs[2, 0], torch.sqrt(eig_vecs[2, 1] ** 2 + eig_vecs[2, 2] ** 2))
    angle_z = torch.atan2(eig_vecs[1, 0], eig_vecs[0, 0])

    # Get the eigenvalues.
    eig_val_1 = eig_vals[0]
    eig_val_2 = eig_vals[1]
    eig_val_3 = eig_vals[2]

    return angle_x, angle_y, angle_z, eig_val_1, eig_val_2, eig_val_3

In [83]:
def polar_composition(
        angle_x: torch.Tensor, angle_y: torch.Tensor, angle_z: torch.Tensor,
        eig_val_1: torch.Tensor, eig_val_2: torch.Tensor, eig_val_3: torch.Tensor
    ) -> torch.Tensor:
    '''
    Reconstructs the tensor diffusion tensor from the polar decomposition parameters.
    
    Parameters:
    angle_x (torch.Tensor): The angle of rotation around the x-axis.
    angle_y (torch.Tensor): The angle of rotation around the y-axis.
    angle_z (torch.Tensor): The angle of rotation around the z-axis.
    eig_val_1 (torch.Tensor): The eigenvalue of the first eigenvector.
    eig_val_2 (torch.Tensor): The eigenvalue of the second eigenvector.
    eig_val_3 (torch.Tensor): The eigenvalue of the third eigenvector.

    Returns:
    D (torch.Tensor): The reconstructed diffusion tensor. Shape (3, 3).
    '''

    # Calculate the rotation matrices.
    R_x = torch.tensor([
        [1, 0, 0],
        [0, torch.cos(angle_x), -torch.sin(angle_x)],
        [0, torch.sin(angle_x), torch.cos(angle_x)]
    ])
    R_y = torch.tensor([
        [torch.cos(angle_y), 0, torch.sin(angle_y)],
        [0, 1, 0],
        [-torch.sin(angle_y), 0, torch.cos(angle_y)]
    ])
    R_z = torch.tensor([
        [torch.cos(angle_z), -torch.sin(angle_z), 0],
        [torch.sin(angle_z), torch.cos(angle_z), 0],
        [0, 0, 1]
    ])

    # Calculate the rotation matrix.
    R = R_z @ R_y @ R_x

    # Calculate the diagonal matrix of eigenvalues.
    eig_vals = torch.tensor([
        [eig_val_1, 0, 0],
        [0, eig_val_2, 0],
        [0, 0, eig_val_3]
    ])

    # Reconstruct the diffusion tensor.
    D = R @ eig_vals @ R.T

    return D

In [85]:
total = 0
correct = 0

for i in range(10_000):

    # reconstruct diffusion tensor
    D = create_random_diffusion_tensor()

    # check if the tensor has Cholesky decomposition
    try:
        torch.linalg.cholesky(D)
    except:
        print(i, 'No Cholesky decomposition!')
        break

    D_ = polar_composition(*polar_decomposition(D))

    if torch.allclose(D, D_):
        correct += 1
    
    total += 1

print(f'{100*correct/total:.2f}% correct')

20.66% correct


In [86]:
total = 0
correct = 0

for i in range(10_000):

    # create random angles in radians and range [0, 2pi]
    angle_x = torch.rand(1) * 2 * torch.pi
    angle_y = torch.rand(1) * 2 * torch.pi
    angle_z = torch.rand(1) * 2 * torch.pi

    # create random eigenvalues in range [0, 1] and descending order
    eig_val_1 = torch.rand(1)
    eig_val_2 = torch.rand(1) * eig_val_1
    eig_val_3 = torch.rand(1) * eig_val_2

    # reconstruct diffusion tensor
    D = polar_composition(angle_x, angle_y, angle_z, eig_val_1, eig_val_2, eig_val_3)

    # check if the tensor has Cholesky decomposition
    try:
        torch.linalg.cholesky(D)
    except:
        print(i, 'No Cholesky decomposition!')
        break

    # decompose the tensor
    angle_x_, angle_y_, angle_z_, eig_val_1_, eig_val_2_, eig_val_3_ = polar_decomposition(D)

    # reconstruct the tensor
    D_ = polar_composition(angle_x_, angle_y_, angle_z_, eig_val_1_, eig_val_2_, eig_val_3_)

    if torch.allclose(D, D_):
        correct += 1
    
    total += 1
    
print(f'{100*correct/total:.2f}% correct')

34.91% correct


# quaternion decomposition

In [87]:
def quaternion_decomposition(D: torch.Tensor) \
    -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    Calculates the quaternion decomposition parameters of a diffusion tensor.

    Parameters:
    D (torch.Tensor): The diffusion tensor. Shape (3, 3).

    Returns:
    q (torch.Tensor): The quaternion parameters. Shape (4,).
    eig_val_1 (torch.Tensor): The eigenvalue of the first eigenvector.
    eig_val_2 (torch.Tensor): The eigenvalue of the second eigenvector.
    eig_val_3 (torch.Tensor): The eigenvalue of the third eigenvector.
    '''
    # Calculate the eigenvalues and eigenvectors.
    eig_vals, eig_vecs = torch.linalg.eigh(D)

    # Sort the eigenvalues and eigenvectors.
    eig_vals, indices = torch.sort(eig_vals, descending=True)
    eig_vecs = eig_vecs[:, indices]
    
    # Get the eigenvalues.
    eig_val_1 = eig_vals[0]
    eig_val_2 = eig_vals[1]
    eig_val_3 = eig_vals[2]

    # Calculate quaternion
    q0 = torch.sqrt(1 + eig_vecs[0, 0] + eig_vecs[1, 1] + eig_vecs[2, 2]) / 2
    q1 = (eig_vecs[2, 1] - eig_vecs[1, 2]) / (4 * q0)
    q2 = (eig_vecs[0, 2] - eig_vecs[2, 0]) / (4 * q0)
    q3 = (eig_vecs[1, 0] - eig_vecs[0, 1]) / (4 * q0)

    # Normalize the quaternion
    q = torch.tensor([q0, q1, q2, q3])
    q /= torch.linalg.norm(q)

    return q, eig_val_1, eig_val_2, eig_val_3  

In [88]:
def quaternion_composition(q: torch.Tensor, eig_val_1: torch.Tensor, 
                           eig_val_2: torch.Tensor, eig_val_3: torch.Tensor) -> torch.Tensor:
    '''
    Reconstructs the tensor diffusion tensor from the quaternion decomposition parameters.

    Parameters:
    q (torch.Tensor): The quaternion parameters. Shape (4,).
    eig_val_1 (torch.Tensor): The eigenvalue of the first eigenvector.
    eig_val_2 (torch.Tensor): The eigenvalue of the second eigenvector.
    eig_val_3 (torch.Tensor): The eigenvalue of the third eigenvector.

    Returns:
    D (torch.Tensor): The reconstructed diffusion tensor. Shape (3, 3).
    '''

    # Calculate rotation matrix
    R = torch.tensor([
        [1 - 2*q[2]**2 - 2*q[3]**2, 2*q[1]*q[2] - 2*q[0]*q[3], 2*q[1]*q[3] + 2*q[0]*q[2]],
        [2*q[1]*q[2] + 2*q[0]*q[3], 1 - 2*q[1]**2 - 2*q[3]**2, 2*q[2]*q[3] - 2*q[0]*q[1]],
        [2*q[1]*q[3] - 2*q[0]*q[2], 2*q[2]*q[3] + 2*q[0]*q[1], 1 - 2*q[1]**2 - 2*q[2]**2]
    ])

    # Calculate the diagonal matrix of eigenvalues.
    eig_vals = torch.tensor([
        [eig_val_1, 0, 0],
        [0, eig_val_2, 0],
        [0, 0, eig_val_3]
    ])

    # Reconstruct the diffusion tensor.
    D = R @ eig_vals @ R.T

    return D

In [100]:
total = 0
correct = 0

for i in range(10_000):

    # reconstruct diffusion tensor
    D = create_random_diffusion_tensor()

    # check if the tensor has Cholesky decomposition
    try:
        torch.linalg.cholesky(D)
    except:
        print(i, 'No Cholesky decomposition!')
        break

    D_ = quaternion_composition(*quaternion_decomposition(D))

    if torch.allclose(D, D_):
        correct += 1
    
    total += 1

print(f'{100*correct/total:.2f}% correct')

19.41% correct


In [101]:
total = 0
correct = 0

for i in range(10_000):

    # create random quaternion with values in range [-1, 1]
    q = torch.rand(4) * 2 - 1

    # normalize the quaternion
    q /= torch.linalg.norm(q)

    # create random eigenvalues in range [0, 1] and descending order
    eig_val_1 = torch.rand(1)
    eig_val_2 = torch.rand(1) * eig_val_1
    eig_val_3 = torch.rand(1) * eig_val_2

    # reconstruct diffusion tensor
    D = quaternion_composition(q, eig_val_1, eig_val_2, eig_val_3)

    # check if the tensor has Cholesky decomposition
    try:
        torch.linalg.cholesky(D)
    except:
        print(i, 'No Cholesky decomposition!')
        break

    # decompose the tensor
    q_, eig_val_1_, eig_val_2_, eig_val_3_ = quaternion_decomposition(D)

    # reconstruct the tensor
    D_ = quaternion_composition(q_, eig_val_1_, eig_val_2_, eig_val_3_)

    if torch.allclose(D, D_):
        correct += 1
    
    total += 1
    
print(f'{100*correct/total:.2f}% correct')

30.56% correct
