In [33]:
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 1024, 10, 5, 1

# Create random Tensors to hold inputs and outputs.
x = torch.randn(N, D_in)
y = x.sum(dim=1).reshape((N,-1))

# y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.ReLU(),
          torch.nn.Linear(H, D_out),
        )
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algoriths. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(10000):
  # Forward pass: compute predicted y by passing x to the model.
  y_pred = model(x)

  # Compute and print loss.
  loss = loss_fn(y_pred, y)
  print(t, loss.item())
  
  # Before the backward pass, use the optimizer object to zero all of the
  # gradients for the Tensors it will update (which are the learnable weights
  # of the model)
  optimizer.zero_grad()

  # Backward pass: compute gradient of the loss with respect to model parameters
  loss.backward()

  # Calling the step function on an Optimizer makes an update to its parameters
  optimizer.step()

0 10465.658203125
1 10463.7158203125
2 10461.7724609375
3 10459.830078125
4 10457.888671875
5 10455.94921875
6 10454.01171875
7 10452.0732421875
8 10450.134765625
9 10448.197265625
10 10446.259765625
11 10444.3251953125
12 10442.3955078125
13 10440.4677734375
14 10438.541015625
15 10436.615234375
16 10434.6923828125
17 10432.76953125
18 10430.8486328125
19 10428.9287109375
20 10427.01171875
21 10425.09765625
22 10423.185546875
23 10421.2802734375
24 10419.37890625
25 10417.478515625
26 10415.5791015625
27 10413.6806640625
28 10411.7822265625
29 10409.8857421875
30 10407.9912109375
31 10406.09765625
32 10404.2060546875
33 10402.3154296875
34 10400.42578125
35 10398.5390625
36 10396.6533203125
37 10394.76953125
38 10392.88671875
39 10391.0048828125
40 10389.1240234375
41 10387.2451171875
42 10385.3681640625
43 10383.494140625
44 10381.6201171875
45 10379.74609375
46 10377.8759765625
47 10376.0078125
48 10374.1396484375
49 10372.2744140625
50 10370.4072265625
51 10368.5400390625
52 10366.

455 9585.21875
456 9583.130859375
457 9581.0419921875
458 9578.9521484375
459 9576.861328125
460 9574.7705078125
461 9572.6787109375
462 9570.5859375
463 9568.490234375
464 9566.39453125
465 9564.2978515625
466 9562.19921875
467 9560.0986328125
468 9557.9951171875
469 9555.8916015625
470 9553.787109375
471 9551.681640625
472 9549.57421875
473 9547.466796875
474 9545.357421875
475 9543.248046875
476 9541.13671875
477 9539.0263671875
478 9536.9140625
479 9534.8017578125
480 9532.6875
481 9530.572265625
482 9528.4580078125
483 9526.3408203125
484 9524.2236328125
485 9522.1044921875
486 9519.9833984375
487 9517.8603515625
488 9515.7373046875
489 9513.61328125
490 9511.4873046875
491 9509.361328125
492 9507.234375
493 9505.10546875
494 9502.9736328125
495 9500.8408203125
496 9498.70703125
497 9496.5732421875
498 9494.4375
499 9492.30078125
500 9490.1640625
501 9488.0263671875
502 9485.88671875
503 9483.7451171875
504 9481.603515625
505 9479.4599609375
506 9477.31640625
507 9475.1708984375
5

984 8314.669921875
985 8312.0146484375
986 8309.35546875
987 8306.6943359375
988 8304.0341796875
989 8301.3720703125
990 8298.7119140625
991 8296.0498046875
992 8293.38671875
993 8290.72265625
994 8288.0576171875
995 8285.3916015625
996 8282.7236328125
997 8280.056640625
998 8277.3876953125
999 8274.71875
1000 8272.048828125
1001 8269.37890625
1002 8266.70703125
1003 8264.0361328125
1004 8261.3583984375
1005 8258.673828125
1006 8255.9873046875
1007 8253.302734375
1008 8250.6162109375
1009 8247.9296875
1010 8245.2412109375
1011 8242.5517578125
1012 8239.861328125
1013 8237.169921875
1014 8234.4775390625
1015 8231.7841796875
1016 8229.08984375
1017 8226.39453125
1018 8223.6982421875
1019 8221.0009765625
1020 8218.302734375
1021 8215.6044921875
1022 8212.9052734375
1023 8210.205078125
1024 8207.505859375
1025 8204.8056640625
1026 8202.10546875
1027 8199.4052734375
1028 8196.7041015625
1029 8194.0029296875
1030 8191.3017578125
1031 8188.599609375
1032 8185.8974609375
1033 8183.19482421875


1498 6865.57861328125
1499 6862.69140625
1500 6859.80419921875
1501 6856.91748046875
1502 6854.03076171875
1503 6851.14453125
1504 6848.25830078125
1505 6845.37353515625
1506 6842.48876953125
1507 6839.60400390625
1508 6836.720703125
1509 6833.837890625
1510 6830.955078125
1511 6828.07275390625
1512 6825.19091796875
1513 6822.30859375
1514 6819.427734375
1515 6816.5458984375
1516 6813.6630859375
1517 6810.7802734375
1518 6807.89990234375
1519 6805.01953125
1520 6802.13916015625
1521 6799.25927734375
1522 6796.37939453125
1523 6793.5
1524 6790.62255859375
1525 6787.74560546875
1526 6784.86865234375
1527 6781.9951171875
1528 6779.123046875
1529 6776.25244140625
1530 6773.38232421875
1531 6770.51220703125
1532 6767.64306640625
1533 6764.77392578125
1534 6761.90576171875
1535 6759.03759765625
1536 6756.169921875
1537 6753.30322265625
1538 6750.43603515625
1539 6747.57275390625
1540 6744.708984375
1541 6741.8447265625
1542 6738.98095703125
1543 6736.11767578125
1544 6733.25537109375
1545 67

