In [79]:
import math
import numpy as np
import torch

In [80]:
### validate non-vectorized version ###
# In the non-vectorized version, the input arguments align with the old code, so it stays backward compatible.

In [81]:
# Exemplary instantiations of functions: a_fun, b_fun, A_fun, B_fun.
# The functions need to satisfy conditions:
#
# a_fun(x) * b_fun(x) = 2.0 * math.pi * 1j * x
# A_fun(x, b) * B_fun(x, b) = np.exp(2.0 * math.pi * 1j * x * b)

def a_fun(x):
  """Small coefficient for the x side: 2*pi*i*x."""
  return 2.0 * math.pi * 1j * x


def b_fun(x):
  """Small coefficient for the w side: constant 1."""
  return 1


def A_fun(x, b):
  """Large coefficient for the x side: exp(2*pi*i*x*b)."""
  return np.exp(2.0 * math.pi * 1j * x * b)


def B_fun(x, b):
  """Large coefficient for the w side: constant 1."""
  return 1


In [82]:
# Main mapping producing <num_rand_features>-dimensional random feature vector
# and used to approximate: E_{xi~P} [exp(2*pi*i*xi*(x^t*w+bias_term))].
# Can be used to approximate:
# f(x^t*w+bias_term) = \int_{R} [F(xi)exp(2*pi*i*xi*(x^t*w+bias_term)) dxi], where
# F is a Fourier Transform of f.
#
# xw: input dim-dimensional vector (either x or w)
# bias_term: the mapping is used to approximate: f(x^tw + bias_term)
# xis: <num_rand_features>-dimensional vector of scalars
#      (depends on the form of the approximated f)
# large_fun: A_fun (if input vector is x) or B_fun (if input vector is w)
# small_fun: a_fun (if input vector is x) or b_fun (if input vector is w)
# num_rand_features: number of random features used
# M: hyperparameter of the algorithm (positive scalar)
# dim: dimensionality of the input vector
# seed: seed used to construct the Gaussian projections


def phi_mapping(xw, bias_term, xis, large_fun, small_fun, num_rand_features, M, dim, seed=0):
  """Random feature vector of a single input vector (reference numpy version).

  Re-seeds torch's global RNG with `seed` and draws a (num_rand_features, dim)
  standard Gaussian projection matrix with torch.normal, so that the same
  projections are produced as in the torch implementations.

  Returns an array of num_rand_features features:
    (1/sqrt(num_rand_features)) * large(xis) *
      exp(sqrt(1+4M) * small(xis) * (g_k . xw) - |xw|^2/2 * small(xis)^2 - M*|g_k|^2)
  which is complex whenever large_fun/small_fun are complex-valued.
  """
  def _large_with_bias(v):
    return large_fun(v, bias_term)

  # For 1-D xis, apply_along_axis calls each function once on the whole vector.
  small = np.apply_along_axis(small_fun, 0, xis)
  large = np.apply_along_axis(_large_with_bias, 0, xis)

  # torch (not np.random) so the draws match the vectorized torch version.
  torch.manual_seed(seed)
  projections = torch.normal(mean=0, std=1, size=(num_rand_features, dim)).numpy()

  scale = np.sqrt(1 + 4.0 * M)
  scaled_proj = (scale * small * projections.T).T
  inner = np.einsum('ij,j->i', scaled_proj, xw)

  proj_norms = np.linalg.norm(projections, axis=-1)
  penalty = (np.sum(xw * xw) / 2) * small * small
  penalty += M * proj_norms * proj_norms

  return (1.0 / np.sqrt(num_rand_features)) * large * np.exp(inner - penalty)


