In [1]:
import torch
import numpy as np
torch.cuda.is_available()

True

### Define functions

In [2]:
def random_init(M,dev):
    a1 = torch.rand(1) + 0.5
    a2 = torch.rand(1) + 0.5
    a3 = torch.rand(1) + 0.5
    a4 = torch.rand(1) + 0.5
    
    w1, index = torch.sort(torch.rand(M)*torch.pi)
    w2, index = torch.sort(torch.rand(M)*torch.pi)
    
    a1 = a1.to(dev)
    a1.requires_grad_(True)
    
    a2 = a2.to(dev)
    a2.requires_grad_(True)
    
    a3 = a3.to(dev)
    a3.requires_grad_(True)
    
    a4 = a4.to(dev)
    a4.requires_grad_(True)
    
    w1 = w1.to(dev)
    w1.requires_grad_(True)
    
    w2 = w2.to(dev)
    w2.requires_grad_(True)
    
    
    return a1,a2,a3,a4,w1,w2

In [3]:
def orthogonal_init(M,dev):
    a1 = torch.tensor(1.0, requires_grad = True, device = dev)
    a2 = torch.tensor(np.sqrt(2), requires_grad = True, device = dev) # sqrt(2)
    a3 = torch.tensor(1.0, requires_grad = True, device = dev)
    a4 = torch.tensor(np.sqrt(2), requires_grad = True, device = dev)

    
    w1 = (torch.arange(M, device = dev)*2+1)/(2*M)*torch.pi
    w2 = (torch.arange(M, device = dev)*2+1)/(2*M)*torch.pi
    w1.requires_grad_(True)
    w2.requires_grad_(True)
    # w1.retain_grad()
    # w2.retain_grad()
    
    return a1,a2,a3,a4,w1,w2

In [4]:
def forward_pass_id(x,a1,a2,a3,a4,w1,dev):
    N = x.size()[0]
    M = w1.size()[0]
    
    # weight matrix W1
    W1 = a2 * torch.cos(torch.outer(w1,torch.arange(N, device = dev)))
    W1[:,0] = a1
    
    # frequency domain X
    X = torch.matmul(W1,x)/np.sqrt(N)
    
    # weight matrix W2_1 with same frequency components w1
    W2_1 = a4 * torch.cos(torch.outer(torch.arange(N, device = dev),w1))
    W2_1[0] = a3
            
    y = torch.matmul(W2_1,X)/np.sqrt(N)
    
    return X,y

In [5]:
def forward_pass_dif(x,a1,a2,a3,a4,w1,w2,dev):
    N = x.size()[0]
    M = w1.size()[0]
    
    # weight matrix W1
    W1 = a2 * torch.cos(torch.outer(w1,torch.arange(N, device = dev)))
    W1[:,0] = a1
    
    # frequency domain X
    X = torch.matmul(W1,x)/np.sqrt(N)
    
    # weight matrix W2_2 with different frequency components w2
    W2_2 = a4 * torch.cos(torch.outer(torch.arange(N, device = dev),w2))
    W2_2[0] = a3
            
    y = torch.matmul(W2_2,X)/np.sqrt(N)
    
    return X,y

In [6]:
def loss(x,y):
    return ((x-y)**2).mean()

In [6]:
def parameter_update(loss,lr,a1,a2,a3,a4,w1,w2,forward_type):
    loss.backward()
    with torch.no_grad():
        a1 -= lr * a1.grad
        a1.grad.zero_()
        a2 -= lr * a2.grad
        a2.grad.zero_()
        a3 -= lr * a3.grad
        a3.grad.zero_()
        a4 -= lr * a4.grad
        a4.grad.zero_()
        
        w1 -= lr * w1.grad
        w1.grad.zero_()
        
        if forward_type == "dif":
            w2 -= lr * w2.grad
            w2.grad.zero_()
            
    return a1,a2,a3,a4,w1,w2

### Build the model

#### hyperparameters

In [8]:
N = 512  # sequnece length
M = 512  # frequency components
dev = torch.device("cuda") # operate on GPU
batch_size = 100;

#### input data

In [66]:
# random generated white noise

n_sample = 2000;
x = torch.rand(N,n_sample, device = dev)*2 - 1

In [67]:
x.size()

torch.Size([512, 2000])

In [68]:
n_batches = np.floor(n_sample/batch_size)
n_batches = n_batches.astype(int)
n_batches

20

In [197]:
n_iter = 1000
lr = 0.001

In [38]:
# a1,a2,a3,a4,w1,w2 = random_init(M,dev)
print(f'a1 = {a1}, a2 = {a2}, a3 = {a3}, a4 = {a4}, w1 = {w1}')

