Source: https://pytorch.org/tutorials/beginner/pytorch_with_examples.html

In [110]:
import torch

dtype = torch.float
device = torch.device("cpu")
device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H1, H2, D_out = 1, 2, 100, 100, 10

# Create random Tensors to hold input and outputs.
# Setting requires_grad=False indicates that we do not need to compute gradients
# with respect to these Tensors during the backward pass.
x = torch.randn(N, D_in, device=device, dtype=dtype, requires_grad=False)
y = torch.randn(N, D_out, device=device, dtype=dtype, requires_grad=False)

# Create random Tensors for weights.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
w1 = torch.randn(D_in, H1, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H1, H2, device=device, dtype=dtype, requires_grad=True)
w3 = torch.randn(H2, D_out, device=device, dtype=dtype, requires_grad=True)


#learning_rate = 1e-6
learning_rate = 0.0001
for t in range(5000):
    # Forward pass: compute predicted y using operations on Tensors; these
    # are exactly the same operations we used to compute the forward pass using
    # Tensors, but we do not need to keep references to intermediate values since
    # we are not implementing the backward pass by hand.
    y_pred = x.mm(w1).clamp(min=0).mm(w2).clamp(min=0).mm(w3)

    # Compute and print loss using operations on Tensors.
    # Now loss is a Tensor of shape (1,)
    # loss.item() gets the a scalar value held in the loss.
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass. This call will compute the
    # gradient of loss with respect to all Tensors with requires_grad=True.
    # After this call w1.grad and w2.grad will be Tensors holding the gradient
    # of the loss with respect to w1 and w2 respectively.
    loss.backward()
    # minimizo y_pred, no loss:
    #y_sum = y_pred.sum()
    #y_sum.backward()
    #print(t, y_sum.item())

    # Manually update weights using gradient descent. Wrap in torch.no_grad()
    # because weights have requires_grad=True, but we don't need to track this
    # in autograd.
    # An alternative way is to operate on weight.data and weight.grad.data.
    # Recall that tensor.data gives a tensor that shares the storage with
    # tensor, but doesn't track history.
    # You can also use torch.optim.SGD to achieve this.
    #learning_rate = learning_rate * 1.0/(t+1)
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad
        w3 -= learning_rate * w3.grad

        # Manually zero the gradients after updating weights
        w1.grad.zero_()
        w2.grad.zero_()
        w3.grad.zero_()

0 47193.4296875
1 48124.01171875
2 42706.234375
3 3826.84716796875
4 668.5179443359375
5 248.9574432373047
6 138.90078735351562
7 83.72297668457031
8 51.529258728027344
9 32.189083099365234
10 20.276185989379883
11 13.020116806030273
12 8.346830368041992
13 5.458409786224365
14 3.5204849243164062
15 2.300955057144165
16 1.5174181461334229
17 0.9864927530288696
18 0.6432247757911682
19 0.4299500286579132
20 0.2808133363723755
21 0.18326477706432343
22 0.12127506732940674
23 0.08041305840015411
24 0.05258965492248535
25 0.0344645231962204
26 0.02290826290845871
27 0.015173491090536118
28 0.009937651455402374
29 0.006521172821521759
30 0.0043366593308746815
31 0.002876451937481761
32 0.0018858204130083323
33 0.001238411059603095
34 0.0008222288452088833
35 0.0005472780903801322
36 0.0003590115229599178
37 0.00023590491036884487
38 0.00015614206495229155
39 0.00010442701750434935
40 6.854634557384998e-05
41 4.507265839492902e-05
42 2.964250961667858e-05
43 1.9926246750401333e-05
44 1.31312

365 2.5048407792382932e-11
366 2.5048407792382932e-11
367 2.5048407792382932e-11
368 2.5048407792382932e-11
369 2.5048407792382932e-11
370 2.5048407792382932e-11
371 2.5048407792382932e-11
372 2.5048407792382932e-11
373 2.5048407792382932e-11
374 2.5048407792382932e-11
375 2.5048407792382932e-11
376 2.5048407792382932e-11
377 2.5048407792382932e-11
378 2.5048407792382932e-11
379 2.5048407792382932e-11
380 2.5048407792382932e-11
381 2.5048407792382932e-11
382 2.5048407792382932e-11
383 2.5048407792382932e-11
384 2.5048407792382932e-11
385 2.5048407792382932e-11
386 2.5048407792382932e-11
387 2.5048407792382932e-11
388 2.5048407792382932e-11
389 2.5048407792382932e-11
390 2.5048407792382932e-11
391 2.5048407792382932e-11
392 2.5048407792382932e-11
393 2.5048407792382932e-11
394 2.5048407792382932e-11
395 2.5048407792382932e-11
396 2.5048407792382932e-11
397 2.177991120788647e-11
398 1.5981882484084053e-11
399 1.432276519608422e-11
400 1.432276519608422e-11
401 1.432276519608422e-11
402 1