In [83]:
# def phi_mapping(xw, bias_term, xis, large_fun, small_fun, num_rand_features, M, dim, seed=0):
def input_to_rfs_torch(xw, AB_fun, ab_fun, xis, num_rfs, dim, device,
                                  seed=0, normalize=False, normalization_constant=None,
                                  orthogonal=False, proj_matrix=None, bias_term=0, M=1):
  """Non-vectorized random-feature map for one input vector (numpy implementation).

  Same computation as phi_mapping, but with the argument order of
  input_to_rfs_torch_vectorized so the two can be compared directly.

  Args:
    xw: 1-D numpy array of length dim (an x or a w vector).
    AB_fun: A_fun for x, B_fun for w; applied to xis (with bias_term) via np.apply_along_axis.
    ab_fun: a_fun for x, b_fun for w; applied to xis via np.apply_along_axis.
    xis: 1-D numpy array of num_rfs scalars.
    num_rfs: number of random features (rows of the Gaussian projection matrix).
    dim: dimensionality of xw.
    device, normalize, normalization_constant, orthogonal, proj_matrix:
      accepted only for signature compatibility with the vectorized version; ignored here.
    seed: seed passed to torch.manual_seed before sampling the projections.
    bias_term: bias passed to AB_fun.
    M: positive hyperparameter of the algorithm.

  Returns:
    numpy array of num_rfs random features (re-seeds torch's global RNG as a side effect).
  """
  small_coeffs = np.apply_along_axis(ab_fun, 0, xis)
  large_bias_fun = lambda x: AB_fun(x, bias_term)
  large_coeffs = np.apply_along_axis(large_bias_fun, 0, xis)
  # Use torch.normal (not np.random) so the draws match the vectorized version:
  # np.random and torch give different numbers for the same seed.
  torch.manual_seed(seed)
  # Size comes from the num_rfs argument, not from any notebook-level global.
  gs = torch.normal(mean=0, std=1, size=(num_rfs, dim))
  gs = gs.numpy()
  renorm_gs = np.transpose(np.sqrt(1+4.0 * M) * small_coeffs * np.transpose(gs))
  dot_products = np.einsum('ij,j->i', renorm_gs, xw)
  squared_xw = np.sum(xw * xw)
  correction_vector = (squared_xw / 2) * small_coeffs * small_coeffs
  correction_vector += M * np.linalg.norm(gs, axis=-1) * np.linalg.norm(gs, axis=-1)
  diff_vector = dot_products - correction_vector

  return (1.0 / np.sqrt(num_rfs)) * large_coeffs * np.exp(diff_vector)


In [84]:
# Usage example for: cos(x^Tw + bias).
# cos(z) = (exp(iz) + exp(-iz)) / 2, so the features for xi = +1/(2*pi) and
# xi = -1/(2*pi) are concatenated (scaled by 1/sqrt(2)) and one dot product
# between an x-side and a w-side vector gives the estimate.
dim = 3
L = 4
bias = 1.0
X = np.random.uniform(size=(L, dim))
W = np.random.uniform(size=(L, dim))
G = np.einsum('ij,kj->ik', X, W)

groundtruth_values = np.cos(G + bias)
print("Groundtruth:")
print(groundtruth_values)

# Hyperparameters of the algorithm.
M = 1.0
num_rand_features = 100000
seed = 0

# xis scalars for the cosine mapping f.
xis_plus = np.ones(num_rand_features) * (1.0 / (2.0 * math.pi))
xis_minus = - xis_plus


def estimate_cos_values(use_back_compat_api):
  """Return the (L, L) matrix of random-feature estimates of cos(X[i]^T W[j] + bias).

  use_back_compat_api: if True, features come from input_to_rfs_torch,
  otherwise from phi_mapping (same computation, different argument order).
  """
  def rf(v, xis, large_fun, small_fun):
    if use_back_compat_api:
      return input_to_rfs_torch(v, large_fun, small_fun, xis, num_rand_features, dim,
                                device='cpu', seed=seed, bias_term=bias, M=M)
    return phi_mapping(v, bias, xis, large_fun, small_fun, num_rand_features, M, dim, seed)

  scale = (1.0 / np.sqrt(2.0)) * np.power(1.0 + 4.0 * M, dim / 4.0)

  def cos_rfv(v, large_fun, small_fun):
    # Final rfv for the cosine mapping: [features(+xi), features(-xi)].
    return scale * np.concatenate((rf(v, xis_plus, large_fun, small_fun),
                                   rf(v, xis_minus, large_fun, small_fun)), axis=-1)

  x_rfvs = [cos_rfv(x, A_fun, a_fun) for x in X]
  # The w-side features do not depend on the x row, so compute them once.
  w_rfvs = [cos_rfv(w, B_fun, b_fun) for w in W]
  return np.array([[np.real(np.dot(x_rfv, w_rfv)) for w_rfv in w_rfvs] for x_rfv in x_rfvs])


