In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from typing import Tuple
import pickle
import numpy as np
from tqdm import tqdm

In [None]:
def create_random_diffusion_tensor() -> torch.Tensor:
    '''
    Draws a random symmetric positive-definite 3x3 matrix.

    The matrix is assembled as the product of a random lower-triangular
    factor and its transpose (i.e. from a Cholesky factor).

    Returns:
    D (torch.Tensor): The diffusion tensor. Shape (3, 3).
    '''

    # Lower-triangular factor; adding the identity pushes every diagonal
    # entry to at least 1, so the factor is invertible.
    cholesky_factor = torch.rand(3, 3).tril() + torch.eye(3)

    # Multiply the factor by its transpose to obtain the tensor.
    return cholesky_factor @ cholesky_factor.T

In [None]:
def polar_decomposition(D: torch.Tensor) \
    -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    Calculates the polar decomposition parameters of a diffusion tensor.

    The tensor is written as D = R @ diag(eig_vals) @ R.T with
    R = R_z(angle_z) @ R_y(angle_y) @ R_x(angle_x), and the three angles
    and three eigenvalues (sorted in descending order) are returned.

    Parameters:
    D (torch.Tensor): The diffusion tensor. Shape (3, 3). Expected to be symmetric.

    Returns:
    angle_x (torch.Tensor): The angle of rotation around the x-axis.
    angle_y (torch.Tensor): The angle of rotation around the y-axis.
    angle_z (torch.Tensor): The angle of rotation around the z-axis.
    eig_val_1 (torch.Tensor): The eigenvalue of the first eigenvector.
    eig_val_2 (torch.Tensor): The eigenvalue of the second eigenvector.
    eig_val_3 (torch.Tensor): The eigenvalue of the third eigenvector.
    '''

    # Calculate the eigenvalues and eigenvectors.
    eig_vals, eig_vecs = torch.linalg.eigh(D)

    # Sort the eigenvalues (descending) and reorder the eigenvectors to match.
    eig_vals, indices = torch.sort(eig_vals, descending=True)
    eig_vecs = eig_vecs[:, indices]

    # Eigenvectors are only defined up to sign, so the matrix of eigenvectors
    # may be a reflection (det = -1). The angle formulas below are only valid
    # for a proper rotation (det = +1), so flip the last eigenvector when
    # needed. This leaves R @ diag @ R.T unchanged.
    if torch.linalg.det(eig_vecs) < 0:
        eig_vecs = eig_vecs * eig_vecs.new_tensor([1.0, 1.0, -1.0])

    # Calculate the angles of rotation for R = R_z @ R_y @ R_x.
    # NOTE: when cos(angle_y) == 0 (gimbal lock) angle_x and angle_z are not
    # individually determined; that case is not treated specially here.
    angle_x = torch.atan2(eig_vecs[2, 1], eig_vecs[2, 2])
    angle_y = torch.atan2(-eig_vecs[2, 0], torch.sqrt(eig_vecs[2, 1] ** 2 + eig_vecs[2, 2] ** 2))
    angle_z = torch.atan2(eig_vecs[1, 0], eig_vecs[0, 0])

    # Get the eigenvalues.
    eig_val_1 = eig_vals[0]
    eig_val_2 = eig_vals[1]
    eig_val_3 = eig_vals[2]

    return angle_x, angle_y, angle_z, eig_val_1, eig_val_2, eig_val_3

In [None]:
def polar_composition(
        angle_x: torch.Tensor, angle_y: torch.Tensor, angle_z: torch.Tensor,
        eig_val_1: torch.Tensor, eig_val_2: torch.Tensor, eig_val_3: torch.Tensor
    ) -> torch.Tensor:
    '''
    Builds a diffusion tensor from three rotation angles and three eigenvalues.

    The result is R @ diag(eig_val_1, eig_val_2, eig_val_3) @ R.T with
    R = R_z(angle_z) @ R_y(angle_y) @ R_x(angle_x).

    Parameters:
    angle_x (torch.Tensor): The angle of rotation around the x-axis.
    angle_y (torch.Tensor): The angle of rotation around the y-axis.
    angle_z (torch.Tensor): The angle of rotation around the z-axis.
    eig_val_1 (torch.Tensor): The eigenvalue of the first eigenvector.
    eig_val_2 (torch.Tensor): The eigenvalue of the second eigenvector.
    eig_val_3 (torch.Tensor): The eigenvalue of the third eigenvector.

    Returns:
    D (torch.Tensor): The reconstructed diffusion tensor. Shape (3, 3).
    '''

    # Trigonometric terms of each elementary rotation.
    cos_x, sin_x = torch.cos(angle_x), torch.sin(angle_x)
    cos_y, sin_y = torch.cos(angle_y), torch.sin(angle_y)
    cos_z, sin_z = torch.cos(angle_z), torch.sin(angle_z)

    # Elementary rotations about the x, y and z axes.
    rot_x = torch.tensor([[1, 0, 0], [0, cos_x, -sin_x], [0, sin_x, cos_x]])
    rot_y = torch.tensor([[cos_y, 0, sin_y], [0, 1, 0], [-sin_y, 0, cos_y]])
    rot_z = torch.tensor([[cos_z, -sin_z, 0], [sin_z, cos_z, 0], [0, 0, 1]])

    # Combined rotation: x first, then y, then z.
    rotation = rot_z @ rot_y @ rot_x

    # Eigenvalues on the diagonal.
    spectrum = torch.diag(torch.tensor([eig_val_1, eig_val_2, eig_val_3]))

    # Rotate the diagonal matrix into place.
    return rotation @ spectrum @ rotation.T

In [None]:
# polar decomposition - test 1
# Round trip: random tensor -> polar parameters -> tensor, counting how often
# the reconstruction matches the input.

total = 0
correct = 0

for i in range(10_000):

    # create a random diffusion tensor
    D = create_random_diffusion_tensor()

    # check if the tensor has Cholesky decomposition
    # (torch.linalg.cholesky signals failure with a RuntimeError subclass;
    # anything else, e.g. KeyboardInterrupt, must propagate)
    try:
        torch.linalg.cholesky(D)
    except RuntimeError:
        print(i, 'No Cholesky decomposition!')
        break

    D_ = polar_composition(*polar_decomposition(D))

    if torch.allclose(D, D_):
        correct += 1

    total += 1

# Guard against an early break on the very first iteration.
if total:
    print(f'{100*correct/total:.2f}% correct')
else:
    print('No tensors were tested')

In [None]:
# polar decomposition - test 2
# Round trip starting from random parameters: parameters -> tensor ->
# parameters -> tensor, comparing the two tensors.

total = 0
correct = 0

for i in range(10_000):

    # create random angles in radians and range [0, 2pi]
    x_angles = torch.rand(1) * 2 * torch.pi
    y_angles = torch.rand(1) * 2 * torch.pi
    z_angles = torch.rand(1) * 2 * torch.pi

    # create random eigenvalues in range [0, 1] and descending order
    eig_val_1 = torch.rand(1)
    eig_val_2 = torch.rand(1) * eig_val_1
    eig_val_3 = torch.rand(1) * eig_val_2

    # reconstruct diffusion tensor
    D = polar_composition(x_angles, y_angles, z_angles, eig_val_1, eig_val_2, eig_val_3)

    # check if the tensor has Cholesky decomposition
    # (torch.linalg.cholesky signals failure with a RuntimeError subclass;
    # anything else, e.g. KeyboardInterrupt, must propagate)
    try:
        torch.linalg.cholesky(D)
    except RuntimeError:
        print(i, 'No Cholesky decomposition!')
        break

    # decompose the tensor
    angle_x_, angle_y_, angle_z_, eig_val_1_, eig_val_2_, eig_val_3_ = polar_decomposition(D)

    # reconstruct the tensor
    D_ = polar_composition(angle_x_, angle_y_, angle_z_, eig_val_1_, eig_val_2_, eig_val_3_)

    if torch.allclose(D, D_):
        correct += 1

    total += 1

# Guard against an early break on the very first iteration.
if total:
    print(f'{100*correct/total:.2f}% correct')
else:
    print('No tensors were tested')

In [None]:
def quaternion_decomposition(D: torch.Tensor) \
    -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    Calculates the quaternion decomposition parameters of a diffusion tensor.

    The tensor is written as D = R @ diag(eig_vals) @ R.T, where R is the
    rotation matrix of the unit quaternion q = (q0, q1, q2, q3) with q0 the
    scalar part, and the eigenvalues are sorted in descending order.

    Parameters:
    D (torch.Tensor): The diffusion tensor. Shape (3, 3). Expected to be symmetric.

    Returns:
    q (torch.Tensor): The quaternion parameters. Shape (4,).
    eig_val_1 (torch.Tensor): The eigenvalue of the first eigenvector.
    eig_val_2 (torch.Tensor): The eigenvalue of the second eigenvector.
    eig_val_3 (torch.Tensor): The eigenvalue of the third eigenvector.
    '''
    # Calculate the eigenvalues and eigenvectors.
    eig_vals, eig_vecs = torch.linalg.eigh(D)

    # Sort the eigenvalues (descending) and reorder the eigenvectors to match.
    eig_vals, indices = torch.sort(eig_vals, descending=True)
    R = eig_vecs[:, indices]

    # Eigenvectors are only defined up to sign, so R may be a reflection
    # (det = -1), which no quaternion can represent. Flipping one eigenvector
    # makes it a proper rotation without changing R @ diag @ R.T.
    if torch.linalg.det(R) < 0:
        R = R * R.new_tensor([1.0, 1.0, -1.0])

    # Get the eigenvalues.
    eig_val_1 = eig_vals[0]
    eig_val_2 = eig_vals[1]
    eig_val_3 = eig_vals[2]

    # Calculate the quaternion. Dividing by sqrt(1 + trace) alone breaks down
    # for rotations close to 180 degrees (trace near -1), so pick the branch
    # whose divisor is the largest quaternion component.
    trace = R[0, 0] + R[1, 1] + R[2, 2]
    if trace > 0:
        s = 2 * torch.sqrt(1 + trace)  # s = 4 * q0
        components = (
            s / 4,
            (R[2, 1] - R[1, 2]) / s,
            (R[0, 2] - R[2, 0]) / s,
            (R[1, 0] - R[0, 1]) / s,
        )
    elif R[0, 0] > R[1, 1] and R[0, 0] > R[2, 2]:
        s = 2 * torch.sqrt(1 + R[0, 0] - R[1, 1] - R[2, 2])  # s = 4 * q1
        components = (
            (R[2, 1] - R[1, 2]) / s,
            s / 4,
            (R[0, 1] + R[1, 0]) / s,
            (R[0, 2] + R[2, 0]) / s,
        )
    elif R[1, 1] > R[2, 2]:
        s = 2 * torch.sqrt(1 + R[1, 1] - R[0, 0] - R[2, 2])  # s = 4 * q2
        components = (
            (R[0, 2] - R[2, 0]) / s,
            (R[0, 1] + R[1, 0]) / s,
            s / 4,
            (R[1, 2] + R[2, 1]) / s,
        )
    else:
        s = 2 * torch.sqrt(1 + R[2, 2] - R[0, 0] - R[1, 1])  # s = 4 * q3
        components = (
            (R[1, 0] - R[0, 1]) / s,
            (R[0, 2] + R[2, 0]) / s,
            (R[1, 2] + R[2, 1]) / s,
            s / 4,
        )

    # Normalize the quaternion to remove round-off drift.
    q = torch.stack(components)
    q = q / torch.linalg.norm(q)

    return q, eig_val_1, eig_val_2, eig_val_3

