In [1]:
import torch

"""
 y = sum(a[i]*x[i])
 X - input signals, shape is (n, )
 y - output
 A - coefficients, shape is (n, )
"""
class Linear(torch.nn.Module):
    def __init__(self, input_size):
        super(Linear, self).__init__()
        self.coeffs = torch.nn.Parameter(torch.rand(input_size, requires_grad = True))

    def forward(self, x):
        return torch.matmul(self.coeffs, x)

In [2]:
model = Linear(5)

list(model.parameters())

[Parameter containing:
 tensor([0.7297, 0.8068, 0.6318, 0.8962, 0.9281], requires_grad=True)]

In [3]:
"""
   (f(x) - y)**2
   L = dot(A, X)
   if L > 0: A = -A => L => -L
   if L < 0: A = 10 * A
   
   f(X) = a1*x1 + a2*x2 + ...
   df/da = (x1, x2, ...)
"""

loss = model(torch.tensor([-0.1, 2.3, -1.5, 3.14, 0.0]))

loss.backward()

loss

tensor(3.6492, grad_fn=<DotBackward0>)

In [4]:
for param in model.parameters():
    print(param)
    print(param.grad)

Parameter containing:
tensor([0.7297, 0.8068, 0.6318, 0.8962, 0.9281], requires_grad=True)
tensor([-0.1000,  2.3000, -1.5000,  3.1400,  0.0000])


In [5]:
a = [0.0225, 0.0147, 0.8035, 0.7803, 0.2380]
x = [-0.1, 2.3, -1.5, 3.14, 0.0]
sum([a[i] * x[i] for i in range(len(a))])

1.2764520000000001

In [7]:
for _ in range(100):
    model.zero_grad()
    loss = model(torch.tensor([-0.1, 2.3, -1.5, 3.14, 0.0]))
    print(loss.item())
    loss.backward()
    for param in model.parameters():
        param.data -= param.grad
for param in model.parameters():
    print(param)
    print(param.grad)

-1737.31201171875
-1754.7215576171875
-1772.1312255859375
-1789.541015625
-1806.9505615234375
-1824.3602294921875
-1841.77001953125
-1859.1795654296875
-1876.589111328125
-1893.998779296875
-1911.408447265625
-1928.8182373046875
-1946.227783203125
-1963.637451171875
-1981.046875
-1998.4566650390625
-2015.8662109375
-2033.27587890625
-2050.685546875
-2068.094970703125
-2085.5048828125
-2102.914306640625
-2120.323974609375
-2137.733642578125
-2155.143310546875
-2172.552734375
-2189.96240234375
-2207.3720703125
-2224.78173828125
-2242.19140625
-2259.60107421875
-2277.010498046875
-2294.420166015625
-2311.829833984375
-2329.2392578125
-2346.64892578125
-2364.05859375
-2381.46826171875
-2398.8779296875
-2416.28759765625
-2433.697021484375
-2451.106689453125
-2468.516357421875
-2485.926025390625
-2503.33544921875
-2520.7451171875
-2538.15478515625
-2555.564453125
-2572.97412109375
-2590.3837890625
-2607.793212890625
-2625.202880859375
-2642.612548828125
-2660.022216796875
-2677.431640625
-26