approx_values = estimate_cos_values(use_back_compat_api=False)
print("Approximate")
print(approx_values)

# validate back compatible code: both APIs must give the same estimates.
approx_values_2 = estimate_cos_values(use_back_compat_api=True)
print("Approximate")
print(approx_values_2)
np.testing.assert_allclose(approx_values_2, approx_values)


Groundtruth:
[[-0.09871976 -0.46026593 -0.32263711  0.18555275]
 [ 0.4152287   0.13229234  0.1517042   0.00663355]
 [-0.216538   -0.8542541  -0.68533179 -0.47266385]
 [ 0.06901962 -0.68062977 -0.34661228 -0.36555236]]
Approximate
[[-0.10751362 -0.46808341 -0.3317336   0.17797384]
 [ 0.41773888  0.13403115  0.15296017  0.00363442]
 [-0.23466393 -0.86620414 -0.70998969 -0.48148429]
 [ 0.0671927  -0.6823076  -0.3509625  -0.36831847]]
Approximate
[[-0.10751362 -0.46808341 -0.3317336   0.17797384]
 [ 0.41773888  0.13403115  0.15296017  0.00363442]
 [-0.23466393 -0.86620414 -0.70998969 -0.48148429]
 [ 0.0671927  -0.6823076  -0.3509625  -0.36831847]]


In [85]:
# Usage example for: sin(x^Tw + bias).
# sin(z) = (exp(iz) - exp(-iz)) / (2i): the x-side features are scaled by 1/2,
# the w-side features by 1/i with the -xi half negated.
dim = 3
L = 4
bias = 0.5
X = np.random.uniform(size=(L, dim))
W = np.random.uniform(size=(L, dim))
G = np.einsum('ij,kj->ik', X, W)

groundtruth_values = np.sin(G + bias)
print("Groundtruth:")
print(groundtruth_values)

# Hyperparameters of the algorithm.
M = 1.0
num_rand_features = 100000
seed = 0

# xis scalars for the sine mapping f (the same +/- 1/(2*pi) as for the cosine).
xis_plus = np.ones(num_rand_features) * (1.0 / (2.0 * math.pi))
xis_minus = - xis_plus


def estimate_sin_values(use_back_compat_api):
  """Return the (L, L) matrix of random-feature estimates of sin(X[i]^T W[j] + bias).

  use_back_compat_api: if True, features come from input_to_rfs_torch,
  otherwise from phi_mapping (same computation, different argument order).
  """
  def rf(v, xis, large_fun, small_fun):
    if use_back_compat_api:
      return input_to_rfs_torch(v, large_fun, small_fun, xis, num_rand_features, dim,
                                device='cpu', seed=seed, bias_term=bias, M=M)
    return phi_mapping(v, bias, xis, large_fun, small_fun, num_rand_features, M, dim, seed)

  scale = np.power(1.0 + 4.0 * M, dim / 4.0)

  def x_side_rfv(x):
    # Final x-side rfv for the sine mapping: [features(+xi), features(-xi)] / 2.
    return (1.0 / 2.0) * scale * np.concatenate((rf(x, xis_plus, A_fun, a_fun),
                                                 rf(x, xis_minus, A_fun, a_fun)), axis=-1)

  def w_side_rfv(w):
    # Final w-side rfv for the sine mapping: [features(+xi), -features(-xi)] / i.
    return (1.0 / 1.0j) * scale * np.concatenate((rf(w, xis_plus, B_fun, b_fun),
                                                  -rf(w, xis_minus, B_fun, b_fun)), axis=-1)

  x_rfvs = [x_side_rfv(x) for x in X]
  # The w-side features do not depend on the x row, so compute them once.
  w_rfvs = [w_side_rfv(w) for w in W]
  return np.array([[np.real(np.dot(x_rfv, w_rfv)) for w_rfv in w_rfvs] for x_rfv in x_rfvs])