In [None]:
def quaternion_composition(q: torch.Tensor, eig_val_1: torch.Tensor, 
                           eig_val_2: torch.Tensor, eig_val_3: torch.Tensor) -> torch.Tensor:
    '''
    Builds a diffusion tensor from a rotation quaternion and three eigenvalues.

    The result is R @ diag(eig_val_1, eig_val_2, eig_val_3) @ R.T, where R is
    the rotation matrix derived from q (q[0] is the scalar part).

    Parameters:
    q (torch.Tensor): The quaternion parameters. Shape (4,).
    eig_val_1 (torch.Tensor): The eigenvalue of the first eigenvector.
    eig_val_2 (torch.Tensor): The eigenvalue of the second eigenvector.
    eig_val_3 (torch.Tensor): The eigenvalue of the third eigenvector.

    Returns:
    D (torch.Tensor): The reconstructed diffusion tensor. Shape (3, 3).
    '''

    # Name the quaternion components: scalar part first, then the vector part.
    w, x, y, z = q[0], q[1], q[2], q[3]

    # Rotation matrix, written out row by row.
    row_0 = [1 - 2*y**2 - 2*z**2, 2*x*y - 2*w*z, 2*x*z + 2*w*y]
    row_1 = [2*x*y + 2*w*z, 1 - 2*x**2 - 2*z**2, 2*y*z - 2*w*x]
    row_2 = [2*x*z - 2*w*y, 2*y*z + 2*w*x, 1 - 2*x**2 - 2*y**2]
    rotation = torch.tensor([row_0, row_1, row_2])

    # Eigenvalues on the diagonal.
    spectrum = torch.diag(torch.tensor([eig_val_1, eig_val_2, eig_val_3]))

    # Rotate the diagonal matrix into place.
    return rotation @ spectrum @ rotation.T