725 1.4948042803553108e-11
726 1.4948042803553108e-11
727 1.4948042803553108e-11
728 1.4948042803553108e-11
729 1.4948042803553108e-11
730 1.4948042803553108e-11
731 1.4948042803553108e-11
732 1.4948042803553108e-11
733 1.4948042803553108e-11
734 1.4948042803553108e-11
735 1.4948042803553108e-11
736 1.4948042803553108e-11
737 1.4948042803553108e-11
738 1.4948042803553108e-11
739 1.4948042803553108e-11
740 1.4948042803553108e-11
741 1.4948042803553108e-11
742 1.4948042803553108e-11
743 1.4948042803553108e-11
744 1.4948042803553108e-11
745 1.4948042803553108e-11
746 1.4948042803553108e-11
747 1.4948042803553108e-11
748 1.4948042803553108e-11
749 1.4948042803553108e-11
750 1.4948042803553108e-11
751 1.4948042803553108e-11
752 1.4948042803553108e-11
753 1.4948042803553108e-11
754 1.4948042803553108e-11
755 1.4948042803553108e-11
756 1.4948042803553108e-11
757 1.4948042803553108e-11
758 1.4948042803553108e-11
759 1.4948042803553108e-11
760 1.4948042803553108e-11
761 1.4948042803553108e-11
7

1102 1.4948042803553108e-11
1103 1.4948042803553108e-11
1104 1.4948042803553108e-11
1105 1.4948042803553108e-11
1106 1.4948042803553108e-11
1107 1.4948042803553108e-11
1108 1.4948042803553108e-11
1109 1.4948042803553108e-11
1110 1.4948042803553108e-11
1111 1.4948042803553108e-11
1112 1.4948042803553108e-11
1113 1.4948042803553108e-11
1114 1.4948042803553108e-11
1115 1.4948042803553108e-11
1116 1.4948042803553108e-11
1117 1.4948042803553108e-11
1118 1.4948042803553108e-11
1119 1.4948042803553108e-11
1120 1.4948042803553108e-11
1121 1.4948042803553108e-11
1122 1.4948042803553108e-11
1123 1.4948042803553108e-11
1124 1.4948042803553108e-11
1125 1.4948042803553108e-11
1126 1.4948042803553108e-11
1127 1.4948042803553108e-11
1128 1.4948042803553108e-11
1129 1.4948042803553108e-11
1130 1.4948042803553108e-11
1131 1.4948042803553108e-11
1132 1.4948042803553108e-11
1133 1.4948042803553108e-11
1134 1.4948042803553108e-11
1135 1.4948042803553108e-11
1136 1.4948042803553108e-11
1137 1.4948042803553

1452 1.3093526263219246e-11
1453 1.1047163184230158e-11
1454 9.682921131570765e-12
1455 1.9346302337908128e-11
1456 1.3093526263219246e-11
1457 1.1047163184230158e-11
1458 2.2870594307278225e-11
1459 1.9346302337908128e-11
1460 1.3093526263219246e-11
1461 1.2919443292958022e-11
1462 1.1047163184230158e-11
1463 9.682921131570765e-12
1464 1.9346302337908128e-11
1465 1.3093526263219246e-11
1466 1.1047163184230158e-11
1467 2.2870594307278225e-11
1468 1.6930457036323787e-11
1469 1.658939652315894e-11
1470 1.1047163184230158e-11
1471 9.682921131570765e-12
1472 2.488853567683691e-11
1473 1.658939652315894e-11
1474 1.658939652315894e-11
1475 1.1047163184230158e-11
1476 3.929834235805174e-11
1477 2.1023183194301964e-11
1478 1.658939652315894e-11
1479 1.658939652315894e-11
1480 1.1047163184230158e-11
1481 5.589662066540768e-11
1482 1.667466165145015e-11
1483 2.3524293624177517e-11
1484 1.786837344752712e-11
1485 1.8834711568160856e-11
1486 1.8834711568160856e-11
1487 2.0820678514610336e-11
1488 

1849 1.1683098932735447e-11
1850 1.1683098932735447e-11
1851 1.1683098932735447e-11
1852 1.1683098932735447e-11
1853 1.1683098932735447e-11
1854 1.1683098932735447e-11
1855 1.1683098932735447e-11
1856 1.1683098932735447e-11
1857 1.1683098932735447e-11
1858 1.1683098932735447e-11
1859 1.1683098932735447e-11
1860 1.1683098932735447e-11
1861 1.1683098932735447e-11
1862 1.1683098932735447e-11
1863 1.1683098932735447e-11
1864 1.1683098932735447e-11
1865 1.1683098932735447e-11
1866 1.1683098932735447e-11
1867 1.1683098932735447e-11
1868 1.1683098932735447e-11
1869 1.1683098932735447e-11
1870 1.1683098932735447e-11
1871 1.1683098932735447e-11
1872 1.1683098932735447e-11
1873 1.1683098932735447e-11
1874 1.1683098932735447e-11
1875 1.1683098932735447e-11
1876 1.1683098932735447e-11
1877 1.1683098932735447e-11
1878 1.1683098932735447e-11
1879 1.1683098932735447e-11
1880 1.1683098932735447e-11
1881 1.1683098932735447e-11
1882 1.1683098932735447e-11
1883 1.1683098932735447e-11
1884 1.1683098932735