1963 5563.52978515625
1964 5560.8408203125
1965 5558.15283203125
1966 5555.46630859375
1967 5552.7802734375
1968 5550.0947265625
1969 5547.41064453125
1970 5544.72705078125
1971 5542.04443359375
1972 5539.36279296875
1973 5536.681640625
1974 5534.00146484375
1975 5531.32177734375
1976 5528.64306640625
1977 5525.96533203125
1978 5523.2880859375
1979 5520.611328125
1980 5517.935546875
1981 5515.26025390625
1982 5512.58642578125
1983 5509.9130859375
1984 5507.23828125
1985 5504.5625
1986 5501.88671875
1987 5499.21142578125
1988 5496.5361328125
1989 5493.86181640625
1990 5491.18798828125
1991 5488.5146484375
1992 5485.84130859375
1993 5483.1689453125
1994 5480.49755859375
1995 5477.826171875
1996 5475.15576171875
1997 5472.4853515625
1998 5469.81591796875
1999 5467.14697265625
2000 5464.478515625
2001 5461.81103515625
2002 5459.14404296875
2003 5456.478515625
2004 5453.8134765625
2005 5451.1494140625
2006 5448.4853515625
2007 5445.82177734375
2008 5443.15869140625
2009 5440.49609375
2010 5

2370 4526.3759765625
2371 4524.01171875
2372 4521.64794921875
2373 4519.28564453125
2374 4516.923828125
2375 4514.5634765625
2376 4512.20361328125
2377 4509.84521484375
2378 4507.48828125
2379 4505.13232421875
2380 4502.77880859375
2381 4500.42822265625
2382 4498.0791015625
2383 4495.7314453125
2384 4493.38427734375
2385 4491.0390625
2386 4488.6943359375
2387 4486.3505859375
2388 4484.0078125
2389 4481.66552734375
2390 4479.32421875
2391 4476.98486328125
2392 4474.646484375
2393 4472.31005859375
2394 4469.97509765625
2395 4467.64111328125
2396 4465.3076171875
2397 4462.97509765625
2398 4460.64404296875
2399 4458.31298828125
2400 4455.9833984375
2401 4453.654296875
2402 4451.3251953125
2403 4448.9970703125
2404 4446.66943359375
2405 4444.3427734375
2406 4442.0166015625
2407 4439.6923828125
2408 4437.36865234375
2409 4435.0458984375
2410 4432.7236328125
2411 4430.40234375
2412 4428.08203125
2413 4425.7626953125
2414 4423.443359375
2415 4421.12548828125
2416 4418.80810546875
2417 4416.491

2819 3557.901611328125
2820 3555.952392578125
2821 3554.003662109375
2822 3552.055908203125
2823 3550.10888671875
2824 3548.162353515625
2825 3546.216552734375
2826 3544.271484375
2827 3542.326904296875
2828 3540.383056640625
2829 3538.43994140625
2830 3536.49755859375
2831 3534.55615234375
2832 3532.614990234375
2833 3530.6748046875
2834 3528.73486328125
2835 3526.7958984375
2836 3524.857666015625
2837 3522.921142578125
2838 3520.985595703125
2839 3519.051025390625
2840 3517.116943359375
2841 3515.183837890625
2842 3513.25146484375
2843 3511.31982421875
2844 3509.38916015625
2845 3507.458984375
2846 3505.529296875
2847 3503.6005859375
2848 3501.6728515625
2849 3499.745361328125
2850 3497.81884765625
2851 3495.893310546875
2852 3493.96923828125
2853 3492.0458984375
2854 3490.12353515625
2855 3488.2021484375
2856 3486.28125
2857 3484.36181640625
2858 3482.443115234375
2859 3480.525146484375
2860 3478.608154296875
2861 3476.69189453125
2862 3474.7763671875
2863 3472.861572265625
2864 347

3286 2729.80029296875
3287 2728.20703125
3288 2726.614501953125
3289 2725.02294921875
3290 2723.431640625
3291 2721.841064453125
3292 2720.251220703125
3293 2718.66259765625
3294 2717.074951171875
3295 2715.48779296875
3296 2713.902099609375
3297 2712.317626953125
3298 2710.733642578125
3299 2709.150390625
3300 2707.56787109375
3301 2705.986083984375
3302 2704.40478515625
3303 2702.824462890625
3304 2701.244873046875
3305 2699.665771484375
3306 2698.08740234375
3307 2696.50927734375
3308 2694.931884765625
3309 2693.355224609375
3310 2691.779296875
3311 2690.203857421875
3312 2688.627685546875
3313 2687.05224609375
3314 2685.477294921875
3315 2683.90283203125
3316 2682.32861328125
3317 2680.75537109375
3318 2679.1826171875
3319 2677.610107421875
3320 2676.03857421875
3321 2674.467529296875
3322 2672.89697265625
3323 2671.326904296875
3324 2669.757568359375
3325 2668.18896484375
3326 2666.62109375
3327 2665.0537109375
3328 2663.48681640625
3329 2661.920654296875
3330 2660.35498046875
333