In [None]:
# quaternion decomposition - test 1
# Round trip: random tensor -> quaternion parameters -> tensor, counting how
# often the reconstruction matches the input.

total = 0
correct = 0

for i in range(10_000):

    # create a random diffusion tensor
    D = create_random_diffusion_tensor()

    # check if the tensor has Cholesky decomposition
    # (torch.linalg.cholesky signals failure with a RuntimeError subclass;
    # anything else, e.g. KeyboardInterrupt, must propagate)
    try:
        torch.linalg.cholesky(D)
    except RuntimeError:
        print(i, 'No Cholesky decomposition!')
        break

    D_ = quaternion_composition(*quaternion_decomposition(D))

    if torch.allclose(D, D_):
        correct += 1

    total += 1

# Guard against an early break on the very first iteration.
if total:
    print(f'{100*correct/total:.2f}% correct')
else:
    print('No tensors were tested')

In [None]:
# quaternion decomposition - test 2
# Round trip starting from random parameters: parameters -> tensor ->
# parameters -> tensor, comparing the two tensors.

total = 0
correct = 0

for i in range(10_000):

    # create random quaternion with values in range [-1, 1]
    q = torch.rand(4) * 2 - 1

    # normalize the quaternion
    q /= torch.linalg.norm(q)

    # create random eigenvalues in range [0, 1] and descending order
    eig_val_1 = torch.rand(1)
    eig_val_2 = torch.rand(1) * eig_val_1
    eig_val_3 = torch.rand(1) * eig_val_2

    # reconstruct diffusion tensor
    D = quaternion_composition(q, eig_val_1, eig_val_2, eig_val_3)

    # check if the tensor has Cholesky decomposition
    # (torch.linalg.cholesky signals failure with a RuntimeError subclass;
    # anything else, e.g. KeyboardInterrupt, must propagate)
    try:
        torch.linalg.cholesky(D)
    except RuntimeError:
        print(i, 'No Cholesky decomposition!')
        break

    # decompose the tensor
    q_, eig_val_1_, eig_val_2_, eig_val_3_ = quaternion_decomposition(D)

    # reconstruct the tensor
    D_ = quaternion_composition(q_, eig_val_1_, eig_val_2_, eig_val_3_)

    if torch.allclose(D, D_):
        correct += 1

    total += 1