2227 8.194334100153355e-12
2228 1.092281820547214e-11
2229 8.194334100153355e-12
2230 1.092281820547214e-11
2231 8.194334100153355e-12
2232 1.092281820547214e-11
2233 8.194334100153355e-12
2234 1.092281820547214e-11
2235 8.194334100153355e-12
2236 1.092281820547214e-11
2237 8.194334100153355e-12
2238 1.092281820547214e-11
2239 8.194334100153355e-12
2240 1.092281820547214e-11
2241 1.20596865826883e-11
2242 1.092281820547214e-11
2243 1.3651302310790925e-11
2244 2.5019986082952528e-11
2245 1.4560797012563853e-11
2246 2.5019986082952528e-11
2247 2.396838283402758e-11
2248 6.091127602303459e-12
2249 1.0809131367750524e-11
2250 2.106936847212637e-11
2251 2.3996804543457984e-11
2252 2.354205719257152e-11
2253 2.956745959181717e-11
2254 3.269384762916161e-11
2255 3.2807534466883226e-11
2256 1.0183853760281636e-11
2257 1.0240697179142444e-11
2258 1.524291803889355e-11
2259 1.6436629834970518e-11
2260 1.7289281117882638e-11
2261 1.7687185049908294e-11
2262 9.615419571673556e-12
2263 7.2848393983

2604 1.2372325386422744e-11
2605 7.853273586988507e-12
2606 5.892175636290631e-12
2607 5.892175636290631e-12
2608 1.0524914273446484e-11
2609 1.2372325386422744e-11
2610 6.545874953189923e-12
2611 1.0240697179142444e-11
2612 9.871214956547192e-12
2613 1.3224976669334865e-11
2614 1.1150191880915372e-11
2615 2.3741009158584347e-11
2616 1.723243769902183e-11
2617 1.1150191880915372e-11
2618 1.7601919921617082e-11
2619 1.2372325386422744e-11
2620 7.853273586988507e-12
2621 5.892175636290631e-12
2622 5.892175636290631e-12
2623 1.1381118270037405e-11
2624 1.2372325386422744e-11
2625 5.440980999082967e-12
2626 6.3504757008558954e-12
2627 5.440980999082967e-12
2628 6.3504757008558954e-12
2629 5.440980999082967e-12
2630 6.3504757008558954e-12
2631 5.440980999082967e-12
2632 6.3504757008558954e-12
2633 5.440980999082967e-12
2634 6.3504757008558954e-12
2635 5.440980999082967e-12
2636 6.3504757008558954e-12
2637 5.440980999082967e-12
2638 6.3504757008558954e-12
2639 5.440980999082967e-12
2640 6.35

2981 8.986589250525867e-12
2982 8.986589250525867e-12
2983 8.986589250525867e-12
2984 8.986589250525867e-12
2985 8.986589250525867e-12
2986 8.986589250525867e-12
2987 8.986589250525867e-12
2988 8.986589250525867e-12
2989 8.986589250525867e-12
2990 8.986589250525867e-12
2991 8.986589250525867e-12
2992 8.986589250525867e-12
2993 8.986589250525867e-12
2994 8.986589250525867e-12
2995 8.986589250525867e-12
2996 8.986589250525867e-12
2997 8.986589250525867e-12
2998 8.986589250525867e-12
2999 8.986589250525867e-12
3000 1.2169820706731116e-11
3001 8.986589250525867e-12
3002 8.986589250525867e-12
3003 1.2169820706731116e-11
3004 1.148769968040142e-11
3005 1.2397194382174348e-11
3006 1.148769968040142e-11
3007 1.2397194382174348e-11
3008 7.394973522423243e-12
3009 6.144418307485466e-12
3010 9.213962925969099e-12
3011 4.211742066217994e-12
3012 8.72724115197343e-12
3013 4.211742066217994e-12
3014 9.068301665138279e-12
3015 1.6912693467929785e-11
3016 1.546318628697918e-11
3017 1.7481127656537865e

3387 2.6576074674267147e-11
3388 1.9872103962370602e-11
3389 2.6576074674267147e-11
3390 1.787903158856352e-11
3391 2.6576074674267147e-11
3392 1.4667378422927868e-11
3393 2.737188253831846e-11
3394 1.4667378422927868e-11
3395 2.1801227489959274e-11
3396 1.5974777056726452e-11
3397 1.674216321134736e-11
3398 1.850430919603241e-11
3399 2.2312818259706546e-11
3400 1.3359979789129284e-11
3401 2.2312818259706546e-11
3402 1.70299330193302e-11
3403 2.737188253831846e-11
3404 1.2081002864761103e-11
3405 1.9811707829830993e-11
3406 1.70299330193302e-11
3407 3.169198237173987e-11
3408 2.970246271161159e-11
3409 2.163069723337685e-11
3410 1.6912693467929785e-11
3411 2.1374901848503214e-11
3412 1.9470647316666145e-11
3413 1.836220064888039e-11
3414 2.2540191935149778e-11
3415 1.2308376540204335e-11
3416 9.125145083999087e-12
3417 1.9555912444957357e-11
3418 9.409362178303127e-12
3419 9.409362178303127e-12
3420 1.9555912444957357e-11
3421 9.409362178303127e-12
3422 1.0546230555519287e-11
3423 1.53

