Source: https://pytorch.org/tutorials/beginner/pytorch_with_examples.html

In [1]:
import torch

dtype = torch.float
device = torch.device("cpu")
device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H1, H2, D_out = 1, 2, 100, 100, 10

# Create random Tensors to hold input and outputs.
# Setting requires_grad=False indicates that we do not need to compute gradients
# with respect to these Tensors during the backward pass.
x = torch.randn(N, D_in, device=device, dtype=dtype, requires_grad=False)
y = torch.randn(N, D_out, device=device, dtype=dtype, requires_grad=False)

# Create random Tensors for weights.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
w1 = torch.randn(D_in, H1, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H1, H2, device=device, dtype=dtype, requires_grad=True)
w3 = torch.randn(H2, D_out, device=device, dtype=dtype, requires_grad=True)


#learning_rate = 1e-6
learning_rate = 0.0001
for t in range(5000):
    # Forward pass: compute predicted y using operations on Tensors; these
    # are exactly the same operations we used to compute the forward pass using
    # Tensors, but we do not need to keep references to intermediate values since
    # we are not implementing the backward pass by hand.
    y_pred = x.mm(w1).clamp(min=0).mm(w2).clamp(min=0).mm(w3)

    # Compute and print loss using operations on Tensors.
    # Now loss is a Tensor of shape (1,)
    # loss.item() gets the a scalar value held in the loss.
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass. This call will compute the
    # gradient of loss with respect to all Tensors with requires_grad=True.
    # After this call w1.grad and w2.grad will be Tensors holding the gradient
    # of the loss with respect to w1 and w2 respectively.
    loss.backward()
    # minimizo y_pred, no loss:
    #y_sum = y_pred.sum()
    #y_sum.backward()
    #print(t, y_sum.item())

    # Manually update weights using gradient descent. Wrap in torch.no_grad()
    # because weights have requires_grad=True, but we don't need to track this
    # in autograd.
    # An alternative way is to operate on weight.data and weight.grad.data.
    # Recall that tensor.data gives a tensor that shares the storage with
    # tensor, but doesn't track history.
    # You can also use torch.optim.SGD to achieve this.
    #learning_rate = learning_rate * 1.0/(t+1)
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad
        w3 -= learning_rate * w3.grad

        # Manually zero the gradients after updating weights
        w1.grad.zero_()
        w2.grad.zero_()
        w3.grad.zero_()

0 7624.07763671875
1 1858.953125
2 1413.042236328125
3 1028.6907958984375
4 1091.4951171875
5 1164.2652587890625
6 1324.1083984375
7 1448.115234375
8 1529.6422119140625
9 1498.400634765625
10 1362.060546875
11 1159.2322998046875
12 916.797119140625
13 687.1103515625
14 489.0791320800781
15 337.1251220703125
16 225.79299926757812
17 148.9300079345703
18 96.86558532714844
19 62.60679244995117
20 40.20335388183594
21 25.75741958618164
22 16.453296661376953
23 10.502440452575684
24 6.69453239440918
25 4.266636371612549
26 2.7173871994018555
27 1.7307437658309937
28 1.1019346714019775
29 0.7016331553459167
30 0.44667932391166687
31 0.28438258171081543
32 0.18102611601352692
33 0.11524602770805359
34 0.07337146997451782
35 0.04671202227473259
36 0.029734371230006218
37 0.018929827958345413
38 0.012053090147674084
39 0.007672385778278112
40 0.004885001573711634
41 0.003109959652647376
42 0.001978575950488448
43 0.0012593934079632163
44 0.0008023218833841383
45 0.0005113435327075422
46 0.00032

334 1.6239712352650315e-10
335 2.0931781308242137e-10
336 1.2600934184980872e-10
337 1.5612747206183997e-10
338 1.790868842110882e-10
339 1.4533610426248345e-10
340 6.031966592878746e-11
341 1.2954890632466132e-11
342 1.2954890632466132e-11
343 1.2393561871215653e-11
344 2.7883393510785837e-11
345 4.144765433644615e-11
346 2.7783917527779423e-11
347 2.3847510771668468e-11
348 9.021930424957247e-11
349 1.358556878106043e-10
350 8.112968630236139e-11
351 1.6791821405570673e-11
352 9.498704600652275e-11
353 9.154624280860446e-11
354 1.9632287462378883e-10
355 7.634418097701712e-11
356 5.4034918900436324e-11
357 5.972636274442777e-11
358 1.293595508489176e-10
359 7.76302633287429e-11
360 8.053460676116231e-11
361 4.427561442477135e-11
362 6.719771961094523e-11
363 2.8933220402871385e-11
364 4.307302084449738e-11
365 3.166880993554777e-11
366 4.3833301571760686e-11
367 8.395942274752599e-11
368 6.142278352605501e-11
369 2.052899239490813e-10
370 1.5187842650199457e-10
371 1.5795179053590402

