In [121]:
!pip install torch



In [122]:
import math
import torch
from torch import nn
import numpy as np

def torch_apply_along_axis(function, x, axis: int = 0):
    """
    Apply ``function`` to every slice of ``x`` along ``axis`` and stack the results.

    ``x`` is split with ``torch.unbind`` (so each slice has ``axis`` removed),
    ``function`` is called once per slice in order, and the outputs are joined
    again with ``torch.stack`` at the same position ``axis``.

    The work is done in a Python-level loop, one call per slice, so this is
    slow and should be avoided on hot paths.
    https://discuss.pytorch.org/t/apply-a-function-along-an-axis/130440
    """
    mapped_slices = []
    for piece in torch.unbind(x, dim=axis):
        mapped_slices.append(function(piece))
    return torch.stack(mapped_slices, dim=axis)

def input_to_rfs_torch(xw, AB_fun, ab_fun, xis, num_rfs, dim):
    """
    Map a single vector to its random-feature representation.

    With ``a_i = ab_fun(xis[i])``, ``A_i = AB_fun(xis[i])`` and ``gs`` a fixed
    ``(num_rfs, dim)`` random matrix, feature ``i`` of the result is::

        A_i * exp(a_i * <gs[i], xw> - a_i**2 * |xw|**2 / 2) / sqrt(num_rfs)

    Args:
        xw: 1-D tensor with ``dim`` entries (contracted against the rows of ``gs``).
        AB_fun: callable applied to each entry of ``xis``; yields the factor
            multiplying the exponential.
        ab_fun: callable applied to each entry of ``xis``; yields the factor
            scaling the random projection inside the exponential.
        xis: tensor iterated along dim 0; it has to provide ``num_rfs``
            coefficients so that they broadcast against the projections.
        num_rfs: number of random features (rows of ``gs``).
        dim: input dimensionality (columns of ``gs``).

    Returns:
        Tensor holding the ``num_rfs`` random features of ``xw``.
    """
    ab_coeffs = torch_apply_along_axis(ab_fun, xis, 0)
    AB_coeffs = torch_apply_along_axis(AB_fun, xis, 0)
    # Draw the projections from a private generator seeded with 0: every call
    # sees the same matrix (so features of different inputs are comparable)
    # without resetting the process-wide RNG as torch.manual_seed(0) would.
    # The generator lives on the current default device, where torch.rand
    # allocates its output.
    rng = torch.Generator(device=torch.empty(0).device)
    rng.manual_seed(0)
    # NOTE(review): torch.rand samples uniformly from [0, 1), whereas the
    # |xw|^2 / 2 * a^2 correction below is the normaliser for standard-normal
    # projections (E[exp(a <g, x>)] = exp(a^2 |x|^2 / 2)). Confirm whether
    # torch.randn was intended; the distribution is deliberately left as is.
    gs = torch.rand(size=(num_rfs, dim), generator=rng)
    # Scale row i of gs by ab_coeffs[i].
    renorm_gs = (ab_coeffs * gs.t()).t()
    dot_products = torch.einsum('ij,j->i', renorm_gs, xw)
    squared_xw = torch.sum(xw * xw)
    correction_vector = (squared_xw / 2) * ab_coeffs * ab_coeffs
    diff_vector = dot_products - correction_vector
    return (1.0 / math.sqrt(num_rfs)) * AB_coeffs * torch.exp(diff_vector)

