In [1]:
import torch

In [14]:
class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clamp(min=0)
    
    def backward(ctx, dout):
        (x, ) = ctx.saved_tensors
        grad_x = dout.clone()
        grad_x[x < 0] = 0
        return grad_x

In [15]:
device = torch.device('cpu')
N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)

w1 = torch.randn(D_in, H, device=device, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, requires_grad=True)

learning_rate = 1e-6

In [16]:
for t in range(500):
    y_pred = MyReLU.apply(x.mm(w1)).mm(w2)
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())
    loss.backward()
    
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        w1.grad.zero_()
        w2.grad.zero_()

0 32566816.0
1 31606992.0
2 32267054.0
3 29542348.0
4 22292370.0
5 13749096.0
6 7431736.5
7 3932700.5
8 2254174.0
9 1461790.5
10 1059908.375
11 829822.125
12 679965.0
13 571655.125
14 488197.15625
15 421159.96875
16 365853.15625
17 319496.78125
18 280254.96875
19 246811.15625
20 218122.53125
21 193369.09375
22 171905.390625
23 153230.546875
24 136926.171875
25 122652.5703125
26 110135.0703125
27 99116.671875
28 89381.234375
29 80753.53125
30 73096.1015625
31 66277.4296875
32 60213.0625
33 54796.19921875
34 49940.6796875
35 45578.20703125
36 41652.68359375
37 38112.875
38 34916.06640625
39 32023.8046875
40 29404.8125
41 27032.126953125
42 24879.421875
43 22918.873046875
44 21132.78125
45 19503.986328125
46 18015.62890625
47 16655.00390625
48 15410.74609375
49 14272.515625
50 13228.6484375
51 12270.4775390625
52 11391.7666015625
53 10584.15625
54 9840.1767578125
55 9155.732421875
56 8524.2119140625
57 7941.22119140625
58 7402.666015625
59 6904.201171875
60 6442.939453125
61 6015.63769531

385 0.0008801647927612066
386 0.0008516847155988216
387 0.0008249807870015502
388 0.0007977974601089954
389 0.0007707941113039851
390 0.0007455244776792824
391 0.0007218625396490097
392 0.0006998440367169678
393 0.0006773574859835207
394 0.0006558149470947683
395 0.000635328353382647
396 0.0006163360667414963
397 0.000597754551563412
398 0.0005801969091407955
399 0.0005621873424388468
400 0.0005437657819129527
401 0.000528572010807693
402 0.0005118754925206304
403 0.0004973375471308827
404 0.00048320487258024514
405 0.00046797608956694603
406 0.00045506394235417247
407 0.00044234763481654227
408 0.0004304910253267735
409 0.0004183158453088254
410 0.0004057384794577956
411 0.0003943214542232454
412 0.0003833928203675896
413 0.0003737035149242729
414 0.0003615999303292483
415 0.0003530292888171971
416 0.00034304294968023896
417 0.0003332844644319266
418 0.0003247404529247433
419 0.0003172496217302978
420 0.000308716029394418
421 0.00030055566458031535
422 0.0002935499942395836
423 0.0002