3674 2160.864990234375
3675 2159.531982421875
3676 2158.199462890625
3677 2156.86767578125
3678 2155.536376953125
3679 2154.20556640625
3680 2152.87548828125
3681 2151.5458984375
3682 2150.216796875
3683 2148.88818359375
3684 2147.560302734375
3685 2146.233642578125
3686 2144.907958984375
3687 2143.58251953125
3688 2142.2578125
3689 2140.933837890625
3690 2139.610595703125
3691 2138.28759765625
3692 2136.96533203125
3693 2135.643798828125
3694 2134.32275390625
3695 2133.002197265625
3696 2131.682373046875
3697 2130.36328125
3698 2129.04443359375
3699 2127.726318359375
3700 2126.409912109375
3701 2125.0947265625
3702 2123.780029296875
3703 2122.466064453125
3704 2121.15283203125
3705 2119.840087890625
3706 2118.528076171875
3707 2117.216796875
3708 2115.90625
3709 2114.59765625
3710 2113.28955078125
3711 2111.982421875
3712 2110.676025390625
3713 2109.3701171875
3714 2108.064697265625
3715 2106.76025390625
3716 2105.456298828125
3717 2104.1533203125
3718 2102.851318359375
3719 2101.5500

4141 1601.681884765625
4142 1600.610595703125
4143 1599.5399169921875
4144 1598.4703369140625
4145 1597.4014892578125
4146 1596.333251953125
4147 1595.265625
4148 1594.198486328125
4149 1593.132080078125
4150 1592.06591796875
4151 1591.0006103515625
4152 1589.935791015625
4153 1588.871337890625
4154 1587.8076171875
4155 1586.744140625
4156 1585.68115234375
4157 1584.61865234375
4158 1583.5567626953125
4159 1582.49560546875
4160 1581.4349365234375
4161 1580.3746337890625
4162 1579.3150634765625
4163 1578.2559814453125
4164 1577.197265625
4165 1576.13916015625
4166 1575.081787109375
4167 1574.0247802734375
4168 1572.96826171875
4169 1571.912353515625
4170 1570.8568115234375
4171 1569.8018798828125
4172 1568.7474365234375
4173 1567.6934814453125
4174 1566.64013671875
4175 1565.58740234375
4176 1564.534912109375
4177 1563.483154296875
4178 1562.4317626953125
4179 1561.3809814453125
4180 1560.33056640625
4181 1559.281005859375
4182 1558.2318115234375
4183 1557.18310546875
4184 1556.13488769

4566 1193.7242431640625
4567 1192.8717041015625
4568 1192.01953125
4569 1191.1680908203125
4570 1190.317138671875
4571 1189.4664306640625
4572 1188.6163330078125
4573 1187.766845703125
4574 1186.917724609375
4575 1186.0689697265625
4576 1185.2208251953125
4577 1184.3731689453125
4578 1183.52587890625
4579 1182.6790771484375
4580 1181.832763671875
4581 1180.98681640625
4582 1180.1414794921875
4583 1179.296630859375
4584 1178.4522705078125
4585 1177.6082763671875
4586 1176.7646484375
4587 1175.9217529296875
4588 1175.07958984375
4589 1174.2376708984375
4590 1173.396484375
4591 1172.5556640625
4592 1171.71533203125
4593 1170.87548828125
4594 1170.0362548828125
4595 1169.197509765625
4596 1168.359130859375
4597 1167.521240234375
4598 1166.6839599609375
4599 1165.8470458984375
4600 1165.0106201171875
4601 1164.1748046875
4602 1163.33935546875
4603 1162.50439453125
4604 1161.6697998046875
4605 1160.835693359375
4606 1160.0020751953125
4607 1159.1689453125
4608 1158.3363037109375
4609 1157.50

5060 827.8118896484375
5061 827.178466796875
5062 826.545654296875
5063 825.9132690429688
5064 825.2811279296875
5065 824.6493530273438
5066 824.01806640625
5067 823.3870849609375
5068 822.756591796875
5069 822.126708984375
5070 821.4971923828125
5071 820.867919921875
5072 820.2391357421875
5073 819.6107788085938
5074 818.9827270507812
5075 818.3550415039062
5076 817.72802734375
5077 817.101318359375
5078 816.47509765625
5079 815.84912109375
5080 815.2235717773438
5081 814.5985107421875
5082 813.9735717773438
5083 813.3494873046875
5084 812.7256469726562
5085 812.1022338867188
5086 811.4791259765625
5087 810.8564453125
5088 810.2340698242188
5089 809.6121826171875
5090 808.9907836914062
5091 808.369873046875
5092 807.749267578125
5093 807.1290893554688
5094 806.5091552734375
5095 805.8896484375
5096 805.2705688476562
5097 804.6519165039062
5098 804.033447265625
5099 803.4153442382812
5100 802.7979736328125
5101 802.1812133789062
5102 801.56494140625
5103 800.9489135742188
5104 800.3336

