# make this fully parameterized

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from typing import Tuple
import pickle
import numpy as np
from tqdm import tqdm

In [2]:
def sample_gradients(n: int) -> torch.Tensor:
    '''
    Samples n gradient vectors from a uniform distribution on the unit sphere.

    Parameters:
    n (int): The number of gradient vectors to sample.

    Returns:
    g (torch.Tensor): The sampled gradient vectors. Shape (n, 3).
    '''

    # Save the current state of the random number generator
    state = torch.random.get_rng_state()

    # Set the seed
    torch.manual_seed(0)

    # Sample n angles from a uniform distribution on [0, 2pi].
    theta_angles = torch.rand(n) * 2 * torch.pi

    # Sample n angles from a uniform distribution on [0, pi].
    phi_angles = torch.rand(n) * torch.pi

    # Convert the angles to cartesian coordinates.
    x = torch.sin(phi_angles) * torch.cos(theta_angles)
    y = torch.sin(phi_angles) * torch.sin(theta_angles)
    z = torch.cos(phi_angles)

    # Stack the coordinates together
    g = torch.stack((x, y, z), dim=-1)

    # Restore the state of the random number generator
    torch.random.set_rng_state(state)

    return g

In [None]:
def create_random_diffusion_tensor() -> torch.Tensor:
    '''
    Creates a random diffusion tensor. 
    The tensor is symmetric and positive-definite.

    Returns:
    D (torch.Tensor): The diffusion tensor. Shape (3, 3).
    '''

    # Create a random lower triangular matrix with positive diagonal elements.
    L = torch.tril(torch.rand(3, 3)) + torch.eye(3)

    # Get diffusion tensor by Cholesky composition.
    D = L @ L.T

    return D

# polar decomposition

