Source: https://pytorch.org/tutorials/beginner/pytorch_with_examples.html

In [1]:
import torch

dtype = torch.float
device = torch.device("cpu")
device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H1, H2, D_out = 1, 2, 100, 100, 10

# Create random Tensors to hold input and outputs.
# Setting requires_grad=False indicates that we do not need to compute gradients
# with respect to these Tensors during the backward pass.
x = torch.randn(N, D_in, device=device, dtype=dtype, requires_grad=False)
y = torch.randn(N, D_out, device=device, dtype=dtype, requires_grad=False)

# Create random Tensors for weights.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
w1 = torch.randn(D_in, H1, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H1, H2, device=device, dtype=dtype, requires_grad=True)
w3 = torch.randn(H2, D_out, device=device, dtype=dtype, requires_grad=True)


#learning_rate = 1e-6
learning_rate = 0.0001
for t in range(5000):
    # Forward pass: compute predicted y using operations on Tensors; these
    # are exactly the same operations we used to compute the forward pass using
    # Tensors, but we do not need to keep references to intermediate values since
    # we are not implementing the backward pass by hand.
    y_pred = x.mm(w1).clamp(min=0).mm(w2).clamp(min=0).mm(w3)

    # Compute and print loss using operations on Tensors.
    # Now loss is a Tensor of shape (1,)
    # loss.item() gets the a scalar value held in the loss.
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass. This call will compute the
    # gradient of loss with respect to all Tensors with requires_grad=True.
    # After this call w1.grad and w2.grad will be Tensors holding the gradient
    # of the loss with respect to w1 and w2 respectively.
    loss.backward()
    # minimizo y_pred, no loss:
    #y_sum = y_pred.sum()
    #y_sum.backward()
    #print(t, y_sum.item())

    # Manually update weights using gradient descent. Wrap in torch.no_grad()
    # because weights have requires_grad=True, but we don't need to track this
    # in autograd.
    # An alternative way is to operate on weight.data and weight.grad.data.
    # Recall that tensor.data gives a tensor that shares the storage with
    # tensor, but doesn't track history.
    # You can also use torch.optim.SGD to achieve this.
    #learning_rate = learning_rate * 1.0/(t+1)
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad
        w3 -= learning_rate * w3.grad

        # Manually zero the gradients after updating weights
        w1.grad.zero_()
        w2.grad.zero_()
        w3.grad.zero_()

0 59774.671875
1 37577.58203125
2 5851.09326171875
3 268.9912109375
4 38.263389587402344
5 10.922318458557129
6 3.408141613006592
7 1.1035585403442383
8 0.3654593229293823
9 0.12309855222702026
10 0.04206923395395279
11 0.014563562348484993
12 0.005102201364934444
13 0.0018069930374622345
14 0.0006459626602008939
15 0.00023329464602284133
16 8.486256410833448e-05
17 3.1118943297769874e-05
18 1.1508658644743264e-05
19 4.252947292116005e-06
20 1.5899769323368673e-06
21 5.935345370744471e-07
22 2.2383197517683584e-07
23 8.679882768092284e-08
24 3.318261221352259e-08
25 1.4104973189432712e-08
26 6.345295844312204e-09
27 3.106386259332794e-09
28 1.8519585864851251e-09
29 9.84677339666007e-10
30 6.42135344897099e-10
31 3.717276486625565e-10
32 3.1131375255455396e-10
33 2.3752744215954635e-10
34 2.741985527521251e-10
35 2.6165036803860175e-10
36 1.7287871134641364e-10
37 1.7445611621980106e-10
38 1.6010315295744704e-10
39 1.7428913867689744e-10
40 6.674139019224867e-11
41 5.608324915584717e-1