approx_values = estimate_sin_values(use_back_compat_api=False)
print("Approximate:")
print(approx_values)

# validate back compatible code: both APIs must give the same estimates.
approx_values_2 = estimate_sin_values(use_back_compat_api=True)
print("Approximate")
print(approx_values_2)
np.testing.assert_allclose(approx_values_2, approx_values)

Groundtruth:
[[0.99336994 0.99537018 0.97433318 0.92207096]
 [0.88893583 0.93166626 0.83373491 0.81176232]
 [0.90103314 0.80373569 0.887045   0.76288439]
 [0.75886593 0.84960487 0.6876303  0.71699613]]
Approximate:
[[0.99824607 1.01340828 0.97738155 0.929027  ]
 [0.89181886 0.94139565 0.83576861 0.81486762]
 [0.90064885 0.80887199 0.88611532 0.76211318]
 [0.7575847  0.85446197 0.68558617 0.71637163]]
Approximate
[[0.99824607 1.01340828 0.97738155 0.929027  ]
 [0.89181886 0.94139565 0.83576861 0.81486762]
 [0.90064885 0.80887199 0.88611532 0.76211318]
 [0.7575847  0.85446197 0.68558617 0.71637163]]


In [86]:
# # RELU variant:

# def phi_relu_mapping(xw, num_rand_features, dim, seed=0):
#   np.random.seed(seed)
#   gs = np.random.normal(size=(num_rand_features, dim))
#   dot_products = np.einsum('ij,j->i', gs, xw)
#   return (1.0 / np.sqrt(num_rand_features)) * np.maximum(dot_products, np.zeros(num_rand_features))

In [87]:
### validate vectorized version ###

In [88]:
import os
from google.colab import drive

# Colab only: mount Google Drive, then make the project folder the working
# directory (presumably so that performer_attention can be imported; confirm).
drive.mount('/content/drive')
os.chdir('/content/drive/MyDrive/NNK/vectorized_validation/new_nnk')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [89]:
import math
from torch import nn
from performer_attention import gaussian_orthogonal_random_matrix

def torch_apply_along_axis(function, x, axis: int = 0, b=None):
    """
    Torch equivalent of numpy apply along axis. This function is slow and should be avoided
    https://discuss.pytorch.org/t/apply-a-function-along-an-axis/130440

    Unbinds `x` along `axis`, calls function(slice) for each slice (or
    function(slice, b) when `b` is not None), and stacks the results back along
    `axis`.

    Non-tensor results (e.g. the constant 1 returned by b_fun / B_fun) are
    converted with torch.as_tensor so that torch.stack accepts them. Tensor
    results pass through unchanged.
    """
    extra_args = () if b is None else (b,)
    return torch.stack(
        [torch.as_tensor(function(x_i, *extra_args)) for x_i in torch.unbind(x, dim=axis)],
        dim=axis,
    )

# A_fun = lambda x, b: np.exp(2.0 * math.pi * 1j * x * b)