3765 9.409362178303127e-12
3766 9.409362178303127e-12
3767 9.409362178303127e-12
3768 9.409362178303127e-12
3769 9.409362178303127e-12
3770 9.409362178303127e-12
3771 9.409362178303127e-12
3772 9.409362178303127e-12
3773 9.409362178303127e-12
3774 9.409362178303127e-12
3775 9.409362178303127e-12
3776 9.409362178303127e-12
3777 9.409362178303127e-12
3778 9.409362178303127e-12
3779 9.409362178303127e-12
3780 9.409362178303127e-12
3781 9.409362178303127e-12
3782 9.409362178303127e-12
3783 9.409362178303127e-12
3784 9.409362178303127e-12
3785 9.409362178303127e-12
3786 9.409362178303127e-12
3787 1.375788372115494e-11
3788 1.2592593634508376e-11
3789 1.6969536886790593e-11
3790 1.5718981671852816e-11
3791 2.452971159527806e-11
3792 2.0721202531603922e-11
3793 2.703082202515361e-11
3794 2.0721202531603922e-11
3795 2.703082202515361e-11
3796 2.5840662942755444e-11
3797 1.5125678487493133e-11
3798 2.4362734052374435e-11
3799 7.110756428119203e-12
3800 7.110756428119203e-12
3801 7.1107564281192

4149 5.377032152864558e-12
4150 6.8549610432455665e-12
4151 5.377032152864558e-12
4152 5.377032152864558e-12
4153 6.8549610432455665e-12
4154 5.377032152864558e-12
4155 5.377032152864558e-12
4156 6.8549610432455665e-12
4157 5.377032152864558e-12
4158 5.377032152864558e-12
4159 6.8549610432455665e-12
4160 5.377032152864558e-12
4161 5.377032152864558e-12
4162 6.8549610432455665e-12
4163 5.377032152864558e-12
4164 5.377032152864558e-12
4165 6.8549610432455665e-12
4166 5.377032152864558e-12
4167 5.377032152864558e-12
4168 6.8549610432455665e-12
4169 5.377032152864558e-12
4170 5.377032152864558e-12
4171 6.8549610432455665e-12
4172 5.377032152864558e-12
4173 5.377032152864558e-12
4174 6.8549610432455665e-12
4175 5.377032152864558e-12
4176 5.377032152864558e-12
4177 6.8549610432455665e-12
4178 5.377032152864558e-12
4179 5.377032152864558e-12
4180 6.8549610432455665e-12
4181 5.377032152864558e-12
4182 5.377032152864558e-12
4183 6.8549610432455665e-12
4184 5.377032152864558e-12
4185 5.377032152

4518 5.377032152864558e-12
4519 6.8549610432455665e-12
4520 5.377032152864558e-12
4521 5.377032152864558e-12
4522 6.8549610432455665e-12
4523 5.377032152864558e-12
4524 5.377032152864558e-12
4525 6.8549610432455665e-12
4526 5.377032152864558e-12
4527 5.377032152864558e-12
4528 6.8549610432455665e-12
4529 5.377032152864558e-12
4530 5.377032152864558e-12
4531 6.8549610432455665e-12
4532 5.377032152864558e-12
4533 5.377032152864558e-12
4534 6.8549610432455665e-12
4535 5.377032152864558e-12
4536 5.377032152864558e-12
4537 6.8549610432455665e-12
4538 5.377032152864558e-12
4539 5.377032152864558e-12
4540 6.8549610432455665e-12
4541 5.377032152864558e-12
4542 5.377032152864558e-12
4543 6.8549610432455665e-12
4544 5.377032152864558e-12
4545 5.377032152864558e-12
4546 6.8549610432455665e-12
4547 5.377032152864558e-12
4548 5.377032152864558e-12
4549 6.8549610432455665e-12
4550 5.377032152864558e-12
4551 5.377032152864558e-12
4552 6.8549610432455665e-12
4553 5.377032152864558e-12
4554 5.377032152