809 5.975211991859908e-11
810 1.5831860822324018e-10
811 2.5250637691875966e-10
812 1.6324799845257587e-10
813 1.379384662048011e-10
814 2.01096833629677e-10
815 1.201553578855652e-10
816 1.5875203929205384e-10
817 2.387378350565683e-10
818 1.7165727173029666e-10
819 2.4181270874557015e-10
820 3.0437244391556817e-10
821 7.066783269671362e-11
822 6.550218700773769e-11
823 2.0399229527789942e-10
824 8.75449990278554e-11
825 9.292913660807756e-11
826 4.3128976084938486e-11
827 7.05132896516858e-11
828 5.836833794070628e-11
829 6.679537478682107e-11
830 9.470549344747781e-11
831 1.2493731049723067e-10
832 2.911297036778393e-10
833 3.2205785260863706e-10
834 1.9219906222112115e-10
835 2.2430315937960188e-10
836 1.7184201284159428e-10
837 2.5038185413883696e-10
838 2.6245930428991926e-10
839 1.237400459874749e-10
840 1.3120429742663475e-10
841 2.507939689255778e-10
842 1.3893144967802584e-10
843 8.12069578248753e-11
844 3.0772637910070344e-11
845 2.12300489688122e-11
846 3.698988684797122e-1

1162 1.6683463638367257e-11
1163 2.2396227233878463e-11
1164 4.089876007307147e-11
1165 3.5242839896421074e-11
1166 3.6464973401928447e-11
1167 1.5232606842552343e-10
1168 7.48031914188374e-11
1169 6.97974178454075e-11
1170 9.802905709399568e-11
1171 7.742864682747097e-11
1172 6.672609687008446e-11
1173 1.3328085857189365e-10
1174 1.486019363117208e-10
1175 5.869153496540491e-10
1176 4.797121588850928e-10
1177 2.323065351195197e-10
1178 3.9153489478360726e-11
1179 9.447634341519517e-11
1180 1.0474368594692862e-10
1181 3.663816819376997e-11
1182 1.0003989303619676e-10
1183 8.396741635330329e-11
1184 4.168302161766668e-11
1185 5.0136703816372474e-11
1186 1.6097426169814355e-10
1187 2.2617988038042824e-10
1188 1.1249748355091072e-10
1189 1.2770842716669506e-10
1190 9.247438925719109e-11
1191 6.097602978094585e-11
1192 7.379599709089746e-11
1193 5.185976995059072e-11
1194 3.951764263043778e-11
1195 8.97174834424419e-11
1196 3.401358361632134e-10
1197 7.26573523568419e-11
1198 6.92236545862

1513 3.111813931533369e-11
1514 3.530323602896068e-11
1515 7.689840431091e-11
1516 1.0143522133354566e-10
1517 7.90300325181903e-11
1518 1.5179849044422156e-10
1519 6.609993108419587e-11
1520 2.6534321961868557e-10
1521 3.4123628922522187e-10
1522 1.5386972251896225e-10
1523 5.216263879170846e-11
1524 4.698988767537493e-11
1525 7.084280384539454e-11
1526 6.55918930281274e-11
1527 3.370984394401866e-11
1528 5.572956332522416e-11
1529 9.648451482213716e-11
1530 1.5952386633877325e-10
1531 9.18233544755509e-11
1532 1.2992976139436507e-10
1533 1.0629178093246594e-10
1534 9.129932920792783e-11
1535 1.97468624785202e-10
1536 2.620445249679193e-10
1537 1.0543024786535682e-10
1538 5.451453524707439e-11
1539 7.703873650122262e-11
1540 9.143966139824045e-11
1541 1.5407933262601148e-10
1542 1.1583259351688469e-10
1543 1.2574999375125628e-10
1544 3.6821737325887227e-10
1545 3.995301034454002e-10
1546 4.3338124755543106e-10
1547 1.3626869077576487e-10
1548 1.0115100423924162e-10
1549 1.006971450667

1863 5.975744898911728e-11
1864 2.3242661267852682e-11
1865 4.9014046293871516e-11
1866 5.0003477053417456e-11
1867 5.2055169202924745e-11
1868 4.583969662186327e-11
1869 1.400914106941542e-10
1870 1.173407204735355e-10
1871 9.436176839905386e-11
1872 1.6818715864452827e-10
1873 1.5942971942628503e-10
1874 5.188907983844082e-11
1875 6.668257612751916e-11
1876 1.2886039457704612e-10
1877 9.67527447048866e-11
1878 1.6890658316448537e-10
1879 7.26671223194586e-11
1880 8.254633088178309e-11
1881 5.4614011230080806e-11
1882 4.932402056234686e-11
1883 2.0512827547669588e-10
1884 1.1582992898162559e-10
1885 1.4809034554197353e-10
1886 5.2262114774714874e-11
1887 1.1986137082864445e-10
1888 4.500177175348341e-10
1889 3.604911091859009e-10
1890 7.679093472212628e-11
1891 7.181003014444798e-11
1892 6.596226342914235e-11
1893 5.914105316584539e-11
1894 1.357446655081418e-10
1895 2.131725074239199e-10
1896 9.730430350352037e-11
1897 7.463976658961258e-11
1898 5.519932080866319e-11
1899 7.807790525