def input_to_rfs_torch_vectorized(xw, AB_fun, ab_fun, xis, num_rfs, dim, device,
                                  seed=0, normalize=False, normalization_constant=None,
                                  orthogonal=False, proj_matrix=None, bias_term=0, M=1):
    """Batched torch version of input_to_rfs_torch: random features of every row of xw.

    Args:
      xw: real tensor of shape (n, dim) or (b, n, dim); rows along the last axis
        are the input vectors.
      AB_fun: A_fun for x, B_fun for w; applied per element of xis with bias_term.
      ab_fun: a_fun for x, b_fun for w; applied per element of xis.
      xis: 1-D tensor of num_rfs scalars.
      num_rfs: number of random features (rows of the projection matrix).
      dim: dimensionality of each input vector.
      device: 'cpu' keeps the projections on the CPU; any other value moves them to CUDA.
      seed: seed passed to torch.manual_seed before sampling the projections.
      normalize: if True, rescale xw with torch.nn.functional.normalize (default p=2, dim=1).
      normalization_constant: optional multiplier applied after normalization.
      orthogonal: if truthy, sample the projections with gaussian_orthogonal_random_matrix.
      proj_matrix: optional precomputed (num_rfs, dim) projection matrix; skips sampling.
      bias_term: bias passed to AB_fun.
      M: positive hyperparameter of the algorithm.

    Returns:
      complex tensor of shape (n, num_rfs) or (b, n, num_rfs).

    Raises:
      ValueError: if xw is neither 2- nor 3-dimensional.
    """
    if normalize:
        # NOTE(review): F.normalize defaults to dim=1; for a 3-D (b, n, dim) input
        # that normalizes along n rather than the feature axis. Confirm intent.
        xw = torch.nn.functional.normalize(xw)
        if normalization_constant is not None:
            xw = normalization_constant * xw

    ab_coeffs = torch_apply_along_axis(ab_fun, xis, 0)
    AB_coeffs = torch_apply_along_axis(AB_fun, xis, 0, bias_term)
    torch.manual_seed(seed)

    on_cpu = device == 'cpu'
    if proj_matrix is not None:
        gs = proj_matrix if on_cpu else proj_matrix.cuda()
    elif orthogonal:
        gs = gaussian_orthogonal_random_matrix(num_rfs, dim, scaling=0,
                                               device='cpu' if on_cpu else 'cuda')
    else:
        # Sample from the seeded CPU generator and move afterwards, so the
        # projections are identical regardless of device.
        gs = torch.normal(mean=0, std=1, size=(num_rfs, dim))
        if not on_cpu:
            gs = gs.cuda()

    # A Python float keeps full precision and, unlike a 1-element CPU tensor,
    # can be combined with CUDA tensors.
    renorm_gs = (math.sqrt(1 + 4.0 * M) * ab_coeffs * gs.t()).t()
    # torch.einsum needs both operands to be complex doubles.
    renorm_gs_t = renorm_gs.t().to(torch.complex128)
    if xw.dim() == 2:
        dot_products = torch.einsum('ij,jk->ik', xw.type(torch.complex128), renorm_gs_t)
    elif xw.dim() == 3:
        dot_products = torch.einsum('bij,jk->bik', xw.type(torch.complex128), renorm_gs_t)
    else:
        raise ValueError("Unsuported Tensor shape")

    squared_xw = torch.sum(torch.mul(xw, xw), dim=-1)  # do not keepdims here
    ab_squared = torch.mul(ab_coeffs, ab_coeffs)
    # Both ranks use |xw|^2 / 2, as in the reference phi_mapping.
    if squared_xw.dim() == 1:
        correction_vector = torch.outer(squared_xw / 2, ab_squared)
    elif squared_xw.dim() == 2:
        correction_vector = torch.einsum('pq, r->pqr', squared_xw / 2, ab_squared)
    else:
        raise ValueError("Unsupported tensor shape of xw")
    gs_norms = torch.linalg.norm(gs, dim=-1)
    correction_vector += M * gs_norms * gs_norms

    diff_vector = dot_products - correction_vector
    return (1.0 / math.sqrt(num_rfs)) * AB_coeffs * torch.exp(diff_vector)