5545 566.3527221679688
5546 565.9013671875
5547 565.4501953125
5548 564.9996948242188
5549 564.5493774414062
5550 564.0994262695312
5551 563.6497192382812
5552 563.2002563476562
5553 562.7510986328125
5554 562.3024291992188
5555 561.85400390625
5556 561.4059448242188
5557 560.9581298828125
5558 560.5106201171875
5559 560.0634155273438
5560 559.6165161132812
5561 559.170166015625
5562 558.7239990234375
5563 558.2782592773438
5564 557.8326416015625
5565 557.3873901367188
5566 556.9423828125
5567 556.4979248046875
5568 556.0537109375
5569 555.6098022460938
5570 555.1661987304688
5571 554.722900390625
5572 554.2799072265625
5573 553.8372802734375
5574 553.3952026367188
5575 552.9532470703125
5576 552.5115966796875
5577 552.0703125
5578 551.6292724609375
5579 551.1885986328125
5580 550.7483520507812
5581 550.3084106445312
5582 549.8687744140625
5583 549.429443359375
5584 548.9904174804688
5585 548.5516967773438
5586 548.1134643554688
5587 547.6755981445312
5588 547.2379150390625
5589 546.80

6027 384.5555419921875
6028 384.2478942871094
6029 383.9404602050781
6030 383.6332092285156
6031 383.3262634277344
6032 383.01971435546875
6033 382.7133483886719
6034 382.40728759765625
6035 382.1014099121094
6036 381.7957763671875
6037 381.49041748046875
6038 381.1854248046875
6039 380.8807373046875
6040 380.5762023925781
6041 380.27191162109375
6042 379.9678955078125
6043 379.66412353515625
6044 379.3606872558594
6045 379.0576171875
6046 378.75469970703125
6047 378.45208740234375
6048 378.14971923828125
6049 377.8476867675781
6050 377.5459289550781
6051 377.2444152832031
6052 376.9430847167969
6053 376.6419372558594
6054 376.34112548828125
6055 376.04071044921875
6056 375.74041748046875
6057 375.4404296875
6058 375.1407165527344
6059 374.84112548828125
6060 374.5417785644531
6061 374.2428894042969
6062 373.9442443847656
6063 373.64569091796875
6064 373.3471984863281
6065 373.0489501953125
6066 372.7508239746094
6067 372.453125
6068 372.15557861328125
6069 371.85833740234375
6070 371.

6505 263.75689697265625
6506 263.5535888671875
6507 263.3504943847656
6508 263.1475524902344
6509 262.9447021484375
6510 262.74212646484375
6511 262.5396423339844
6512 262.3375549316406
6513 262.1355285644531
6514 261.93365478515625
6515 261.7319030761719
6516 261.5302734375
6517 261.3290100097656
6518 261.12786865234375
6519 260.9268798828125
6520 260.72607421875
6521 260.5254211425781
6522 260.3249816894531
6523 260.1247863769531
6524 259.9247741699219
6525 259.7249450683594
6526 259.525146484375
6527 259.325439453125
6528 259.12591552734375
6529 258.92657470703125
6530 258.7273864746094
6531 258.5284423828125
6532 258.32952880859375
6533 258.130859375
6534 257.9324035644531
6535 257.73431396484375
6536 257.5362548828125
6537 257.33843994140625
6538 257.14080810546875
6539 256.94329833984375
6540 256.74609375
6541 256.5491027832031
6542 256.3522644042969
6543 256.1556396484375
6544 255.9591522216797
6545 255.76284790039062
6546 255.56680297851562
6547 255.37098693847656
6548 255.1753

6969 187.00177001953125
6970 186.86968994140625
6971 186.73765563964844
6972 186.60586547851562
6973 186.4742431640625
6974 186.34266662597656
6975 186.21127319335938
6976 186.08001708984375
6977 185.94891357421875
6978 185.8179168701172
6979 185.68711853027344
6980 185.556396484375
6981 185.42584228515625
6982 185.29539489746094
6983 185.1651611328125
6984 185.0350799560547
6985 184.90509033203125
6986 184.7752685546875
6987 184.64549255371094
6988 184.5159149169922
6989 184.3863983154297
6990 184.25656127929688
6991 184.12684631347656
6992 183.99716186523438
6993 183.86758422851562
6994 183.7380828857422
6995 183.60879516601562
6996 183.4795684814453
6997 183.35047912597656
6998 183.22142028808594
6999 183.092529296875
7000 182.96380615234375
7001 182.83518981933594
7002 182.70668029785156
7003 182.5782928466797
7004 182.45001220703125
7005 182.3218231201172
7006 182.19381713867188
7007 182.0659637451172
7008 181.938232421875
7009 181.81057739257812
7010 181.68295288085938
7011 181.5

7461 135.02752685546875
7462 134.94508361816406
7463 134.86264038085938
7464 134.7802734375
7465 134.697998046875
7466 134.6157684326172
7467 134.53363037109375
7468 134.45164489746094
7469 134.3697052001953
7470 134.287841796875
7471 134.20602416992188
7472 134.12432861328125
7473 134.04269409179688
7474 133.96124267578125
7475 133.8798065185547
7476 133.79847717285156
7477 133.71719360351562
7478 133.63600158691406
7479 133.55490112304688
7480 133.47396850585938
7481 133.39305114746094
7482 133.312255859375
7483 133.23150634765625
7484 133.15084838867188
7485 133.07025146484375
7486 132.98974609375
7487 132.9093475341797
7488 132.82907104492188
7489 132.74887084960938
7490 132.668701171875
7491 132.58863830566406
7492 132.50869750976562
7493 132.4288330078125
7494 132.34906005859375
7495 132.2693634033203
7496 132.1897735595703
7497 132.11021423339844
7498 132.03074645996094
7499 131.95138549804688
7500 131.87213134765625
7501 131.79298400878906
7502 131.71389770507812
7503 131.63485