2163 9.455361493770909e-11
2164 1.429007190356657e-10
2165 1.1281811596042246e-10
2166 9.60954926743085e-11
2167 1.7144854980166713e-10
2168 1.1832126944888444e-10
2169 1.1493908602666636e-10
2170 1.0676784456542521e-10
2171 6.995640178253382e-11
2172 1.1446479875054649e-10
2173 6.649428230254273e-11
2174 3.088959366071009e-10
2175 1.825543327615975e-10
2176 2.385461619902607e-11
2177 5.322845289534861e-11
2178 9.659109623250117e-11
2179 8.502523685116614e-11
2180 1.5536186226405846e-10
2181 5.5382285563121414e-11
2182 5.1338409218226744e-11
2183 6.390080131701836e-11
2184 5.0354307529199005e-11
2185 5.419834372966115e-11
2186 3.3052591913440565e-11
2187 3.008962870532095e-11
2188 2.982672789308971e-11
2189 4.2836765384857145e-11
2190 2.414949143436651e-11
2191 2.3026833911865552e-11
2192 1.2411325439609655e-11
2193 3.875114465423657e-11
2194 2.331165538382862e-10
2195 2.912860230797065e-10
2196 2.472323734625803e-10
2197 2.6790383800268103e-10
2198 1.3660619857525091e-10
2199 4.209691

2487 1.4327126290902825e-11
2488 1.4327126290902825e-11
2489 1.4327126290902825e-11
2490 1.4327126290902825e-11
2491 1.4327126290902825e-11
2492 1.4327126290902825e-11
2493 1.4327126290902825e-11
2494 1.4327126290902825e-11
2495 1.2053389536470505e-11
2496 1.2053389536470505e-11
2497 1.2053389536470505e-11
2498 1.2053389536470505e-11
2499 2.1546240486225443e-11
2500 1.7453514328247266e-11
2501 2.1546240486225443e-11
2502 3.557945951748742e-11
2503 5.5815716631935075e-11
2504 5.10053423108392e-11
2505 9.754588803367881e-11
2506 1.1155423806918918e-10
2507 5.769154945434174e-11
2508 9.671988210335769e-11
2509 2.2080551276282279e-10
2510 1.897103862891214e-10
2511 1.053050147081791e-10
2512 3.6947254283825615e-11
2513 4.640990716731075e-11
2514 1.130872340215916e-10
2515 1.9804149986590858e-10
2516 1.364534318870625e-10
2517 3.3416194811231037e-10
2518 1.768726554107758e-10
2519 1.7186510548050649e-10
2520 2.841237523032447e-10
2521 3.3469485516413044e-10
2522 5.5542157678667436e-11
2523 

2817 8.978409682391941e-11
2818 2.0266802125412653e-10
2819 3.72660929892632e-10
2820 1.2324000153718373e-10
2821 4.292647140524686e-11
2822 6.499414895166922e-11
2823 1.2694281736891355e-10
2824 7.151870762278634e-11
2825 1.605319488451329e-10
2826 5.402159622414082e-11
2827 1.861185927598541e-10
2828 1.818500072747753e-10
2829 1.9379955973342078e-10
2830 8.04537825249696e-11
2831 1.2146808558988198e-10
2832 1.3414949706636037e-10
2833 8.169190324203157e-11
2834 1.2780790314970147e-10
2835 7.173009408667497e-11
2836 1.159258522509532e-10
2837 2.3393634251966944e-10
2838 5.6105262796757316e-11
2839 5.2616497964175224e-11
2840 8.569581155803974e-11
2841 7.352243813762982e-11
2842 8.873710793944056e-12
2843 3.266978701454981e-11
2844 5.0710467075498755e-11
2845 3.377112825497797e-11
2846 3.0360523123329486e-11
2847 3.791359240445935e-11
2848 1.3045467484040785e-10
2849 2.3985338715171167e-10
2850 8.691350417144861e-11
2851 8.025127784527797e-11
2852 7.955761049949217e-11
2853 7.654313294

3157 6.450831535609325e-11
3158 1.315871023255255e-10
3159 5.491243917910005e-11
3160 8.19254941664127e-11
3161 1.1459269644298331e-10
3162 5.154446661159717e-11
3163 7.008962854548884e-11
3164 1.0838788200295824e-10
3165 3.065902254295594e-10
3166 1.2710180130603987e-10
3167 1.5686821286386987e-10
3168 1.4600401443409794e-10
3169 1.4048220919882226e-10
3170 5.1139457252213916e-11
3171 9.904424502771292e-11
3172 3.208454890657464e-10
3173 5.825827598116007e-10
3174 2.4030902268101784e-10
3175 1.6213688724953101e-10
3176 9.024594960216348e-11
3177 5.691616969394353e-11
3178 5.3121871484984595e-11
3179 7.314496230925727e-11
3180 1.8952653335624348e-10
3181 1.4510873058704021e-10
3182 4.700409853009013e-11
3183 1.2272663441059706e-10
3184 1.193941889798822e-10
3185 1.9995907707404115e-10
3186 1.188470710733469e-10
3187 8.128955841790741e-11
3188 4.1135015532711705e-11
3189 1.0075665302089476e-10
3190 5.5255276049104296e-11
3191 5.823777571301036e-11
3192 3.2983313996703956e-11
3193 8.5826

