In [1]:
import torch

In [2]:
# Input batch drawn from a standard normal distribution:
# 1 batch entry x 100 rows x 3 features (the 3 matches fc1's input size).
x = torch.randn(1, 100, 3)
x

tensor([[[ 0.1886, -1.9313,  0.4641],
         [ 0.2014,  1.4163,  0.3644],
         [ 0.5420, -0.5078, -0.6814],
         [ 1.0743, -0.2754, -0.7268],
         [ 2.1418, -0.3118, -0.9626],
         [-0.6959, -0.5682,  0.6131],
         [ 0.1704,  0.3047,  0.0907],
         [-2.5671, -1.1497, -1.0126],
         [-1.5670,  0.4291,  0.3933],
         [ 0.4028,  0.3603, -1.1063],
         [-1.1847, -0.3725, -1.4269],
         [-1.5416,  0.0916, -1.4121],
         [-1.7325,  1.1106,  0.1059],
         [-0.2482,  0.0446, -0.0752],
         [ 1.0771,  0.2439, -0.1640],
         [-0.8642,  1.2327, -1.2407],
         [-1.1373, -0.7371, -0.9571],
         [ 0.2219, -0.5558, -0.6385],
         [-0.7771, -1.3614,  0.3535],
         [-0.5261, -0.7367,  0.5674],
         [ 0.1557, -2.3513,  1.4731],
         [ 2.2819,  1.8996, -0.3253],
         [-0.2503,  0.6653,  0.1656],
         [ 0.2045, -1.2054,  1.3440],
         [-0.7550,  1.2551, -1.4434],
         [ 1.2443,  0.3172, -1.8432],
         [-0

In [3]:
# Regression target drawn from a standard normal distribution:
# 1 batch entry x 100 rows x 64 values (the 64 matches fc1's output size).
target = torch.randn(1, 100, 64)
target

tensor([[[-0.4358, -1.3479,  0.8358,  ...,  0.2478, -0.9059, -0.5684],
         [-1.6908, -1.0111, -0.2906,  ..., -0.5051,  0.2567, -1.9523],
         [ 0.1445, -0.0155,  0.4045,  ..., -1.7331, -0.4210,  0.2701],
         ...,
         [ 0.0789,  0.8596,  1.9881,  ..., -0.6076,  0.6359,  0.1348],
         [ 0.5301, -0.7586,  0.6715,  ...,  0.7232, -1.2076, -1.3507],
         [ 0.8830, -0.7228, -1.5907,  ..., -0.8787,  0.3828,  1.4236]]])

In [4]:
class Net(torch.nn.Module):
    """Minimal network consisting of a single fully connected layer.

    The layer computes ``w*x + b`` (w - weight, b - bias) over the last
    dimension of the input.

    Args:
        in_features: Size of the last dimension of the input. Defaults to 3.
        out_features: Size of the last dimension of the output. Defaults to 64.
    """

    def __init__(self, in_features=3, out_features=64):
        super().__init__()
        # Layer sizes are parameters (with the original 3 -> 64 defaults) so the
        # same class can be reused for differently shaped data.
        self.fc1 = torch.nn.Linear(in_features, out_features)  # w*x+b

    def forward(self, x):  # invoked through __call__, i.e. model(x)
        """Apply the linear layer to ``x``.

        Args:
            x: Tensor whose last dimension equals ``in_features``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``out_features``.
        """
        return self.fc1(x)

In [5]:
# Instantiate the network; fc1's weight and bias get their initial values here.
model = Net()

In [6]:
# Forward pass through the model (calling the module dispatches to forward());
# the result is shown as the cell output.
model(x)

tensor([[[ 1.1106,  0.3220,  0.4648,  ..., -0.5395, -0.0068, -0.2968],
         [-0.4575, -0.5705, -0.3382,  ...,  0.4390, -0.7644,  0.2331],
         [ 0.9963,  0.2057,  0.0534,  ..., -0.7729, -0.1326, -0.3513],
         ...,
         [ 1.7822, -0.4386, -0.0107,  ..., -0.2880,  0.2617, -1.4483],
         [ 1.7190, -0.0941,  0.0554,  ..., -0.6564,  0.2123, -1.1476],
         [-0.0122, -0.2823,  0.0839,  ...,  0.3198, -0.5111,  0.1383]]],
       grad_fn=<ViewBackward0>)

In [7]:
# Sanity check: the output shape should line up with `target` (1, 100, 64).
model(x).size()

torch.Size([1, 100, 64])

In [8]:
# Mean squared error, averaged over every element ("mean" is the default reduction).
criterion = torch.nn.MSELoss(reduction="mean")

In [9]:
# Plain SGD over all model parameters (fc1 weight and bias) with a small step size.
optimizer = torch.optim.SGD(params=model.parameters(), lr=3e-4)

In [12]:
# Keep the forward-pass result so the loss can be computed from it below.
output = model(x)

### Loss

In [13]:
# MSE between the model output and the target, as a scalar tensor.
loss = criterion(output, target)

In [15]:
# .item() converts the scalar loss tensor to a plain Python float for printing.
print(loss.item())

1.4128917455673218


In [16]:
# Manual MSE check: sum of squared errors divided by the number of elements.
# Should match criterion(output, target). output.numel() replaces the
# hard-coded 6400 (= 1 * 100 * 64) so the check stays correct if shapes change.
torch.sum((output - target) ** 2) / output.numel()

tensor(1.4129, grad_fn=<DivBackward0>)

In [10]:
# Ten full-batch SGD steps on the fixed (x, target) pair.
for epoch in range(10):
    # Clear w.grad / b.grad from the previous step before new ones are computed.
    optimizer.zero_grad()

    # Forward pass and loss for the current parameters.
    output = model(x)
    loss = criterion(output, target)
    print('Epoch: ', epoch, 'Loss: ', loss.item())

    # Backward pass fills w.grad / b.grad; step() applies the parameter update.
    loss.backward()
    optimizer.step()

Epoch:  0 Loss:  1.4129834175109863
Epoch:  1 Loss:  1.412974238395691
Epoch:  2 Loss:  1.4129650592803955
Epoch:  3 Loss:  1.4129558801651
Epoch:  4 Loss:  1.4129468202590942
Epoch:  5 Loss:  1.4129375219345093
Epoch:  6 Loss:  1.4129283428192139
Epoch:  7 Loss:  1.4129191637039185
Epoch:  8 Loss:  1.412909984588623
Epoch:  9 Loss:  1.4129008054733276


In [11]:
# Materialise the parameter iterator to inspect fc1's weight and bias.
[*model.parameters()]

[Parameter containing:
 tensor([[-0.1108, -0.4839, -0.5358],
         [ 0.4631, -0.2711, -0.0925],
         [ 0.0186, -0.2379,  0.0693],
         [-0.1824, -0.3259,  0.4563],
         [-0.2399,  0.2391,  0.2108],
         [ 0.4658,  0.0530,  0.3135],
         [-0.2289,  0.4873, -0.2183],
         [ 0.0676, -0.2338, -0.5768],
         [-0.4936,  0.2793,  0.2278],
         [ 0.0121,  0.4278,  0.5225],
         [ 0.4636,  0.0181,  0.3542],
         [-0.0743, -0.1652,  0.4555],
         [-0.5748,  0.3301, -0.1855],
         [ 0.4070, -0.0157,  0.2231],
         [-0.5392, -0.0966,  0.5703],
         [-0.5713, -0.4461, -0.3337],
         [-0.5467,  0.4507, -0.3565],
         [-0.4622, -0.1913,  0.3445],
         [-0.0410,  0.0660, -0.3433],
         [-0.1777, -0.3218,  0.5148],
         [ 0.1953,  0.1761, -0.2110],
         [ 0.0536, -0.4375,  0.1953],
         [ 0.1947,  0.0492,  0.4788],
         [-0.2869, -0.0300,  0.2454],
         [-0.4383,  0.0014,  0.4209],
         [ 0.3826,  0.1925,