In [1]:
import torch
import numpy as np

# Change import path to access tfn_torch
import sys
import os
sys.path.append(os.path.abspath('..'))
from models.tfn_torch import *

In [2]:
# Test get_eijk function
eijk = get_eijk()
print("Levi-Civita Tensor (eijk):")
print(eijk)


Levi-Civita Tensor (eijk):
tensor([[[ 0.,  0.,  0.],
         [ 0.,  0.,  1.],
         [ 0., -1.,  0.]],

        [[ 0.,  0., -1.],
         [ 0.,  0.,  0.],
         [ 1.,  0.,  0.]],

        [[ 0.,  1.,  0.],
         [-1.,  0.,  0.],
         [ 0.,  0.,  0.]]])


In [3]:
# Test norm_with_epsilon function
tensor = torch.randn(5, 5, 3)  # Random tensor of shape [5, 5, 3]
norm = norm_with_epsilon(tensor, dim=-1)
print("Norm with epsilon (sample tensor):")
print(norm)


Norm with epsilon (sample tensor):
tensor([[2.6316, 0.9024, 2.0546, 1.6455, 1.5706],
        [1.2793, 0.8458, 1.2129, 2.1174, 1.3368],
        [1.2780, 1.8707, 2.6705, 3.0279, 1.2397],
        [1.6068, 1.7479, 1.6441, 2.6085, 1.0254],
        [1.2805, 2.7064, 1.6735, 2.1859, 1.8716]])


In [4]:
# Test shifted soft plus function (ssp)
x = torch.randn(10)  # A random vector
ssp_out = ssp(x)
print("Shifted Soft Plus output:")
print(ssp_out)


Shifted Soft Plus output:
tensor([-0.0072, -0.3984, -0.0549, -0.0681,  0.1712,  0.1557, -0.1043,  0.0500,
        -0.4511, -0.0986])


In [5]:
# Test rotation equivariant nonlinearity
x = torch.randn(10, 5, 1)  # Random tensor of shape [10, 5, 1] (channels, M)
output = rotation_equivariant_nonlinearity(x)
print("Rotation Equivariant Nonlinearity Output:")
print(output)


Rotation Equivariant Nonlinearity Output:
tensor([[[ 0.4601],
         [-0.5550],
         [ 1.1628],
         [ 0.3020],
         [-0.1060]],

        [[-0.4740],
         [ 0.7780],
         [-0.4257],
         [ 0.1247],
         [ 0.2622]],

        [[ 0.0193],
         [-0.2629],
         [ 0.1126],
         [-0.1536],
         [ 0.2525]],

        [[ 0.9076],
         [ 0.6009],
         [ 0.2431],
         [-0.3391],
         [ 0.3530]],

        [[-0.0517],
         [-0.3011],
         [ 0.3752],
         [ 1.4019],
         [-0.1941]],

        [[-0.3680],
         [-0.4822],
         [ 0.3700],
         [-0.2205],
         [-0.0439]],

        [[-0.5906],
         [-0.3711],
         [ 0.6829],
         [ 1.0274],
         [ 0.0671]],

        [[ 0.2444],
         [ 0.0447],
         [ 0.6268],
         [ 0.3311],
         [-0.3509]],

        [[-0.1242],
         [-0.3000],
         [-0.3512],
         [ 0.1351],
         [ 0.7347]],

        [[ 0.4957],
         [-0.1904],


In [6]:
# Test difference matrix
geometry = torch.randn(4, 3)  # Random set of 4 points in 3D
diff_matrix = difference_matrix(geometry)
print("Difference Matrix (relative vectors):")
print(diff_matrix)

# Test distance matrix
dist_matrix = distance_matrix(geometry)
print("Distance Matrix:")
print(dist_matrix)


Difference Matrix (relative vectors):
tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.2452,  0.9747,  1.6944],
         [-0.4484,  2.8429, -0.6752],
         [ 0.0074, -0.0919, -1.0078]],

        [[-0.2452, -0.9747, -1.6944],
         [ 0.0000,  0.0000,  0.0000],
         [-0.6936,  1.8682, -2.3696],
         [-0.2379, -1.0666, -2.7022]],

        [[ 0.4484, -2.8429,  0.6752],
         [ 0.6936, -1.8682,  2.3696],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.4558, -2.9348, -0.3326]],

        [[-0.0074,  0.0919,  1.0078],
         [ 0.2379,  1.0666,  2.7022],
         [-0.4558,  2.9348,  0.3326],
         [ 0.0000,  0.0000,  0.0000]]])
Distance Matrix:
tensor([[1.0000e-04, 1.9701e+00, 2.9562e+00, 1.0120e+00],
        [1.9701e+00, 1.0000e-04, 3.0962e+00, 2.9148e+00],
        [2.9562e+00, 3.0962e+00, 1.0000e-04, 2.9885e+00],
        [1.0120e+00, 2.9148e+00, 2.9885e+00, 1.0000e-04]])


In [7]:
# Test random rotation matrix
rng = np.random.RandomState(42)  # Use a fixed random state for reproducibility
rand_rotation_matrix = random_rotation_matrix(rng)
print("Random Rotation Matrix:")
print(rand_rotation_matrix)


Random Rotation Matrix:
[[ 0.71633922 -0.6942997   0.0693257 ]
 [ 0.60546963  0.56914161 -0.55631318]
 [ 0.34679193  0.44048356  0.82807584]]


In [8]:
# Test rotation matrix
axis = np.array([0, 0, 1])  # Rotation around the Z-axis
theta = np.pi / 4  # 45 degrees
rotation_matrix_result = rotation_matrix(axis, theta)
print("Rotation Matrix (45 degrees around Z-axis):")
print(rotation_matrix_result)


Rotation Matrix (45 degrees around Z-axis):
[[ 0.70710678 -0.70710678  0.        ]
 [ 0.70710678  0.70710678  0.        ]
 [ 0.          0.          1.        ]]