441 1.5056511593058985e-11
442 1.0967338148759609e-11
443 8.693601394327288e-12
444 9.386380561693386e-12
445 8.693601394327288e-12
446 9.148348745213752e-12
447 8.693601394327288e-12
448 9.148348745213752e-12
449 8.693601394327288e-12
450 9.148348745213752e-12
451 1.4747425503003342e-11
452 1.0285217122429913e-11
453 1.5202172853889806e-11
454 1.3340550886198344e-11
455 1.0612066780879559e-11
456 1.1066814131766023e-11
457 1.0612066780879559e-11
458 1.1066814131766023e-11
459 1.0285217122429913e-11
460 1.9465429268450407e-11
461 9.077294471637742e-12
462 4.510181117467482e-11
463 1.430333629315328e-11
464 9.98678917341067e-12
465 1.558586593120026e-11
466 4.231648365049523e-11
467 1.4417023130874895e-11
468 1.0100476011132287e-11
469 1.3578582702677977e-11
470 9.869549622010254e-12
471 1.4335310716262484e-11
472 1.4104384327140451e-11
473 4.6835535449929466e-11
474 4.2440828629253247e-11
475 3.530697956222184e-11
476 3.9246939032011596e-11
477 1.1134315691663232e-11
478 2.469147108996

763 8.150036201470812e-12
764 7.908451671312378e-12
765 8.150036201470812e-12
766 7.695288850584348e-12
767 1.4459655695020501e-11
768 7.965295090173186e-12
769 9.969025605016668e-12
770 7.908451671312378e-12
771 7.794764833590762e-12
772 7.908451671312378e-12
773 7.794764833590762e-12
774 7.581602012862731e-12
775 1.4615975096887723e-11
776 8.107403637325206e-12
777 1.2228551504733787e-11
778 7.581602012862731e-12
779 7.581602012862731e-12
780 7.581602012862731e-12
781 8.036349363749196e-12
782 7.581602012862731e-12
783 1.4857559627046157e-11
784 2.8920199568460703e-12
785 2.6504354266876362e-12
786 2.6504354266876362e-12
787 2.6504354266876362e-12
788 1.1290635093530454e-11
789 1.2228551504733787e-11
790 7.581602012862731e-12
791 7.581602012862731e-12
792 1.4615975096887723e-11
793 2.6504354266876362e-12
794 2.6504354266876362e-12
795 2.6504354266876362e-12
796 1.1290635093530454e-11
797 1.2228551504733787e-11
798 7.908451671312378e-12
799 7.908451671312378e-12
800 1.4615975096887723

1077 5.978872952283609e-11
1078 3.446143370666732e-11
1079 3.694833328182767e-11
1080 3.4049318919926463e-11
1081 3.6326608388037585e-11
1082 3.297284667524991e-11
1083 4.2010950274118386e-11
1084 3.297284667524991e-11
1085 4.3485326450820594e-11
1086 2.970435009075345e-11
1087 4.166633704727474e-11
1088 3.152333949429931e-11
1089 4.166633704727474e-11
1090 3.181821472963975e-11
1091 3.744571319685974e-11
1092 4.600420044909015e-11
1093 2.814115607208123e-11
1094 3.2645997016800266e-11
1095 4.5773274059968116e-11
1096 4.6345260962255e-11
1097 2.9391711287019007e-11
1098 3.2987057529965114e-11
1099 8.297018627700936e-11
1100 3.7538083752508555e-11
1101 1.3038570223500301e-11
1102 6.102862659673747e-11
1103 4.686140364640323e-12
1104 4.686140364640323e-12
1105 4.686140364640323e-12
1106 4.686140364640323e-12
1107 4.686140364640323e-12
1108 4.686140364640323e-12
1109 4.686140364640323e-12
1110 4.686140364640323e-12
1111 4.512057394379099e-12
1112 3.0199176492828883e-12
1113 3.019917649282

1389 2.6126767416201346e-11
1390 2.3995139208921046e-11
1391 2.1923907134180354e-11
1392 1.6847079287174438e-11
1393 3.013422844588831e-11
1394 1.3265943898943533e-11
1395 5.0313642141475157e-11
1396 1.2598033727329039e-11
1397 4.5197734444002435e-11
1398 2.1309287667747867e-11
1399 2.0001889033949283e-11
1400 1.3535950138532371e-11
1401 4.27996527108121e-11
1402 1.763222901018935e-11
1403 1.997702003819768e-11
1404 4.49277282044136e-11
1405 4.948941256799344e-11
1406 1.4207413023825666e-11
1407 1.6424306359397178e-11
1408 1.4335310716262484e-11
1409 3.465683295900135e-11
1410 1.5926926444365108e-11
1411 4.277123100138169e-11
1412 4.822464649834046e-11
1413 2.197719783936236e-11
1414 4.3989811793210265e-11
1415 2.0588086790951365e-11
1416 7.691991488201211e-11
1417 2.020439371364091e-11
1418 2.020439371364091e-11
1419 4.0810133050683817e-11
1420 2.018663014524691e-11
1421 3.5193292724500225e-11
1422 1.6946555270180852e-11
1423 3.8092307086401433e-11
1424 3.662858905073563e-11
1425 4.60