3499 1.500452262437335e-10
3500 1.0434844655016207e-10
3501 1.4978054907466287e-10
3502 1.0347270262833774e-10
3503 3.0146472124181756e-11
3504 4.452252802544798e-11
3505 1.0555814555779364e-10
3506 6.404290986417038e-11
3507 9.527836852818439e-11
3508 5.7653357782294634e-11
3509 9.138992340673724e-11
3510 5.91952320494471e-11
3511 6.62580268429025e-11
3512 1.1723236270633208e-10
3513 1.8522419709121607e-10
3514 1.0184911247712591e-10
3515 1.330170695812427e-10
3516 1.30919192153911e-10
3517 9.629799735400013e-11
3518 6.637171368062411e-11
3519 1.0080461465555857e-10
3520 8.652270566678055e-11
3521 3.1908618108866804e-11
3522 3.435999054723915e-11
3523 2.5833477718117948e-11
3524 3.8857726064600584e-11
3525 4.096803798980808e-11
3526 4.649783683086106e-11
3527 1.413766048674603e-10
3528 1.1278436518047386e-10
3529 1.0407488759689443e-10
3530 7.68930752403918e-11
3531 3.316627875116218e-11
3532 3.2619160844626904e-11
3533 2.7673783403736607e-11
3534 6.098402338672315e-11
3535 5.64081316

3939 8.770487114340142e-11
3940 7.469483365163399e-11
3941 1.0415393547624774e-10
3942 8.252856731338909e-11
3943 4.2192836030574554e-11
3944 3.2188394311072344e-11
3945 1.7297194926380044e-11
3946 9.649606114159326e-11
3947 6.119452167219208e-11
3948 4.541870005092541e-11
3949 4.705472470001304e-11
3950 3.96935019575384e-11
3951 5.5678937155301256e-11
3952 5.3250657355841113e-11
3953 6.476855163306539e-11
3954 1.428039075879184e-10
3955 8.882930502274178e-11
3956 5.572689878996506e-11
3957 3.284919905532924e-11
3958 3.097336623292257e-11
3959 5.413173034818364e-11
3960 5.7167524186718666e-11
3961 6.013226028223073e-11
3962 1.2405090843436994e-10
3963 7.894743192515818e-11
3964 8.778480720117443e-11
3965 1.0702630448555794e-10
3966 6.264758156682149e-11
3967 2.101688614808417e-11
3968 1.4160309036448382e-10
3969 1.1753185230611862e-11
3970 2.357661635365993e-11
3971 1.1753185230611862e-11
3972 8.68364061212823e-12
3973 8.68364061212823e-12
3974 1.5959598226311655e-11
3975 2.50545452440

4296 2.4193012176931816e-11
4297 1.9702382086927983e-11
4298 2.4193012176931816e-11
4299 1.9702382086927983e-11
4300 2.513803401549275e-11
4301 4.558212488015023e-11
4302 2.1423671864306826e-11
4303 2.930359080388634e-11
4304 5.094672253513899e-11
4305 2.934622336803194e-11
4306 2.8562850001856432e-11
4307 5.082415391322037e-11
4308 1.1456693926881201e-10
4309 3.095070033598546e-10
4310 1.5865078695220802e-10
4311 1.3518955399582921e-10
4312 1.1973169677936824e-10
4313 1.0707781883390055e-10
4314 1.761869816707673e-10
4315 1.5867565594795963e-10
4316 1.4979920082147657e-10
4317 5.956204973678325e-11
4318 4.9351554093357564e-11
4319 3.5394718406189796e-11
4320 4.124248512149542e-11
4321 2.180203587109908e-11
4322 2.146275171477363e-11
4323 4.9564716914085594e-11
4324 9.483339113991462e-11
4325 1.5601112068885925e-10
4326 2.138066668155858e-10
4327 2.0407933676303003e-10
4328 9.639303244490804e-11
4329 5.5815716631935075e-11
4330 1.1087034068602009e-10
4331 1.0472947509221342e-10
4332 9.

