# Simple PyTorch example

In [1]:
import torch

# Problem sizes:
#   N     - number of samples in the batch
#   D_in  - width of each input vector
#   H     - width of the hidden layer
#   D_out - width of each target vector
N = 64
D_in = 1000
H = 100
D_out = 10

# Random input batch and random targets for the network to fit.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Two-layer fully connected network with a ReLU in between, assembled from
# torch.nn modules. The layers are constructed in the listed order, so their
# random weight initialisation consumes the RNG in the same sequence as a
# literal Sequential(...) call would.
layers = [
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
]
model = torch.nn.Sequential(*layers)

# Squared error summed (not averaged) over every element of the batch.
loss_fn = torch.nn.MSELoss(reduction='sum')

# Adam updates every parameter registered on the model; torch.optim provides
# many other optimization algorithms behind the same zero_grad()/step() API.
learning_rate = 1e-4
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

num_steps = 500
for t in range(num_steps):
    # By default .backward() accumulates into each parameter's .grad instead
    # of overwriting it, so clear what the previous step left behind before
    # computing fresh gradients (see torch.autograd.backward for details).
    optimizer.zero_grad()

    # Forward pass, then report the current training loss.
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Backpropagate to populate the gradients, then let Adam apply the update.
    loss.backward()
    optimizer.step()

0 663.2269897460938
1 646.5026245117188
2 630.2738647460938
3 614.476318359375
4 599.1412353515625
5 584.2478637695312
6 569.7598266601562
7 555.6636962890625
8 541.8959350585938
9 528.5098876953125
10 515.5949096679688
11 503.08795166015625
12 490.90863037109375
13 479.0406799316406
14 467.5443420410156
15 456.4287109375
16 445.643798828125
17 435.234130859375
18 425.07269287109375
19 415.1888122558594
20 405.59307861328125
21 396.2909240722656
22 387.2802734375
23 378.4934997558594
24 369.93902587890625
25 361.5513000488281
26 353.3130187988281
27 345.22259521484375
28 337.3291015625
29 329.65618896484375
30 322.1756286621094
31 314.8938293457031
32 307.7514953613281
33 300.7596435546875
34 293.9186706542969
35 287.21881103515625
36 280.6751708984375
37 274.26251220703125
38 268.01416015625
39 261.8861999511719
40 255.89224243164062
41 250.00636291503906
42 244.23414611816406
43 238.6081085205078
44 233.11363220214844
45 227.7243194580078
46 222.41763305664062
47 217.20574951171875
4

365 2.073223004117608e-05
366 1.9374527255422436e-05
367 1.810719732020516e-05
368 1.6922049326240085e-05
369 1.581281321705319e-05
370 1.4777931937715039e-05
371 1.38080531542073e-05
372 1.2901261470688041e-05
373 1.205408352689119e-05
374 1.1263200576649979e-05
375 1.0523825039854273e-05
376 9.830088856688235e-06
377 9.183500878862105e-06
378 8.577982043789234e-06
379 8.01282385509694e-06
380 7.482412001991179e-06
381 6.98920348440879e-06
382 6.526671313622501e-06
383 6.0928041420993395e-06
384 5.68868608752382e-06
385 5.310783762979554e-06
386 4.957795226800954e-06
387 4.62805337519967e-06
388 4.318641913414467e-06
389 4.0302052184415516e-06
390 3.7608992897730786e-06
391 3.509131602186244e-06
392 3.2738860227254918e-06
393 3.052879719689372e-06
394 2.84745169665257e-06
395 2.654951231306768e-06
396 2.4764274257904617e-06
397 2.3084517124516424e-06
398 2.151458147636731e-06
399 2.00540421246842e-06
400 1.868811295935302e-06
401 1.7413291288903565e-06
402 1.6223978036578046e-06
403 1