7960 102.24430847167969
7961 102.19244384765625
7962 102.14065551757812
7963 102.08888244628906
7964 102.03711700439453
7965 101.98538970947266
7966 101.93372344970703
7967 101.88208770751953
7968 101.83048248291016
7969 101.77892303466797
7970 101.72747039794922
7971 101.67603302001953
7972 101.62460327148438
7973 101.57328033447266
7974 101.52191162109375
7975 101.4706039428711
7976 101.4193344116211
7977 101.36809539794922
7978 101.31697082519531
7979 101.26588439941406
7980 101.21484375
7981 101.16386413574219
7982 101.11294555664062
7983 101.06206512451172
7984 101.01123046875
7985 100.9604263305664
7986 100.90975952148438
7987 100.85908508300781
7988 100.8084716796875
7989 100.75787353515625
7990 100.70733642578125
7991 100.6568832397461
7992 100.60643005371094
7993 100.55606842041016
7994 100.50574493408203
7995 100.45547485351562
7996 100.40528106689453
7997 100.35508728027344
7998 100.3049087524414
7999 100.25482177734375
8000 100.20477294921875
8001 100.15476989746094
8002 10

8459 81.07569885253906
8460 81.04176330566406
8461 81.00784301757812
8462 80.97393798828125
8463 80.9400634765625
8464 80.90621185302734
8465 80.87239074707031
8466 80.83863067626953
8467 80.80487823486328
8468 80.77105712890625
8469 80.73723602294922
8470 80.70347595214844
8471 80.66972351074219
8472 80.6360092163086
8473 80.6023178100586
8474 80.56864166259766
8475 80.53499603271484
8476 80.50133514404297
8477 80.46766662597656
8478 80.4339828491211
8479 80.40033721923828
8480 80.36670684814453
8481 80.33307647705078
8482 80.29949188232422
8483 80.26594543457031
8484 80.23241424560547
8485 80.19892120361328
8486 80.16542053222656
8487 80.13199615478516
8488 80.09856414794922
8489 80.065185546875
8490 80.03181457519531
8491 79.9985122680664
8492 79.96531677246094
8493 79.9321060180664
8494 79.89894104003906
8495 79.86580657958984
8496 79.83265686035156
8497 79.79955291748047
8498 79.76643371582031
8499 79.73340606689453
8500 79.70036315917969
8501 79.66738891601562
8502 79.63442993164

8957 66.83607482910156
8958 66.81224822998047
8959 66.78844451904297
8960 66.7646713256836
8961 66.74089050292969
8962 66.7171401977539
8963 66.6933822631836
8964 66.66957092285156
8965 66.64578247070312
8966 66.62199401855469
8967 66.59822845458984
8968 66.57447814941406
8969 66.55075073242188
8970 66.52702331542969
8971 66.5032958984375
8972 66.47962188720703
8973 66.4559555053711
8974 66.43228149414062
8975 66.40865325927734
8976 66.3850326538086
8977 66.3614501953125
8978 66.33787536621094
8979 66.3143081665039
8980 66.2907943725586
8981 66.26728820800781
8982 66.2437973022461
8983 66.22035217285156
8984 66.19692993164062
8985 66.17352294921875
8986 66.150146484375
8987 66.12676239013672
8988 66.1034164428711
8989 66.08008575439453
8990 66.05680084228516
8991 66.03353118896484
8992 66.01026153564453
8993 65.98704528808594
8994 65.96383666992188
8995 65.94064331054688
8996 65.91748046875
8997 65.89434051513672
8998 65.8712387084961
8999 65.84815979003906
9000 65.82510375976562
9001 

9435 56.4086799621582
9436 56.38822937011719
9437 56.36780548095703
9438 56.34738540649414
9439 56.32697296142578
9440 56.30656433105469
9441 56.28615188598633
9442 56.26570129394531
9443 56.24529266357422
9444 56.22488784790039
9445 56.204498291015625
9446 56.18410110473633
9447 56.163734436035156
9448 56.143367767333984
9449 56.12303161621094
9450 56.10268020629883
9451 56.08233642578125
9452 56.061981201171875
9453 56.04166793823242
9454 56.02135467529297
9455 56.00102996826172
9456 55.98073959350586
9457 55.960479736328125
9458 55.940185546875
9459 55.919918060302734
9460 55.89967727661133
9461 55.87944412231445
9462 55.85923385620117
9463 55.839027404785156
9464 55.81884002685547
9465 55.7985954284668
9466 55.77836227416992
9467 55.75813674926758
9468 55.73795700073242
9469 55.717742919921875
9470 55.69755935668945
9471 55.677371978759766
9472 55.657196044921875
9473 55.63702392578125
9474 55.61689758300781
9475 55.59676742553711
9476 55.5766487121582
9477 55.55653381347656
9478 5