1703 2.622013717257232e-12
1704 2.380429187098798e-12
1705 2.622013717257232e-12
1706 2.380429187098798e-12
1707 2.622013717257232e-12
1708 2.3129276272015886e-12
1709 2.5545121573600227e-12
1710 2.3129276272015886e-12
1711 4.822109378466166e-11
1712 2.1302182240390266e-11
1713 2.062006121406057e-11
1714 1.9952151042446076e-11
1715 1.9017787344921544e-11
1716 1.8730017536938703e-11
1717 1.8118950784185017e-11
1718 1.0160872143671895e-11
1719 9.880207763046656e-12
1720 9.606648809779017e-12
1721 9.421907698481391e-12
1722 9.421907698481391e-12
1723 4.38060698826348e-12
1724 4.38060698826348e-12
1725 4.38060698826348e-12
1726 4.38060698826348e-12
1727 4.38060698826348e-12
1728 4.224287586396258e-12
1729 4.224287586396258e-12
1730 4.224287586396258e-12
1731 4.224287586396258e-12
1732 4.224287586396258e-12
1733 4.224287586396258e-12
1734 4.224287586396258e-12
1735 4.224287586396258e-12
1736 1.530520155057502e-11
1737 6.778688721453818e-12
1738 6.778688721453818e-12
1739 6.778688721453818e-

2021 9.599543382421416e-12
2022 1.139721650389447e-11
2023 2.458133696592313e-11
2024 1.2061573961830163e-11
2025 9.503620113093803e-12
2026 1.1244449815706048e-11
2027 2.1625479185161112e-11
2028 1.4150569604964858e-11
2029 1.1219580819954444e-11
2030 1.1219580819954444e-11
2031 1.1219580819954444e-11
2032 1.2061573961830163e-11
2033 1.1194711824202841e-11
2034 1.1194711824202841e-11
2035 1.2061573961830163e-11
2036 1.9323320721298387e-11
2037 1.8651857836005092e-11
2038 1.3535950138532371e-11
2039 1.8012369373821002e-11
2040 1.2924883385778685e-11
2041 1.0352718682327122e-11
2042 1.1276424238815252e-11
2043 1.7660650719619753e-11
2044 1.1208922678918043e-11
2045 5.6524784852740595e-12
2046 1.1081024986481225e-11
2047 8.807288232048904e-12
2048 6.345257652640157e-12
2049 6.345257652640157e-12
2050 6.345257652640157e-12
2051 6.306177802173352e-12
2052 6.306177802173352e-12
2053 6.306177802173352e-12
2054 6.306177802173352e-12
2055 1.1351031226070063e-11
2056 1.8030132942215005e-11
2057

2349 1.2498557744322625e-11
2350 1.2498557744322625e-11
2351 9.77007363900384e-12
2352 9.77007363900384e-12
2353 9.77007363900384e-12
2354 7.681077995869146e-12
2355 7.681077995869146e-12
2356 7.681077995869146e-12
2357 7.681077995869146e-12
2358 7.656209000117542e-12
2359 2.8333002610736457e-11
2360 4.3208214783874155e-11
2361 2.4321988867370692e-11
2362 2.373934382404741e-11
2363 1.7600254587080144e-11
2364 1.7600254587080144e-11
2365 1.7600254587080144e-11
2366 1.7159718090908882e-11
2367 2.0158208435816505e-11
2368 2.0158208435816505e-11
2369 2.0158208435816505e-11
2370 9.386380561693386e-12
2371 9.386380561693386e-12
2372 9.41480227112379e-12
2373 9.41480227112379e-12
2374 9.41480227112379e-12
2375 9.20163945039576e-12
2376 9.20163945039576e-12
2377 9.98323645973187e-12
2378 1.8538170998283476e-11
2379 5.648925771595259e-12
2380 5.648925771595259e-12
2381 5.648925771595259e-12
2382 5.648925771595259e-12
2383 4.0715208982078366e-12
2384 4.046651902456233e-12
2385 4.046651902456233e

