In [1]:
import torch
import numpy as np
torch.cuda.is_available()

True

In [678]:
def random_init(M,dev,generator=None):
    """Draw a random parameter set for the cosine transform / reconstruction model.

    Parameters
    ----------
    M : int
        Number of frequency components.
    dev : torch.device
        Device the returned tensors are moved to.
    generator : torch.Generator, optional
        CPU generator used for every random draw, so an initialisation can be
        reproduced (e.g. ``torch.Generator().manual_seed(0)``). When ``None``
        (the default) the global RNG is used, exactly as before.

    Returns
    -------
    a1, a2, a3, a4 : torch.Tensor
        Shape-(1,) amplitude scales drawn uniformly from [0.5, 1.5).
    w1, w2 : torch.Tensor
        Shape-(M,) frequencies drawn uniformly from [0, pi) and sorted ascending.

    All six tensors live on ``dev`` and have ``requires_grad=True``.
    """
    # Draw order is unchanged (four amplitudes, then w1, then w2), so a given
    # global seed still produces the same parameters when no generator is passed.
    amplitudes = [torch.rand(1, generator=generator) + 0.5 for _ in range(4)]
    frequencies = [torch.sort(torch.rand(M, generator=generator) * torch.pi).values
                   for _ in range(2)]

    # Move to the device first, then enable grad, so each returned tensor is a
    # leaf whose .grad is filled in by backward().
    a1, a2, a3, a4, w1, w2 = (t.to(dev).requires_grad_(True) for t in amplitudes + frequencies)
    return a1,a2,a3,a4,w1,w2

In [930]:
def orthogonal_init(M,dev):
    """Initialise the parameters so the model is an orthonormal cosine transform pair.

    With a1 = a3 = 1, a2 = a4 = sqrt(2) and w_k = (2k + 1) * pi / (2M), the
    analysis matrix built by the forward pass is the orthonormal DCT-III when N == M:
    entries a2 * cos(w_k * n), column n = 0 replaced by a1, divided by sqrt(N).
    The synthesis matrix is its transpose, so y reconstructs x up to rounding.

    Parameters
    ----------
    M : int
        Number of frequency components.
    dev : torch.device
        Device on which the tensors are created.

    Returns
    -------
    a1, a2, a3, a4 : torch.Tensor
        0-d tensors in torch's default float dtype.
    w1, w2 : torch.Tensor
        Shape-(M,) frequency tensors with equal values but built as separate
        tensors, so they receive separate gradients.

    All six tensors have ``requires_grad=True``.
    """
    # float(...) keeps a2/a4 in the default dtype. torch.tensor(np.sqrt(2)) would
    # infer float64 from the numpy scalar, unlike a1/a3 and w1/w2.
    sqrt2 = float(np.sqrt(2))
    a1 = torch.tensor(1.0, requires_grad = True, device = dev)
    a2 = torch.tensor(sqrt2, requires_grad = True, device = dev)
    a3 = torch.tensor(1.0, requires_grad = True, device = dev)
    a4 = torch.tensor(sqrt2, requires_grad = True, device = dev)

    # Centres of M equal-width bins on [0, pi): w_k = (2k + 1) * pi / (2M).
    w1 = (torch.arange(M, device = dev)*2+1)/(2*M)*torch.pi
    w2 = (torch.arange(M, device = dev)*2+1)/(2*M)*torch.pi
    w1.requires_grad_(True)
    w2.requires_grad_(True)

    return a1,a2,a3,a4,w1,w2

In [679]:
def forward_pass_id(x,a1,a2,a3,a4,w1,dev):
    """Project ``x`` onto M cosine components, then reconstruct it with the SAME frequencies.

    Analysis:  X = A @ x / sqrt(N), where A[k, n] = a2 * cos(w1[k] * n) and column n = 0 is a1.
    Synthesis: y = S @ X / sqrt(M), where S[n, k] = a4 * cos(n * w1[k]) and row n = 0 is a3.

    Parameters
    ----------
    x : torch.Tensor
        Input whose first dimension has length N. This is a length-N signal, or an
        (N, B) batch of column signals, since the products are plain matmuls.
    a1, a2, a3, a4 : torch.Tensor
        Scale parameters, broadcast into the DC column / row and the cosine entries.
    w1 : torch.Tensor
        Shape-(M,) frequencies shared by the analysis and synthesis phases.
    dev : torch.device
        Device used for the sample-index vector n = 0..N-1.

    Returns
    -------
    (X, y) : the frequency-domain coefficients and the reconstruction.
    """
    num_samples = x.shape[0]
    num_freqs = w1.shape[0]
    sample_idx = torch.arange(num_samples, device = dev)

    # Analysis matrix (M x N); the n = 0 column gets its own scale a1.
    analysis = a2 * torch.cos(torch.outer(w1, sample_idx))
    analysis[:, 0] = a1
    X = (analysis @ x) / np.sqrt(num_samples)

    # Synthesis matrix (N x M) built from the same frequencies; row n = 0 uses a3.
    synthesis = a4 * torch.cos(torch.outer(sample_idx, w1))
    synthesis[0] = a3
    y = (synthesis @ X) / np.sqrt(num_freqs)

    return X, y