In [None]:
def polar_decomposition(D: torch.Tensor) \
    -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    Calculates the polar decomposition parameters of a diffusion tensor.

    Parameters:
    D (torch.Tensor): The diffusion tensor. Shape (3, 3).

    Returns:
    angle_x (torch.Tensor): The angle of rotation around the x-axis.
    angle_y (torch.Tensor): The angle of rotation around the y-axis.
    angle_z (torch.Tensor): The angle of rotation around the z-axis.
    eig_val_1 (torch.Tensor): The eigenvalue of the first eigenvector.
    eig_val_2 (torch.Tensor): The eigenvalue of the second eigenvector.
    eig_val_3 (torch.Tensor): The eigenvalue of the third eigenvector.
    '''

    # Calculate the eigenvalues and eigenvectors.
    eig_vals, eig_vecs = torch.linalg.eigh(D)

    # Sort the eigenvalues and eigenvectors.
    eig_vals, indices = torch.sort(eig_vals, descending=True)
    eig_vecs = eig_vecs[:, indices]

    # Calculate the angles of rotation.
    angle_x = torch.atan2(eig_vecs[2, 1], eig_vecs[2, 2])
    angle_y = torch.atan2(-eig_vecs[2, 0], torch.sqrt(eig_vecs[2, 1] ** 2 + eig_vecs[2, 2] ** 2))
    angle_z = torch.atan2(eig_vecs[1, 0], eig_vecs[0, 0])

    # Get the eigenvalues.
    eig_val_1 = eig_vals[0]
    eig_val_2 = eig_vals[1]
    eig_val_3 = eig_vals[2]

    return angle_x, angle_y, angle_z, eig_val_1, eig_val_2, eig_val_3

In [None]:
def polar_composition(
        angle_x: torch.Tensor, angle_y: torch.Tensor, angle_z: torch.Tensor,
        eig_val_1: torch.Tensor, eig_val_2: torch.Tensor, eig_val_3: torch.Tensor
    ) -> torch.Tensor:
    '''
    Reconstructs the tensor diffusion tensor from the polar decomposition parameters.
    
    Parameters:
    angle_x (torch.Tensor): The angle of rotation around the x-axis.
    angle_y (torch.Tensor): The angle of rotation around the y-axis.
    angle_z (torch.Tensor): The angle of rotation around the z-axis.
    eig_val_1 (torch.Tensor): The eigenvalue of the first eigenvector.
    eig_val_2 (torch.Tensor): The eigenvalue of the second eigenvector.
    eig_val_3 (torch.Tensor): The eigenvalue of the third eigenvector.

    Returns:
    D (torch.Tensor): The reconstructed diffusion tensor. Shape (3, 3).
    '''

    # Calculate the rotation matrices.
    R_x = torch.tensor([
        [1, 0, 0],
        [0, torch.cos(angle_x), -torch.sin(angle_x)],
        [0, torch.sin(angle_x), torch.cos(angle_x)]
    ])
    R_y = torch.tensor([
        [torch.cos(angle_y), 0, torch.sin(angle_y)],
        [0, 1, 0],
        [-torch.sin(angle_y), 0, torch.cos(angle_y)]
    ])
    R_z = torch.tensor([
        [torch.cos(angle_z), -torch.sin(angle_z), 0],
        [torch.sin(angle_z), torch.cos(angle_z), 0],
        [0, 0, 1]
    ])

    # Calculate the rotation matrix.
    R = R_z @ R_y @ R_x

    # Calculate the diagonal matrix of eigenvalues.
    eig_vals = torch.tensor([
        [eig_val_1, 0, 0],
        [0, eig_val_2, 0],
        [0, 0, eig_val_3]
    ])

    # Reconstruct the diffusion tensor.
    D = R @ eig_vals @ R.T

    return D

In [None]:
total = 0
correct = 0

for i in range(10_000):

    # reconstruct diffusion tensor
    D = create_random_diffusion_tensor()

    # check if the tensor has Cholesky decomposition
    try:
        torch.linalg.cholesky(D)
    except:
        print(i, 'No Cholesky decomposition!')
        break

    D_ = polar_composition(*polar_decomposition(D))

    if torch.allclose(D, D_):
        correct += 1
    
    total += 1

print(f'{100*correct/total:.2f}% correct')

In [None]:
total = 0
correct = 0

for i in range(10_000):

    # create random angles in radians and range [0, 2pi]
    angle_x = torch.rand(1) * 2 * torch.pi
    angle_y = torch.rand(1) * 2 * torch.pi
    angle_z = torch.rand(1) * 2 * torch.pi

    # create random eigenvalues in range [0, 1] and descending order
    eig_val_1 = torch.rand(1)
    eig_val_2 = torch.rand(1) * eig_val_1
    eig_val_3 = torch.rand(1) * eig_val_2

    # reconstruct diffusion tensor
    D = polar_composition(angle_x, angle_y, angle_z, eig_val_1, eig_val_2, eig_val_3)

    # check if the tensor has Cholesky decomposition
    try:
        torch.linalg.cholesky(D)
    except:
        print(i, 'No Cholesky decomposition!')
        break

    # decompose the tensor
    angle_x_, angle_y_, angle_z_, eig_val_1_, eig_val_2_, eig_val_3_ = polar_decomposition(D)

    # reconstruct the tensor
    D_ = polar_composition(angle_x_, angle_y_, angle_z_, eig_val_1_, eig_val_2_, eig_val_3_)

    if torch.allclose(D, D_):
        correct += 1
    
    total += 1
    
print(f'{100*correct/total:.2f}% correct')

# quaternion decomposition

In [None]:
def quaternion_decomposition(D: torch.Tensor) \
    -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    Calculates the quaternion decomposition parameters of a diffusion tensor.

    Parameters:
    D (torch.Tensor): The diffusion tensor. Shape (3, 3).

    Returns:
    q (torch.Tensor): The quaternion parameters. Shape (4,).
    eig_val_1 (torch.Tensor): The eigenvalue of the first eigenvector.
    eig_val_2 (torch.Tensor): The eigenvalue of the second eigenvector.
    eig_val_3 (torch.Tensor): The eigenvalue of the third eigenvector.
    '''
    # Calculate the eigenvalues and eigenvectors.
    eig_vals, eig_vecs = torch.linalg.eigh(D)

    # Sort the eigenvalues and eigenvectors.
    eig_vals, indices = torch.sort(eig_vals, descending=True)
    eig_vecs = eig_vecs[:, indices]
    
    # Get the eigenvalues.
    eig_val_1 = eig_vals[0]
    eig_val_2 = eig_vals[1]
    eig_val_3 = eig_vals[2]

    # Calculate quaternion
    q0 = torch.sqrt(1 + eig_vecs[0, 0] + eig_vecs[1, 1] + eig_vecs[2, 2]) / 2
    q1 = (eig_vecs[2, 1] - eig_vecs[1, 2]) / (4 * q0)
    q2 = (eig_vecs[0, 2] - eig_vecs[2, 0]) / (4 * q0)
    q3 = (eig_vecs[1, 0] - eig_vecs[0, 1]) / (4 * q0)

    # Normalize the quaternion
    q = torch.tensor([q0, q1, q2, q3])
    q /= torch.linalg.norm(q)

    return q, eig_val_1, eig_val_2, eig_val_3  