# Guard against an early break on the very first iteration.
if total:
    print(f'{100*correct/total:.2f}% correct')
else:
    print('No tensors were tested')

#
---
---

In [2]:
def sample_gradients(n: int, seed: int) -> torch.Tensor:
    '''
    Samples n gradient vectors from a uniform distribution on the unit sphere.

    The draw is reproducible for a given seed and does not touch the global
    random number generator: a private torch.Generator is used instead of
    reseeding and restoring the global state.

    Parameters:
    n (int): The number of gradient vectors to sample.
    seed (int): The seed of the private random number generator.

    Returns:
    g (torch.Tensor): The sampled gradient vectors. Shape (n, 3).
    '''

    # Private generator: reproducible, and no global RNG side effects.
    generator = torch.Generator().manual_seed(seed)

    # Sample n azimuthal angles from a uniform distribution on [0, 2pi).
    theta_angles = torch.rand(n, generator=generator) * 2 * torch.pi

    # For a uniform distribution over the sphere's surface the cosine of the
    # polar angle (not the angle itself) must be uniform on [-1, 1]; a uniform
    # polar angle would crowd the samples around the poles.
    z = 1 - 2 * torch.rand(n, generator=generator)
    sin_phi = torch.sqrt(torch.clamp(1 - z ** 2, min=0))

    # Convert to cartesian coordinates.
    x = sin_phi * torch.cos(theta_angles)
    y = sin_phi * torch.sin(theta_angles)

    # Stack the coordinates together
    g = torch.stack((x, y, z), dim=-1)

    return g