In [680]:
def forward_pass_dif(x,a1,a2,a3,a4,w1,w2,dev):
    """Project ``x`` onto cosines at ``w1``, then reconstruct with cosines at ``w2``.

    Analysis:  X = E @ x / sqrt(N), where E[k, n] = a2 * cos(w1[k] * n) and column n = 0 is a1.
    Synthesis: y = D @ X / sqrt(M), where D[n, k] = a4 * cos(n * w2[k]) and row n = 0 is a3.
    M is taken from len(w1). w2 must have the same length for the second matmul to work.

    Parameters
    ----------
    x : torch.Tensor
        Input whose first dimension has length N. This is a length-N signal, or an
        (N, B) batch of column signals, since the products are plain matmuls.
    a1, a2, a3, a4 : torch.Tensor
        Scale parameters, broadcast into the DC column / row and the cosine entries.
    w1, w2 : torch.Tensor
        Analysis and synthesis frequencies (each of length M).
    dev : torch.device
        Device used for the sample-index vector n = 0..N-1.

    Returns
    -------
    (X, y) : the frequency-domain coefficients and the reconstruction.
    """
    n_len = x.shape[0]
    k_len = w1.shape[0]
    n = torch.arange(n_len, device = dev)

    # Encoder (M x N) at frequencies w1; the n = 0 column is scaled by a1.
    enc = a2 * torch.cos(torch.outer(w1, n))
    enc[:, 0] = a1
    X = (enc @ x) / np.sqrt(n_len)

    # Decoder (N x M) at independent frequencies w2; the n = 0 row is scaled by a3.
    dec = a4 * torch.cos(torch.outer(n, w2))
    dec[0] = a3
    y = (dec @ X) / np.sqrt(k_len)

    return X, y

In [681]:
def loss(x,y):
    """Return the sum of squared element-wise differences between ``x`` and ``y`` as a 0-d tensor."""
    residual = x - y
    return residual.pow(2).sum()

In [1718]:
# Experiment configuration.
N = 1000                    # signal length (number of samples n)
M = 1000                    # number of frequency components k
dev = torch.device("cpu")
SEED = 0
# Seed the global RNG so random_init / torch.rand draws in later cells are reproducible.
torch.manual_seed(SEED)

In [1726]:
a1,a2,a3,a4,w1,w2 = random_init(M,dev)
print(f'a1 = {a1}, a2 = {a2}, a3 = {a3}, a4 = {a4}, w1 = {w1}, w2 = {w2}')