4667 2.996224657270119e-10
4668 1.4520820657004663e-10
4669 4.416548030072853e-11
4670 4.6444546125679054e-11
4671 4.101422326763249e-11
4672 7.420544734237922e-11
4673 4.909309417322483e-11
4674 2.1282522966181716e-10
4675 1.2196724186175345e-10
4676 7.092540443842665e-11
4677 9.794201560886506e-11
4678 4.7873625202976555e-11
4679 6.283587539179791e-11
4680 1.3023795930600102e-10
4681 6.121761431110428e-11
4682 6.005410058129712e-11
4683 4.378978082919538e-11
4684 2.385905709112457e-11
4685 5.266445959883903e-11
4686 1.060386500828514e-10
4687 3.958692054717439e-11
4688 1.3397363773925974e-10
4689 1.4383969709763633e-11
4690 1.3986065777737977e-11
4691 2.0551480656161303e-11
4692 5.182957188432091e-11
4693 2.6015554294156473e-11
4694 6.309167077667155e-11
4695 4.005765510961545e-11
4696 7.535919110956968e-11
4697 3.664349726428817e-11
4698 2.5743771697728235e-11
4699 3.4895562134318325e-11
4700 1.727301912612944e-10
4701 1.349985956355937e-10
4702 2.0351800800177955e-10
4703 1.1095205

Lo mismo pero definiendo una red bien bonita y trabajando sobre ella después.

Se observa una velocidad de ejecución de aprox 20% con respecto a la anterior. 

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square convolution
        # kernel
        #self.conv1 = nn.Conv2d(1, 6, 3)
        #self.conv2 = nn.Conv2d(6, 16, 3)
        # an affine operation: y = Wx + b
        # input: 147 chars from state and 1 from action taken
        self.fc1 = nn.Linear(2, 100)  # 6*6 from image dimension
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, 10)


    def forward(self, x):
        # Max pooling over a (2, 2) window
        #x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square you can only specify a single number
        #x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        #x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
#net = net.to('cuda:0')
print(net)

Net(
  (fc1): Linear(in_features=2, out_features=100, bias=True)
  (fc2): Linear(in_features=100, out_features=100, bias=True)
  (fc3): Linear(in_features=100, out_features=10, bias=True)
)


In [6]:
params = list(net.parameters())
print(len(params))
print(params[0].size())  # conv1's .weight

6
torch.Size([100, 2])


In [7]:
input = torch.randn(2)
out = net(input)
print(out)


tensor([-0.1424, -0.1077, -0.1101,  0.1672, -0.0550,  0.2324,  0.0541,  0.1777,
        -0.0224, -0.1102], grad_fn=<AddBackward0>)


In [8]:
net.zero_grad()
out.backward(torch.randn(1, 10))

In [9]:
output = net(input)
target = torch.randn(10)  # a dummy target, for example
target = target.view(1, -1)  # make it the same shape as output
criterion = nn.MSELoss()

loss = criterion(output, target)
print(loss)


tensor(0.6235, grad_fn=<MseLossBackward>)


  return F.mse_loss(input, target, reduction=self.reduction)


In [10]:
print(loss.grad_fn)  # MSELoss
print(loss.grad_fn.next_functions[0][0])  # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0])  # ReLU

<MseLossBackward object at 0x7f4aac698f60>
<ExpandBackward object at 0x7f4aac698d30>
<AddBackward0 object at 0x7f4aac698f60>


In [11]:
net.zero_grad()     # zeroes the gradient buffers of all parameters

print('conv1.bias.grad before backward')
print(net.fc1.bias.grad)

loss.backward()

print('conv1.bias.grad after backward')
print(net.fc1.bias.grad)