In [90]:
# Reference for the vectorized test: features of each row of x_vec from the
# non-vectorized input_to_rfs_torch, stacked into one array (one row per input).
np.random.seed(0)
x_vec = np.random.uniform(size=(L, dim))
x_rfs_concatenated = np.array([
    input_to_rfs_torch(row, A_fun, a_fun, xis_plus, num_rand_features, dim,
                       device='cpu', bias_term=bias)
    for row in x_vec
])
print(x_rfs_concatenated)

[[-3.97686176e-04-3.16334357e-05j -9.01379439e-04+1.09944888e-03j
  -5.04900796e-05-1.59215933e-05j ... -3.49882371e-05-2.67185725e-04j
  -3.06621338e-05+3.13011013e-05j -2.88912889e-04+8.70202255e-04j]
 [-2.37251213e-04-2.52395717e-04j -2.50101786e-04+1.20886212e-03j
   6.75237992e-06-4.54695463e-05j ... -8.14224056e-05-2.19351461e-04j
  -2.75699046e-05+2.62181249e-05j -5.24427476e-04+5.99017519e-04j]
 [-5.51959760e-04+1.69438731e-04j -2.04345478e-03-2.41017231e-04j
  -7.52261685e-05+1.45501804e-05j ...  4.11820670e-06-3.89972237e-04j
   2.07383364e-05+5.99285747e-05j -1.15621185e-03+6.51281772e-04j]
 [-3.53788441e-04-1.25095706e-04j -1.02909432e-03+8.54003487e-04j
  -4.97961232e-05-3.53509038e-07j ...  1.45428747e-04-2.07594760e-04j
  -4.11722616e-05+1.87965277e-06j  7.32151897e-05+8.59350840e-04j]]


In [91]:
# Vectorized version on the same inputs: all rows of x_vec in a single call.
xis_plus_tensor = torch.from_numpy(xis_plus)
x_vec_tensor = torch.from_numpy(x_vec)
x_rfs_vec = input_to_rfs_torch_vectorized(
    x_vec_tensor, A_fun, a_fun, xis_plus_tensor, num_rand_features, dim,
    device='cpu', bias_term=bias,
)
print(x_rfs_vec)

tensor([[-3.9769e-04-3.1633e-05j, -9.0138e-04+1.0994e-03j,
         -5.0490e-05-1.5922e-05j,  ...,
         -3.4988e-05-2.6719e-04j, -3.0662e-05+3.1301e-05j,
         -2.8891e-04+8.7020e-04j],
        [-2.3725e-04-2.5240e-04j, -2.5010e-04+1.2089e-03j,
          6.7524e-06-4.5470e-05j,  ...,
         -8.1422e-05-2.1935e-04j, -2.7570e-05+2.6218e-05j,
         -5.2443e-04+5.9902e-04j],
        [-5.5196e-04+1.6944e-04j, -2.0435e-03-2.4102e-04j,
         -7.5226e-05+1.4550e-05j,  ...,
          4.1182e-06-3.8997e-04j,  2.0738e-05+5.9929e-05j,
         -1.1562e-03+6.5128e-04j],
        [-3.5379e-04-1.2510e-04j, -1.0291e-03+8.5400e-04j,
         -4.9796e-05-3.5351e-07j,  ...,
          1.4543e-04-2.0759e-04j, -4.1172e-05+1.8797e-06j,
          7.3215e-05+8.5935e-04j]], dtype=torch.complex128)


In [92]:
np.isclose(x_rfs_vec.numpy(), x_rfs_concatenated)  # element-wise match against the reference

array([[ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True]])

In [93]:
# Every one of the L * num_rand_features entries must match the reference.
num_close = int(np.sum(np.isclose(x_rfs_vec, x_rfs_concatenated)))
assert num_close == L * num_rand_features, f"only {num_close} of {L * num_rand_features} entries match"
num_close

400000

In [94]:
# # old code, for reference
# def input_to_rfs_torch_vectorized(xw, AB_fun, ab_fun, xis, num_rfs, dim, device,
#                                   seed=0, normalize=False, normalization_constant=None,
#                                   orthogonal=False, proj_matrix=None):
#     if normalize :
#       if normalization_constant is None :
#         xw = torch.nn.functional.normalize(xw)
#       else :
#         xw = normalization_constant*torch.nn.functional.normalize(xw)

