# Custom nn Modules

In [1]:
import torch
from torch.autograd import Variable

In [4]:
class TwoLayerNet(torch.nn.Module):
    """Fully connected network: Linear -> ReLU (via clamp) -> Linear.

    Args:
        D_in: number of input features (size of the last input dimension).
        H: width of the hidden layer.
        D_out: number of output features.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        # Submodule attribute names are kept as-is: they determine the
        # parameter / state_dict keys ("linear1.weight", ...).
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Return predictions for ``x``, whose last dimension must be ``D_in``."""
        hidden = self.linear1(x)
        # clamp(min=0) zeroes negative activations, i.e. acts as a ReLU.
        hidden = hidden.clamp(min=0)
        return self.linear2(hidden)

In [3]:
# Toy regression data: N samples with D_in input features and D_out targets;
# H is the hidden-layer width used by the model below.
# Seed first so the data and the model's initial weights (created later from
# the same global torch RNG) are reproducible under Restart & Run All.
torch.manual_seed(0)
N, D_in, H, D_out = 64, 1000, 100, 10
# Plain tensors replace the deprecated autograd.Variable wrapper (PyTorch >= 0.4);
# tensors do not require grad by default, so no extra flag is needed on y.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

In [5]:
# Instantiate the two-layer network with the sizes from the data cell.
model = TwoLayerNet(D_in=D_in, H=H, D_out=D_out)

In [7]:
# Summed (not averaged) squared error, same semantics as the original
# size_average=False; `reduction='sum'` is the supported, non-deprecated spelling.
loss_fn = torch.nn.MSELoss(reduction='sum')
# Plain SGD over all of the model's parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

In [8]:
# Training loop: 500 full-batch SGD steps on the toy data.
for t in range(500):
    # Forward pass and summed-MSE loss.
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    # `loss` is a 0-dim tensor; `.item()` extracts the Python float.
    # (Indexing it as loss.data[0] raises IndexError on PyTorch >= 0.5.)
    print(t, loss.item())

    # Clear stale gradients, backpropagate, then apply one SGD update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 725.7544555664062
1 670.37744140625
2 623.1653442382812
3 581.8516235351562
4 545.5613403320312
5 513.1138305664062
6 483.7323303222656
7 457.184814453125
8 432.6991882324219
9 409.6831359863281
10 388.116455078125
11 367.8990478515625
12 348.8924560546875
13 330.83929443359375
14 313.75146484375
15 297.6293029785156
16 282.2702331542969
17 267.58209228515625
18 253.5344696044922
19 240.07989501953125
20 227.25161743164062
21 214.99781799316406
22 203.28477478027344
23 192.05360412597656
24 181.36639404296875
25 171.15762329101562
26 161.42095947265625
27 152.1680145263672
28 143.37274169921875
29 135.0388641357422
30 127.14613342285156
31 119.67090606689453
32 112.59944152832031
33 105.88972473144531
34 99.56292724609375
35 93.60193634033203
36 87.96621704101562
37 82.65080261230469
38 77.6543960571289
39 72.9538803100586
40 68.5313949584961
41 64.38270568847656
42 60.490150451660156
43 56.8192138671875
44 53.37432861328125
45 50.14763259887695
46 47.11772537231445
47 44.24076461791

485 4.220221399009461e-06
486 4.109213023184566e-06
487 4.000268745585345e-06
488 3.895023382938234e-06
489 3.792870984398178e-06
490 3.692605105243274e-06
491 3.5960140394308837e-06
492 3.501493210933404e-06
493 3.4095123737643007e-06
494 3.319760708109243e-06
495 3.2328543966286816e-06
496 3.1478375603910536e-06
497 3.065113105549244e-06
498 2.984935235872399e-06
499 2.90691514237551e-06


In [9]:
import random
import torch
from torch.autograd import Variable

In [13]:
class DynamicNet(torch.nn.Module):
    """Network whose depth is re-sampled on every forward pass.

    The input layer is always applied. Then the single shared ``middle_linear``
    layer is applied a random number of times (0 to 3 inclusive, drawn from
    Python's global ``random``), reusing the same weights each time. Finally the
    output layer is applied.

    Args:
        D_in: number of input features (size of the last input dimension).
        H: width of every hidden layer.
        D_out: number of output features.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        # Attribute names are kept as-is: they determine parameter/state_dict keys.
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Return predictions for ``x`` using a freshly drawn number of hidden layers."""
        hidden = self.input_linear(x).clamp(min=0)
        # Exactly one draw per call, so the global `random` state advances the same way.
        repeats = random.randint(0, 3)
        while repeats > 0:
            hidden = self.middle_linear(hidden).clamp(min=0)
            repeats -= 1
        return self.output_linear(hidden)

In [14]:
# Toy regression data for the dynamic-depth experiment (same sizes as above).
# Seed both RNGs: torch drives the data and weight init, while Python's
# `random` drives DynamicNet's per-step depth. This makes the run reproducible.
torch.manual_seed(0)
random.seed(0)
N, D_in, H, D_out = 64, 1000, 100, 10
# Plain tensors replace the deprecated autograd.Variable wrapper (PyTorch >= 0.4);
# tensors do not require grad by default.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

In [15]:
# Fresh dynamic-depth model, summed-MSE loss and plain SGD.
model = DynamicNet(D_in, H, D_out)
# reduction='sum' is the supported equivalent of the deprecated size_average=False.
loss_fn = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

# 500 full-batch SGD steps. The network depth is re-drawn on every forward
# pass, so the printed loss can fluctuate from step to step.
for t in range(500):
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    # `.item()` reads the 0-dim loss tensor; loss.data[0] raises IndexError
    # on PyTorch >= 0.5.
    print(t, loss.item())

    # Clear stale gradients, backpropagate, then apply one SGD update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 677.8267211914062
1 676.5379638671875
2 675.2866821289062
3 674.0646362304688
4 684.9634399414062
5 676.3870849609375
6 639.0174560546875
7 598.8134765625
8 676.0250854492188
9 675.8939208984375
10 671.6927490234375
11 670.5164794921875
12 668.55908203125
13 669.351806640625
14 562.1741333007812
15 661.6380004882812
16 674.9970092773438
17 667.9838256835938
18 530.0537719726562
19 500.53533935546875
20 666.7437133789062
21 655.0119018554688
22 473.4829406738281
23 448.160888671875
24 424.560546875
25 665.3704223632812
26 663.9849853515625
27 674.1983642578125
28 402.3142395019531
29 662.5014038085938
30 673.7531127929688
31 673.4135131835938
32 660.8395385742188
33 381.3753662109375
34 672.9721069335938
35 672.6248779296875
36 672.2828369140625
37 361.5393371582031
38 648.4260864257812
39 671.8368530273438
40 658.6547241210938
41 657.135498046875
42 640.2018432617188
43 655.4592895507812
44 343.01983642578125
45 653.8536376953125
46 670.702880859375
47 670.3306274414062
48 669.968017

477 5.196931838989258
478 4.468626022338867
479 3.953098773956299
480 7.012937068939209
481 3.7984936237335205
482 7.208336353302002
483 9.542912483215332
484 3.687302827835083
485 6.813019275665283
486 2.9671034812927246
487 2.624783992767334
488 5.112341403961182
489 5.880403995513916
490 5.414257049560547
491 3.7777087688446045
492 2.9975602626800537
493 3.2517547607421875
494 4.803659915924072
495 3.309929370880127
496 4.495142936706543
497 3.099874973297119
498 6.410638332366943
499 4.205537796020508