def input_to_rfs_torch_vectorized(xw, AB_fun, ab_fun, xis, num_rfs, dim):
    """
    Map a batch of vectors to their random-feature representations.

    With ``a_i = ab_fun(xis[i])``, ``A_i = AB_fun(xis[i])`` and ``gs`` a fixed
    ``(num_rfs, dim)`` random matrix, entry ``[n, i]`` of the result is::

        A_i * exp(a_i * <gs[i], xw[n]> - a_i**2 * |xw[n]|**2 / 2) / sqrt(num_rfs)

    Args:
        xw: 2-D tensor of shape ``(batch, dim)``; one input vector per row.
        AB_fun: callable applied to each entry of ``xis``; yields the factor
            multiplying the exponential.
        ab_fun: callable applied to each entry of ``xis``; yields the factor
            scaling the random projection inside the exponential.
        xis: tensor iterated along dim 0; it has to provide ``num_rfs``
            coefficients (``torch.outer`` below requires them to be 1-D).
        num_rfs: number of random features (rows of ``gs``).
        dim: input dimensionality (columns of ``gs``).

    Returns:
        Tensor of shape ``(batch, num_rfs)``.
    """
    ab_coeffs = torch_apply_along_axis(ab_fun, xis, 0)
    AB_coeffs = torch_apply_along_axis(AB_fun, xis, 0)
    # Draw the projections from a private generator seeded with 0: every call
    # sees the same matrix (so features of different inputs are comparable)
    # without resetting the process-wide RNG as torch.manual_seed(0) would.
    # The generator lives on the current default device, where torch.rand
    # allocates its output.
    rng = torch.Generator(device=torch.empty(0).device)
    rng.manual_seed(0)
    # NOTE(review): torch.rand samples uniformly from [0, 1), whereas the
    # |x|^2 / 2 * a^2 correction below is the normaliser for standard-normal
    # projections (E[exp(a <g, x>)] = exp(a^2 |x|^2 / 2)). Confirm whether
    # torch.randn was intended; the distribution is deliberately left as is.
    gs = torch.rand(size=(num_rfs, dim), generator=rng)
    # Scale row i of gs by ab_coeffs[i].
    renorm_gs = (ab_coeffs * gs.t()).t()
    # (batch, dim) x (dim, num_rfs) -> (batch, num_rfs)
    dot_products = torch.einsum('ij,jk->ik', xw, renorm_gs.t())
    # Squared norm of every row of xw.
    squared_xw = torch.sum(torch.mul(xw, xw), dim=1)
    correction_vector = torch.outer(squared_xw / 2, torch.mul(ab_coeffs, ab_coeffs))
    diff_vector = dot_products - correction_vector
    return (1.0 / math.sqrt(num_rfs)) * AB_coeffs * torch.exp(diff_vector)

class mynetwork(nn.Module):
    """
    Module with one learnable vector, initialised from the random features of ``w``.

    ``forward`` maps its input through the same random-feature map and returns
    the product of those features with the learnable vector.

    Both ``__init__`` and ``forward`` read the notebook-level names ``A_fun``,
    ``a_fun``, ``xis``, ``num_rfs`` and ``dim``; they must be defined before
    the class is instantiated or called.

    NOTE(review): the weights are built with ``A_fun``/``a_fun``, the same pair
    used for the inputs, while the stand-alone check in this notebook maps
    ``w`` with ``B_fun``/``b_fun`` - confirm which pair is intended.
    """

    def __init__(self, w):
        super().__init__()
        self.w = w
        initial_weights = input_to_rfs_torch(self.w, A_fun, a_fun, xis, num_rfs, dim)
        # Register the feature vector as the trainable parameter.
        self.weights = nn.Parameter(initial_weights)
        # self.bias = nn.Parameter(torch.zeros(10))

    def forward(self, x):
        features = input_to_rfs_torch(x, A_fun, a_fun, xis, num_rfs, dim)
        return torch.matmul(features, self.weights)

In [123]:
###################### TEST
# Fixture shared by the checks in the following cells: two 5-dimensional
# vectors, the reference value cos(<x, w> + bias), the coefficient functions
# and the random frequencies `xis`.

# Seed for the frequency draw. Deliberately different from the seed (0) used
# for the projection matrix inside input_to_rfs_torch*, otherwise the tosses
# would coincide with the first entries of that matrix.
SEED = 1234

dim = 5
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
w = torch.tensor([5.0, 4.0, 3.0, 2.0, 1.0], dtype=torch.float32)

bias = torch.tensor([0.0], dtype=torch.float32)
groundtruth_value = torch.cos(torch.dot(x, w) + bias)
num_rfs = 10000
# a_fun= lambda xi: 2.0 * math.pi * 1j * xi
# b_fun= lambda x: 1
# A_fun= lambda x: np.exp(bias)
# B_fun= lambda x: 1
a_fun = lambda x: np.sin(x)
b_fun = lambda x: np.cos(x)
A_fun = lambda x: np.sin(x)
B_fun = lambda x: np.cos(x)

