In [18]:
import torch
import torch.nn as nn
from helpers.convert_to_var_foo import convert_to_var
import numpy as np

In [119]:
class Enet(nn.Module):
    """Fully connected network whose outputs pass through a sigmoid.

    Architecture: ``Linear(input_dim, hidden_size)`` + ReLU, followed by
    ``num_hidden`` blocks of ``Linear(hidden_size, hidden_size)`` + ReLU, then
    ``Linear(hidden_size, num_actions)`` + Sigmoid.

    Args:
        num_actions: width of the output layer.
        input_dim: number of input features.
        num_hidden: number of hidden-to-hidden Linear+ReLU blocks added after
            the input block.
        hidden_size: width of every hidden layer.
    """

    def __init__(self, num_actions, input_dim,
                 num_hidden=2, hidden_size=512):
        super().__init__()
        layers = [nn.Linear(input_dim, hidden_size), nn.ReLU()]
        for _ in range(num_hidden):
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_size, num_actions))
        layers.append(nn.Sigmoid())
        self.net = nn.Sequential(*layers)
        # Zeroing the output layer is opt-in: call set_weights() explicitly.
        #self.set_weights()


    def set_weights(self):
        """Zero the weight and bias of the output Linear layer, in place.

        Afterwards the pre-sigmoid activations are 0 for any input, so the
        network outputs 0.5 everywhere.
        """
        # net[-1] is the Sigmoid, so net[-2] is the output Linear layer.
        # Address it by position: sorting state_dict keys as strings puts
        # "10.weight" before "2.weight", which would pick a hidden layer as
        # soon as the output layer's index reaches 10 (num_hidden >= 4).
        output_layer = self.net[-2]
        output_layer.weight.data.zero_()
        output_layer.bias.data.zero_()


    def forward(self, x):
        """Run ``x`` through the layer stack and return the sigmoid outputs."""
        return self.net(x)


In [125]:
# Small Enet for the toy experiment: 1 input feature, 2 outputs, hidden width 10
# (num_hidden keeps its default).
model = Enet(num_actions=2, input_dim=1, hidden_size=10)
# Adam with its default hyperparameters over every parameter of the model.
opt = torch.optim.Adam(params=model.parameters())

In [127]:
# Weight and bias of the output Linear layer (net[-1] is the Sigmoid, net[-2] the Linear).
list(model.net[-2].parameters())

[Parameter containing:
 -0.1501 -0.0329  0.1505  0.1863 -0.1667  0.1146  0.2216 -0.2758 -0.0088 -0.1446
  0.1033 -0.3008 -0.0179  0.0633 -0.2488  0.0114  0.1992  0.1116 -0.2989 -0.0713
 [torch.FloatTensor of size 2x10], Parameter containing:
  0.1579
 -0.2121
 [torch.FloatTensor of size 2]]

In [121]:
# Show the module tree to check layer order and sizes.
print(model)

Enet(
  (net): Sequential(
    (0): Linear(in_features=1, out_features=10, bias=True)
    (1): ReLU()
    (2): Linear(in_features=10, out_features=10, bias=True)
    (3): ReLU()
    (4): Linear(in_features=10, out_features=10, bias=True)
    (5): ReLU()
    (6): Linear(in_features=10, out_features=2, bias=True)
    (7): Sigmoid()
  )
)


In [122]:
# Collect every parameter tensor of the model into a list for positional inspection.
p = list(model.parameters())

In [123]:
# Last parameter: the output layer's bias (one entry per output).
p[-1]

Parameter containing:
 0.1941
-0.2720
[torch.FloatTensor of size 2]

In [124]:
# Second-to-last parameter: the output layer's weight matrix.
p[-2]

Parameter containing:
 0.2413 -0.1922  0.1368  0.0941 -0.1402 -0.1898  0.0340 -0.1540  0.0490  0.2392
-0.1650 -0.0361  0.0765 -0.2090 -0.2212 -0.2828  0.0479  0.0892  0.3125  0.0945
[torch.FloatTensor of size 2x10]

In [102]:
# Toy regression data: 1000 standard-normal inputs; both target columns are x**2.
x = np.random.randn(1000)
# column_stack keeps row i of y paired with x[i]. (hstack followed by
# reshape((1000, 2)) would instead place x[2i]**2 and x[2i+1]**2 on row i,
# so the targets would no longer correspond to their inputs.)
y = np.column_stack((x**2, x**2))

In [92]:
# Sanity check: one row per sample, two target columns.
y.shape

(1000, 2)

In [103]:
# Minibatch training of `model` on (x, y) with `opt` and a mean-squared-error loss.
num_steps = 1000
batch_size = 32
mse = nn.MSELoss()  # loop-invariant, so build it once instead of every step

for _ in range(num_steps):
    # Draw a minibatch of row indices from the whole dataset
    # (np.random.choice samples with replacement by default).
    ind = np.random.choice(len(x), batch_size)

    # Inputs become a (batch_size, 1) column; targets are the matching rows of y.
    # NOTE(review): convert_to_var is a project helper not shown here; presumably
    # it wraps numpy arrays as float tensors/Variables. Confirm.
    x_ = convert_to_var(x[ind][:, np.newaxis])
    y_ = convert_to_var(y[ind])

    loss = mse(model(x_), y_)

    print(loss.data.numpy())

    # Standard update: clear stale gradients, backpropagate, apply the step.
    opt.zero_grad()
    loss.backward()
    opt.step()
    