9934 46.381134033203125
9935 46.36015319824219
9936 46.3390998840332
9937 46.318023681640625
9938 46.29692840576172
9939 46.27587127685547
9940 46.254764556884766
9941 46.23365783691406
9942 46.21255111694336
9943 46.191471099853516
9944 46.17036437988281
9945 46.14925765991211
9946 46.12811279296875
9947 46.10698699951172
9948 46.08585739135742
9949 46.06473922729492
9950 46.04357147216797
9951 46.02234649658203
9952 46.001094818115234
9953 45.979705810546875
9954 45.95831298828125
9955 45.936927795410156
9956 45.915504455566406
9957 45.89411163330078
9958 45.8726921081543
9959 45.85126495361328
9960 45.82976531982422
9961 45.80826187133789
9962 45.786773681640625
9963 45.76527404785156
9964 45.743770599365234
9965 45.72220993041992
9966 45.700538635253906
9967 45.67882537841797
9968 45.65707778930664
9969 45.635311126708984
9970 45.613494873046875
9971 45.591590881347656
9972 45.56958770751953
9973 45.54753494262695
9974 45.525428771972656
9975 45.503318786621094
9976 45.481185913085

In [35]:
x = torch.zeros(10)
print(x)
model(x)

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])


tensor([-0.2419], grad_fn=<AddBackward0>)

---

In [38]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
    
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [40]:
epochs = 10
device = torch.device('cpu')
torch.manual_seed(1)

kwargs = {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=64, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=1000, shuffle=True, **kwargs)


model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!

Test set: Average loss: 0.1018, Accuracy: 9665/10000 (97%)


Test set: Average loss: 0.0608, Accuracy: 9827/10000 (98%)


Test set: Average loss: 0.0562, Accuracy: 9811/10000 (98%)


Test set: Average loss: 0.0408, Accuracy: 9860/10000 (99%)


Test set: Average loss: 0.0382, Accuracy: 9868/10000 (99%)


Test set: Average loss: 0.0336, Accuracy: 9894/10000 (99%)


Test set: Average loss: 0.0342, Accuracy: 9877/10000 (99%)


Test set: Average loss: 0.0392, Accuracy: 9880/10000 (99%)


Test set: Average loss: 0.0292, Accuracy: 9910/10000 (99%)


Test set: Average loss: 0.0313, Accuracy: 9896/10000 (99%)



---

In [25]:
import torch
import numpy as np

import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torchvision import datasets, transforms

from se3cnn import SE3Convolution, SE3Dropout
from se3cnn.blocks import GatedBlock
from se3cnn.non_linearities import ScalarActivation
from se3cnn.dropout import SE3Dropout
from se3cnn import kernel
from se3cnn.filter import low_pass_filter