a1 = tensor([1.3394], requires_grad=True), a2 = tensor([0.7492], requires_grad=True), a3 = tensor([0.9715], requires_grad=True), a4 = tensor([1.3431], requires_grad=True), w1 = tensor([0.0040, 0.0050, 0.0060, 0.0154, 0.0179, 0.0193, 0.0238, 0.0268, 0.0285,
        0.0304, 0.0325, 0.0375, 0.0448, 0.0487, 0.0586, 0.0635, 0.0638, 0.0702,
        0.0783, 0.0815, 0.0847, 0.0852, 0.0999, 0.1062, 0.1084, 0.1103, 0.1111,
        0.1138, 0.1139, 0.1226, 0.1242, 0.1336, 0.1366, 0.1391, 0.1406, 0.1473,
        0.1482, 0.1490, 0.1511, 0.1521, 0.1543, 0.1555, 0.1620, 0.1644, 0.1647,
        0.1682, 0.1691, 0.1724, 0.1726, 0.1753, 0.1776, 0.1781, 0.1785, 0.1903,
        0.1911, 0.1978, 0.1984, 0.2004, 0.2035, 0.2064, 0.2074, 0.2093, 0.2097,
        0.2100, 0.2125, 0.2142, 0.2234, 0.2246, 0.2248, 0.2253, 0.2375, 0.2411,
        0.2423, 0.2429, 0.2468, 0.2490, 0.2527, 0.2539, 0.2611, 0.2720, 0.2729,
        0.2878, 0.2913, 0.2937, 0.2946, 0.2975, 0.3024, 0.3078, 0.3091, 0.3120,
        0.3131, 0.3139,

In [1727]:
x = torch.rand(N)*2 - 1
x

tensor([-0.8321,  0.1716,  0.2912,  0.0097,  0.5804, -0.0315,  0.4208, -0.9788,
        -0.6965,  0.1288, -0.1852,  0.0997,  0.6989,  0.6855,  0.8883,  0.3532,
        -0.2019, -0.7404,  0.0165, -0.6417, -0.5419,  0.2563,  0.7534, -0.7029,
        -0.4713,  0.0243, -0.8625, -0.0589,  0.6355,  0.0081,  0.0406,  0.8726,
        -0.4646, -0.3660,  0.0552,  0.4737, -0.0621, -0.7686, -0.2630, -0.8168,
         0.8152, -0.3126, -0.4358, -0.2170, -0.2621,  0.6809,  0.2015,  0.7723,
        -0.2445, -0.2938, -0.2204, -0.3042, -0.2486, -0.0970, -0.5632, -0.9486,
         0.6107, -0.1353, -0.0328,  0.7308,  0.4608, -0.7156,  0.8832,  0.7938,
         0.4621, -0.0589,  0.0595,  0.4611, -0.7248, -0.9268, -0.8729,  0.8436,
        -0.1111, -0.1486,  0.3076,  0.6013,  0.8056, -0.8143,  0.2173,  0.7238,
         0.1814, -0.0570, -0.8819,  0.3411,  0.2098, -0.9269,  0.4553, -0.8590,
         0.5324,  0.0160, -0.1348, -0.5659,  0.6818, -0.0664,  0.6412, -0.9408,
         0.3775, -0.9061,  0.0540, -0.76

In [1728]:
X,y = forward_pass_dif(x,a1,a2,a3,a4,w1,w2,dev)

In [1729]:
ls = loss(x,y)
ls

tensor(445.8278, grad_fn=<SumBackward0>)

In [1730]:
ls.backward()

In [1731]:
print(f'a1 = {a1.grad}, a2 = {a2.grad}, a3 = {a3.grad}, a4 = {a4.grad}, w1 = {w1.grad.abs().mean()}, w2 = {w2.grad.abs().mean()}')

a1 = tensor([1.8689]), a2 = tensor([299.4886]), a3 = tensor([-0.2547]), a4 = tensor([169.1209]), w1 = 167.9651336669922, w2 = 157.51023864746094


In [1732]:
# Divide each raw gradient by a candidate normalisation so all parameters can be compared on one scale.
# NOTE(review): the factors 1.4 (a4), /10 (w1) and /5 (w2) look like hand-tuned constants with no derivation.
# They also differ from the factors tried in later cells (e.g. 2*sqrt(M/N) for a1, sqrt(2)/(N-1) for a4),
# so confirm which normalisation is intended before relying on these numbers.
print(f'a1 = {a1.grad/((M+N)/np.sqrt(M*N))}, a2 = {a2.grad/np.sqrt(M*N)}, a3 = {a3.grad}, a4 = {a4.grad*1.4/(N-1)}, w1 = {w1.grad/N*np.sqrt(M)/10}, w2 = {w2.grad/N*np.sqrt(M)/5}')

a1 = tensor([0.9345]), a2 = tensor([0.2995]), a3 = tensor([-0.2547]), a4 = tensor([0.2370]), w1 = tensor([ 6.1596e-02,  1.1576e-01,  3.6485e-01, -1.5096e-01,  1.7334e+00,
        -7.3301e-02, -7.2116e-01, -1.4447e+00,  2.2265e+00,  1.6595e+00,
        -2.5194e+00, -8.5395e-02, -2.1497e+00, -4.2662e+00, -7.1463e-01,
        -3.6205e-01, -6.2151e-01, -4.4275e-01,  7.1217e-01,  7.4931e-01,
        -1.0352e+00, -8.6107e-01,  4.0265e-02, -5.6640e-01,  4.9736e-01,
        -1.4373e+00, -2.6841e+00, -1.1649e-01, -2.4010e-02,  1.9442e+00,
        -1.0273e+00,  1.3778e+00, -1.0739e-01,  6.1562e-01,  1.9376e+00,
         2.8567e+00,  2.3924e-01,  2.0818e+00,  1.7962e-01, -5.5108e-01,
         4.3817e-01,  8.5331e-02, -3.3411e-02, -3.6631e-01, -9.7426e-01,
         3.9662e-01, -6.3477e-01,  7.8097e-02,  2.8550e-01,  1.7111e+00,
         2.5353e-01, -3.0644e-01, -9.4516e-01,  4.2973e-01, -3.2058e-01,
        -1.8733e+00, -1.3475e+00, -2.1873e-01,  1.0338e+00, -8.3066e-02,
        -2.7514e-01, -3.84

In [1609]:
a1.grad/(2*np.sqrt(M/N))

tensor([-0.2444])

In [1610]:
a2.grad/np.sqrt(M*N)

tensor([0.4592])

In [1611]:
a4.grad*np.sqrt(2)/(N-1)

tensor([0.4465])

In [1633]:
w1.grad.abs().mean()/N*np.sqrt(M)

tensor(1.5484)

In [1632]:
w2.grad.abs().median()/N*np.sqrt(M)

tensor(0.8511)

In [1146]:
0.5*2/np.pi

0.3183098861837907

In [876]:
w2.grad.max()

tensor(1705.6671)

In [1215]:
X

tensor([ 0.1115,  0.2275, -0.4555,  0.3689,  0.6550,  0.5515,  0.3438, -0.6631,
         0.8158, -0.1205,  0.2374, -0.3639, -0.8425, -0.6010, -0.5199,  0.0020,
         0.5588, -0.0182, -0.2570,  0.3402, -0.1640,  0.4977, -0.3149,  0.0290,
        -0.2168, -0.0064,  0.1437,  0.1634, -0.2613, -0.3131, -0.2773,  0.1989,
         0.3823,  0.3129, -0.6933, -0.0030, -0.3549,  0.6032,  0.6331, -0.1449,
         0.1356,  0.1581, -0.4372, -0.0487,  0.0219,  0.1967, -0.3185, -0.0303,
         0.8918, -0.0673], grad_fn=<DivBackward0>)

#### Find numerical relations

In [837]:
# Repeat the gradient measurement over nn independent random initialisations with the same input x.
# Each row of `gradient` is one trial:
#   columns 0-3: |dl/da1|, |dl/da2|, |dl/da3|, |dl/da4|
#   columns 4-5: median over the M frequencies of |dl/dw1_k| and |dl/dw2_k|
# NOTE(review): `nn` shadows the conventional `torch.nn` alias; it is not used after this cell.
nn = 100
gradient = torch.zeros(nn,6, device = dev)

for i in range(nn):
    # random_init draws fresh leaf tensors every trial, so .grad does not accumulate across trials.
    # After the loop, a1..a4, w1, w2, X, y and ls hold the LAST trial's values; later cells
    # (the manual gradient derivation) reuse them.
    a1,a2,a3,a4,w1,w2 = random_init(M,dev)
    X,y = forward_pass_dif(x,a1,a2,a3,a4,w1,w2,dev)
    ls = loss(x,y)
    ls.backward()
    gradient[i,0] = a1.grad.abs().item()
    gradient[i,1] = a2.grad.abs().item()
    gradient[i,2] = a3.grad.abs().item()
    gradient[i,3] = a4.grad.abs().item()
    gradient[i,4] = w1.grad.abs().median()
    gradient[i,5] = w2.grad.abs().median()
    

In [756]:
ratio = torch.div(gradient,gradient[:,1:2])

In [205]:
# ratio.median(0)

torch.return_types.median(
values=tensor([0.0052, 1.0000, 0.0017, 1.0102, 0.4305, 0.3658]),
indices=tensor([ 9, 49, 61, 64, 72, 44]))

In [609]:
ratio.median(0)

torch.return_types.median(
values=tensor([8.8082e-04, 1.0000e+00, 9.8251e-04, 9.0350e-01, 5.0594e-01, 4.0735e-01]),
indices=tensor([44, 49, 82, 52, 83, 21]))

In [588]:
# a4.grad/a3.grad = (N-1)*(2/pi)
(N-1)*2/np.pi

635.9831525952138

In [589]:
a4_a3 = torch.div(gradient[:,3],gradient[:,2])
a4_a3.max()

tensor(380492.9688)

In [590]:
a4_a3.min()

tensor(47.7600)

In [591]:
a4_a3.mean()

tensor(11362.4600)

In [592]:
a4_a3.median()

tensor(932.8977)

In [593]:
a4_a3 = torch.div(gradient[:,3].mean(),gradient[:,2].mean())
a4_a3

tensor(571.3546)

We need to balance the gradients for all parameters.

#### Gradient Calculation

#### Forward pass:

1. transformation phase

$$
X_k = \frac{1}{\sqrt{N}} \left( a_1 x_0 + a_2 \sum_{n=1}^{N-1} \cos(w_k\cdot n)x_n \right)
$$

$$
\begin{bmatrix}
X_0 \\ X_1 \\ \vdots \\ X_{M-1}
\end{bmatrix}_{M \times 1}
= 
\frac{1}{\sqrt{N}}
\begin{bmatrix}
a_{1} & a_{2}\cos(w_0\cdot 1) & \cdots & a_{2}\cos(w_0\cdot(N-1)) \\ 
a_{1} & a_{2}\cos(w_1\cdot 1) & \cdots & a_{2}\cos(w_1\cdot(N-1)) \\ 
\vdots & \vdots &  & \vdots \\ 
a_{1} & a_{2}\cos(w_{M-1}\cdot 1) & \cdots & a_{2}\cos(w_{M-1}\cdot(N-1))
\end{bmatrix}_{M \times N}
\begin{bmatrix}
x_{0} \\ x_1 \\ \vdots \\ x_{N-1}
\end{bmatrix}_{N \times 1}
$$

2. reconstruction phase

$$
\begin{cases}
y_0 = \frac{a_3}{\sqrt{M}} \sum_{k=0}^{M-1} X_k&\\
y_n = \frac{a_4}{\sqrt{M}} \sum_{k=0}^{M-1}  \cos(\tilde{w}_k\cdot n)X_k & n = 1,\dots,N-1
\end{cases}
$$

$$
\begin{bmatrix}
y_0 \\ y_1 \\ \vdots \\ y_{N-1}
\end{bmatrix}_{N \times 1}
= 
\frac{1}{\sqrt{M}}
\begin{bmatrix}
a_{3} & a_{3} & \cdots & a_{3} \\ 
a_{4}\cos(\tilde{w}_0\cdot 1) & a_{4}\cos(\tilde{w}_1\cdot 1) & \cdots & a_{4}\cos(\tilde{w}_{M-1}\cdot 1) \\ 
\vdots & \vdots &  & \vdots \\ 
a_{4}\cos(\tilde{w}_0\cdot(N-1)) & a_{4}\cos(\tilde{w}_1\cdot(N-1)) & \cdots & a_{4}\cos(\tilde{w}_{M-1}\cdot(N-1))
\end{bmatrix}_{N \times M}
\begin{bmatrix}
X_{0} \\ X_1 \\ \vdots \\ X_{M-1}
\end{bmatrix}_{M \times 1}
$$

#### Loss

$$
l = ||\mathrm{y-x}||_{L_2}^2
$$

#### Backward pass
$$ 
\frac{\partial l}{\partial \mathrm{y}} = 2(\mathrm{y-x})
$$

$a_3$ is shared via $M$ connections,

$$
\frac{\partial y_0}{\partial a_3} = \frac{1}{\sqrt{M}} \sum_{k=0}^{M-1} X_k
$$

In [838]:
y = y.detach()
x = x.detach()
X = X.detach()

In [839]:
l = ((y-x)**2).sum()
l

tensor(351.6760)

In [840]:
ls

tensor(351.6760, grad_fn=<SumBackward0>)

In [841]:
dl_dy = 2 * (y-x)
dl_dy[0]

tensor(0.1805)

In [858]:
X.sum()

tensor(6.6418)

In [842]:
dy0_da3 = X.sum()/np.sqrt(M)
dy0_da3

tensor(0.2100)

In [843]:
dl_da3 = dl_dy[0] * dy0_da3
dl_da3

tensor(0.0379)

In [844]:
a3.grad

tensor([0.0379])

In [845]:
a3.grad.abs().item()

0.03791802003979683

$$
\frac{\partial y_n}{\partial a_4} = \frac{1}{\sqrt{M}} \sum_{k=0}^{M-1}  \cos(\tilde{w}_k\cdot n)X_k , \ n = 1,\dots,N-1
$$

$$
\begin{bmatrix}
\frac{\partial y_1}{\partial a_4} \\ \frac{\partial y_2}{\partial a_4} \\ \vdots \\ \frac{\partial y_{N-1}}{\partial a_4}
\end{bmatrix}_{(N-1) \times 1}
= 
\frac{1}{\sqrt{M}}
\begin{bmatrix}
\cos(\tilde{w}_0\cdot 1) & \cos(\tilde{w}_1\cdot 1) & \cdots & \cos(\tilde{w}_{M-1}\cdot 1) \\ 
\cos(\tilde{w}_0\cdot 2) & \cos(\tilde{w}_1\cdot 2) & \cdots & \cos(\tilde{w}_{M-1}\cdot 2) \\ 
\vdots & \vdots &  & \vdots \\ 
\cos(\tilde{w}_0\cdot(N-1)) & \cos(\tilde{w}_1\cdot(N-1)) & \cdots & \cos(\tilde{w}_{M-1}\cdot(N-1))
\end{bmatrix}_{(N-1) \times M}
\begin{bmatrix}
X_{0} \\ X_1 \\ \vdots \\ X_{M-1}
\end{bmatrix}_{M \times 1}
$$

$$
\frac{\partial l}{\partial a_4} = \sum_{n=1}^{N-1}\frac{\partial l}{\partial y_n}\cdot\frac{\partial y_n}{\partial a_4} = 
2[y_1 - x_1, y_2 - x_2, \dots, y_{N-1} - x_{N-1}]\begin{bmatrix}
\frac{\partial y_1}{\partial a_4} \\ \frac{\partial y_2}{\partial a_4} \\ \vdots \\ \frac{\partial y_{N-1}}{\partial a_4}
\end{bmatrix}
$$

In [846]:
# Manual dy_n/da4 for n = 1..N-1: (1/sqrt(M)) * sum_k cos(w2_k * n) * X_k.
# cos_matrix is (N-1, M). Row n = 0 is excluded because y_0 is scaled by a3, not a4.
cos_matrix = torch.cos(torch.outer(torch.arange(1,N),w2))
dy_da4 = torch.matmul(cos_matrix,X)/np.sqrt(M)

In [847]:
dy_da4.abs().mean()*np.pi/2

tensor(0.2745, grad_fn=<DivBackward0>)

In [848]:
dy0_da3

tensor(0.2100)

In [849]:
dl_da4 = (dl_dy[1:]*dy_da4).sum()
dl_da4

tensor(75.3291, grad_fn=<SumBackward0>)

In [850]:
a4.grad

tensor([75.3291])

Because it sums $N-1$ terms, the gradient of $a_4$ is much larger than the gradient of $a_3$: $a_3$ is shared across $M$ connections, while $a_4$ is shared across $M \times (N-1)$ connections.

However, in the sum over $k$ for $a_4$ each $X_k$ is weighted by a cosine term. As a result, the values $\{\frac{\partial y_n}{\partial a_4}\}_{n = 1,\dots,N-1}$ are not all equal to $\frac{\partial y_0}{\partial a_3}$. Instead they are spread over roughly the range $(-|\frac{\partial y_0}{\partial a_3}|,|\frac{\partial y_0}{\partial a_3}|)$. This can be seen by comparing

$$
\frac{\partial y_0}{\partial a_3} = \frac{1}{\sqrt{M}} \sum_{k=0}^{M-1} X_k
$$

and

$$
\frac{\partial y_n}{\partial a_4} = \frac{1}{\sqrt{M}} \sum_{k=0}^{M-1}  \cos(\tilde{w}_k\cdot n)X_k , \ n = 1,\dots,N-1
$$

Therefore, dividing $\frac{\partial l}{\partial a_4}$ by $N-1$ is not enough. We also have to account for the cosine weighting and scale $\frac{\partial l}{\partial a_4}/(N-1)$ up by a certain amount. A reasonable choice is to divide by the average absolute value of the cosine,
$$
\frac{1}{\pi} \int_{-\frac{\pi}{2}}^\frac{\pi}{2} \cos(x) dx = \frac{2}{\pi}
$$

Therefore, we multiply $\frac{\partial l}{\partial a_4}$ by $\frac{\pi}{2(N-1)}$ to balance its gradient with $a_3$.

In [770]:
np.pi/2/(N-1)

0.0015723686954903868

In [851]:
dl_da4*np.pi/2/(N-1)

tensor(0.1184, grad_fn=<DivBackward0>)

In [1122]:
dl_da4*np.sqrt(2)/(N-1)

tensor(0.1066, grad_fn=<DivBackward0>)

In [1123]:
dl_da3

tensor(0.0379)

$$
\frac{\partial y_n}{\partial \tilde{w}_k} = -\frac{a_4}{\sqrt{M}} \cdot n\sin(\tilde{w}_k\cdot n)X_k , \ n = 1,\dots,N-1
$$



In [1080]:
np.pi/2

1.5707963267948966

In [853]:
X.abs().mean()

tensor(0.2203)

In [862]:
X.abs().median()

tensor(0.1851)

In [774]:
1/(np.sqrt(2)/np.pi*0.38)

5.845898602839955

In [636]:
6*M/N/(N-1)

0.011741682974559686

In [1362]:
w1.grad.abs().mean()

AttributeError: 'NoneType' object has no attribute 'abs'

In [855]:
w1.grad.abs().max() * M/N/(N-1)

tensor(0.6757)

In [856]:
w1.grad.abs().min() * M/N/(N-1)

tensor(2.1144e-05)

In [857]:
w1.grad.abs().median() * M/N/(N-1)

tensor(0.0416)

In [863]:
a1.grad /np.sqrt(N)

tensor([0.0052])

In [864]:
a3.grad

tensor([0.0379])

#### Approximation

In [1420]:
N = 1000
M = 800
dev = torch.device("cpu")

In [1421]:
a1,a2,a3,a4,w1,w2 = random_init(M,dev)

In [1413]:
x = torch.rand(N,800)*2 - 1

In [1414]:
X,y = forward_pass_dif(x,a1,a2,a3,a4,w1,w2,dev)

In [1415]:
X

tensor([[ 0.4598, -0.3606,  0.6415,  ...,  0.4045, -0.3895, -0.6628],
        [-0.4721,  0.1397, -0.8591,  ...,  0.1167,  0.2939, -0.5597],
        [-0.3645,  0.4974, -0.5694,  ..., -0.3329,  0.5015, -0.5941],
        ...,
        [ 0.7418, -0.4864,  0.0994,  ..., -0.1749,  0.7142,  0.3974],
        [ 0.8384, -0.5902,  0.1772,  ..., -0.3368,  0.7546,  0.6148],
        [-0.0180, -0.1045,  0.5986,  ..., -0.2249,  0.0555,  0.3196]],
       grad_fn=<DivBackward0>)

In [1266]:
y

tensor([[ 0.1465, -0.0086,  0.6379,  ..., -0.6438, -0.2735, -0.5575],
        [ 0.9485, -0.2307, -0.6410,  ...,  0.4648,  0.8622,  0.7022],
        [ 0.4927,  0.9810, -0.4208,  ..., -0.7656, -0.2328, -0.9784],
        ...,
        [ 0.1194,  0.5219,  0.5036,  ...,  0.5356, -0.6212, -0.0105],
        [-0.3770,  0.8338, -0.4038,  ..., -0.4052,  0.4597,  0.0815],
        [-0.4295, -0.7324,  0.9020,  ..., -0.8081,  0.6354, -0.5520]],
       grad_fn=<DivBackward0>)

In [1070]:
X.abs().max()

tensor(1.1546, grad_fn=<MaxBackward1>)

In [1071]:
X.abs().mean()

tensor(0.1992, grad_fn=<MeanBackward0>)

In [1072]:
X.abs().median()

tensor(0.1686, grad_fn=<MedianBackward0>)

In [1073]:
y[1:].abs().max()

tensor(1.4739, grad_fn=<MaxBackward1>)

In [1074]:
y[1:].abs().mean()

tensor(0.2155, grad_fn=<MeanBackward0>)

In [1075]:
y[1:].abs().median()

tensor(0.1807, grad_fn=<MedianBackward0>)

In [1076]:
y[:,0].size()

torch.Size([1000])

In [1077]:
y[:,0].abs().max()

tensor(0.9931, grad_fn=<MaxBackward1>)

In [1078]:
y[:,0].abs().mean()

tensor(0.2051, grad_fn=<MeanBackward0>)

In [1079]:
y[:,0].abs().median()

tensor(0.1721, grad_fn=<MedianBackward0>)

In [1081]:
0.4/np.sqrt(2)

0.282842712474619

In [1121]:
# Mean over columns of |sum_n n * x_n|, divided by a normalising constant.
# NOTE(review): 62500 is hard-coded. It equals (N*(N-2)/4 + N/2)/2/2 evaluated at N = 1000
# (the cell that prints 62500.0), so it is stale if N changes.
(torch.matmul(torch.arange(0,N).view(1,-1).type(torch.float32),x).sum(0).abs()/62500).mean()

tensor(0.1328)

In [1116]:
(N*(N-2)/4+N/2)/2/2

62500.0

In [1618]:
# Sum of cos(w * n) over n = 1..N-1 for a single fixed frequency w.
w = 0.01
torch.cos(w*torch.arange(1,N)).sum()

tensor(-54.4821)

In [1245]:
w = 2

In [1246]:
torch.inner(torch.sin(w*torch.arange(1,N)),1.0*torch.arange(1,N))

tensor(-346.7193)

In [1360]:
N=500

In [1361]:
torch.matmul(torch.sin(torch.outer(w2,torch.arange(0,N))),1.0*torch.arange(0,N)).abs()/N

tensor([2.0214e+02, 2.2016e+01, 7.6059e+00, 3.6358e+00, 2.0019e+00, 1.1749e+00,
        6.9928e-01, 4.0081e-01, 2.0140e-01, 6.1494e-02, 4.0325e-02, 1.1676e-01,
        1.7559e-01, 2.2184e-01, 2.5890e-01, 2.8898e-01, 3.1375e-01, 3.3441e-01,
        3.5180e-01, 3.6659e-01, 3.7924e-01, 3.9022e-01, 3.9972e-01, 4.0808e-01,
        4.1542e-01, 4.2190e-01, 4.2768e-01, 4.3281e-01, 4.3744e-01, 4.4158e-01,
        4.4536e-01, 4.4874e-01, 4.5189e-01, 4.5464e-01, 4.5727e-01, 4.5959e-01,
        4.6181e-01, 4.6378e-01, 4.6565e-01, 4.6731e-01, 4.6894e-01, 4.7038e-01,
        4.7177e-01, 4.7305e-01, 4.7433e-01, 4.7534e-01, 4.7633e-01, 4.7739e-01,
        4.7827e-01, 4.7914e-01, 4.7990e-01, 4.8072e-01, 4.8148e-01, 4.8208e-01,
        4.8276e-01, 4.8339e-01, 4.8399e-01, 4.8468e-01, 4.8501e-01, 4.8550e-01,
        4.8601e-01, 4.8638e-01, 4.8688e-01, 4.8723e-01, 4.8767e-01, 4.8801e-01,
        4.8829e-01, 4.8867e-01, 4.8901e-01, 4.8934e-01, 4.8961e-01, 4.8989e-01,
        4.9018e-01, 4.9044e-01, 4.9068e-

In [1359]:
torch.matmul(torch.sin(torch.outer(w2,torch.arange(0,N))),1.0*torch.arange(0,N)).size()

torch.Size([500])

In [1595]:
N = 1500
M = 512
dev = torch.device("cpu")

In [1596]:
a1,a2,a3,a4,w1,w2 = random_init(M,dev)

In [1597]:
W2_2 = a4 * torch.cos(torch.outer(torch.arange(N, device = dev),w2))
W2_2[0] = a3
W2_2

tensor([[ 1.4552,  1.4552,  1.4552,  ...,  1.4552,  1.4552,  1.4552],
        [ 1.3083,  1.3083,  1.3082,  ..., -1.3080, -1.3081, -1.3083],
        [ 1.3083,  1.3083,  1.3078,  ...,  1.3072,  1.3075,  1.3082],
        ...,
        [ 1.1120,  1.0567, -0.2917,  ..., -0.9719, -1.0894,  0.2195],
        [ 1.1147,  1.0531, -0.3092,  ...,  0.9896,  1.0769, -0.2107],
        [ 1.1173,  1.0495, -0.3267,  ..., -1.0070, -1.0641,  0.2012]],
       grad_fn=<CopySlices>)

In [1598]:
torch.matmul(torch.transpose(W2_2,0,1),W2_2)

tensor([[ 1.1854e+03,  1.0134e+03,  1.0742e+02,  ...,  6.9383e-01,
          6.7053e-01,  1.3718e+00],
        [ 1.0134e+03,  1.3735e+03,  1.1666e+02,  ...,  7.2977e-01,
          7.0789e-01,  1.3650e+00],
        [ 1.0742e+02,  1.1666e+02,  1.2692e+03,  ...,  1.4327e+00,
          1.4395e+00,  1.2280e+00],
        ...,
        [ 6.9383e-01,  7.2980e-01,  1.4326e+00,  ...,  1.2641e+03,
         -2.4946e+02,  3.6114e+01],
        [ 6.7054e-01,  7.0789e-01,  1.4395e+00,  ..., -2.4946e+02,
          1.3083e+03,  3.1293e+01],
        [ 1.3718e+00,  1.3650e+00,  1.2280e+00,  ...,  3.6114e+01,
          3.1293e+01,  1.3022e+03]], grad_fn=<MmBackward0>)

In [1599]:
W2_2.sum(1)

tensor([745.0869,  -6.8673, -10.8802,  ...,  29.4976,  10.6277, -25.1379],
       grad_fn=<SumBackward1>)

In [1600]:
W2_2[1:,:].sum()

tensor(-268.8798, grad_fn=<SumBackward0>)