conv1.bias.grad before backward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.])
conv1.bias.grad after backward
tensor([ 0.0000,  0.0000,  0.0000,  0.0000, -0.0201,  0.0000,  0.0000,  0.0044,
         0.0000,  0.0099,  0.0134,  0.0022, -0.0128,  0.0000,  0.0000, -0.0087,
        -0.0214, -0.0008,  0.0000,  0.0038,  0.0075,  0.0000,  0.0000,  0.0000,
        -0.0186, -0.0041,  0.0093,  0.0000, -0.0132,  0.0000,  0.0028,  0.0106,
        -0.0196,  0.0254,  0.0000,  0.0000, -0.0110,  0.0000,  0.0034,  0.0000,
        -0.0131, -0.0115,  0.0000, -0.0038,  0.0125,  0.0000, -0.0031,  0.0048,
         0.0000,

In [12]:
learning_rate = 0.01
for f in net.parameters():
    f.data.sub_(f.grad.data * learning_rate)


In [13]:
import torch.optim as optim

# create your optimizer
learning_rate = 1e-6
#optimizer = optim.SGD(net.parameters(), lr=0.01)
optimizer = optim.SGD(net.parameters(), lr=0.01)

# in your training loop:
for t in range(5000):
    optimizer.zero_grad()   # zero the gradient buffers
    output = net(input)
    loss = criterion(output, target)
    loss.backward()
    print(t, loss.item())
    optimizer.step()    # Does the update


0 0.5983940958976746
1 0.5740922689437866
2 0.550510585308075
3 0.5275876522064209
4 0.5052733421325684
5 0.48352739214897156
6 0.46231862902641296
7 0.44162312150001526
8 0.42142266035079956
9 0.4017053246498108
10 0.3824838697910309
11 0.36392924189567566
12 0.3460129201412201
13 0.3286775052547455
14 0.3118070662021637
15 0.29537126421928406
16 0.2793858051300049
17 0.26386725902557373
18 0.2488192617893219
19 0.2342539280653
20 0.2201826423406601
21 0.2066156566143036
22 0.19356201589107513
23 0.1810290515422821
24 0.169022336602211
25 0.15754550695419312
26 0.14660008251667023
27 0.13618548214435577
28 0.1262989193201065
29 0.11693547666072845
30 0.10808819532394409
31 0.09974806010723114
32 0.09190420061349869
33 0.08454403281211853
34 0.07765346765518188
35 0.07121698558330536
36 0.06521794945001602
37 0.059638749808073044
38 0.05446101352572441
39 0.04966581612825394
40 0.04523388296365738
41 0.041145745664834976
42 0.03738192096352577
43 0.03392308950424194
44 0.03075022436678

496 1.5545342248700356e-13
497 1.5545342248700356e-13
498 1.5119017691446468e-13
499 1.5119017691446468e-13
500 1.5119017691446468e-13
501 1.4912515124664016e-13
502 1.4912515124664016e-13
503 1.508126983755867e-13
504 1.4912515124664016e-13
505 1.4912515124664016e-13
506 1.4912515124664016e-13
507 1.4761525064365538e-13
508 1.509237206780492e-13
509 1.509237206780492e-13
510 1.4817036215596796e-13
511 1.4817036215596796e-13
512 1.4559465015984846e-13
513 1.4559465015984846e-13
514 1.4559465015984846e-13
515 1.4559465015984846e-13
516 1.4426238253029827e-13
517 1.4426238253029827e-13
518 1.427524819273135e-13
519 1.427524819273135e-13
520 1.356470545697125e-13
521 1.3173906139151564e-13
522 1.3173906139151564e-13
523 1.3173906139151564e-13
524 1.3173906139151564e-13
525 1.2976287253919916e-13
526 1.2976287253919916e-13
527 1.2976287253919916e-13
528 1.2976287253919916e-13
529 1.2976287253919916e-13
530 1.2976287253919916e-13
531 1.2514433391473678e-13
532 1.2514433391473678e-13
533 1.2

1001 2.4180657815149088e-14
1002 2.4180657815149088e-14
1003 2.4180657815149088e-14
1004 2.4180657815149088e-14
1005 2.4180657815149088e-14
1006 2.4180657815149088e-14
1007 2.4180657815149088e-14
1008 2.4180657815149088e-14
1009 2.4180657815149088e-14
1010 2.4180657815149088e-14
1011 2.4180657815149088e-14
1012 2.4180657815149088e-14
1013 2.4180657815149088e-14
1014 2.4180657815149088e-14
1015 2.4180657815149088e-14
1016 2.4180657815149088e-14
1017 2.4180657815149088e-14
1018 2.4180657815149088e-14
1019 2.4180657815149088e-14
1020 2.4180657815149088e-14
1021 2.4180657815149088e-14
1022 2.4180657815149088e-14
1023 2.3359091760486936e-14
1024 2.2915002550636873e-14
1025 2.2915002550636873e-14
1026 2.2915002550636873e-14
1027 2.2915002550636873e-14
1028 2.2915002550636873e-14
1029 2.2915002550636873e-14
1030 2.2915002550636873e-14
1031 2.2915002550636873e-14
1032 2.2915002550636873e-14
1033 2.2915002550636873e-14
1034 2.2915002550636873e-14
1035 2.2915002550636873e-14
1036 2.2915002550636

1503 1.2190249149197398e-14
1504 1.2190249149197398e-14
1505 1.2190249149197398e-14
1506 1.2190249149197398e-14
1507 1.2190249149197398e-14
1508 1.2190249149197398e-14
1509 1.2190249149197398e-14
1510 1.2190249149197398e-14
1511 1.2190249149197398e-14
1512 1.2190249149197398e-14
1513 1.2190249149197398e-14
1514 1.2190249149197398e-14
1515 1.2190249149197398e-14
1516 1.2190249149197398e-14
1517 1.2190249149197398e-14
1518 1.2190249149197398e-14
1519 1.2190249149197398e-14
1520 1.2190249149197398e-14
1521 1.2190249149197398e-14
1522 1.2190249149197398e-14
1523 1.2190249149197398e-14
1524 1.2190249149197398e-14
1525 1.2190249149197398e-14
1526 1.2190249149197398e-14
1527 1.2190249149197398e-14
1528 1.2190249149197398e-14
1529 1.2190249149197398e-14
1530 1.2190249149197398e-14
1531 1.2190249149197398e-14
1532 1.2190249149197398e-14
1533 1.2190249149197398e-14
1534 1.2190249149197398e-14
1535 1.2190249149197398e-14
1536 1.2190249149197398e-14
1537 1.2190249149197398e-14
1538 1.2190249149197

1987 8.970602377784444e-15
1988 8.970602377784444e-15
1989 8.970602377784444e-15
1990 8.970602377784444e-15
1991 8.970602377784444e-15
1992 8.970602377784444e-15
1993 8.970602377784444e-15
1994 8.970602377784444e-15
1995 8.970602377784444e-15
1996 8.970602377784444e-15
1997 8.970602377784444e-15
1998 8.970602377784444e-15
1999 8.970602377784444e-15
2000 8.970602377784444e-15
2001 8.970602377784444e-15
2002 8.970602377784444e-15
2003 8.970602377784444e-15
2004 8.970602377784444e-15
2005 8.970602377784444e-15
2006 8.970602377784444e-15
2007 8.970602377784444e-15
2008 8.970602377784444e-15
2009 8.970602377784444e-15
2010 8.970602377784444e-15
2011 8.970602377784444e-15
2012 8.970602377784444e-15
2013 8.970602377784444e-15
2014 8.970602377784444e-15
2015 8.970602377784444e-15
2016 8.970602377784444e-15
2017 8.970602377784444e-15
2018 8.970602377784444e-15
2019 8.970602377784444e-15
2020 8.970602377784444e-15
2021 8.970602377784444e-15
2022 8.970602377784444e-15
2023 8.970602377784444e-15
2

2470 7.638334748234256e-15
2471 7.638334748234256e-15
2472 7.638334748234256e-15
2473 7.638334748234256e-15
2474 7.638334748234256e-15
2475 7.638334748234256e-15
2476 7.638334748234256e-15
2477 7.638334748234256e-15
2478 7.638334748234256e-15
2479 7.638334748234256e-15
2480 7.638334748234256e-15
2481 7.638334748234256e-15
2482 7.638334748234256e-15
2483 7.638334748234256e-15
2484 7.638334748234256e-15
2485 7.638334748234256e-15
2486 7.638334748234256e-15
2487 7.638334748234256e-15
2488 7.638334748234256e-15
2489 7.638334748234256e-15
2490 7.638334748234256e-15
2491 7.638334748234256e-15
2492 7.638334748234256e-15
2493 7.638334748234256e-15
2494 7.638334748234256e-15
2495 7.638334748234256e-15
2496 7.638334748234256e-15
2497 7.638334748234256e-15
2498 7.638334748234256e-15
2499 7.638334748234256e-15
2500 7.638334748234256e-15
2501 7.638334748234256e-15
2502 7.638334748234256e-15
2503 7.638334748234256e-15
2504 7.638334748234256e-15
2505 7.638334748234256e-15
2506 7.638334748234256e-15
2

2915 7.638334748234256e-15
2916 7.638334748234256e-15
2917 7.638334748234256e-15
2918 7.638334748234256e-15
2919 7.638334748234256e-15
2920 7.638334748234256e-15
2921 7.638334748234256e-15
2922 7.638334748234256e-15
2923 7.638334748234256e-15
2924 7.638334748234256e-15
2925 7.638334748234256e-15
2926 7.638334748234256e-15
2927 7.638334748234256e-15
2928 7.638334748234256e-15
2929 7.638334748234256e-15
2930 7.638334748234256e-15
2931 7.638334748234256e-15
2932 7.638334748234256e-15
2933 7.638334748234256e-15
2934 7.638334748234256e-15
2935 7.638334748234256e-15
2936 7.638334748234256e-15
2937 7.638334748234256e-15
2938 7.638334748234256e-15
2939 7.638334748234256e-15
2940 7.638334748234256e-15
2941 7.638334748234256e-15
2942 7.638334748234256e-15
2943 7.638334748234256e-15
2944 7.638334748234256e-15
2945 7.638334748234256e-15
2946 7.638334748234256e-15
2947 7.638334748234256e-15
2948 7.638334748234256e-15
2949 7.638334748234256e-15
2950 7.638334748234256e-15
2951 7.638334748234256e-15
2

3379 7.993605777301127e-15
3380 7.993605777301127e-15
3381 7.993605777301127e-15
3382 7.993605777301127e-15
3383 7.993605777301127e-15
3384 7.993605777301127e-15
3385 7.993605777301127e-15
3386 7.993605777301127e-15
3387 7.993605777301127e-15
3388 7.993605777301127e-15
3389 7.993605777301127e-15
3390 7.993605777301127e-15
3391 7.993605777301127e-15
3392 7.993605777301127e-15
3393 7.993605777301127e-15
3394 7.993605777301127e-15
3395 7.993605777301127e-15
3396 7.638334748234256e-15
3397 7.638334748234256e-15
3398 7.638334748234256e-15
3399 7.638334748234256e-15
3400 7.638334748234256e-15
3401 7.638334748234256e-15
3402 7.638334748234256e-15
3403 7.638334748234256e-15
3404 7.638334748234256e-15
3405 7.638334748234256e-15
3406 7.638334748234256e-15
3407 7.638334748234256e-15
3408 7.638334748234256e-15
3409 7.638334748234256e-15
3410 7.638334748234256e-15
3411 7.638334748234256e-15
3412 7.638334748234256e-15
3413 7.638334748234256e-15
3414 7.638334748234256e-15
3415 7.638334748234256e-15
3

3794 7.638334748234256e-15
3795 7.638334748234256e-15
3796 7.638334748234256e-15
3797 7.638334748234256e-15
3798 7.638334748234256e-15
3799 7.638334748234256e-15
3800 7.638334748234256e-15
3801 7.638334748234256e-15
3802 7.638334748234256e-15
3803 7.638334748234256e-15
3804 7.638334748234256e-15
3805 7.638334748234256e-15
3806 7.638334748234256e-15
3807 7.638334748234256e-15
3808 7.638334748234256e-15
3809 7.638334748234256e-15
3810 7.638334748234256e-15
3811 7.638334748234256e-15
3812 7.638334748234256e-15
3813 7.638334748234256e-15
3814 7.638334748234256e-15
3815 7.638334748234256e-15
3816 7.638334748234256e-15
3817 7.638334748234256e-15
3818 7.638334748234256e-15
3819 7.638334748234256e-15
3820 7.638334748234256e-15
3821 7.638334748234256e-15
3822 7.638334748234256e-15
3823 7.638334748234256e-15
3824 7.638334748234256e-15
3825 7.638334748234256e-15
3826 7.638334748234256e-15
3827 7.638334748234256e-15
3828 7.638334748234256e-15
3829 7.638334748234256e-15
3830 7.638334748234256e-15
3

4244 7.571720689130389e-15
4245 7.571720689130389e-15
4246 7.571720689130389e-15
4247 7.571720689130389e-15
4248 7.571720689130389e-15
4249 7.571720689130389e-15
4250 7.571720689130389e-15
4251 7.571720689130389e-15
4252 7.571720689130389e-15
4253 7.571720689130389e-15
4254 7.571720689130389e-15
4255 7.571720689130389e-15
4256 7.571720689130389e-15
4257 7.571720689130389e-15
4258 7.571720689130389e-15
4259 7.571720689130389e-15
4260 7.571720689130389e-15
4261 7.571720689130389e-15
4262 7.571720689130389e-15
4263 7.571720689130389e-15
4264 7.571720689130389e-15
4265 7.571720689130389e-15
4266 7.571720689130389e-15
4267 7.571720689130389e-15
4268 7.571720689130389e-15
4269 7.571720689130389e-15
4270 7.571720689130389e-15
4271 7.571720689130389e-15
4272 7.571720689130389e-15
4273 7.571720689130389e-15
4274 7.571720689130389e-15
4275 7.571720689130389e-15
4276 7.571720689130389e-15
4277 7.571720689130389e-15
4278 7.571720689130389e-15
4279 7.571720689130389e-15
4280 7.571720689130389e-15
4

4736 6.505906754896828e-15
4737 6.505906754896828e-15
4738 6.505906754896828e-15
4739 6.505906754896828e-15
4740 6.505906754896828e-15
4741 6.505906754896828e-15
4742 6.505906754896828e-15
4743 6.505906754896828e-15
4744 6.505906754896828e-15
4745 6.505906754896828e-15
4746 6.505906754896828e-15
4747 6.505906754896828e-15
4748 6.505906754896828e-15
4749 6.505906754896828e-15
4750 6.505906754896828e-15
4751 6.505906754896828e-15
4752 6.505906754896828e-15
4753 6.505906754896828e-15
4754 6.505906754896828e-15
4755 6.505906754896828e-15
4756 6.505906754896828e-15
4757 6.505906754896828e-15
4758 6.505906754896828e-15
4759 6.505906754896828e-15
4760 6.505906754896828e-15
4761 6.505906754896828e-15
4762 6.505906754896828e-15
4763 6.505906754896828e-15
4764 6.505906754896828e-15
4765 6.505906754896828e-15
4766 6.505906754896828e-15
4767 6.505906754896828e-15
4768 6.505906754896828e-15
4769 6.505906754896828e-15
4770 6.505906754896828e-15
4771 6.505906754896828e-15
4772 6.505906754896828e-15
4

In [14]:
# 1. Para predecir un valor de entrada state-action
print(net.fc1)
print(net.fc2)
print(net.fc3)
input = torch.randn(2)
input = torch.tensor([1, 2.0])
out = net(input)
print(input)
print(input, out)

Linear(in_features=2, out_features=100, bias=True)
Linear(in_features=100, out_features=100, bias=True)
Linear(in_features=100, out_features=10, bias=True)
tensor([1., 2.])
tensor([1., 2.]) tensor([-1.1987, -0.2445, -0.3766,  0.1018,  0.5629,  0.4160, -0.8819,  0.6287,
        -0.7383,  0.2487], grad_fn=<AddBackward0>)


In [None]:
# 2. Gradientes de cada parametro con respecto a la Advantage Function
for f in net.parameters():
    f.data.sub_(f.grad.data * learning_rate)

In [None]:
# 3. actualizar parameters:
for f in net.parameters():
    f.data.sub_(f.grad.data * learning_rate)

#o capaz lo pueda definir en el backward pass


In [None]:
# Iterate over layers
for name, param in net.named_parameters(): 
    print(name, param.shape) 