In [None]:
# load lstsq_results
# Reads the pickled mapping from 'lstsq_results.pkl' in the working directory.
# Only its values are used later in the notebook (as diffusion tensors).

# Declared type of the mapping: 3-float tuple keys, NumPy array values.
lstsq_results: dict[Tuple[float, float, float], np.ndarray]

# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only run this cell on a file from a trusted source.
with open('lstsq_results.pkl', 'rb') as f:
    lstsq_results = pickle.load(f)

In [None]:
# calculate d_tensors and noisy_signals
# For every diffusion tensor, simulate the signal exp(-b * g^T D g) for each
# (b-value, gradient) pair, add Gaussian noise and normalize by the mean of
# the b=0 measurements.

num_b_0 = 10
num_b_1k = 128
SNR = 20
seed = 0

b_values = torch.tensor([0.0] * num_b_0 + [1_000.0] * num_b_1k)
gradients = sample_gradients(num_b_0 + num_b_1k, seed)

# Dedicated generator so the simulated noise is reproducible across runs.
# A different seed than the gradients' keeps the two draws independent.
noise_generator = torch.Generator().manual_seed(seed + 1)

d_tensors = []
noisy_signals = []

for d_array in tqdm(lstsq_results.values()):

    d_tensor = torch.tensor(d_array).float()

    # Quadratic form g^T D g for every gradient; the single (3, 3) tensor is
    # shared by all gradients, so there is no need to repeat it per row.
    quadratic_form = torch.einsum('bi, ij, bj -> b', gradients, d_tensor, gradients)

    signal = torch.exp(- b_values * quadratic_form)

    # Additive Gaussian noise with standard deviation 1 / SNR.
    noise = torch.normal(mean=torch.zeros_like(signal),
                         std=torch.ones_like(signal) / SNR,
                         generator=noise_generator)

    noisy_signal = signal + noise

    # Normalize by the mean of the b=0 measurements (the first num_b_0 entries).
    noisy_signal /= torch.mean(noisy_signal[:num_b_0])

    d_tensors.append(d_tensor)
    noisy_signals.append(noisy_signal)

d_tensors = torch.stack(d_tensors) # shape (num_d_tensors, 3, 3)
noisy_signals = torch.stack(noisy_signals) # shape (num_d_tensors, num_b_0 + num_b_1k)

In [None]:
# save d_tensors and noisy_signals
# Each tensor is written to its own .pt file in the working directory.

for path, tensor in (('d_tensors.pt', d_tensors), ('noisy_signals.pt', noisy_signals)):
    torch.save(tensor, path)