4860 9.626077712709957e-12
4861 9.597656003279553e-12
4862 9.626077712709957e-12
4863 9.597656003279553e-12
4864 9.626077712709957e-12
4865 9.597656003279553e-12
4866 9.626077712709957e-12
4867 9.597656003279553e-12
4868 9.626077712709957e-12
4869 9.597656003279553e-12
4870 9.626077712709957e-12
4871 9.597656003279553e-12
4872 9.626077712709957e-12
4873 9.597656003279553e-12
4874 9.626077712709957e-12
4875 9.597656003279553e-12
4876 9.626077712709957e-12
4877 9.597656003279553e-12
4878 9.626077712709957e-12
4879 9.597656003279553e-12
4880 9.626077712709957e-12
4881 9.597656003279553e-12
4882 9.626077712709957e-12
4883 9.597656003279553e-12
4884 9.626077712709957e-12
4885 9.597656003279553e-12
4886 9.626077712709957e-12
4887 9.597656003279553e-12
4888 9.626077712709957e-12
4889 9.597656003279553e-12
4890 9.626077712709957e-12
4891 9.597656003279553e-12
4892 9.626077712709957e-12
4893 9.597656003279553e-12
4894 9.626077712709957e-12
4895 9.597656003279553e-12
4896 9.626077712709957e-12
4

Lo mismo pero definiendo una red bien bonita y trabajando sobre ella después.

Se observa una velocidad de ejecución de aprox 20% con respecto a la anterior. 

In [120]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square convolution
        # kernel
        #self.conv1 = nn.Conv2d(1, 6, 3)
        #self.conv2 = nn.Conv2d(6, 16, 3)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(2, 100)  # 6*6 from image dimension
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, 10)


    def forward(self, x):
        # Max pooling over a (2, 2) window
        #x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square you can only specify a single number
        #x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        #x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
print(net)

Net(
  (fc1): Linear(in_features=2, out_features=100, bias=True)
  (fc2): Linear(in_features=100, out_features=100, bias=True)
  (fc3): Linear(in_features=100, out_features=10, bias=True)
)


In [121]:
params = list(net.parameters())
print(len(params))
print(params[0].size())  # conv1's .weight

6
torch.Size([100, 2])


In [122]:
input = torch.randn(2)
out = net(input)
print(out)


tensor([-0.1499, -0.1285,  0.2043, -0.0272, -0.3382, -0.1479, -0.0184,  0.0513,
        -0.0155, -0.1450], grad_fn=<AddBackward0>)


In [123]:
net.zero_grad()
out.backward(torch.randn(1, 10))

In [124]:
output = net(input)
target = torch.randn(10)  # a dummy target, for example
target = target.view(1, -1)  # make it the same shape as output
criterion = nn.MSELoss()

loss = criterion(output, target)
print(loss)


tensor(0.4094, grad_fn=<MseLossBackward>)


In [125]:
print(loss.grad_fn)  # MSELoss
print(loss.grad_fn.next_functions[0][0])  # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0])  # ReLU

<MseLossBackward object at 0x7fa425d9dc18>
<ExpandBackward object at 0x7fa425d9d400>
<AddBackward0 object at 0x7fa425d9dc18>


In [126]:
net.zero_grad()     # zeroes the gradient buffers of all parameters

print('conv1.bias.grad before backward')
print(net.fc1.bias.grad)

loss.backward()

print('conv1.bias.grad after backward')
print(net.fc1.bias.grad)