[3.223347]
[0.7288763]
[1.7346147]
[2.047936]
[1.4716315]
[2.1947198]
[1.1703575]
[1.2441672]
[1.1383361]
[1.6181172]
[1.6037029]
[1.7479316]
[2.021064]
[1.9555826]
[2.3279114]
[1.5727242]
[2.365561]
[1.9275218]
[1.5527768]
[1.0775574]
[0.87951976]
[1.7344458]
[1.6341017]
[1.556728]
[1.561366]
[1.2358235]
[2.3295279]
[3.0029263]
[0.7117042]
[1.649991]
[1.6125838]
[1.8451241]
[1.3706475]
[1.0062699]
[1.7576331]
[1.5614252]
[2.4533546]
[2.5758243]
[2.0550666]
[1.2383928]
[1.4714626]
[0.763224]
[1.4480313]
[1.9706581]
[1.8408517]
[1.7168338]
[1.2362995]
[3.1663756]
[1.8364404]
[1.6895378]
[3.042874]
[1.4958513]
[1.7979953]
[1.9489927]
[1.8300436]
[1.802195]
[1.354597]
[2.504327]
[2.5895255]
[2.2414827]
[1.5935187]
[1.1886942]
[1.9075732]
[2.1703603]
[0.8483476]
[1.0511371]
[2.1395762]
[1.7757137]
[4.2658925]
[2.939877]
[3.2640631]
[0.7953836]
[1.9931586]
[2.1562393]
[2.758805]
[2.8587654]
[1.9008322]
[2.0094254]
[2.6906765]
[1.6660966]
[1.2616688]
[1.8904816]
[1.6375415]
[2.2760868]
[2.49

[2.375364]
[2.7547047]
[1.1111628]
[2.5362115]
[4.0495453]
[1.9297518]
[0.8229731]
[2.3383796]
[2.0136142]
[3.7238307]
[2.187535]
[2.4355657]
[1.561655]
[0.9243715]
[1.4278862]
[1.8069029]
[2.9597504]
[2.5060663]
[1.8804371]
[1.2218423]
[2.5786853]
[1.9593128]
[2.115158]
[1.2688355]
[1.4897568]
[3.2887955]
[1.6755944]
[2.983587]
[1.8542283]
[2.4748075]
[1.6657287]
[1.6305141]
[1.7640406]
[2.0555983]
[2.1408882]
[1.9883028]
[1.99597]
[1.7023797]
[1.6287262]
[2.6526551]
[2.4364593]
[2.8554275]
[2.142396]
[1.204727]
[1.128637]
[1.6421431]
[3.625352]
[1.1919241]
[2.3977172]
[2.532395]
[2.1341572]
[0.9954425]
[2.0451257]
[1.8999281]
[2.0512812]
[2.060645]
[1.119611]
[2.38646]
[2.5787811]
[1.4653313]
[1.6859876]
[3.9888225]
[1.0824106]
[1.6082395]
[1.4422501]
[1.4695609]
[1.6922991]
[1.5641435]
[1.7770896]
[2.644973]
[1.1780612]
[2.1679833]
[2.6770494]
[2.631204]
[2.2386432]
[1.6804333]
[1.5765208]
[3.6104171]
[3.2221916]


In [104]:
# Name -> tensor mapping for the Sequential held by the model.
state_dict = model.net.state_dict()

In [105]:
# Print the Sequential so the numeric prefixes of the state_dict keys can be matched to layers.
print(model.net)

Sequential(
  (0): Linear(in_features=1, out_features=10, bias=True)
  (1): ReLU()
  (2): Linear(in_features=10, out_features=10, bias=True)
  (3): ReLU()
  (4): Linear(in_features=10, out_features=10, bias=True)
  (5): ReLU()
  (6): Linear(in_features=10, out_features=2, bias=True)
  (7): Sigmoid()
)


In [106]:
# Keys are "<module index>.<weight|bias>"; only the Linear layers contribute entries.
state_dict.keys()

odict_keys(['0.weight', '0.bias', '2.weight', '2.bias', '4.weight', '4.bias', '6.weight', '6.bias'])

In [107]:
# Weight of the output Linear layer (index 6 in the Sequential printed above).
state_dict['6.weight']


 0.0000  0.3366  0.3943  0.3393  0.0000  0.0000  0.3239  0.4266  0.6474  0.3411
 0.0000  0.3407  0.3990  0.3436  0.0000  0.0000  0.3279  0.4333  0.6103  0.3456
[torch.FloatTensor of size 2x10]

In [108]:
# Weight of the Linear layer at index 4, the last hidden-to-hidden layer in this configuration.
state_dict['4.weight']


 0.1800  0.2272  0.0891 -0.2757  0.0657 -0.0995 -0.0042 -0.1193  0.2161  0.2002
-0.2189  0.2714  0.1317  0.5088  0.4943 -0.2321  0.6849  0.5739  0.0860  0.2694
 0.2560  0.3156 -0.1666  0.3172  0.7495 -0.1579  0.6386  0.4843  0.3335  0.2484
 0.0203  0.2195  0.0651  0.7414  0.6870  0.0437  0.3486  0.3772 -0.2326 -0.1278
-0.0231 -0.2470  0.3111  0.0711 -0.2917 -0.2876 -0.1493  0.0899  0.2617  0.0143
 0.1092 -0.0379  0.1214  0.2671  0.0012  0.2259 -0.1541 -0.2264  0.1892  0.0741
 0.2097 -0.0755 -0.1035  0.2670  0.2563  0.1982  0.6586  0.4396 -0.2626 -0.1390
 0.0946 -0.0087 -0.0491  0.4947  0.6931 -0.2657  0.3854  0.1874 -0.1858  0.2440
 0.1187  0.0030 -0.1533  0.2471  0.7344 -0.1222  0.7203  0.4387 -0.1798 -0.2634
 0.1381  0.2909 -0.2942  0.3358  0.4022  0.2976  0.4044  0.3428 -0.0374 -0.2219
[torch.FloatTensor of size 10x10]