In [None]:
# load d_tensors and noisy_signals
# Reads back the two .pt files from the working directory.

d_tensors, noisy_signals = (
    torch.load(path) for path in ('d_tensors.pt', 'noisy_signals.pt')
)

In [None]:
class DiffusionDataset(Dataset):
    '''
    Dataset of (diffusion tensor, noisy signal) pairs.

    Both tensors are indexed along their first dimension; item i is the
    pair (d_tensors[i], noisy_signals[i]).
    '''

    def __init__(self, d_tensors: torch.Tensor, noisy_signals: torch.Tensor):
        super().__init__()
        self.d_tensors, self.noisy_signals = d_tensors, noisy_signals

    def __len__(self):
        # The number of samples is the size of the leading dimension.
        return self.d_tensors.size(0)

    def __getitem__(self, idx):
        target = self.d_tensors[idx]
        signal = self.noisy_signals[idx]
        return target, signal

In [None]:
# shuffle the data and split 80% train and 20% validation
# The split is driven by a seeded generator, so it is the same on every run.

dataset = DiffusionDataset(d_tensors, noisy_signals)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

generator = torch.Generator().manual_seed(seed)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)

In [None]:
# create dataloaders
# Both loaders drop the last incomplete batch; only the training loader shuffles.

batch_size = 32
loader_kwargs = dict(batch_size=batch_size, drop_last=True)

train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, **loader_kwargs)