2692 4.0146774793470286e-12
2693 7.197908935552277e-12
2694 5.378919532006421e-12
2695 1.536915039679343e-11
2696 3.0716873489211594e-11
2697 1.197275611986015e-11
2698 5.606293207449653e-12
2699 1.509914415720459e-11
2700 7.42528261099551e-12
2701 5.378919532006421e-12
2702 8.789524663654902e-12
2703 1.8339219032270648e-11
2704 9.869549622010254e-12
2705 7.42528261099551e-12
2706 1.3280154753658735e-11
2707 2.8897884085665737e-11
2708 1.4473866549735703e-11
2709 5.606293207449653e-12
2710 6.231570814918541e-12
2711 2.7036262117974275e-11
2712 8.576361842926872e-12
2713 6.302625088494551e-12
2714 1.126221338410005e-11
2715 1.7060242107902468e-11
2716 1.0580092357770354e-11
2717 1.0679568340776768e-11
2718 7.979505944888388e-12
2719 1.1716960734986515e-11
2720 9.215850305110962e-12
2721 9.897971331440658e-12
2722 9.215850305110962e-12
2723 1.2711720565050655e-11
2724 3.4320324360237464e-12
2725 4.1141534623534426e-12
2726 1.2526979453753029e-11
2727 9.897971331440658e-12
2728 9.21585030

3123 4.313105428366271e-12
3124 4.895750471689553e-12
3125 1.3877010651697219e-11
3126 1.0807466033213586e-11
3127 1.430333629315328e-11
3128 1.4786505353470147e-11
3129 1.069377919549197e-11
3130 1.0423772955903132e-11
3131 1.1560641333119293e-11
3132 4.838907052828745e-12
3133 1.3621215266823583e-11
3134 1.1223133533633245e-11
3135 4.856670621222747e-12
3136 1.4250045587971272e-11
3137 1.1208922678918043e-11
3138 1.1208922678918043e-11
3139 3.792532954349781e-11
3140 2.1948776129931957e-11
3141 1.522704184964141e-11
3142 9.102163467389346e-12
3143 6.373679362070561e-12
3144 6.373679362070561e-12
3145 5.464184660297633e-12
3146 6.348810366318958e-12
3147 5.4393156645460294e-12
3148 6.348810366318958e-12
3149 5.4393156645460294e-12
3150 6.50512976818618e-12
3151 5.140887715526787e-12
3152 1.636746294053637e-11
3153 2.9313551586085396e-11
3154 1.693589712914445e-11
3155 2.9810931501117466e-11
3156 4.13892253803283e-11
3157 2.756206374243675e-11
3158 4.366296213476062e-11
3159 2.55583332

3532 4.0589864802598186e-11
3533 4.109790285866666e-11
3534 5.549349868516629e-11
3535 1.23209220603826e-11
3536 1.1795120435920126e-11
3537 1.884370437466032e-11
3538 2.22543095063088e-11
3539 3.700517670068848e-11
3540 4.344269388667499e-11
3541 1.2590928299971438e-11
3542 1.168143359819851e-11
3543 6.22446538756094e-12
3544 6.22446538756094e-12
3545 6.22446538756094e-12
3546 4.405475984015084e-12
3547 4.405475984015084e-12
3548 4.405475984015084e-12
3549 4.405475984015084e-12
3550 2.883748795312613e-11
3551 5.885081311163276e-11
3552 2.883748795312613e-11
3553 5.502454047956462e-11
3554 3.4496960843455327e-11
3555 1.8545276425641077e-11
3556 5.4200310906082905e-11
3557 4.217082238966441e-11
3558 5.0654702654640005e-11
3559 2.607702942469814e-11
3560 3.750610932939935e-11
3561 2.1742718736561528e-11
3562 3.750610932939935e-11
3563 2.1742718736561528e-11
3564 2.8880120517271735e-11
3565 3.744926591053854e-11
3566 2.550859523609006e-11
3567 4.0717762495035004e-11
3568 5.259448432326508