In [None]:
def quaternion_composition(q: torch.Tensor, eig_val_1: torch.Tensor, 
                           eig_val_2: torch.Tensor, eig_val_3: torch.Tensor) -> torch.Tensor:
    '''
    Reconstructs the tensor diffusion tensor from the quaternion decomposition parameters.

    Parameters:
    q (torch.Tensor): The quaternion parameters. Shape (4,).
    eig_val_1 (torch.Tensor): The eigenvalue of the first eigenvector.
    eig_val_2 (torch.Tensor): The eigenvalue of the second eigenvector.
    eig_val_3 (torch.Tensor): The eigenvalue of the third eigenvector.

    Returns:
    D (torch.Tensor): The reconstructed diffusion tensor. Shape (3, 3).
    '''

    # Calculate rotation matrix
    R = torch.tensor([
        [1 - 2*q[2]**2 - 2*q[3]**2, 2*q[1]*q[2] - 2*q[0]*q[3], 2*q[1]*q[3] + 2*q[0]*q[2]],
        [2*q[1]*q[2] + 2*q[0]*q[3], 1 - 2*q[1]**2 - 2*q[3]**2, 2*q[2]*q[3] - 2*q[0]*q[1]],
        [2*q[1]*q[3] - 2*q[0]*q[2], 2*q[2]*q[3] + 2*q[0]*q[1], 1 - 2*q[1]**2 - 2*q[2]**2]
    ])

    # Calculate the diagonal matrix of eigenvalues.
    eig_vals = torch.tensor([
        [eig_val_1, 0, 0],
        [0, eig_val_2, 0],
        [0, 0, eig_val_3]
    ])

    # Reconstruct the diffusion tensor.
    D = R @ eig_vals @ R.T

    return D

In [None]:
total = 0
correct = 0

for i in range(10_000):

    # reconstruct diffusion tensor
    D = create_random_diffusion_tensor()

    # check if the tensor has Cholesky decomposition
    try:
        torch.linalg.cholesky(D)
    except:
        print(i, 'No Cholesky decomposition!')
        break

    D_ = quaternion_composition(*quaternion_decomposition(D))

    if torch.allclose(D, D_):
        correct += 1
    
    total += 1

print(f'{100*correct/total:.2f}% correct')

In [None]:
total = 0
correct = 0

for i in range(10_000):

    # create random quaternion with values in range [-1, 1]
    q = torch.rand(4) * 2 - 1

    # normalize the quaternion
    q /= torch.linalg.norm(q)

    # create random eigenvalues in range [0, 1] and descending order
    eig_val_1 = torch.rand(1)
    eig_val_2 = torch.rand(1) * eig_val_1
    eig_val_3 = torch.rand(1) * eig_val_2

    # reconstruct diffusion tensor
    D = quaternion_composition(q, eig_val_1, eig_val_2, eig_val_3)

    # check if the tensor has Cholesky decomposition
    try:
        torch.linalg.cholesky(D)
    except:
        print(i, 'No Cholesky decomposition!')
        break

    # decompose the tensor
    q_, eig_val_1_, eig_val_2_, eig_val_3_ = quaternion_decomposition(D)

    # reconstruct the tensor
    D_ = quaternion_composition(q_, eig_val_1_, eig_val_2_, eig_val_3_)

    if torch.allclose(D, D_):
        correct += 1
    
    total += 1
    
print(f'{100*correct/total:.2f}% correct')

---
---

In [3]:
lstsq_results: dict[Tuple[float, float, float], np.ndarray]

with open('lstsq_results.pkl', 'rb') as f:
    lstsq_results = pickle.load(f)

In [4]:
num_b_0 = 10
num_b_1k = 128
SNR = 20

b_values = torch.tensor([0.0] * num_b_0 + [1_000.0] * num_b_1k) # shape (num_b_0 + num_b_1k,)
gradients = sample_gradients(num_b_0 + num_b_1k) # shape (num_b_0 + num_b_1k, 3)

d_tensors = []
normalized_noisy_signals = []

for d_array in tqdm(lstsq_results.values()):

    d_tensor = torch.tensor(d_array).float() # shape (3, 3)

    signal = torch.exp(- b_values * torch.einsum('bi, bij, bj -> b', 
                                                  gradients, 
                                                  d_tensor.unsqueeze(0).repeat(num_b_0 + num_b_1k, 1, 1),
                                                  gradients
                                                  )) # shape (num_b_0 + num_b_1k,)
    
    noise = torch.normal(mean=torch.zeros_like(signal), std=torch.ones_like(signal) / SNR) # shape (num_b_0 + num_b_1k,)

    noisy_signal = signal + noise # shape (num_b_0 + num_b_1k,)

    normalized_noisy_signal = noisy_signal / torch.mean(noisy_signal[:num_b_0]) # shape (num_b_0 + num_b_1k,)

    d_tensors.append(d_tensor)
    normalized_noisy_signals.append(normalized_noisy_signal)

d_tensors = torch.stack(d_tensors) # shape (num_d_tensors, 3, 3)
normalized_noisy_signals = torch.stack(normalized_noisy_signals) # shape (num_d_tensors, num_b_0 + num_b_1k)

100%|██████████| 300446/300446 [00:51<00:00, 5815.48it/s]