a1 = tensor([1.3706], device='cuda:0', requires_grad=True), a2 = tensor([0.7380], device='cuda:0', requires_grad=True), a3 = tensor([0.8174], device='cuda:0', requires_grad=True), a4 = tensor([0.6694], device='cuda:0', requires_grad=True), w1 = tensor([1.1334e-03, 7.8128e-03, 2.5556e-02, 5.4142e-02, 7.5123e-02, 7.6172e-02,
        7.6510e-02, 8.5882e-02, 8.6886e-02, 9.1110e-02, 9.1691e-02, 9.5817e-02,
        1.0957e-01, 1.1061e-01, 1.2605e-01, 1.3971e-01, 1.4089e-01, 1.5058e-01,
        1.7243e-01, 1.7544e-01, 1.8666e-01, 1.9130e-01, 2.0587e-01, 2.1194e-01,
        2.1738e-01, 2.3759e-01, 2.3906e-01, 2.4217e-01, 2.4721e-01, 2.5127e-01,
        2.5664e-01, 2.6353e-01, 2.6488e-01, 2.7214e-01, 2.7320e-01, 2.8084e-01,
        2.8270e-01, 2.9685e-01, 2.9696e-01, 3.0753e-01, 3.1418e-01, 3.1720e-01,
        3.2904e-01, 3.3152e-01, 3.3628e-01, 3.3845e-01, 3.5030e-01, 3.5149e-01,
        3.5447e-01, 3.6072e-01, 3.6275e-01, 3.6316e-01, 3.6380e-01, 3.7369e-01,
        3.7369e-01, 3.7530e-01, 3.7

In [157]:
a1,a2,a3,a4,w1 = torch.load('parameters3.pt')
print(f'a1 = {a1}, a2 = {a2}, a3 = {a3}, a4 = {a4}, w1 = {w1}')

a1 = tensor([0.9927], device='cuda:0', requires_grad=True), a2 = tensor([1.3094], device='cuda:0', requires_grad=True), a3 = tensor([0.9328], device='cuda:0', requires_grad=True), a4 = tensor([1.2709], device='cuda:0', requires_grad=True), w1 = tensor([-2.0554e-03,  8.6142e-03,  2.1764e-02,  5.1523e-02,  5.7763e-02,
         6.3935e-02,  8.8423e-02,  9.4723e-02,  7.5860e-02,  9.9779e-02,
         8.2261e-02,  1.0755e-01,  1.1253e-01,  1.1938e-01,  1.2437e-01,
         1.3102e-01,  1.4318e-01,  1.5533e-01,  1.6186e-01,  1.6607e-01,
         1.7363e-01,  1.7899e-01,  1.8700e-01,  1.9174e-01,  1.9962e-01,
         2.0411e-01,  2.1181e-01,  2.3675e-01,  2.1612e-01,  2.2824e-01,
         2.2358e-01,  2.4882e-01,  2.4050e-01,  2.5316e-01,  2.6032e-01,
         2.7278e-01,  2.6574e-01,  2.7796e-01,  2.8405e-01,  2.9051e-01,
         2.9548e-01,  3.0775e-01,  3.2674e-01,  3.0327e-01,  3.1592e-01,
         3.3867e-01,  3.3369e-01,  3.4612e-01,  3.2124e-01,  3.7556e-01,
         3.5082e-01,  4.0

In [50]:
cache_n = 10
l = torch.ones(cache_n)*100
l = l.to(dev)
l

tensor([100., 100., 100., 100., 100., 100., 100., 100., 100., 100.],
       device='cuda:0')

In [218]:
# forward_pass_id

list_para = []
L = 0
decimation = 50

for epoch in range(n_iter):
    
    x = x[:,torch.randperm(n_sample)]
    for b in range(n_batches):
        x_batch = x[:,b*batch_size:(b+1)*batch_size]
        X,y = forward_pass_id(x_batch,a1,a2,a3,a4,w1,dev)
        lss = loss(x_batch,y)
        #if lss > l.max():
        #    print(f'a1 = {a1}, a2 = {a2}, a3 = {a3}, a4 = {a4}, w1 = {w1}, loss = {lss}, prev_loss = {l}')
        #    torch.save([a1,a2,a3,a4,w1],'parameters.pt')
        #    raise Exception("Checkpoint failed") 
        #else:
        #    transit = torch.clone(l)
        #    l[0:cache_n-1] = transit[1:cache_n]
        #    l[cache_n-1] = lss.item()
        #    
        L += lss.item()
        lss.backward()
        with torch.no_grad():
            a1 -= lr * a1.grad * (N-1)
            a1.grad.zero_()
            a2 -= lr * a2.grad
            a2.grad.zero_()
            a3 -= lr * a3.grad * (N-1)
            a3.grad.zero_()
            a4 -= lr * a4.grad
            a4.grad.zero_()

            w1 -= lr * w1.grad
            w1.grad.zero_()
            
    if (epoch+1) % decimation == 0:
        L = L/(n_batches*decimation)
        print(f'a1 = {a1.item()}, a2 = {a2.item()}, a3 = {a3.item()}, a4 = {a4.item()}, loss = {L}')
        # print(f'a1 = {a1}, a2 = {a2}, a3 = {a3}, a4 = {a4}, w1 = {w1}, loss = {L}')
        # list_para.append([a1.item(),a2.item(),a3.item(),a4.item(),w1.detach().to('cpu').numpy()])
        torch.save([a1,a2,a3,a4,w1],'parameters.pt')
        L = 0

a1 = 1.0413668155670166, a2 = 1.420009970664978, a3 = 0.9421101808547974, a4 = 1.3846323490142822, loss = 0.0057315724836662415
a1 = 1.0413720607757568, a2 = 1.4200108051300049, a3 = 0.9421101808547974, a4 = 1.3846327066421509, loss = 0.005731783212628215
a1 = 1.041373372077942, a2 = 1.4200096130371094, a3 = 0.9421074390411377, a4 = 1.3846322298049927, loss = 0.0057317745489999655
a1 = 1.0413721799850464, a2 = 1.4200096130371094, a3 = 0.9421025514602661, a4 = 1.3846324682235718, loss = 0.005731547077186406
a1 = 1.0413762331008911, a2 = 1.420009970664978, a3 = 0.9421055316925049, a4 = 1.3846324682235718, loss = 0.0057316397158429025
a1 = 1.041375994682312, a2 = 1.4200105667114258, a3 = 0.9421002864837646, a4 = 1.3846324682235718, loss = 0.005731502728071064
a1 = 1.041375994682312, a2 = 1.42001211643219, a3 = 0.9421005845069885, a4 = 1.3846343755722046, loss = 0.005731547180563211
a1 = 1.0413750410079956, a2 = 1.4200105667114258, a3 = 0.9420995712280273, a4 = 1.3846355676651, loss = 0.00

In [200]:
W1 = (torch.arange(M, device = dev)*2+1)/(2*M)*torch.pi

In [219]:
torch.norm(W1-w1)

tensor(0.7823, device='cuda:0', grad_fn=<NormBackward1>)

In [217]:
torch.norm(W1-w1)

tensor(0.7846, device='cuda:0', grad_fn=<NormBackward1>)

In [164]:
w2 = torch.clone(w1).detach()
w2.requires_grad_(True)
w2

tensor([-3.0793e-03,  9.2257e-03,  2.1620e-02,  5.2028e-02,  5.8188e-02,
         6.4248e-02,  8.8801e-02,  9.4937e-02,  7.6636e-02,  1.0106e-01,
         8.2700e-02,  1.0717e-01,  1.1326e-01,  1.1941e-01,  1.2559e-01,
         1.3163e-01,  1.4418e-01,  1.5610e-01,  1.6225e-01,  1.6836e-01,
         1.7449e-01,  1.8061e-01,  1.8674e-01,  1.9285e-01,  1.9897e-01,
         2.0509e-01,  2.1122e-01,  2.3569e-01,  2.1733e-01,  2.2958e-01,
         2.2345e-01,  2.4794e-01,  2.4181e-01,  2.5405e-01,  2.6017e-01,
         2.7241e-01,  2.6629e-01,  2.7853e-01,  2.8465e-01,  2.9077e-01,
         2.9688e-01,  3.0912e-01,  3.2748e-01,  3.0301e-01,  3.1524e-01,
         3.3971e-01,  3.3359e-01,  3.4583e-01,  3.2136e-01,  3.7641e-01,
         3.5194e-01,  4.0087e-01,  3.5806e-01,  3.7029e-01,  3.9476e-01,
         4.1310e-01,  3.6418e-01,  3.8864e-01,  4.2533e-01,  4.0699e-01,
         4.3144e-01,  3.8253e-01,  4.3756e-01,  4.1922e-01,  4.4978e-01,
         4.5590e-01,  4.4367e-01,  4.6812e-01,  4.6

In [184]:
# forward_pass_dif

list_para = []
L = 0
decimation = 50

for epoch in range(n_iter):
    
    x = x[:,torch.randperm(n_sample)]
    for b in range(n_batches):
        x_batch = x[:,b*batch_size:(b+1)*batch_size]
        X,y = forward_pass_dif(x_batch,a1,a2,a3,a4,w1,w2,dev)
        lss = loss(x_batch,y)
        #if lss > l.max():
        #    print(f'a1 = {a1}, a2 = {a2}, a3 = {a3}, a4 = {a4}, w1 = {w1}, loss = {lss}, prev_loss = {l}')
        #    torch.save([a1,a2,a3,a4,w1],'parameters.pt')
        #    raise Exception("Checkpoint failed") 
        #else:
        #    transit = torch.clone(l)
        #    l[0:cache_n-1] = transit[1:cache_n]
        #    l[cache_n-1] = lss.item()
        #    
        L += lss.item()
        lss.backward()
        with torch.no_grad():
            a1 -= lr * a1.grad * (N-1)
            a1.grad.zero_()
            a2 -= lr * a2.grad
            a2.grad.zero_()
            a3 -= lr * a3.grad * (N-1)
            a3.grad.zero_()
            a4 -= lr * a4.grad
            a4.grad.zero_()

            w1 -= lr * w1.grad
            w1.grad.zero_()
            w2 -= lr * w2.grad
            w2.grad.zero_()
            
    if (epoch+1) % decimation == 0:
        L = L/(n_batches*decimation)
        print(f'a1 = {a1.item()}, a2 = {a2.item()}, a3 = {a3.item()}, a4 = {a4.item()}, loss = {L}')
        # print(f'a1 = {a1}, a2 = {a2}, a3 = {a3}, a4 = {a4}, w1 = {w1}, loss = {L}')
        # list_para.append([a1.item(),a2.item(),a3.item(),a4.item(),w1.detach().to('cpu').numpy()])
        torch.save([a1,a2,a3,a4,w1,w2],'parameters_dif.pt')
        L = 0

a1 = 1.0411267280578613, a2 = 1.4199849367141724, a3 = 0.9422417879104614, a4 = 1.3846133947372437, loss = 0.005735061415936798
a1 = 1.041124701499939, a2 = 1.4199851751327515, a3 = 0.9422412514686584, a4 = 1.3846129179000854, loss = 0.00573510434338823
a1 = 1.041129469871521, a2 = 1.4199862480163574, a3 = 0.9422422051429749, a4 = 1.384615421295166, loss = 0.005735167752485722
a1 = 1.0411309003829956, a2 = 1.4199872016906738, a3 = 0.9422382712364197, a4 = 1.384615421295166, loss = 0.00573522446770221
a1 = 1.0411386489868164, a2 = 1.4199868440628052, a3 = 0.9422338604927063, a4 = 1.38461434841156, loss = 0.005735276406630874
a1 = 1.0411405563354492, a2 = 1.419985294342041, a3 = 0.9422324299812317, a4 = 1.384613037109375, loss = 0.005734992334619164
a1 = 1.0411410331726074, a2 = 1.4199860095977783, a3 = 0.9422239065170288, a4 = 1.3846124410629272, loss = 0.005735049194190651
a1 = 1.0411441326141357, a2 = 1.419987678527832, a3 = 0.9422239065170288, a4 = 1.3846142292022705, loss = 0.005735

In [124]:
list_para

[[0.9948885440826416,
  1.3040611743927002,
  0.9398193359375,
  1.2653857469558716,
  array([-2.13834178e-03,  9.67403408e-03,  2.16691997e-02,  5.17155975e-02,
          5.74714504e-02,  6.44636229e-02,  8.82105157e-02,  9.41595137e-02,
          7.61836171e-02,  1.01200074e-01,  8.19512010e-02,  1.06008701e-01,
          1.13346256e-01,  1.19104818e-01,  1.24790311e-01,  1.31481901e-01,
          1.43861666e-01,  1.55668750e-01,  1.61281526e-01,  1.68309078e-01,
          1.73118129e-01,  1.80571750e-01,  1.84459955e-01,  1.92790329e-01,
          1.98435828e-01,  2.04908744e-01,  2.10013166e-01,  2.34978840e-01,
          2.17449367e-01,  2.29720026e-01,  2.22764730e-01,  2.47190714e-01,
          2.42231727e-01,  2.54457712e-01,  2.59261936e-01,  2.71682858e-01,
          2.66449600e-01,  2.79208153e-01,  2.84264565e-01,  2.91570902e-01,
          2.96564579e-01,  3.08894008e-01,  3.27057779e-01,  3.03096741e-01,
          3.15136909e-01,  3.38460326e-01,  3.33957583e-01,  3.46602

In [113]:
a4

tensor([1.2609], device='cuda:0', requires_grad=True)

In [142]:
X,y = forward_pass_id(x_batch,a1,a2,a3,a4,w1,dev)
y

tensor([[-0.0926, -0.2471,  0.5919,  ..., -0.9111,  0.6393,  0.1052],
        [ 0.4042, -0.1179, -0.6990,  ..., -0.0931,  0.3232, -0.1757],
        [-0.0742, -0.2354, -0.2431,  ..., -0.2107, -0.4751, -0.2798],
        ...,
        [-0.0791,  0.2420,  0.3527,  ..., -0.2526,  0.7959, -0.4959],
        [ 0.0603, -0.2363,  0.6479,  ..., -0.2369,  0.1252, -0.4896],
        [ 0.3531,  0.5500,  0.2493,  ..., -0.2125, -0.4343,  0.2217]],
       device='cuda:0', grad_fn=<DivBackward0>)

In [143]:
lss= loss(x_batch,y)
lss

tensor(0.1881, device='cuda:0', grad_fn=<MeanBackward0>)

In [144]:
lss.backward()

In [137]:
a1.grad

tensor([-1.2656e-05], device='cuda:0')

loss = 0.005742670204024761

In [40]:
l

tensor([100., 100., 100., 100., 100.], device='cuda:0')

In [20]:
loss(x_batch,y)

tensor(0.1084, device='cuda:0', grad_fn=<MeanBackward0>)

In [182]:
a11,a21,a31,a41,w11 = torch.load('parameters4.pt')
print(f'a1 = {a11}, a2 = {a21}, a3 = {a31}, a4 = {a41}, w1 = {w11}')

a1 = tensor([1.0408], device='cuda:0', requires_grad=True), a2 = tensor([1.4200], device='cuda:0', requires_grad=True), a3 = tensor([0.9426], device='cuda:0', requires_grad=True), a4 = tensor([1.3846], device='cuda:0', requires_grad=True), w1 = tensor([-3.0793e-03,  9.2257e-03,  2.1620e-02,  5.2028e-02,  5.8188e-02,
         6.4248e-02,  8.8801e-02,  9.4937e-02,  7.6636e-02,  1.0106e-01,
         8.2700e-02,  1.0717e-01,  1.1326e-01,  1.1941e-01,  1.2559e-01,
         1.3163e-01,  1.4418e-01,  1.5610e-01,  1.6225e-01,  1.6836e-01,
         1.7449e-01,  1.8061e-01,  1.8674e-01,  1.9285e-01,  1.9897e-01,
         2.0509e-01,  2.1122e-01,  2.3569e-01,  2.1733e-01,  2.2958e-01,
         2.2345e-01,  2.4794e-01,  2.4181e-01,  2.5405e-01,  2.6017e-01,
         2.7241e-01,  2.6629e-01,  2.7853e-01,  2.8465e-01,  2.9077e-01,
         2.9688e-01,  3.0912e-01,  3.2748e-01,  3.0301e-01,  3.1524e-01,
         3.3971e-01,  3.3359e-01,  3.4583e-01,  3.2136e-01,  3.7641e-01,
         3.5194e-01,  4.0

In [193]:
a11,a21,a31,a41,w11,w21 = torch.load('parameters_dif.pt')
print(f'a1 = {a11}, a2 = {a21}, a3 = {a31}, a4 = {a41}, w1 = {w11}, w2 = {w21}')

a1 = tensor([1.0412], device='cuda:0', requires_grad=True), a2 = tensor([1.4200], device='cuda:0', requires_grad=True), a3 = tensor([0.9422], device='cuda:0', requires_grad=True), a4 = tensor([1.3846], device='cuda:0', requires_grad=True), w1 = tensor([-3.0600e-03,  9.2529e-03,  2.1638e-02,  5.2068e-02,  5.8182e-02,
         6.4295e-02,  8.8802e-02,  9.4914e-02,  7.6577e-02,  1.0104e-01,
         8.2695e-02,  1.0714e-01,  1.1327e-01,  1.1940e-01,  1.2553e-01,
         1.3167e-01,  1.4401e-01,  1.5610e-01,  1.6221e-01,  1.6834e-01,
         1.7446e-01,  1.8057e-01,  1.8669e-01,  1.9281e-01,  1.9893e-01,
         2.0505e-01,  2.1117e-01,  2.3564e-01,  2.1729e-01,  2.2952e-01,
         2.2341e-01,  2.4787e-01,  2.4176e-01,  2.5399e-01,  2.6011e-01,
         2.7234e-01,  2.6623e-01,  2.7846e-01,  2.8457e-01,  2.9069e-01,
         2.9681e-01,  3.0904e-01,  3.2739e-01,  3.0292e-01,  3.1515e-01,
         3.3962e-01,  3.3350e-01,  3.4573e-01,  3.2127e-01,  3.7630e-01,
         3.5184e-01,  4.0

In [179]:
W1 = (torch.arange(M, device = dev)*2+1)/(2*M)*torch.pi
W1

tensor([3.0680e-03, 9.2039e-03, 1.5340e-02, 2.1476e-02, 2.7612e-02, 3.3748e-02,
        3.9884e-02, 4.6019e-02, 5.2155e-02, 5.8291e-02, 6.4427e-02, 7.0563e-02,
        7.6699e-02, 8.2835e-02, 8.8971e-02, 9.5107e-02, 1.0124e-01, 1.0738e-01,
        1.1351e-01, 1.1965e-01, 1.2579e-01, 1.3192e-01, 1.3806e-01, 1.4419e-01,
        1.5033e-01, 1.5647e-01, 1.6260e-01, 1.6874e-01, 1.7487e-01, 1.8101e-01,
        1.8715e-01, 1.9328e-01, 1.9942e-01, 2.0555e-01, 2.1169e-01, 2.1783e-01,
        2.2396e-01, 2.3010e-01, 2.3623e-01, 2.4237e-01, 2.4850e-01, 2.5464e-01,
        2.6078e-01, 2.6691e-01, 2.7305e-01, 2.7918e-01, 2.8532e-01, 2.9146e-01,
        2.9759e-01, 3.0373e-01, 3.0986e-01, 3.1600e-01, 3.2214e-01, 3.2827e-01,
        3.3441e-01, 3.4054e-01, 3.4668e-01, 3.5282e-01, 3.5895e-01, 3.6509e-01,
        3.7122e-01, 3.7736e-01, 3.8350e-01, 3.8963e-01, 3.9577e-01, 4.0190e-01,
        4.0804e-01, 4.1417e-01, 4.2031e-01, 4.2645e-01, 4.3258e-01, 4.3872e-01,
        4.4485e-01, 4.5099e-01, 4.5713e-

In [199]:
torch.norm(W1-w1)

tensor(0.8043, device='cuda:0', grad_fn=<NormBackward1>)

In [251]:
# initialization

a1 = torch.tensor(1.3706, requires_grad = True, device = dev)
a2 = torch.tensor(0.7380, requires_grad = True, device = dev)
a3 = torch.tensor(0.8174, requires_grad = True, device = dev)
a4 = torch.tensor(0.6694, requires_grad = True, device = dev)
w1 = torch.tensor([1.1334e-03, 7.8128e-03, 2.5556e-02, 5.4142e-02, 7.5123e-02, 7.6172e-02,
        7.6510e-02, 8.5882e-02, 8.6886e-02, 9.1110e-02, 9.1691e-02, 9.5817e-02,
        1.0957e-01, 1.1061e-01, 1.2605e-01, 1.3971e-01, 1.4089e-01, 1.5058e-01,
        1.7243e-01, 1.7544e-01, 1.8666e-01, 1.9130e-01, 2.0587e-01, 2.1194e-01,
        2.1738e-01, 2.3759e-01, 2.3906e-01, 2.4217e-01, 2.4721e-01, 2.5127e-01,
        2.5664e-01, 2.6353e-01, 2.6488e-01, 2.7214e-01, 2.7320e-01, 2.8084e-01,
        2.8270e-01, 2.9685e-01, 2.9696e-01, 3.0753e-01, 3.1418e-01, 3.1720e-01,
        3.2904e-01, 3.3152e-01, 3.3628e-01, 3.3845e-01, 3.5030e-01, 3.5149e-01,
        3.5447e-01, 3.6072e-01, 3.6275e-01, 3.6316e-01, 3.6380e-01, 3.7369e-01,
        3.7369e-01, 3.7530e-01, 3.7816e-01, 3.8047e-01, 3.8091e-01, 3.8705e-01,
        3.8853e-01, 3.8901e-01, 3.9189e-01, 3.9193e-01, 4.0319e-01, 4.1430e-01,
        4.1970e-01, 4.2083e-01, 4.2808e-01, 4.3936e-01, 4.4141e-01, 4.5836e-01,
        4.6695e-01, 4.8586e-01, 4.9130e-01, 5.0329e-01, 5.1413e-01, 5.1521e-01,
        5.1918e-01, 5.2263e-01, 5.2306e-01, 5.2800e-01, 5.5170e-01, 5.6189e-01,
        5.6740e-01, 5.9523e-01, 5.9937e-01, 6.0351e-01, 6.0504e-01, 6.1747e-01,
        6.2096e-01, 6.2463e-01, 6.3205e-01, 6.3589e-01, 6.3936e-01, 6.4256e-01,
        6.4791e-01, 6.4819e-01, 6.7773e-01, 6.9186e-01, 7.0218e-01, 7.1077e-01,
        7.1499e-01, 7.1659e-01, 7.1672e-01, 7.2524e-01, 7.2783e-01, 7.2968e-01,
        7.3806e-01, 7.4392e-01, 7.4523e-01, 7.5162e-01, 7.6501e-01, 7.6655e-01,
        7.6843e-01, 7.7050e-01, 7.7300e-01, 7.8455e-01, 7.9845e-01, 8.0296e-01,
        8.0581e-01, 8.1032e-01, 8.1180e-01, 8.1503e-01, 8.1829e-01, 8.2715e-01,
        8.3038e-01, 8.3414e-01, 8.3668e-01, 8.4204e-01, 8.4537e-01, 8.4974e-01,
        8.5228e-01, 8.5367e-01, 8.5638e-01, 8.5896e-01, 8.6945e-01, 8.7149e-01,
        8.9765e-01, 9.0214e-01, 9.0497e-01, 9.1642e-01, 9.2448e-01, 9.3773e-01,
        9.6030e-01, 9.6883e-01, 9.6915e-01, 9.7510e-01, 9.8488e-01, 1.0024e+00,
        1.0086e+00, 1.0155e+00, 1.0240e+00, 1.0565e+00, 1.0593e+00, 1.0690e+00,
        1.0761e+00, 1.0943e+00, 1.1065e+00, 1.1131e+00, 1.1203e+00, 1.1279e+00,
        1.1316e+00, 1.1358e+00, 1.1423e+00, 1.1555e+00, 1.1570e+00, 1.1601e+00,
        1.1646e+00, 1.1693e+00, 1.1758e+00, 1.1760e+00, 1.1971e+00, 1.1972e+00,
        1.2000e+00, 1.2039e+00, 1.2086e+00, 1.2120e+00, 1.2152e+00, 1.2206e+00,
        1.2207e+00, 1.2309e+00, 1.2388e+00, 1.2435e+00, 1.2544e+00, 1.2570e+00,
        1.2588e+00, 1.2618e+00, 1.2760e+00, 1.2768e+00, 1.2798e+00, 1.2800e+00,
        1.2830e+00, 1.2858e+00, 1.2873e+00, 1.3044e+00, 1.3108e+00, 1.3164e+00,
        1.3221e+00, 1.3244e+00, 1.3256e+00, 1.3337e+00, 1.3354e+00, 1.3422e+00,
        1.3628e+00, 1.3772e+00, 1.3779e+00, 1.3838e+00, 1.3903e+00, 1.3980e+00,
        1.4083e+00, 1.4121e+00, 1.4147e+00, 1.4166e+00, 1.4179e+00, 1.4221e+00,
        1.4373e+00, 1.4416e+00, 1.4440e+00, 1.4515e+00, 1.4524e+00, 1.4594e+00,
        1.4804e+00, 1.4813e+00, 1.4855e+00, 1.4882e+00, 1.4885e+00, 1.4969e+00,
        1.5047e+00, 1.5056e+00, 1.5073e+00, 1.5155e+00, 1.5156e+00, 1.5192e+00,
        1.5193e+00, 1.5247e+00, 1.5247e+00, 1.5254e+00, 1.5258e+00, 1.5278e+00,
        1.5367e+00, 1.5393e+00, 1.5476e+00, 1.5498e+00, 1.5554e+00, 1.5564e+00,
        1.5601e+00, 1.5637e+00, 1.5642e+00, 1.5682e+00, 1.5698e+00, 1.5708e+00,
        1.5710e+00, 1.5730e+00, 1.5939e+00, 1.6111e+00, 1.6197e+00, 1.6208e+00,
        1.6218e+00, 1.6231e+00, 1.6339e+00, 1.6381e+00, 1.6401e+00, 1.6416e+00,
        1.6419e+00, 1.6431e+00, 1.6466e+00, 1.6495e+00, 1.6498e+00, 1.6540e+00,
        1.6572e+00, 1.6645e+00, 1.6728e+00, 1.7095e+00, 1.7128e+00, 1.7172e+00,
        1.7213e+00, 1.7277e+00, 1.7516e+00, 1.7532e+00, 1.7559e+00, 1.7738e+00,
        1.7794e+00, 1.7811e+00, 1.7828e+00, 1.7834e+00, 1.7838e+00, 1.7850e+00,
        1.7939e+00, 1.8071e+00, 1.8291e+00, 1.8326e+00, 1.8401e+00, 1.8533e+00,
        1.8542e+00, 1.8807e+00, 1.8910e+00, 1.8969e+00, 1.8996e+00, 1.9038e+00,
        1.9038e+00, 1.9089e+00, 1.9161e+00, 1.9204e+00, 1.9251e+00, 1.9262e+00,
        1.9311e+00, 1.9327e+00, 1.9338e+00, 1.9370e+00, 1.9437e+00, 1.9520e+00,
        1.9527e+00, 1.9579e+00, 1.9581e+00, 1.9678e+00, 1.9685e+00, 1.9687e+00,
        1.9787e+00, 1.9924e+00, 1.9963e+00, 2.0023e+00, 2.0074e+00, 2.0243e+00,
        2.0266e+00, 2.0338e+00, 2.0420e+00, 2.0439e+00, 2.0581e+00, 2.0606e+00,
        2.0664e+00, 2.0665e+00, 2.0673e+00, 2.0756e+00, 2.0843e+00, 2.0854e+00,
        2.0855e+00, 2.0875e+00, 2.1030e+00, 2.1118e+00, 2.1208e+00, 2.1304e+00,
        2.1331e+00, 2.1352e+00, 2.1391e+00, 2.1433e+00, 2.1439e+00, 2.1475e+00,
        2.1572e+00, 2.1610e+00, 2.1828e+00, 2.1875e+00, 2.1925e+00, 2.1936e+00,
        2.2100e+00, 2.2178e+00, 2.2235e+00, 2.2315e+00, 2.2415e+00, 2.2445e+00,
        2.2559e+00, 2.2642e+00, 2.2845e+00, 2.2958e+00, 2.2976e+00, 2.2980e+00,
        2.3036e+00, 2.3060e+00, 2.3137e+00, 2.3236e+00, 2.3300e+00, 2.3367e+00,
        2.3373e+00, 2.3422e+00, 2.3453e+00, 2.3580e+00, 2.3601e+00, 2.3618e+00,
        2.3641e+00, 2.3649e+00, 2.3736e+00, 2.3823e+00, 2.3831e+00, 2.3838e+00,
        2.3875e+00, 2.3957e+00, 2.4004e+00, 2.4108e+00, 2.4130e+00, 2.4133e+00,
        2.4218e+00, 2.4246e+00, 2.4339e+00, 2.4367e+00, 2.4369e+00, 2.4455e+00,
        2.4492e+00, 2.4556e+00, 2.4572e+00, 2.4678e+00, 2.4747e+00, 2.4846e+00,
        2.4911e+00, 2.5004e+00, 2.5051e+00, 2.5161e+00, 2.5188e+00, 2.5192e+00,
        2.5202e+00, 2.5217e+00, 2.5408e+00, 2.5486e+00, 2.5562e+00, 2.5629e+00,
        2.5631e+00, 2.5638e+00, 2.5779e+00, 2.5786e+00, 2.5829e+00, 2.5831e+00,
        2.5897e+00, 2.5983e+00, 2.6045e+00, 2.6088e+00, 2.6129e+00, 2.6157e+00,
        2.6174e+00, 2.6227e+00, 2.6278e+00, 2.6280e+00, 2.6296e+00, 2.6302e+00,
        2.6324e+00, 2.6425e+00, 2.6516e+00, 2.6519e+00, 2.6772e+00, 2.6832e+00,
        2.6869e+00, 2.6993e+00, 2.7076e+00, 2.7130e+00, 2.7157e+00, 2.7200e+00,
        2.7325e+00, 2.7374e+00, 2.7467e+00, 2.7639e+00, 2.7667e+00, 2.7669e+00,
        2.7749e+00, 2.7779e+00, 2.7789e+00, 2.7827e+00, 2.7851e+00, 2.7914e+00,
        2.7988e+00, 2.8020e+00, 2.8086e+00, 2.8088e+00, 2.8113e+00, 2.8446e+00,
        2.8457e+00, 2.8504e+00, 2.8542e+00, 2.8549e+00, 2.8601e+00, 2.8765e+00,
        2.8863e+00, 2.8929e+00, 2.8941e+00, 2.8961e+00, 2.9023e+00, 2.9026e+00,
        2.9028e+00, 2.9129e+00, 2.9146e+00, 2.9240e+00, 2.9302e+00, 2.9340e+00,
        2.9347e+00, 2.9368e+00, 2.9411e+00, 2.9435e+00, 2.9498e+00, 2.9522e+00,
        2.9561e+00, 2.9562e+00, 2.9594e+00, 2.9631e+00, 2.9790e+00, 2.9992e+00,
        3.0000e+00, 3.0010e+00, 3.0092e+00, 3.0105e+00, 3.0114e+00, 3.0354e+00,
        3.0385e+00, 3.0432e+00, 3.0513e+00, 3.0514e+00, 3.0784e+00, 3.0988e+00,
        3.1001e+00, 3.1105e+00, 3.1114e+00, 3.1179e+00, 3.1268e+00, 3.1284e+00,
        3.1319e+00, 3.1384e+00], requires_grad = True, device = dev)
print(f'a1 = {a1}, a2 = {a2}, a3 = {a3}, a4 = {a4}, w1 = {w1}')

a1 = 1.3705999851226807, a2 = 0.7379999756813049, a3 = 0.8173999786376953, a4 = 0.6693999767303467, w1 = tensor([1.1334e-03, 7.8128e-03, 2.5556e-02, 5.4142e-02, 7.5123e-02, 7.6172e-02,
        7.6510e-02, 8.5882e-02, 8.6886e-02, 9.1110e-02, 9.1691e-02, 9.5817e-02,
        1.0957e-01, 1.1061e-01, 1.2605e-01, 1.3971e-01, 1.4089e-01, 1.5058e-01,
        1.7243e-01, 1.7544e-01, 1.8666e-01, 1.9130e-01, 2.0587e-01, 2.1194e-01,
        2.1738e-01, 2.3759e-01, 2.3906e-01, 2.4217e-01, 2.4721e-01, 2.5127e-01,
        2.5664e-01, 2.6353e-01, 2.6488e-01, 2.7214e-01, 2.7320e-01, 2.8084e-01,
        2.8270e-01, 2.9685e-01, 2.9696e-01, 3.0753e-01, 3.1418e-01, 3.1720e-01,
        3.2904e-01, 3.3152e-01, 3.3628e-01, 3.3845e-01, 3.5030e-01, 3.5149e-01,
        3.5447e-01, 3.6072e-01, 3.6275e-01, 3.6316e-01, 3.6380e-01, 3.7369e-01,
        3.7369e-01, 3.7530e-01, 3.7816e-01, 3.8047e-01, 3.8091e-01, 3.8705e-01,
        3.8853e-01, 3.8901e-01, 3.9189e-01, 3.9193e-01, 4.0319e-01, 4.1430e-01,
        4.1970e

#### Calculate gradients

In [85]:
x_batch = torch.rand(N,100, device = dev)*2 - 1
x_batch

tensor([[ 0.7179,  0.8650, -0.5943,  ...,  0.7901, -0.2121,  0.7680],
        [ 0.2279, -0.9521, -0.8201,  ..., -0.8383, -0.2029,  0.4879],
        [-0.9810,  0.6741,  0.3377,  ...,  0.7077, -0.9687,  0.9071],
        ...,
        [ 0.3681, -0.5291, -0.7496,  ...,  0.6096,  0.3589,  0.5597],
        [-0.9341, -0.5160, -0.1211,  ...,  0.3075,  0.0845,  0.8373],
        [ 0.9541, -0.9633, -0.8746,  ..., -0.8030,  0.7034, -0.3630]],
       device='cuda:0')

In [86]:
x_batch.mean()

tensor(-0.0008, device='cuda:0')

In [43]:
a1,a2,a3,a4,w1,w2 = random_init(M,dev)

In [32]:
X,y = forward_pass_id(x_batch,a1,a2,a3,a4,w1,dev)
y

tensor([[ 0.2199, -0.1001,  0.0973,  ...,  0.1587, -0.6989,  0.3203],
        [-0.2368,  0.0965,  0.3227,  ...,  0.0224,  0.1350,  0.1543],
        [ 0.2133,  0.8393,  0.1149,  ..., -0.0242,  0.3296, -0.1854],
        ...,
        [ 0.1300,  0.0925, -0.3117,  ...,  0.1103,  0.2252, -0.2563],
        [-0.1816,  0.2469, -1.3127,  ..., -0.5568,  0.6143,  0.1308],
        [ 0.2991,  0.1904, -0.5052,  ..., -0.4926, -0.2186,  0.1235]],
       device='cuda:0', grad_fn=<DivBackward0>)

In [33]:
X

tensor([[-0.4877,  0.8658,  0.3636,  ..., -0.5485, -0.2062, -0.7994],
        [-0.4585, -0.2088,  0.2274,  ..., -0.4528,  0.0210,  0.0974],
        [-0.4310, -0.7905,  0.0646,  ..., -0.0349,  0.1922,  0.5211],
        ...,
        [ 0.4162,  0.1246, -0.6604,  ...,  0.3254,  0.1437, -0.5955],
        [ 0.5448, -0.0054, -0.4135,  ...,  0.0921,  0.3512, -0.6635],
        [ 0.2768,  0.6091, -0.7016,  ...,  0.5071, -1.2274, -0.4421]],
       device='cuda:0', grad_fn=<DivBackward0>)

In [34]:
ls = loss(x_batch,y)
ls

tensor(0.1632, device='cuda:0', grad_fn=<MeanBackward0>)

In [35]:
ls.backward()

In [42]:
a2.grad

tensor([-0.0050], device='cuda:0')

In [56]:
nn = 10
a1_d = torch.zeros(nn, device = dev)
a2_d = torch.zeros(nn, device = dev)
a3_d = torch.zeros(nn, device = dev)
a4_d = torch.zeros(nn, device = dev)
w1_d = torch.zeros(M, nn, device = dev)
for i in range(nn):
    a1,a2,a3,a4,w1,w2 = random_init(M,dev)
    X,y = forward_pass_id(x_batch,a1,a2,a3,a4,w1,dev)
    ls = loss(x_batch,y)
    ls.backward()
    a1_d[i] = a1.grad.item()
    a2_d[i] = a2.grad.item()
    a3_d[i] = a3.grad.item()
    a4_d[i] = a4.grad.item()
    w1_d[:,i] = w1.grad.detach()
    
print(a1_d)
print(a2_d)
print(a3_d)
print(a4_d)
print(w1_d)

tensor([-0.0003, -0.0001,  0.0008,  0.0018, -0.0001,  0.0023,  0.0014,  0.0011,
         0.0020,  0.0014], device='cuda:0')
tensor([-0.0524, -0.0821,  0.1735,  0.3194, -0.0140, -0.0108,  0.0924,  0.0893,
         0.0911,  0.0823], device='cuda:0')
tensor([ 3.7608e-05,  4.3939e-04,  6.3131e-04,  1.6191e-03, -1.2122e-04,
         1.6408e-03,  7.6690e-04,  4.8274e-04,  1.2257e-03,  5.9257e-04],
       device='cuda:0')
tensor([-0.1277, -0.1583,  0.1963,  0.3436, -0.0200, -0.0067,  0.0643,  0.0487,
         0.0660,  0.0484], device='cuda:0')
tensor([[ 0.0662, -0.0272,  0.0769,  ...,  0.0131,  0.0453,  0.3159],
        [-0.0076,  0.0015,  0.1747,  ...,  0.2179,  0.1939, -0.0079],
        [ 0.0334,  0.0301, -0.1834,  ...,  0.0317, -0.1531, -0.2122],
        ...,
        [-0.0279,  0.0187,  0.7321,  ..., -0.0244,  0.0009,  0.0583],
        [-0.0083, -0.0179,  0.7170,  ...,  0.0655, -0.1442, -0.0433],
        [-0.0143, -0.0729,  0.7112,  ...,  0.0119, -0.5258, -0.0568]],
       device='cuda:0')

In [60]:
w1_d.max(0)

torch.return_types.max(
values=tensor([0.1726, 0.0959, 0.7345, 1.3880, 0.2829, 0.2404, 0.3498, 0.4108, 0.4903,
        0.4482], device='cuda:0'),
indices=tensor([295, 298, 508, 509, 230, 394, 318, 246, 401, 480], device='cuda:0'))

In [80]:
w1_d[:,1].to('cpu').numpy().round(4)

array([-0.0272,  0.0015,  0.0301, -0.0142,  0.0206,  0.0123, -0.0124,
       -0.0075, -0.0409,  0.0118, -0.0125,  0.0099, -0.0288, -0.0416,
        0.0079, -0.0147, -0.0076,  0.0051,  0.027 , -0.0112,  0.0288,
        0.0111, -0.0374,  0.0458,  0.0246, -0.0155,  0.0057,  0.0942,
        0.0291,  0.0157,  0.0026, -0.0346, -0.1043, -0.0012, -0.0068,
       -0.0216,  0.011 , -0.0325, -0.0062, -0.0177,  0.0349, -0.0253,
        0.0189, -0.0223, -0.04  ,  0.0057,  0.0043, -0.0131,  0.038 ,
       -0.0135, -0.0027,  0.0596,  0.0342, -0.0227, -0.0345, -0.0436,
       -0.0046,  0.0143,  0.0051, -0.0055,  0.0122,  0.0432,  0.0154,
       -0.0364,  0.0224, -0.0194, -0.0013,  0.0115,  0.0236, -0.0029,
       -0.0271,  0.0061,  0.0013,  0.0012,  0.0022, -0.0054,  0.0214,
       -0.0025,  0.0013, -0.0031, -0.0324,  0.0736,  0.0129, -0.03  ,
       -0.0435,  0.015 , -0.0205, -0.0247, -0.0216,  0.0158, -0.0161,
        0.0013,  0.0107, -0.0022, -0.0391,  0.0828,  0.0341,  0.0041,
       -0.0301, -0.0

In [81]:
dl_dy = 2*(y - x_batch)
dl_dy

tensor([[-1.3589,  0.2524,  0.8368,  ..., -0.7726, -0.1416, -0.1690],
        [ 0.0718, -0.2572,  0.5511,  ..., -0.4763, -0.4765,  0.6708],
        [-0.2237, -1.4052,  0.3405,  ..., -0.5510,  1.0055,  0.2452],
        ...,
        [ 1.3694,  0.7889,  1.3972,  ..., -0.0640, -1.5495,  0.0327],
        [ 0.7400, -0.2287,  1.2224,  ..., -1.6703, -1.3714,  0.0965],
        [-0.4092,  1.0259,  0.1044,  ...,  0.7106, -0.2987, -0.0573]],
       device='cuda:0', grad_fn=<MulBackward0>)

In [163]:
dl_dy.size()

torch.Size([512, 100])

In [91]:
a1

tensor([0.8346], device='cuda:0', requires_grad=True)

In [88]:
W1 = a2 * torch.cos(torch.outer(w1,torch.arange(N, device = dev)))
W1[:,0] = a1
W1

tensor([[ 0.8346,  0.8258,  0.8258,  ..., -0.5609, -0.5636, -0.5664],
        [ 0.8346,  0.8258,  0.8258,  ..., -0.7133, -0.7103, -0.7073],
        [ 0.8346,  0.8258,  0.8257,  ..., -0.2774, -0.2708, -0.2640],
        ...,
        [ 0.8346, -0.8258,  0.8256,  ..., -0.8220,  0.8209, -0.8197],
        [ 0.8346, -0.8258,  0.8257,  ...,  0.0117, -0.0040, -0.0035],
        [ 0.8346, -0.8258,  0.8257,  ...,  0.4252, -0.4193,  0.4135]],
       device='cuda:0', grad_fn=<CopySlices>)

In [89]:
W1.mean()

tensor(0.0004, device='cuda:0', grad_fn=<MeanBackward0>)

In [95]:
torch.matmul(W1,x_batch)

tensor([[ 10.4778,   7.1326,   0.1995,  ...,   1.6283,  -2.7474,  18.2252],
        [ 12.2094,  11.3534,  -3.4186,  ...,  -1.6751,  -2.7673,  17.6689],
        [ 10.2783,   8.8102,  -5.5412,  ...,  -2.4758,  -1.3650,  13.6412],
        ...,
        [ 13.2678,   2.2307,  -0.4822,  ...,   2.9631,   5.4477, -21.5998],
        [  6.7982,  -6.8268,  -4.4299,  ...,   1.8589,   9.6202, -12.4924],
        [  4.2076,  -9.1806,  -5.0634,  ...,   1.0289,   8.9828,  -7.0521]],
       device='cuda:0', grad_fn=<MmBackward0>)

In [96]:
# frequency domain X
X = torch.matmul(W1,x_batch)/np.sqrt(N)

In [97]:
X

tensor([[ 0.4631,  0.3152,  0.0088,  ...,  0.0720, -0.1214,  0.8054],
        [ 0.5396,  0.5018, -0.1511,  ..., -0.0740, -0.1223,  0.7809],
        [ 0.4542,  0.3894, -0.2449,  ..., -0.1094, -0.0603,  0.6029],
        ...,
        [ 0.5864,  0.0986, -0.0213,  ...,  0.1310,  0.2408, -0.9546],
        [ 0.3004, -0.3017, -0.1958,  ...,  0.0822,  0.4252, -0.5521],
        [ 0.1860, -0.4057, -0.2238,  ...,  0.0455,  0.3970, -0.3117]],
       device='cuda:0', grad_fn=<DivBackward0>)

In [98]:
X.mean()

tensor(-0.0017, device='cuda:0', grad_fn=<MeanBackward0>)

In [83]:
X.size()

torch.Size([512, 100])

In [100]:
X.sum(0)

tensor([  3.8039,  21.6656, -24.8391, -21.9632,  24.0998,  -7.8583,  15.1261,
          5.5488,   7.3820, -10.2491,   4.9027,   2.4158,   8.2366,  10.4478,
        -14.3395,  18.7449,  -9.5519, -14.9607,  12.8356, -13.0192,   8.0894,
          2.1276,  -5.1401, -20.0954,   3.9024,  14.5469,   7.3188, -22.6988,
          7.0743,  -4.2544, -24.2597, -17.3178,  -0.6942,  13.6916,  19.1789,
         -0.6431,   2.3849,  -3.1406, -10.4475,   4.3627,  18.5296,  -7.3796,
          4.2071,  11.9442,   0.6315,  -4.1080,   7.7273,  18.5813,   6.3990,
        -20.0664,  24.4701,   2.2485,   9.7525, -15.8875, -13.0969,  -0.9899,
          4.5734,  -6.7116,  -9.6121, -17.9269,  15.4654, -20.2539, -21.4766,
         -9.2987,  -6.7565, -16.3769,  -8.1326,   4.8595,   4.5846, -11.7108,
          5.9306,   2.5477, -12.5502, -10.1454,   5.3299, -11.2778,  -2.2759,
          6.5861, -13.9564, -12.6160,  10.6098, -19.3919,  -0.5500,  23.6657,
         13.6566,  13.6390, -12.1026, -16.5028,  14.8447, -13.36

In [101]:
dy_da3 = X.sum(0)/np.sqrt(N)
dy_da3

tensor([ 0.1681,  0.9575, -1.0977, -0.9706,  1.0651, -0.3473,  0.6685,  0.2452,
         0.3262, -0.4530,  0.2167,  0.1068,  0.3640,  0.4617, -0.6337,  0.8284,
        -0.4221, -0.6612,  0.5673, -0.5754,  0.3575,  0.0940, -0.2272, -0.8881,
         0.1725,  0.6429,  0.3234, -1.0032,  0.3126, -0.1880, -1.0721, -0.7653,
        -0.0307,  0.6051,  0.8476, -0.0284,  0.1054, -0.1388, -0.4617,  0.1928,
         0.8189, -0.3261,  0.1859,  0.5279,  0.0279, -0.1816,  0.3415,  0.8212,
         0.2828, -0.8868,  1.0814,  0.0994,  0.4310, -0.7021, -0.5788, -0.0437,
         0.2021, -0.2966, -0.4248, -0.7923,  0.6835, -0.8951, -0.9491, -0.4109,
        -0.2986, -0.7238, -0.3594,  0.2148,  0.2026, -0.5175,  0.2621,  0.1126,
        -0.5546, -0.4484,  0.2355, -0.4984, -0.1006,  0.2911, -0.6168, -0.5576,
         0.4689, -0.8570, -0.0243,  1.0459,  0.6035,  0.6028, -0.5349, -0.7293,
         0.6560, -0.5907,  0.7628, -0.6215, -0.1347, -0.6998, -0.0937, -0.5282,
         0.3599,  0.8678, -0.1025,  0.45

In [165]:
dy_da3.size()

torch.Size([100])

In [102]:
dy_da3.mean()

tensor(-0.0378, device='cuda:0', grad_fn=<MeanBackward0>)

In [103]:
dl_da3 = (dl_dy[0] * dy_da3).mean()
dl_da3

tensor(-0.0200, device='cuda:0', grad_fn=<MeanBackward0>)

In [104]:
a3.grad

tensor([0.0006], device='cuda:0')

In [105]:
cos_multi = torch.cos(torch.outer(torch.arange(1,N, device = dev),w1))
cos_multi.size()

torch.Size([511, 512])

In [106]:
X.size()

torch.Size([512, 100])

In [107]:
dy_da4 = torch.matmul(cos_multi,X)/np.sqrt(N)
dy_da4.size()

torch.Size([511, 100])

In [153]:
dl_dy[1:].size()

torch.Size([511, 100])

In [181]:
dl_da4 = (dl_dy[1:] * dy_da4).sum(0)
dl_da4

tensor([-56.7261, -59.3937, -57.0181, -65.0392, -67.3055, -59.3198, -60.3569,
        -65.7551, -59.8873, -62.4759, -61.7070, -58.9181, -61.8631, -62.8462,
        -59.6636, -63.0951, -60.9490, -62.9357, -58.9745, -62.1437, -63.1904,
        -57.2409, -62.4031, -66.2778, -59.8081, -61.8413, -61.2392, -60.9484,
        -64.4489, -65.6322, -62.3818, -60.8777, -62.1798, -57.0120, -63.6822,
        -62.0460, -59.5706, -61.4337, -66.1770, -57.4036, -58.7815, -58.5906,
        -61.7434, -61.9033, -60.4371, -56.3937, -59.9153, -61.0512, -56.1465,
        -60.2822, -60.7507, -55.2314, -60.7931, -59.6922, -59.4700, -60.0377,
        -59.2238, -63.5398, -61.0271, -60.6404, -60.7489, -60.8953, -60.3152,
        -67.7634, -61.8541, -54.0387, -62.0272, -56.3970, -64.6307, -60.4493,
        -60.8935, -64.2360, -58.3825, -60.0144, -61.9775, -59.5303, -62.9110,
        -54.4366, -57.9840, -58.6321, -61.8986, -61.0809, -59.2599, -60.4900,
        -57.2575, -59.1833, -62.3352, -60.3611, -64.9047, -62.93

In [180]:
dl_da4 = torch.mean((dl_dy[1:] * dy_da4),0)
dl_da4

tensor([-0.1110, -0.1162, -0.1116, -0.1273, -0.1317, -0.1161, -0.1181, -0.1287,
        -0.1172, -0.1223, -0.1208, -0.1153, -0.1211, -0.1230, -0.1168, -0.1235,
        -0.1193, -0.1232, -0.1154, -0.1216, -0.1237, -0.1120, -0.1221, -0.1297,
        -0.1170, -0.1210, -0.1198, -0.1193, -0.1261, -0.1284, -0.1221, -0.1191,
        -0.1217, -0.1116, -0.1246, -0.1214, -0.1166, -0.1202, -0.1295, -0.1123,
        -0.1150, -0.1147, -0.1208, -0.1211, -0.1183, -0.1104, -0.1173, -0.1195,
        -0.1099, -0.1180, -0.1189, -0.1081, -0.1190, -0.1168, -0.1164, -0.1175,
        -0.1159, -0.1243, -0.1194, -0.1187, -0.1189, -0.1192, -0.1180, -0.1326,
        -0.1210, -0.1058, -0.1214, -0.1104, -0.1265, -0.1183, -0.1192, -0.1257,
        -0.1143, -0.1174, -0.1213, -0.1165, -0.1231, -0.1065, -0.1135, -0.1147,
        -0.1211, -0.1195, -0.1160, -0.1184, -0.1120, -0.1158, -0.1220, -0.1181,
        -0.1270, -0.1232, -0.1236, -0.1175, -0.1106, -0.1217, -0.1196, -0.1197,
        -0.1242, -0.1196, -0.1181, -0.12

In [183]:
dl_da4.mean()/511

tensor(-0.1192, device='cuda:0', grad_fn=<DivBackward0>)

In [198]:
dl_da3

tensor(0.1242, device='cuda:0', grad_fn=<MeanBackward0>)

In [199]:
a4.grad/2

tensor([-0.1196], device='cuda:0')

In [202]:
a3.grad * 511

tensor([0.2699], device='cuda:0')

In [191]:
0.1242/0.0005

248.4

In [204]:
w1.grad

tensor([ 4.8353e-02, -2.1832e-02, -6.7726e-02,  4.3784e-02,  2.0491e-02,
         2.2088e-02, -4.1011e-02, -2.0388e-02, -2.6816e-02,  6.0689e-02,
        -2.7155e-02, -5.8485e-02,  3.8285e-02, -1.6583e-02,  2.6921e-02,
         5.4503e-02,  3.6822e-02, -2.1170e-02,  2.6921e-02, -2.8249e-02,
        -3.1844e-02,  1.0213e-02, -9.0833e-03,  4.4703e-02,  3.9033e-02,
        -5.1572e-02,  4.1289e-02,  1.5854e-02, -6.9290e-03,  1.8753e-02,
        -2.4627e-02, -3.4077e-02,  2.2680e-02,  7.3754e-02, -1.9768e-02,
        -9.0115e-04, -9.6327e-03,  3.1749e-02, -1.7784e-02, -6.0398e-03,
         2.5965e-02,  4.7157e-03, -6.5160e-03, -5.4560e-03,  4.7569e-03,
        -3.9051e-03,  7.6522e-03,  1.2075e-02,  1.5859e-02, -9.0265e-03,
        -2.6498e-02, -3.2532e-02,  2.9611e-02,  2.4680e-02,  3.0539e-03,
        -3.2995e-03, -2.1078e-02,  1.0367e-02, -2.7440e-02,  1.8663e-02,
        -4.7591e-03, -5.9654e-03,  1.6895e-02, -1.9860e-03, -1.1219e-03,
         2.3761e-02, -4.3094e-02, -5.0682e-02,  7.9

In [193]:
w1

tensor([-1.4524e-03,  1.3914e-02,  2.6110e-02,  5.1137e-02,  5.7489e-02,
         6.3424e-02,  8.8028e-02,  9.3941e-02,  7.5795e-02,  1.0031e-01,
         8.1919e-02,  1.0623e-01,  1.1273e-01,  1.1876e-01,  1.2497e-01,
         1.3113e-01,  1.4362e-01,  1.5592e-01,  1.6827e-01,  1.8005e-01,
         1.8673e-01,  1.9262e-01,  2.0473e-01,  2.1068e-01,  2.1706e-01,
         2.2296e-01,  2.2906e-01,  2.5316e-01,  2.3492e-01,  2.4714e-01,
         2.4084e-01,  2.6496e-01,  2.5898e-01,  2.7122e-01,  2.7660e-01,
         2.8789e-01,  2.8243e-01,  2.9324e-01,  2.9842e-01,  3.0359e-01,
         3.0874e-01,  3.1842e-01,  3.3233e-01,  3.1360e-01,  3.2322e-01,
         3.4127e-01,  3.3682e-01,  3.4574e-01,  3.2781e-01,  3.6703e-01,
         3.4996e-01,  3.8414e-01,  3.5427e-01,  3.6289e-01,  3.7983e-01,
         3.9299e-01,  3.5851e-01,  3.7568e-01,  4.0197e-01,  3.8858e-01,
         4.0680e-01,  3.7128e-01,  4.1154e-01,  3.9740e-01,  4.2205e-01,
         4.2762e-01,  4.1654e-01,  4.3905e-01,  4.3

In [40]:
w2 = torch.clone(w1).detach()
w2.requires_grad_(True)
w2

tensor([-1.4524e-03,  1.3914e-02,  2.6110e-02,  5.1137e-02,  5.7489e-02,
         6.3424e-02,  8.8028e-02,  9.3941e-02,  7.5795e-02,  1.0031e-01,
         8.1919e-02,  1.0623e-01,  1.1273e-01,  1.1876e-01,  1.2497e-01,
         1.3113e-01,  1.4362e-01,  1.5592e-01,  1.6827e-01,  1.8005e-01,
         1.8673e-01,  1.9262e-01,  2.0473e-01,  2.1068e-01,  2.1706e-01,
         2.2296e-01,  2.2906e-01,  2.5316e-01,  2.3492e-01,  2.4714e-01,
         2.4084e-01,  2.6496e-01,  2.5898e-01,  2.7122e-01,  2.7660e-01,
         2.8789e-01,  2.8243e-01,  2.9324e-01,  2.9842e-01,  3.0359e-01,
         3.0874e-01,  3.1842e-01,  3.3233e-01,  3.1360e-01,  3.2322e-01,
         3.4127e-01,  3.3682e-01,  3.4574e-01,  3.2781e-01,  3.6703e-01,
         3.4996e-01,  3.8414e-01,  3.5427e-01,  3.6289e-01,  3.7983e-01,
         3.9299e-01,  3.5851e-01,  3.7568e-01,  4.0197e-01,  3.8858e-01,
         4.0680e-01,  3.7128e-01,  4.1154e-01,  3.9740e-01,  4.2205e-01,
         4.2762e-01,  4.1654e-01,  4.3905e-01,  4.3

In [41]:
X1,y1 = forward_pass_dif(x_batch,a1,a2,a3,a4,w1,w2,dev)

In [46]:
X1

tensor([[-0.7498, -0.8620, -0.0225,  ..., -0.1736,  0.4348, -0.1198],
        [-0.2499,  0.5823, -0.5130,  ...,  0.5855,  1.1677, -0.0350],
        [-0.4920, -0.0222, -0.4505,  ..., -0.2605,  0.4984, -0.1551],
        ...,
        [-1.3402, -0.1130,  0.4768,  ...,  0.0270, -0.7364, -0.1691],
        [ 1.0679, -0.4533,  0.2394,  ...,  0.6994,  1.4049,  0.3380],
        [-1.3402, -0.1130,  0.4768,  ...,  0.0270, -0.7364, -0.1691]],
       device='cuda:0', grad_fn=<DivBackward0>)

In [25]:
X

tensor([[-0.7498, -0.8620, -0.0225,  ..., -0.1736,  0.4348, -0.1198],
        [-0.2499,  0.5823, -0.5130,  ...,  0.5855,  1.1677, -0.0350],
        [-0.4920, -0.0222, -0.4505,  ..., -0.2605,  0.4984, -0.1551],
        ...,
        [-1.3402, -0.1130,  0.4768,  ...,  0.0270, -0.7364, -0.1691],
        [ 1.0679, -0.4533,  0.2394,  ...,  0.6994,  1.4049,  0.3380],
        [-1.3402, -0.1130,  0.4768,  ...,  0.0270, -0.7364, -0.1691]],
       device='cuda:0', grad_fn=<DivBackward0>)

In [26]:
y1

tensor([[ 0.6552, -0.8300, -0.6165,  ...,  0.9054,  0.5810,  0.4090],
        [ 0.3385, -0.6028,  0.2263,  ...,  0.0966,  0.4237, -0.1138],
        [ 0.1991, -0.1402,  0.5351,  ...,  0.3229,  0.0932,  0.0115],
        ...,
        [-0.5201,  0.0823,  0.2022,  ...,  0.2409,  0.0792, -0.4491],
        [ 0.2720, -0.7335,  0.5273,  ...,  0.2191, -0.3254, -0.0637],
        [-0.5109,  0.3300,  0.2367,  ..., -0.5150, -0.2529,  0.1025]],
       device='cuda:0', grad_fn=<DivBackward0>)

In [27]:
y

tensor([[ 0.6552, -0.8300, -0.6165,  ...,  0.9054,  0.5810,  0.4090],
        [ 0.3385, -0.6028,  0.2263,  ...,  0.0966,  0.4237, -0.1138],
        [ 0.1991, -0.1402,  0.5351,  ...,  0.3229,  0.0932,  0.0115],
        ...,
        [-0.5201,  0.0823,  0.2022,  ...,  0.2409,  0.0792, -0.4491],
        [ 0.2720, -0.7335,  0.5273,  ...,  0.2191, -0.3254, -0.0637],
        [-0.5109,  0.3300,  0.2367,  ..., -0.5150, -0.2529,  0.1025]],
       device='cuda:0', grad_fn=<DivBackward0>)

In [42]:
ls = loss(x_batch, y1)
ls

tensor(0.0699, device='cuda:0', grad_fn=<MeanBackward0>)

In [47]:
X,y = forward_pass_id(x_batch,a1,a2,a3,a4,w1,dev)

In [48]:
ls1 = loss(x_batch, y)
ls1

tensor(0.0699, device='cuda:0', grad_fn=<MeanBackward0>)

In [49]:
ls1.backward()

In [50]:
w1.grad

tensor([-3.2269e-02, -1.6775e-02, -8.6333e-03,  6.0898e-02,  2.7650e-02,
        -2.6620e-02,  8.6567e-03, -4.7139e-02,  9.2869e-02,  5.0643e-02,
         5.5782e-02, -1.3405e-01,  8.8785e-02,  4.0628e-02, -4.5992e-02,
        -8.4136e-02,  1.6509e-02,  8.9708e-02, -2.3796e-02, -1.2901e-01,
         1.5366e-01, -5.4417e-02,  2.9598e-03, -4.0075e-02,  1.0697e-01,
        -1.0461e-02,  1.6068e-02,  4.5605e-02,  5.8366e-02, -3.6545e-02,
        -1.1060e-02,  1.2200e-02,  3.2817e-03,  4.3648e-02, -6.6355e-02,
         3.9191e-03, -2.2088e-02,  5.6015e-03,  3.1159e-02,  2.8556e-03,
         6.9526e-05, -3.9369e-02, -3.0371e-03,  3.5657e-02,  3.5211e-02,
        -1.2532e-02,  2.4644e-02,  2.9904e-03,  7.9798e-03, -2.1865e-02,
        -1.9660e-02, -2.2828e-02,  1.2928e-02, -4.3067e-03, -1.5276e-02,
        -1.1111e-02,  1.0466e-02,  3.9705e-02, -3.1349e-02,  1.8986e-02,
        -1.4962e-03, -8.1867e-04, -1.0493e-02, -4.3695e-03,  3.3145e-02,
         1.6788e-02, -4.9543e-02,  1.6041e-02, -7.5

In [43]:
ls.backward()

In [44]:
w1.grad

tensor([-9.4187e-03, -8.8009e-03, -6.4019e-03,  3.3779e-02,  1.8285e-02,
        -1.4840e-02,  4.7133e-03, -2.8404e-02,  5.2655e-02,  2.9327e-02,
         3.3072e-02, -7.9524e-02,  5.4245e-02,  2.4831e-02, -2.6935e-02,
        -4.7668e-02,  1.1047e-02,  5.0902e-02, -1.1984e-02, -7.8454e-02,
         9.0431e-02, -3.2866e-02,  2.6786e-03, -2.4339e-02,  6.3557e-02,
        -6.2083e-03,  9.8891e-03,  2.7597e-02,  3.4790e-02, -2.2362e-02,
        -7.0334e-03,  7.2461e-03,  1.4739e-03,  2.5491e-02, -4.0113e-02,
         1.6682e-03, -1.3106e-02,  3.0579e-03,  1.9101e-02,  1.8769e-03,
         3.4521e-04, -2.4048e-02, -1.9994e-03,  2.1364e-02,  2.0577e-02,
        -7.3541e-03,  1.5155e-02,  2.1160e-03,  4.2882e-03, -1.3057e-02,
        -1.1888e-02, -1.3530e-02,  7.3353e-03, -2.6826e-03, -9.3052e-03,
        -6.9142e-03,  6.1498e-03,  2.3509e-02, -1.7654e-02,  1.0989e-02,
        -7.5325e-04, -1.9375e-04, -6.4017e-03, -2.5433e-03,  1.9550e-02,
         9.6819e-03, -2.9558e-02,  9.6846e-03, -4.6

In [45]:
w2.grad

tensor([-3.6281e-02, -7.1465e-03,  1.9390e-03,  2.0458e-02,  4.4487e-04,
        -8.7193e-03,  3.1737e-03, -9.0673e-03,  2.7773e-02,  1.3306e-02,
         1.2349e-02, -2.9535e-02,  1.4836e-02,  6.7634e-03, -1.1179e-02,
        -2.5268e-02, -1.2340e-04,  2.6710e-02, -1.1641e-02, -2.2657e-02,
         3.6031e-02, -1.0236e-02, -2.1162e-03, -7.1345e-03,  2.3267e-02,
        -2.2976e-03,  2.4696e-03,  8.4192e-03,  1.2360e-02, -6.0040e-03,
        -1.0195e-03,  2.6623e-03,  2.1417e-03,  1.0823e-02, -1.2371e-02,
         2.8336e-03, -4.8587e-03,  2.0293e-03,  5.0149e-03,  8.0322e-05,
        -8.9657e-04, -6.5924e-03, -7.6106e-05,  7.2217e-03,  8.6908e-03,
        -3.0024e-03,  3.8220e-03, -3.6734e-04,  3.0952e-03, -4.5600e-03,
        -3.6560e-03, -5.0658e-03,  3.8507e-03, -5.6566e-04, -2.6363e-03,
        -1.4801e-03,  2.4822e-03,  8.8836e-03, -9.7366e-03,  5.0043e-03,
        -7.3257e-04, -1.0561e-03, -1.7809e-03, -1.1092e-03,  7.6410e-03,
         4.5294e-03, -1.0412e-02,  3.0288e-03, -1.3

#### test

In [21]:
x = torch.rand(3, requires_grad = True,device = dev)
x

tensor([0.4936, 0.6857, 0.4417], device='cuda:0', requires_grad=True)

In [22]:
y = x + 2
y

tensor([2.4936, 2.6857, 2.4417], device='cuda:0', grad_fn=<AddBackward0>)

In [23]:
z = y * y
z = z.mean()
z

tensor(6.4644, device='cuda:0', grad_fn=<MeanBackward0>)

In [26]:
if z < 10:
    print(z)

tensor(6.4644, device='cuda:0', grad_fn=<MeanBackward0>)


In [30]:
torch.save([x,z],'z.pt')

In [32]:
x,z = torch.load('z.pt')

In [34]:
z

tensor(6.4644, device='cuda:0', requires_grad=True)

In [17]:
l

tensor(0.3342, device='cuda:0', grad_fn=<MeanBackward0>)

In [22]:
a4.grad

  """Entry point for launching an IPython kernel.


In [18]:
l

tensor(0.1684, device='cuda:0', grad_fn=<MeanBackward0>)

In [35]:
b=0
x[:,b:b+batch_size].size()

torch.Size([512, 100])

In [40]:
np.minimum(3,4)

3

In [43]:
a1,a2,a3,a4,w1,w2 = random_init(M,dev)

In [44]:
a2

tensor([0.3381], device='cuda:0', requires_grad=True)

In [28]:
torch.randperm(5)

tensor([1, 0, 2, 3, 4])

In [30]:
a = torch.rand(4,5)
a

tensor([[0.6569, 0.0911, 0.5602, 0.0551, 0.7909],
        [0.3291, 0.7895, 0.5451, 0.7612, 0.4997],
        [0.7181, 0.5634, 0.7729, 0.7181, 0.1605],
        [0.0017, 0.2127, 0.3340, 0.1913, 0.4456]])

In [32]:
a[:,torch.randperm(5)]

tensor([[0.0911, 0.5602, 0.6569, 0.7909, 0.0551],
        [0.7895, 0.5451, 0.3291, 0.4997, 0.7612],
        [0.5634, 0.7729, 0.7181, 0.1605, 0.7181],
        [0.2127, 0.3340, 0.0017, 0.4456, 0.1913]])

In [24]:
l = torch.tensor([1,2,3,4,5])
l

tensor([1, 2, 3, 4, 5])

In [25]:
l.max()

tensor(5)

In [36]:
loss(x_batch,y).item()

0.10698503255844116