3939 4.149680599141448e-12
3940 8.924527783449321e-12
3941 4.135469744426246e-12
3942 4.817590770755942e-12
3943 1.210775923965457e-11
3944 4.817590770755942e-12
3945 4.149680599141448e-12
3946 4.868649927658453e-11
3947 4.783384799367241e-11
3948 8.924527783449321e-12
3949 7.105538379903464e-12
3950 3.4475644561382524e-11
3951 2.5551227800235665e-11
3952 8.924527783449321e-12
3953 5.51392265180084e-12
3954 1.310251906971871e-11
3955 2.8052338230111218e-11
3956 1.1653011888768106e-11
3957 3.914391033532638e-11
3958 2.798838938389281e-11
3959 2.537359211629564e-11
3960 3.567646178481709e-11
3961 2.211575367283558e-11
3962 1.2139733662763774e-11
3963 1.9039103626994347e-11
3964 3.213085353337419e-11
3965 1.9039103626994347e-11
3966 2.055611236784216e-11
3967 1.883659894730272e-11
3968 2.4265145448509884e-11
3969 1.883659894730272e-11
3970 3.5616065652277484e-11
3971 3.4980129903772195e-11
3972 2.36718422641502e-11
3973 6.26344531795553e-11
3974 4.28316271339213e-11
3975 4.826727906248607

4401 1.5319412405290223e-11
4402 6.679212738447404e-12
4403 6.679212738447404e-12
4404 1.3500423001744366e-11
4405 1.9042656340673147e-11
4406 5.382372325613005e-11
4407 4.2280956513707224e-11
4408 5.387346124763326e-11
4409 2.7018498549580272e-11
4410 2.8112734362650826e-11
4411 2.247102504071563e-11
4412 1.8605672558180686e-11
4413 2.8595903422967694e-11
4414 1.8946733071345534e-11
4415 4.046196711016137e-11
4416 1.3571477275320376e-11
4417 1.3162915202258318e-11
4418 1.3162915202258318e-11
4419 1.2132628235406173e-11
4420 1.2132628235406173e-11
4421 1.175959329913212e-11
4422 1.2789880265984266e-11
4423 1.5269674413787016e-11
4424 1.4822032490258152e-11
4425 1.4822032490258152e-11
4426 1.4822032490258152e-11
4427 1.480426892186415e-11
4428 1.4505840972844908e-11
4429 1.2782774838626665e-11
4430 1.0239031844605506e-11
4431 3.8015146586189985e-12
4432 3.8015146586189985e-12
4433 3.8015146586189985e-12
4434 1.0608514067200758e-11
4435 1.2907119817384682e-11
4436 1.2374212765564607e-11


4851 2.3849477948090225e-11
4852 2.332012360994895e-11
4853 3.510447488253021e-11
4854 2.6847968292997848e-11
4855 2.1600610189409508e-11
4856 4.6473158654691815e-11
4857 3.112188284859485e-11
4858 6.354039516764942e-11
4859 4.409994591725308e-11
4860 2.947697641531022e-11
4861 2.5384250257332042e-11
4862 3.0727531630247995e-11
4863 2.969013923603825e-11
4864 2.0595192218308966e-11
4865 2.2311152925169608e-11
4866 3.673517046109964e-11
4867 4.2348458073604434e-11
4868 3.14913650711901e-11
4869 2.594913173226132e-11
4870 2.1870616428998346e-11
4871 3.673517046109964e-11
4872 4.730804636920993e-11
4873 2.899025464131455e-11
4874 2.1362578372929875e-11
4875 2.118849540266865e-11
4876 2.303590651564491e-11
4877 2.5682678206351284e-11
4878 9.762968211646239e-12
4879 9.762968211646239e-12
4880 9.212297591432161e-12
4881 1.9540036255705218e-11
4882 7.332912055346696e-12
4883 1.913147418264316e-11
4884 7.332912055346696e-12
4885 1.913147418264316e-11
4886 7.00606239689705e-12
4887 1.8918311361

Lo mismo pero definiendo una red bien bonita y trabajando sobre ella después.