#     ab_coeffs = torch_apply_along_axis(ab_fun, xis, 0)
#     AB_coeffs = torch_apply_along_axis(AB_fun, xis, 0)
#     torch.manual_seed(seed)

#     if proj_matrix is None:
#       if device == 'cpu':
#         if orthogonal is False :
#           gs = torch.rand(size=(num_rfs, dim))
#         else :
#           gs = gaussian_orthogonal_random_matrix(num_rfs, dim, scaling = 0, device = 'cpu')
#       else :
#         if orthogonal is False :
#           gs = torch.rand(size=(num_rfs, dim)).cuda()
#         else :
#           gs = gaussian_orthogonal_random_matrix(num_rfs, dim, scaling = 0, device = 'cuda')
#     else :
#       if device == 'cpu':
#         gs = proj_matrix
#       else :
#          gs = proj_matrix.cuda()
#     renorm_gs = (ab_coeffs * gs.t()).t()
#     if len(xw.shape) == 2 :
#       dot_products = torch.einsum('ij,jk->ik', xw, renorm_gs.t())
#     elif len(xw.shape) == 3:
#       dot_products = torch.einsum('bij,jk->bik', xw, renorm_gs.t())
#     else :
#       raise ValueError("Unsuported Tensor shape")
#     squared_xw = torch.sum(torch.mul(xw, xw), dim=-1) #do not keepdims here
#     if len(squared_xw.shape) == 1 :
#       correction_vector = torch.outer(squared_xw / 2, torch.mul(ab_coeffs, ab_coeffs))
#     elif len(squared_xw.shape) == 2 :
#       correction_vector = torch.einsum('pq, r->pqr', squared_xw, torch.mul(ab_coeffs, ab_coeffs))
#     else :
#       raise ValueError("Unsupported tensor shape of xw")
#     diff_vector = dot_products - correction_vector
#     return (1.0 / math.sqrt(num_rfs)) * AB_coeffs * torch.exp(diff_vector)


# # class NNK(nn.Module) :
# #   def __init__(self, input_weights, A_fun, a_fun, xis, num_rfs, dim, model_device, seed=0, \
# #                normalize=False, normalization_constant=None, orthogonal=False, proj_matrix=None):
# #         super().__init__()
# #         self.input_weights = input_weights
# #         self.A_fun = A_fun
# #         self.a_fun = a_fun
# #         self.xis = xis
# #         self.num_rfs = num_rfs
# #         self.dim = dim
# #         self.model_device = model_device
# #         self.seed = seed
# #         self.normalize = normalize
# #         self.normalization_constant = normalization_constant
# #         self.orthogonal = orthogonal
# #         self.proj_matrix = proj_matrix

# #         self.weights = input_to_rfs_torch_vectorized(xw=self.input_weights, AB_fun=self.A_fun, ab_fun=self.a_fun, xis=self.xis, \
# #                                                      num_rfs=self.num_rfs, dim=self.dim, device=self.model_device, seed=self.seed,
# #                                                      normalize=self.normalize, normalize_constant=self.normalize_constant, \
# #                                                      orthogonal=self.orthogonal, proj_matrix=self.proj_matrix)
# #         self.weights = nn.Parameter(self.weights)
# #         # TODO: ADD BIAS

# #   def forward(self, x):
# #         output_x = input_to_rfs_torch_vectorized(xw=x, AB_fun=self.A_fun, ab_fun=self.a_fun, xis=self.xis, \
# #                                                      num_rfs=self.num_rfs, dim=self.dim, device=self.model_device,
# #                                                      seed=self.seed, normalize=self.normalize, \
# #                                                      normalize_constant=self.normalize_constant, \
# #                                                      orthogonal=self.orthogonal, proj_matrix=self.proj_matrix
# #                                                     )
# #         return output_x @ self.weights.t()