# Map a coin toss in [0, 1) to +1/(2*pi) or -1/(2*pi). A toss of exactly 0.5
# goes to the negative branch; with strict `<` it would produce a frequency of 0.
xis_creator = lambda x: 1.0 / (2.0 * math.pi) * (x > 0.5) - 1.0 / (2.0 * math.pi) * (x <= 0.5)
# Seed explicitly so that `xis` is reproducible under Restart & Run All.
torch.manual_seed(SEED)
random_tosses = torch.rand(num_rfs)
xis = xis_creator(random_tosses)

In [124]:
# test vectorized version
# Feed the same two vectors through the batched and the single-vector feature
# maps; row n of the batched result must match the single-vector result.
x_vec = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [5.0, 4.0, 3.0, 2.0, 1.0]], dtype=torch.float32)
x_rfs_vec = input_to_rfs_torch_vectorized(x_vec, A_fun, a_fun, xis, num_rfs, dim)
print('vectorized output: ', x_rfs_vec)
x_rfs = input_to_rfs_torch(x, A_fun, a_fun, xis, num_rfs, dim)
print('non-vectorized output: ', x_rfs)
x_2 = torch.tensor([5.0, 4.0, 3.0, 2.0, 1.0], dtype=torch.float32)
x_rfs_2 = input_to_rfs_torch(x_2, A_fun, a_fun, xis, num_rfs, dim)
print('non-vectorized output (2nd vector): ', x_rfs_2)

# The printed tensors are truncated, so check the agreement explicitly.
# The tolerance allows for float32 rounding differences between the two einsum paths.
assert torch.allclose(x_rfs_vec[0], x_rfs, rtol=1e-4, atol=1e-7), \
    'vectorized row 0 differs from the non-vectorized features of x'
assert torch.allclose(x_rfs_vec[1], x_rfs_2, rtol=1e-4, atol=1e-7), \
    'vectorized row 1 differs from the non-vectorized features of x_2'

vectorized output:  tensor([[ 0.0016,  0.0035,  0.0014,  ..., -0.0001,  0.0026,  0.0016],
        [ 0.0022,  0.0035,  0.0015,  ..., -0.0002,  0.0021,  0.0012]])
non-vectorized output:  tensor([ 0.0016,  0.0035,  0.0014,  ..., -0.0001,  0.0026,  0.0016])
non-vectorized output (2nd vector):  tensor([ 0.0022,  0.0035,  0.0015,  ..., -0.0002,  0.0021,  0.0012])


In [126]:
# Random-feature estimate of the target: x is mapped with the (A, a)
# coefficients, w with the (B, b) coefficients, and the two feature vectors
# are contracted.
x_rfs = input_to_rfs_torch(x, A_fun, a_fun, xis, num_rfs, dim)
w_rfs = input_to_rfs_torch(w, B_fun, b_fun, xis, num_rfs, dim)

print(torch.dot(x_rfs, w_rfs))
print(groundtruth_value) # not great

net = mynetwork(w)
# real stupid test
# Smoke test only: run forward/backward a few times without any optimizer step.
for i in range(5):
    # Clear the gradient left by the previous pass; backward() accumulates
    # into .grad, so without this the stored gradient grows every iteration.
    net.zero_grad()
    l = net(x)
    print(l)
    # NOTE(review): presumably the cost is dominated by the per-element Python
    # loop in the feature map that forward() re-runs each pass - TODO profile.
    l.backward() #quite slow

tensor(5.2822e-09)
tensor([-0.9037])
tensor(0.0416, grad_fn=<DotBackward0>)
tensor(0.0416, grad_fn=<DotBackward0>)
tensor(0.0416, grad_fn=<DotBackward0>)
tensor(0.0416, grad_fn=<DotBackward0>)
tensor(0.0416, grad_fn=<DotBackward0>)