Se observa una velocidad de ejecución de aprox 20% con respecto a la anterior. 

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square convolution
        # kernel
        #self.conv1 = nn.Conv2d(1, 6, 3)
        #self.conv2 = nn.Conv2d(6, 16, 3)
        # an affine operation: y = Wx + b
        # input: 147 chars from state and 1 from action taken
        self.fc1 = nn.Linear(2, 100)  # 6*6 from image dimension
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, 10)


    def forward(self, x):
        # Max pooling over a (2, 2) window
        #x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square you can only specify a single number
        #x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        #x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
net = net.to('cuda:0')
print(net)

Net(
  (fc1): Linear(in_features=2, out_features=100, bias=True)
  (fc2): Linear(in_features=100, out_features=100, bias=True)
  (fc3): Linear(in_features=100, out_features=10, bias=True)
)


In [17]:
params = list(net.parameters())
print(len(params))
print(params[0].size())  # conv1's .weight

6
torch.Size([100, 2])


In [18]:
input = torch.randn(2)
out = net(input)
print(out)


RuntimeError: Expected object of backend CPU but got backend CUDA for argument #2 'mat2'

In [19]:
net.zero_grad()
out.backward(torch.randn(1, 10))

In [6]:
output = net(input)
target = torch.randn(10)  # a dummy target, for example
target = target.view(1, -1)  # make it the same shape as output
criterion = nn.MSELoss()

loss = criterion(output, target)
print(loss)


tensor(0.4881, grad_fn=<MseLossBackward>)


  return F.mse_loss(input, target, reduction=self.reduction)


In [7]:
print(loss.grad_fn)  # MSELoss
print(loss.grad_fn.next_functions[0][0])  # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0])  # ReLU

<MseLossBackward object at 0x7f9ba0da6be0>
<ExpandBackward object at 0x7f9ba0da6d30>
<AddBackward0 object at 0x7f9ba0da6be0>


In [20]:
net.zero_grad()     # zeroes the gradient buffers of all parameters

print('conv1.bias.grad before backward')
print(net.fc1.bias.grad)

loss.backward()

print('conv1.bias.grad after backward')
print(net.fc1.bias.grad)

conv1.bias.grad before backward
None


RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

In [9]:
learning_rate = 0.01
for f in net.parameters():
    f.data.sub_(f.grad.data * learning_rate)


In [23]:
import torch.optim as optim

# create your optimizer
learning_rate = 1e-6
#optimizer = optim.SGD(net.parameters(), lr=0.01)
optimizer = optim.SGD(net.parameters(), lr=0.01)

# in your training loop:
for t in range(5000):
    optimizer.zero_grad()   # zero the gradient buffers
    output = net(input)
    loss = criterion(output, target)
    loss.backward()
    print(t, loss.item())
    optimizer.step()    # Does the update


RuntimeError: Expected object of backend CPU but got backend CUDA for argument #2 'mat2'

In [11]:
# 1. Para predecir un valor de entrada state-action
print(net.fc1)
print(net.fc2)
print(net.fc3)
input = torch.randn(2)
input = torch.tensor([1, 2.0])
out = net(input)
print(input)
print(input, out)

Linear(in_features=2, out_features=100, bias=True)
Linear(in_features=100, out_features=100, bias=True)
Linear(in_features=100, out_features=10, bias=True)
tensor([1., 2.])
tensor([1., 2.]) tensor([-1.2330, -0.7993,  0.3579, -0.2056, -1.0268,  0.4119, -0.8720, -0.2456,
         1.9545,  0.9065], grad_fn=<AddBackward0>)


In [12]:
# 2. Gradientes de cada parametro con respecto a la Advantage Function
for f in net.parameters():
    f.data.sub_(f.grad.data * learning_rate)

In [13]:
# 3. actualizar parameters:
for f in net.parameters():
    f.data.sub_(f.grad.data * learning_rate)

#o capaz lo pueda definir en el backward pass


In [14]:
# Iterate over layers
for name, param in net.named_parameters(): 
    print(name, param.shape) 

fc1.weight torch.Size([100, 2])
fc1.bias torch.Size([100])
fc2.weight torch.Size([100, 100])
fc2.bias torch.Size([100])
fc3.weight torch.Size([10, 100])
fc3.bias torch.Size([10])