In [62]:
class PoolGatedBlock(torch.nn.Module):
    def __init__(self,
                 repr_in, repr_out, size, radial_window=kernel.gaussian_window_wrapper,  # kernel params
                 activation=(None, None), pool_size=0, pool_stride=0, stride=1, padding=0, dilation=1, capsule_dropout_p=None,  # conv/nonlinearity/dropout params
                 normalization=None, batch_norm_momentum=0.1,  # batch norm params
                 bias=True, smooth_stride=True, dyn_iso=False, checkpoint=True, verbose=False, transpose=False):
        '''
        :param repr_in: tuple with multiplicities of repr. (1, 3, 5, ..., 15)
        :param repr_out: same but for the output
        :param int size: the filters are cubes of dimension = size x size x size
        :param radial_window: radial window function
        :param activation: (scalar activation, gate activation) which are functions like torch.nn.functional.relu or None
        :param int pool_size: size of torch.nn.AvgPool3d
        :param int pool_stride: stride of torch.nn.AvgPool3d
        :param int stride: stride of the convolution (for torch.nn.functional.conv3d)
        :param int padding: padding of the convolution (for torch.nn.functional.conv3d)
        :param int dilation: dilation of the convolution (for torch.nn.functional.conv3d)
        :param float capsule_dropout_p: dropout probability
        :param str normalization: "batch", "group", "instance" or None
        :param float batch_norm_momentum: batch normalization momentum (ignored if no batch normalization)
        :param bool bias: bias for the gates and scalar fields
        :param bool smooth_stride: apply a low pass filter before the stride
        :param bool dyn_iso: initialize with some sort of Dynamical Isometry (inspired by Algo. 2 in https://arxiv.org/abs/1806.05393)
        '''
        super().__init__()

        if type(activation) is tuple:
            scalar_activation, gate_activation = activation
        else:
            scalar_activation, gate_activation = activation, activation

        self.repr_out = repr_out

        Rs_in = [(m, l) for l, m in enumerate(repr_in)]
        Rs_out_with_gate = [(m, l) for l, m in enumerate(repr_out)]

        if (scalar_activation is not None and repr_out[0] > 0):
            self.scalar_act = ScalarActivation([(repr_out[0], scalar_activation)], bias=bias)
        else:
            self.scalar_act = None

        n_non_scalar = sum(repr_out[1:])
        if gate_activation is not None and n_non_scalar > 0:
            Rs_out_with_gate.append((n_non_scalar, 0))  # concatenate scalar gate capsules after normal capsules
            self.gate_act = ScalarActivation([(n_non_scalar, gate_activation)], bias=bias)
        else:
            self.gate_act = None
            
        if pool_size > 0:
            self.pool = torch.nn.AvgPool3d(kernel_size=pool_size, stride=pool_stride)

        if normalization == None:
            Convolution = SE3Convolution
        elif normalization == "batch":
            Convolution = partial(SE3BNConvolution, momentum=batch_norm_momentum)
        elif normalization == "batch_max":
            Convolution = partial(SE3BNConvolution, reduce='max', momentum=batch_norm_momentum)
        elif normalization == "group":
            Convolution = SE3GNConvolution
        elif normalization == "instance":
            Convolution = partial(SE3GNConvolution, Rs_gn=[(1, 2 * n + 1) for n, mul in enumerate(repr_in) for _ in range(mul)])
        else:
            raise NotImplementedError('normalization mode unknown')

        if transpose == True:
            Convolution = SE3ConvolutionTranspose

        self.conv = Convolution(
            Rs_in=Rs_in,
            Rs_out=Rs_out_with_gate,
            size=size,
            radial_window=radial_window,
            stride=1 if smooth_stride else stride,
            padding=padding,
            dilation=dilation,
            dyn_iso=dyn_iso,
            verbose=verbose,
        )

        self.stride = stride if smooth_stride else 1

        self.dropout = None
        if capsule_dropout_p is not None:
            Rs_out_without_gate = [(mul, 2 * n + 1) for n, mul in enumerate(repr_out)]  # Rs_out without gates
            self.dropout = SE3Dropout(Rs_out_without_gate, capsule_dropout_p)

        self.checkpoint = checkpoint


    def forward(self, x):  # pylint: disable=W

        def gate(y):
            nbatch = y.size(0)
            nx = y.size(2)
            ny = y.size(3)
            nz = y.size(4)

            size_out = sum(mul * (2 * n + 1) for n, mul in enumerate(self.repr_out))

            if self.gate_act is not None:
                g = y[:, size_out:]
                g = self.gate_act(g)
                begin_g = 0  # index of first scalar gate capsule

            z = y.new_empty((y.size(0), size_out, y.size(2), y.size(3), y.size(4)))
            begin_y = 0  # index of first capsule

            for n, mul in enumerate(self.repr_out):
                if mul == 0:
                    continue
                dim = 2 * n + 1

                # crop out capsules of order n
                field_y = y[:, begin_y: begin_y + mul * dim]  # [batch, feature * repr, x, y, z]

                if n == 0:
                    # Scalar activation
                    if self.scalar_act is not None:
                        field = self.scalar_act(field_y)
                    else:
                        field = field_y
                else:
                    if self.gate_act is not None:
                        # reshape channels in capsules and capsule entries
                        field_y = field_y.contiguous()
                        field_y = field_y.view(nbatch, mul, dim, nx, ny, nz)  # [batch, feature, repr, x, y, z]

                        # crop out corresponding scalar gates
                        field_g = g[:, begin_g: begin_g + mul]  # [batch, feature, x, y, z]
                        begin_g += mul
                        # reshape channels for broadcasting
                        field_g = field_g.contiguous()
                        field_g = field_g.view(nbatch, mul, 1, nx, ny, nz)  # [batch, feature, repr, x, y, z]

                        # scale non-scalar capsules by gate values
                        field = field_y * field_g  # [batch, feature, repr, x, y, z]
                        field = field.view(nbatch, mul * dim, nx, ny, nz)  # [batch, feature * repr, x, y, z]
                        del field_g
                    else:
                        field = field_y
                del field_y

                z[:, begin_y: begin_y + mul * dim] = field
                begin_y += mul * dim
                del field

            return z


        # convolution
        z = self.conv(x)
                          
        # pool
        if self.pool is not None:
            z = self.pool(z)

        # gate
        if self.scalar_act is not None or self.gate_act is not None:
            z = torch.utils.checkpoint.checkpoint(gate, z) if self.checkpoint else gate(z)

        # stride
        if self.stride > 1:
            z = low_pass_filter(z, self.stride, self.stride)

        # dropout
        if self.dropout is not None:
            z = self.dropout(z)

        return z

In [63]:
f = torch.nn.Sequential(
            SE3Convolution([(1, 0)], [(2, 0), (2, 1), (1, 2)], size=4),
            SE3Convolution([(2, 0), (2, 1), (1, 2)], [(1, 0)], size=4),
        ).to(torch.float64)

torch.set_default_dtype(torch.float64)

def rotate(t):
    # rotate 90 degrees in plane of axes 2 and 3
    return torch.flip(t, (2, )).transpose(2, 3)

def unrotate(t):
    # undo the rotation by 3 more rotations
    return rotate(rotate(rotate(t)))

inp = torch.randn(2, 1, 16, 16, 16)
print(inp.size())
inp_r = rotate(inp)
print(inp_r.size())

diff_inp = (inp - unrotate(inp_r)).abs().max().item()
print(diff_inp < 1e-10) # sanity check

out = f(inp)
out_r = f(inp_r)

diff_out = (out - unrotate(out_r)).abs().max().item()
print(diff_out < 1e-10)

torch.Size([2, 1, 16, 16, 16])
torch.Size([2, 1, 16, 16, 16])
True
True


In [64]:
conv = SE3Convolution([(1, 0)], [(1, 2)], size=4)
x = torch.randn(1, 1, 4, 4, 4)

In [65]:
conv(x)

tensor([[[[[ 3.4505]]],


         [[[ 4.8365]]],


         [[[-0.1406]]],


         [[[-1.7393]]],


         [[[-1.4492]]]]], grad_fn=<ThnnConv3DBackward>)

In [104]:
class EquiNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.gated = PoolGatedBlock(repr_in, repr_out, size, activation=activation, 
                                    pool_size=pool_size, pool_stride=pool_stride, bias=bias)
        self.lin1 = nn.Linear(n_input_1, n_output_1)
        self.drop1 = nn.Dropout(prob)
        self.lin2 = nn.Linear(n_output_1, n_output_2)
        self.drop2 = nn.Dropout(prob)
        self.lin3 = nn.Linear(n_output_2, NUM_CLASSES+6)

    def forward(self, x):
        x = self.gated(x)
        x = x.view(batch_size,-1) # tf.reshape(x,[batchSize,-1])
        x = F.relu(self.lin1(x))
        x = self.drop1(x)
        x = F.relu(self.lin2(x))
        x = self.drop2(x)
        return self.lin3(x)
    
    