conv1.bias.grad before backward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.])
conv1.bias.grad after backward
tensor([ 1.0078e-03,  0.0000e+00,  0.0000e+00, -8.1420e-03,  0.0000e+00,
         0.0000e+00,  0.0000e+00, -2.7867e-03,  0.0000e+00,  4.1753e-03,
         0.0000e+00,  1.9151e-02,  1.8808e-02,  0.0000e+00,  0.0000e+00,
        -3.8877e-03,  9.8137e-03,  0.0000e+00,  0.0000e+00,  7.5190e-03,
         0.0000e+00,  0.0000e+00,  0.0000e+00, -2.2943e-03,  4.7559e-04,
         4.2798e-03, -6.7302e-03, -7.1130e-05,  0.0000e+00, -1.0452e-02,
         0.0000e+00,  0.0000e+00, -5.0372e-03,  0.0000e+00

In [127]:
learning_rate = 0.01
for f in net.parameters():
    f.data.sub_(f.grad.data * learning_rate)


In [128]:
import torch.optim as optim

# create your optimizer
learning_rate = 1e-6
#optimizer = optim.SGD(net.parameters(), lr=0.01)
optimizer = optim.SGD(net.parameters(), lr=0.01)

# in your training loop:
for t in range(5000):
    optimizer.zero_grad()   # zero the gradient buffers
    output = net(input)
    loss = criterion(output, target)
    loss.backward()
    print(t, loss.item())
    optimizer.step()    # Does the update


0 0.39646151661872864
1 0.3840448558330536
2 0.37213268876075745
3 0.36068812012672424
4 0.34967777132987976
5 0.33907151222229004
6 0.3288422226905823
7 0.3189754784107208
8 0.3094472289085388
9 0.30022740364074707
10 0.2912974953651428
11 0.28268781304359436
12 0.2743530869483948
13 0.2663406729698181
14 0.2586561441421509
15 0.2511778473854065
16 0.24389609694480896
17 0.23680195212364197
18 0.22988736629486084
19 0.2231719046831131
20 0.21666719019412994
21 0.21031782031059265
22 0.2041182965040207
23 0.1980636715888977
24 0.19214974343776703
25 0.1863735467195511
26 0.18074223399162292
27 0.1752389520406723
28 0.1698608100414276
29 0.16460493206977844
30 0.15946874022483826
31 0.15444979071617126
32 0.1495458483695984
33 0.14475490152835846
34 0.14007627964019775
35 0.13552019000053406
36 0.13107091188430786
37 0.1267269104719162
38 0.12248676270246506
39 0.1183491051197052
40 0.11431258916854858
41 0.11037600785493851
42 0.10660909116268158
43 0.10294225811958313
44 0.09936618804

548 4.2244551606840375e-11
549 4.049763302482745e-11
550 3.917017404875267e-11
551 3.7834426563909673e-11
552 3.628211620032573e-11
553 3.4728022541008485e-11
554 3.375860008203446e-11
555 3.246639068699153e-11
556 3.1347414247706595e-11
557 3.01439775918233e-11
558 2.913934024628695e-11
559 2.8003379129448724e-11
560 2.7066095892314124e-11
561 2.6042846698604016e-11
562 2.5161008693208586e-11
563 2.4280050192615477e-11
564 2.348363691007105e-11
565 2.2548764938568056e-11
566 2.1809049358112453e-11
567 2.104574500783052e-11
568 2.025028061902745e-11
569 1.96382146655516e-11
570 1.8872859872121062e-11
571 1.8247138175442323e-11
572 1.7633273313721887e-11
573 1.6985196235608235e-11
574 1.6364049004735648e-11
575 1.5859914076488124e-11
576 1.5265512812168147e-11
577 1.4827051045274153e-11
578 1.4240680683286211e-11
579 1.3929973494142267e-11
580 1.341515960817663e-11
581 1.2960778283943597e-11
582 1.2510899434492462e-11
583 1.2163393556252622e-11
584 1.1719791803699309e-11
585 1.131542082

1072 1.4453577766061765e-13
1073 1.4453577766061765e-13
1074 1.435809750174183e-13
1075 1.435809750174183e-13
1076 1.435809750174183e-13
1077 1.435809750174183e-13
1078 1.4260398688726444e-13
1079 1.4260398688726444e-13
1080 1.4260398688726444e-13
1081 1.4260398688726444e-13
1082 1.4260398688726444e-13
1083 1.4260398688726444e-13
1084 1.4260398688726444e-13
1085 1.4260398688726444e-13
1086 1.4260398688726444e-13
1087 1.4260398688726444e-13
1088 1.4230421853909936e-13
1089 1.4230421853909936e-13
1090 1.4230421853909936e-13
1091 1.413938410799176e-13
1092 1.413938410799176e-13
1093 1.413938410799176e-13
1094 1.413938410799176e-13
1095 1.413938410799176e-13
1096 1.413938410799176e-13
1097 1.413938410799176e-13
1098 1.413938410799176e-13
1099 1.413938410799176e-13
1100 1.4124812430793554e-13
1101 1.4124812430793554e-13
1102 1.4124812430793554e-13
1103 1.4124812430793554e-13
1104 1.4124812430793554e-13
1105 1.4124812430793554e-13
1106 1.4038214492771706e-13
1107 1.4038214492771706e-13
1108 

1622 8.537198996783762e-14
1623 8.537198996783762e-14
1624 8.537198996783762e-14
1625 8.537198996783762e-14
1626 8.537198996783762e-14
1627 8.537198996783762e-14
1628 8.537198996783762e-14
1629 8.537198996783762e-14
1630 8.537198996783762e-14
1631 8.537198996783762e-14
1632 8.529843362669806e-14
1633 8.529843362669806e-14
1634 8.029688025601445e-14
1635 8.029688025601445e-14
1636 8.029688025601445e-14
1637 8.029688025601445e-14
1638 8.029688025601445e-14
1639 8.029688025601445e-14
1640 8.029688025601445e-14
1641 8.029688025601445e-14
1642 8.029688025601445e-14
1643 8.029688025601445e-14
1644 8.029688025601445e-14
1645 8.029688025601445e-14
1646 7.9954096186656e-14
1647 7.9954096186656e-14
1648 7.9954096186656e-14
1649 7.9954096186656e-14
1650 7.9954096186656e-14
1651 7.9954096186656e-14
1652 7.9954096186656e-14
1653 7.9954096186656e-14
1654 7.9954096186656e-14
1655 7.9954096186656e-14
1656 7.9954096186656e-14
1657 7.9954096186656e-14
1658 7.9954096186656e-14
1659 7.9954096186656e-14
16

2150 5.642708522657358e-14
2151 5.642708522657358e-14
2152 5.642708522657358e-14
2153 5.642708522657358e-14
2154 5.638406272911664e-14
2155 5.638406272911664e-14
2156 5.638406272911664e-14
2157 5.638406272911664e-14
2158 5.638406272911664e-14
2159 5.638406272911664e-14
2160 5.638406272911664e-14
2161 5.638406272911664e-14
2162 5.638406272911664e-14
2163 5.638406272911664e-14
2164 5.638406272911664e-14
2165 5.638406272911664e-14
2166 5.638406272911664e-14
2167 5.638406272911664e-14
2168 5.638406272911664e-14
2169 5.638406272911664e-14
2170 5.638406272911664e-14
2171 5.638406272911664e-14
2172 5.638406272911664e-14
2173 5.638406272911664e-14
2174 5.638406272911664e-14
2175 5.638406272911664e-14
2176 5.638406272911664e-14
2177 5.638406272911664e-14
2178 5.638406272911664e-14
2179 5.638406272911664e-14
2180 5.638406272911664e-14
2181 5.638406272911664e-14
2182 5.638406272911664e-14
2183 5.922623773791519e-14
2184 5.922623773791519e-14
2185 5.922623773791519e-14
2186 5.922623773791519e-14
2

2707 4.626299479138299e-14
2708 4.626299479138299e-14
2709 4.626299479138299e-14
2710 4.626299479138299e-14
2711 4.626299479138299e-14
2712 4.626299479138299e-14
2713 4.626299479138299e-14
2714 4.626299479138299e-14
2715 4.626299479138299e-14
2716 4.626299479138299e-14
2717 4.626299479138299e-14
2718 4.626299479138299e-14
2719 4.626299479138299e-14
2720 4.626299479138299e-14
2721 4.626299479138299e-14
2722 4.626299479138299e-14
2723 4.626299479138299e-14
2724 4.626299479138299e-14
2725 4.626299479138299e-14
2726 4.626299479138299e-14
2727 4.626299479138299e-14
2728 4.626299479138299e-14
2729 4.626299479138299e-14
2730 4.626299479138299e-14
2731 4.626299479138299e-14
2732 4.626299479138299e-14
2733 4.626299479138299e-14
2734 4.626299479138299e-14
2735 4.626299479138299e-14
2736 4.626299479138299e-14
2737 4.626299479138299e-14
2738 4.626299479138299e-14
2739 4.626299479138299e-14
2740 4.626299479138299e-14
2741 4.626299479138299e-14
2742 4.626299479138299e-14
2743 4.626299479138299e-14
2

3269 4.175132461980979e-14
3270 4.175132461980979e-14
3271 4.175132461980979e-14
3272 4.175132461980979e-14
3273 4.175132461980979e-14
3274 4.175132461980979e-14
3275 4.175132461980979e-14
3276 4.175132461980979e-14
3277 4.175132461980979e-14
3278 4.175132461980979e-14
3279 4.175132461980979e-14
3280 4.175132461980979e-14
3281 4.175132461980979e-14
3282 4.175132461980979e-14
3283 4.175132461980979e-14
3284 4.175132461980979e-14
3285 4.175132461980979e-14
3286 4.175132461980979e-14
3287 4.175132461980979e-14
3288 4.175132461980979e-14
3289 4.175132461980979e-14
3290 4.175132461980979e-14
3291 4.175132461980979e-14
3292 4.175132461980979e-14
3293 4.175132461980979e-14
3294 4.175132461980979e-14
3295 4.175132461980979e-14
3296 4.175132461980979e-14
3297 4.175132461980979e-14
3298 4.175132461980979e-14
3299 4.175132461980979e-14
3300 4.175132461980979e-14
3301 4.175132461980979e-14
3302 4.175132461980979e-14
3303 4.175132461980979e-14
3304 4.175132461980979e-14
3305 4.175132461980979e-14
3

3830 4.154038292275737e-14
3831 4.154038292275737e-14
3832 4.154038292275737e-14
3833 4.154038292275737e-14
3834 4.154038292275737e-14
3835 4.154038292275737e-14
3836 4.154038292275737e-14
3837 4.154038292275737e-14
3838 4.154038292275737e-14
3839 4.154038292275737e-14
3840 4.154038292275737e-14
3841 4.154038292275737e-14
3842 4.154038292275737e-14
3843 4.154038292275737e-14
3844 4.154038292275737e-14
3845 4.154038292275737e-14
3846 4.154038292275737e-14
3847 4.154038292275737e-14
3848 4.154038292275737e-14
3849 4.154038292275737e-14
3850 4.154038292275737e-14
3851 4.154038292275737e-14
3852 4.154038292275737e-14
3853 4.154038292275737e-14
3854 4.154038292275737e-14
3855 4.154038292275737e-14
3856 4.154038292275737e-14
3857 4.154038292275737e-14
3858 4.154038292275737e-14
3859 4.154038292275737e-14
3860 4.154038292275737e-14
3861 4.154038292275737e-14
3862 4.154038292275737e-14
3863 4.154038292275737e-14
3864 4.154038292275737e-14
3865 4.154038292275737e-14
3866 4.154038292275737e-14
3

4387 4.1283643848312804e-14
4388 4.1283643848312804e-14
4389 4.1283643848312804e-14
4390 4.1283643848312804e-14
4391 4.1283643848312804e-14
4392 4.1283643848312804e-14
4393 4.1283643848312804e-14
4394 4.1283643848312804e-14
4395 4.1283643848312804e-14
4396 4.1283643848312804e-14
4397 4.1283643848312804e-14
4398 4.1283643848312804e-14
4399 4.1283643848312804e-14
4400 4.1283643848312804e-14
4401 4.1283643848312804e-14
4402 4.1283643848312804e-14
4403 4.1283643848312804e-14
4404 4.1283643848312804e-14
4405 4.1283643848312804e-14
4406 4.1283643848312804e-14
4407 4.1283643848312804e-14
4408 4.1283643848312804e-14
4409 4.1283643848312804e-14
4410 4.1283643848312804e-14
4411 4.1283643848312804e-14
4412 4.1283643848312804e-14
4413 4.1283643848312804e-14
4414 4.1283643848312804e-14
4415 4.1283643848312804e-14
4416 4.1283643848312804e-14
4417 4.1283643848312804e-14
4418 4.1283643848312804e-14
4419 4.1283643848312804e-14
4420 4.1283643848312804e-14
4421 4.1283643848312804e-14
4422 4.1283643848312

4935 3.4856838959261213e-14
4936 3.4856838959261213e-14
4937 3.4856838959261213e-14
4938 3.4856838959261213e-14
4939 3.4856838959261213e-14
4940 3.4856838959261213e-14
4941 3.4856838959261213e-14
4942 3.4856838959261213e-14
4943 3.4856838959261213e-14
4944 3.4856838959261213e-14
4945 3.4856838959261213e-14
4946 3.4856838959261213e-14
4947 3.4856838959261213e-14
4948 3.4856838959261213e-14
4949 3.4856838959261213e-14
4950 3.4856838959261213e-14
4951 3.4856838959261213e-14
4952 3.4856838959261213e-14
4953 3.4856838959261213e-14
4954 3.4856838959261213e-14
4955 3.4856838959261213e-14
4956 3.4856838959261213e-14
4957 3.4856838959261213e-14
4958 3.4856838959261213e-14
4959 3.4856838959261213e-14
4960 3.4856838959261213e-14
4961 3.4856838959261213e-14
4962 3.4856838959261213e-14
4963 3.4856838959261213e-14
4964 3.4856838959261213e-14
4965 3.4856838959261213e-14
4966 3.4856838959261213e-14
4967 3.4856838959261213e-14
4968 3.4856838959261213e-14
4969 3.4856838959261213e-14
4970 3.4856838959261

In [152]:
# 1. Para predecir un valor de entrada state-action
print(net.fc1)
print(net.fc2)
print(net.fc3)
input = torch.randn(2)
input = torch.tensor([1, 2.0])
out = net(input)
print(input)
print(input, out)

Linear(in_features=2, out_features=100, bias=True)
Linear(in_features=100, out_features=100, bias=True)
Linear(in_features=100, out_features=10, bias=True)
tensor([1., 2.])
tensor([1., 2.]) tensor([ 0.0486, -0.1873, -0.5130, -0.1192,  0.0852,  1.0731,  0.2093, -0.4722,
         0.6940,  0.2438], grad_fn=<AddBackward0>)


In [153]:
# 2. Gradientes de cada parametro con respecto a la Advantage Function
for f in net.parameters():
    f.data.sub_(f.grad.data * learning_rate)

tensor([[-3.9129e-09,  1.5784e-09],
        [ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00],
        [-1.9713e-09,  7.9517e-10],
        [ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00],
        [-2.1280e-09,  8.5839e-10],
        [ 0.0000e+00,  0.0000e+00],
        [ 1.5246e-09, -6.1500e-10],
        [ 0.0000e+00,  0.0000e+00],
        [-1.8226e-09,  7.3518e-10],
        [ 1.4628e-10, -5.9006e-11],
        [ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00],
        [-6.4611e-10,  2.6062e-10],
        [ 3.0258e-09, -1.2205e-09],
        [ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00],
        [-7.3951e-10,  2.9830e-10],
        [ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00],
        [-5.3299e-09,  2.1500e-09],
        [-8.7137e-10,  3.5149e-10],
        [ 1.5730e-09, -6.3450e-10],
        [-1.0964e-09,  4.4225e-10],
        [-1.1322e-09,  4.567

In [None]:
# 3. actualizar parameters:
for f in net.parameters():
    f.data.sub_(f.grad.data * learning_rate)

#o capaz lo pueda definir en el backward pass


In [158]:
# Iterate over layers
for name, param in net.named_parameters(): 
    print(name, param.shape) 

fc1.weight torch.Size([100, 2])
fc1.bias torch.Size([100])
fc2.weight torch.Size([100, 100])
fc2.bias torch.Size([100])
fc3.weight torch.Size([10, 100])
fc3.bias torch.Size([10])