In [None]:
class PolarDiffusionNet(torch.nn.Module):
    '''
    MLP that maps a signal vector to a 3x3 diffusion tensor.

    The network outputs six raw values per sample. They are activated into
    three rotation-angle fractions, a leading eigenvalue and two eigenvalue
    ratios, and then assembled into D = R @ diag(eigenvalues) @ R.T with
    R = R_z @ R_y @ R_x. The output is therefore symmetric by construction.
    '''

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        '''
        Parameters:
        input_size (int): Number of input features per sample.
        hidden_size (int): Width of the two hidden layers.
        output_size (int): Number of raw outputs; must be 6 for
            activate/reconstruct to unpack them.
        '''
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.net = torch.nn.Sequential(
            torch.nn.Linear(self.input_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.output_size)
        )

    def forward(self, net_input: torch.Tensor) -> torch.Tensor:
        '''
        Parameters:
        net_input (torch.Tensor): Batch of inputs. Shape (batch, input_size).

        Returns:
        predictions (torch.Tensor): Predicted tensors. Shape (batch, 3, 3).
        '''
        net_output = self.net(net_input)
        activated_net_output = self.activate(net_output)
        predictions = self.reconstruct(activated_net_output)
        return predictions

    def activate(self, net_output: torch.Tensor) -> torch.Tensor:
        '''
        Applies a per-column activation to the raw network output.

        Parameters:
        net_output (torch.Tensor): Raw outputs. Shape (batch, 6).

        Returns:
        activated_net_output (torch.Tensor): Shape (batch, 6). Columns 0-2 are
            angle fractions in (0, 1), column 3 is the non-negative leading
            eigenvalue, columns 4-5 are eigenvalue ratios in (0, 1).
        '''

        # Split the tensor into individual columns of shape (batch, 1).
        x_angles, y_angles, z_angles, eig_val_1, eig_val_2_over_1, eig_val_3_over_2 = torch.split(net_output, split_size_or_sections=1, dim=1)

        # Apply different activation functions to each element
        x_angles = torch.sigmoid(x_angles)
        y_angles = torch.sigmoid(y_angles)
        z_angles = torch.sigmoid(z_angles)
        eig_val_1 = torch.relu(eig_val_1)
        eig_val_2_over_1 = torch.sigmoid(eig_val_2_over_1)
        eig_val_3_over_2 = torch.sigmoid(eig_val_3_over_2)

        # Return the activated tensor
        activated_net_output = torch.cat((x_angles, y_angles, z_angles, eig_val_1, eig_val_2_over_1, eig_val_3_over_2), dim=1)

        return activated_net_output

    @staticmethod
    def _as_matrices(*entries: torch.Tensor) -> torch.Tensor:
        '''Stacks nine (batch,) tensors, given in row-major order, into (batch, 3, 3).'''
        flat = torch.stack(entries, dim=-1)
        return flat.reshape(*flat.shape[:-1], 3, 3)

    def reconstruct(self, activated_net_output: torch.Tensor) -> torch.Tensor:
        '''
        Assembles diffusion tensors from activated parameters.

        Parameters:
        activated_net_output (torch.Tensor): Output of activate. Shape (batch, 6).

        Returns:
        D (torch.Tensor): Diffusion tensors. Shape (batch, 3, 3), with the
            same dtype and device as the input.
        '''

        # Split into six (batch,) tensors. unbind keeps the batch dimension
        # even for a batch of one, unlike an axis-less squeeze().
        x_angles, y_angles, z_angles, eig_val_1, eig_val_2_over_1, eig_val_3_over_2 = activated_net_output.unbind(dim=1)

        # Compute angles: fractions in (0, 1) are scaled to (0, 2pi).
        x_angles = x_angles * 2 * torch.pi
        y_angles = y_angles * 2 * torch.pi
        z_angles = z_angles * 2 * torch.pi

        # Compute eigenvalues; the ratios enforce eig_val_1 >= eig_val_2 >= eig_val_3.
        eig_val_2 = eig_val_1 * eig_val_2_over_1
        eig_val_3 = eig_val_2 * eig_val_3_over_2

        # Constants built from the input so they share its dtype and device.
        ones = torch.ones_like(x_angles)
        zeros = torch.zeros_like(x_angles)

        # Create the rotation matrices around the x axis.
        cos_x, sin_x = torch.cos(x_angles), torch.sin(x_angles)
        R_x = self._as_matrices(
            ones, zeros, zeros,
            zeros, cos_x, -sin_x,
            zeros, sin_x, cos_x,
        )

        # Create the rotation matrices around the y axis.
        cos_y, sin_y = torch.cos(y_angles), torch.sin(y_angles)
        R_y = self._as_matrices(
            cos_y, zeros, sin_y,
            zeros, ones, zeros,
            -sin_y, zeros, cos_y,
        )

        # Create the rotation matrices around the z axis.
        cos_z, sin_z = torch.cos(z_angles), torch.sin(z_angles)
        R_z = self._as_matrices(
            cos_z, -sin_z, zeros,
            sin_z, cos_z, zeros,
            zeros, zeros, ones,
        )

        # Calculate the rotation matrices.
        R = torch.bmm(R_z, torch.bmm(R_y, R_x))

        # Calculate the diagonal matrix of eigenvalues.
        eig_vals = torch.diag_embed(torch.stack((eig_val_1, eig_val_2, eig_val_3), dim=-1))

        # Reconstruct the diffusion tensors.
        D = torch.bmm(R, torch.bmm(eig_vals, R.transpose(1, 2)))

        return D

In [None]:
# Pull a single batch from the training loader for a quick forward-pass check.
# Item order follows DiffusionDataset.__getitem__: (d_tensors, noisy_signals).
batch_d_tensors, batch_noisy_signals = next(iter(train_dataloader))

In [None]:
# One input per simulated measurement (b=0 plus b=1000 signals); six outputs
# are what PolarDiffusionNet.activate/reconstruct unpack.
polar_dnet = PolarDiffusionNet(input_size=num_b_0 + num_b_1k, hidden_size=64, output_size=6)

In [None]:
# Forward pass of the untrained network on the sampled batch.
pred_d_tensors = polar_dnet(batch_noisy_signals)

In [None]:
# Mean-squared-error criterion (a callable module, not a loss value).
loss = torch.nn.MSELoss()

In [None]:
# MSE between predicted and reference tensors for the sampled batch; as the
# cell's last expression, the value is displayed.
loss(pred_d_tensors, batch_d_tensors)