In [1]:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hidden-layer width; judging by the name, presumably "number of neurons"
# used by later cells -- TODO confirm. (An alternative value previously tried: 2.)
NUM_NEU = 50

# Seed PyTorch's global RNG so the randomly initialised weights displayed by
# the state_dict() cells below are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

class SharedModel(nn.Module):
    """Fully connected feed-forward network defined by a list of layer widths.

    ``nlist`` lists the width of every layer in order: ``nlist[0]`` is the
    number of input features, ``nlist[-1]`` the number of outputs, and each
    entry in between is the width of one hidden layer. For example,
    ``SharedModel([2, 5, 5, 1])`` maps 2 input features through two hidden
    layers of 5 units each to a single output. (The original notes describe
    the input as a 2-feature (latitude, longitude) pair; that holds only when
    ``nlist[0] == 2``.)

    Hidden layers live in an ``nn.ModuleList`` so their parameters are
    registered and appear in ``state_dict()`` as ``hidden.<i>.weight`` /
    ``hidden.<i>.bias``; the final linear layer is ``ol``. ``activation`` is
    applied after every hidden layer, while the output layer stays linear.
    """

    def __init__(self, nlist, activation=F.relu):
        """Create the hidden layers and the output layer.

        Args:
            nlist: Sequence (e.g. list or tuple) of at least two layer widths,
                ordered from input to output.
            activation: Callable applied after each hidden layer. Defaults to
                ``F.relu``, the previously hard-coded choice.

        Raises:
            ValueError: If ``nlist`` has fewer than two entries, i.e. there is
                no input/output pair to connect.
        """
        super(SharedModel, self).__init__()
        if len(nlist) < 2:
            raise ValueError(
                "nlist needs at least an input and an output width, "
                f"got {list(nlist)!r}"
            )
        self.activation = activation
        # One Linear per consecutive pair of widths, excluding the last pair,
        # which belongs to the output layer created below.
        self.hidden_size = len(nlist) - 2
        self.hidden = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(nlist[:-2], nlist[1:-1])]
        )
        self.ol = nn.Linear(nlist[-2], nlist[-1])  # output layer

    def forward(self, x):
        """Run ``x`` through the hidden layers, then the linear output layer.

        ``nn.Linear`` operates on the last dimension, so ``x`` should have
        ``nlist[0]`` features there; leading (batch) dimensions are preserved.
        """
        for layer in self.hidden:
            x = self.activation(layer(x))
        return self.ol(x)

In [2]:
# 2 inputs -> two hidden layers of width 5 -> 1 output
a = SharedModel(nlist=[2, 5, 5, 1])


In [3]:
# Hidden layers registered through the ModuleList show up as 'hidden.<i>.*';
# the output layer shows up as 'ol.*'.
a.state_dict()

OrderedDict([('hidden.0.weight',
              tensor([[ 0.6956, -0.0825],
                      [-0.6684,  0.3141],
                      [ 0.2535, -0.5232],
                      [-0.4845, -0.6444],
                      [ 0.0763, -0.6695]])),
             ('hidden.0.bias',
              tensor([ 0.5952, -0.2872,  0.2569,  0.1323, -0.0401])),
             ('hidden.1.weight',
              tensor([[-0.2626,  0.0006,  0.4181,  0.3660,  0.3866],
                      [-0.0441,  0.0888,  0.1337,  0.3323,  0.3980],
                      [ 0.3814,  0.4134, -0.2573, -0.3117,  0.1556],
                      [ 0.1248, -0.2768,  0.0807,  0.3789,  0.4125],
                      [ 0.0110, -0.2668,  0.3133,  0.3926, -0.1802]])),
             ('hidden.1.bias',
              tensor([-0.2202, -0.1003,  0.0391, -0.2574,  0.0120])),
             ('ol.weight',
              tensor([[ 0.0080, -0.3974, -0.3341,  0.4001, -0.2792]])),
             ('ol.bias', tensor([0.2097]))])

In [4]:
# 2 inputs -> three hidden layers of width 5 -> 1 output
b = SharedModel(nlist=[2, 5, 5, 5, 1])

In [5]:
# Same layout as `a`, with one extra hidden layer ('hidden.2.*') before 'ol.*'.
b.state_dict()

OrderedDict([('hidden.0.weight',
              tensor([[ 0.2308, -0.1549],
                      [-0.1178,  0.0461],
                      [ 0.3983, -0.6221],
                      [ 0.6566,  0.7038],
                      [ 0.6153,  0.5666]])),
             ('hidden.0.bias',
              tensor([-0.2137,  0.5834, -0.0927,  0.4404, -0.4194])),
             ('hidden.1.weight',
              tensor([[ 0.1900, -0.3686, -0.4102,  0.1153, -0.0254],
                      [ 0.1498, -0.0498,  0.0422, -0.1464, -0.1651],
                      [ 0.2948,  0.3458, -0.0465, -0.0508,  0.1232],
                      [-0.1153, -0.1550,  0.3032, -0.2307, -0.4321],
                      [-0.1507,  0.1387, -0.4251,  0.1420, -0.0923]])),
             ('hidden.1.bias',
              tensor([-0.0880,  0.1184, -0.3258,  0.2974,  0.0580])),
             ('hidden.2.weight',
              tensor([[ 0.0849,  0.0292,  0.2311,  0.3045,  0.0752],
                      [-0.3016,  0.1813, -0.0735, -0.1794, -0.4165],