def train(model, device, train_set, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_set):
        #data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss_fn = nn.MSELoss(reduction='sum')
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_set),
                100. * batch_idx / len(train_set), loss.item()))

def test(model, device, test_set):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_set:
            #data, target = data.to(device), target.to(device)
            output = model(data)
            loss_fn = nn.MSELoss(reduction='sum')
            test_loss += loss_fn(output, target) # sum up batch loss
            # pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += int(torch.argmax(output) == torch.argmax(target))

    test_loss /= len(test_set)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_set),
        100. * correct / len(test_set)))

In [105]:
repr_in = (1,0)
repr_out = (1,0)
size = 4
activation = (F.relu, F.relu)
pool_size = 1
pool_stride = 1
bias = True

n_input_1 = 1 #5
n_output_1 = 1 #512
n_output_2 = 1 #256

batch_size = 1
prob = 0.5
NUM_CLASSES = 1

In [106]:
epochs = 10
device = torch.device('cpu')
torch.manual_seed(1)

train_set = torch.utils.data.DataLoader(
    datasets.MNIST('../MNIST', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=64, shuffle=True)
test_set = torch.utils.data.DataLoader(
    datasets.MNIST('../MNIST', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=1000, shuffle=True)

train_set = [(torch.tensor([[[[[1,2,3,4],
                [1,2,3,4],
                [1,2,3,4],
                [1,2,3,4]],
              
                [[1,2,3,4],
                [1,2,3,4],
                [1,2,3,4],
                [1,2,3,4]],
              
                [[1,2,3,4],
                [1,2,3,4],
                [1,2,3,4],
                [1,2,3,4]],
              
                [[1,2,3,4],
                [1,2,3,4],
                [1,2,3,4],
                [1,2,3,4]]]]], dtype=torch.float64) , torch.tensor([1,1,1,1,1,1,1], dtype=torch.float64))]

test_set = [(torch.tensor([[[[[1,2,3,4],
                [1,2,3,4],
                [1,2,3,4],
                [1,2,3,4]],
              
                [[1,2,3,4],
                [1,2,3,4],
                [1,2,3,4],
                [1,2,3,4]],
              
                [[1,2,3,4],
                [1,2,3,4],
                [1,2,3,4],
                [1,2,3,4]],
              
                [[1,2,3,4],
                [1,2,3,4],
                [1,2,3,4],
                [1,2,3,4]]]]], dtype=torch.float64) , torch.tensor([1,1,1,1,1,1,1], dtype=torch.float64))]


model = EquiNet().to(device)
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(1, epochs + 1):
    train(model, device, train_set, optimizer, epoch)
    test(model, device, test_set)


Test set: Average loss: 7.5405, Accuracy: 0/1 (0%)


Test set: Average loss: 7.5389, Accuracy: 0/1 (0%)


Test set: Average loss: 7.5374, Accuracy: 0/1 (0%)


Test set: Average loss: 7.5358, Accuracy: 0/1 (0%)


Test set: Average loss: 7.5341, Accuracy: 0/1 (0%)


Test set: Average loss: 7.5325, Accuracy: 0/1 (0%)


Test set: Average loss: 7.5309, Accuracy: 0/1 (0%)


Test set: Average loss: 7.5292, Accuracy: 0/1 (0%)


Test set: Average loss: 7.5275, Accuracy: 0/1 (0%)


Test set: Average loss: 7.5258, Accuracy: 0/1 (0%)



In [49]:
a = np.array([[[1,2],[2,3]], [[1,2],[2,3]]])
print(a)
print(np.tile(a, 2))
print(np.tile(a, 2).reshape((2, -1)))

[[[1 2]
  [2 3]]

 [[1 2]
  [2 3]]]
[[[1 2 1 2]
  [2 3 2 3]]

 [[1 2 1 2]
  [2 3 2 3]]]
[[1 2 1 2 2 3 2 3]
 [1 2 1 2 2 3 2 3]]


In [None]:
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algoriths. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.


for t in range(10000):
  # Forward pass: compute predicted y by passing x to the model.
  y_pred = model(x)

  # Compute and print loss.
  loss = loss_fn(y_pred, y)
  print(t, loss.item())
  
  # Before the backward pass, use the optimizer object to zero all of the
  # gradients for the Tensors it will update (which are the learnable weights
  # of the model)
  optimizer.zero_grad()

  # Backward pass: compute gradient of the loss with respect to model parameters
  loss.backward()

  # Calling the step function on an Optimizer makes an update to its parameters
  optimizer.step()

In [None]:
SE3Convolution([(1, 0)], [(1, 2)], size=4)
'activation': (F.relu, torch.sigmoid),
    n_non_scalar = sum(repr_out[1:])
ScalarActivation([(n_non_scalar, F.relu)], bias=bias)



In [None]:
model = torch.nn.Sequential(
            
        )
loss_fn = torch.nn.MSELoss(reduction='sum')

In [None]:
cnn (1,1,1,1,1)
bias
maxpool
relu

linear
relu
dropout 

linear 
relu
dropout

softmax