In [1]:
# Import required libraries
import numpy as np
import torch
import torch.nn.functional as F
from numpy.linalg import svd, matrix_rank

# Initialize random seed for reproducibility
np.random.seed(0)
torch.manual_seed(0)

# Step 1: Simulating group and transformations
class Group:
    """
    Abstract group class with input/output transformations
    """
    def __init__(self, input_matrices, output_matrices):
        self.repr_size_in = input_matrices[0].shape[1]
        self.repr_size_out = output_matrices[0].shape[1]
        self.input_matrices = input_matrices
        self.output_matrices = output_matrices
        self.parameters = range(len(input_matrices))

    def _input_transformation(self, weights, params):
        print(f"Applying input transformation for parameter {params}:")
        transformed_weights = np.matmul(weights, self.input_matrices[params])
        print(transformed_weights)
        return transformed_weights

    def _output_transformation(self, weights, params):
        print(f"Applying output transformation for parameter {params}:")
        transformed_weights = np.matmul(self.output_matrices[params], weights)
        print(transformed_weights)
        return transformed_weights

# Example input and output matrices (for simplicity using 2x2 matrices)
input_matrices =  [
        torch.FloatTensor(np.eye(4)), 
        torch.FloatTensor(-1 * np.eye(4))
    ]
output_matrices =  [
        torch.FloatTensor(np.eye(1)), 
        torch.FloatTensor(-1 * np.eye(1))
    ]

group = Group(input_matrices, output_matrices)

# Step 2: Initialize a random weight matrix
size = (group.repr_size_out, group.repr_size_in) # Example shape
w = np.random.randn(*size)
print("Initial random weight matrix:")
print(w)

# Step 3: Symmetrize the weight matrix
Wsym = np.zeros_like(w)
for param in group.parameters:
    input_trans = group._input_transformation(w, param)
    Wsym = np.concatenate((Wsym, group._output_transformation(input_trans, param)), axis=0)
print("\nSymmetrized weight matrix:")
print(Wsym)

# Step 4: Vectorize the symmetrized matrix
wvec = np.reshape(Wsym, [Wsym.shape[0], -1])
print("\nVectorized weight matrix:")
print(wvec)

# Step 5: Perform SVD (Singular Value Decomposition)
u, s, vh = svd(wvec)
print("\nSingular values from SVD:")
print(s)

# Step 6: Calculate the rank of the matrix
rank = matrix_rank(wvec)
print(f"\nRank of the matrix: {rank}")

# # Step 7: Construct the basis using the right-singular vectors (vh)
# indices = [-1, 1, Wsym.shape[1], Wsym.shape[2]]  # Example new size
# basis = np.reshape(vh[:rank, ...], indices)
# print("\nConstructed basis from SVD:")
# print(basis)

# # Step 8: Create a PyTorch tensor for the basis
# basis_tensor = torch.tensor(basis, dtype=torch.float64).clone().detach().requires_grad_(False)
# print("\nConverted basis to PyTorch tensor:")
# print(basis_tensor)


Initial random weight matrix:
[[1.76405235 0.40015721 0.97873798 2.2408932 ]]
Applying input transformation for parameter 0:
tensor([[1.7641, 0.4002, 0.9787, 2.2409]], dtype=torch.float64)
Applying output transformation for parameter 0:
tensor([[1.7641, 0.4002, 0.9787, 2.2409]], dtype=torch.float64)
Applying input transformation for parameter 1:
tensor([[-1.7641, -0.4002, -0.9787, -2.2409]], dtype=torch.float64)
Applying output transformation for parameter 1:
tensor([[1.7641, 0.4002, 0.9787, 2.2409]], dtype=torch.float64)

Symmetrized weight matrix:
[[0.         0.         0.         0.        ]
 [1.76405235 0.40015721 0.97873798 2.2408932 ]
 [1.76405235 0.40015721 0.97873798 2.2408932 ]]

Vectorized weight matrix:
[[0.         0.         0.         0.        ]
 [1.76405235 0.40015721 0.97873798 2.2408932 ]
 [1.76405235 0.40015721 0.97873798 2.2408932 ]]

Singular values from SVD:
[4.30151993e+00 2.74900183e-16 0.00000000e+00]

Rank of the matrix: 1
