In [1]:
import torch
import numpy as np
import toydiff as tdf
import matplotlib.pyplot as plt
from toydiff.nn.blocks import Module, Linear

In [2]:
# A single Linear layer mapping 1 input feature to 1 output feature, created
# with bias=True; the next two cells list its parameters (a 1x1 weight and a
# length-1 bias, per the printed output).
# NOTE(review): no random seed is set anywhere in this notebook, so the printed
# parameter values are presumably different on every run -- TODO seed whatever
# RNG toydiff uses for initialisation so Restart & Run All is reproducible.
model_simple = Linear(1, 1, bias=True)

In [3]:
# Print each parameter tensor of the single-layer model, one per line.
for param in model_simple.parameters():
    print(param)

Tensor([[-1.1524795]], dtype=float32, track_gradient=True)
Tensor([0.16869831], dtype=float32, track_gradient=True)


In [4]:
# Same parameters as above, but each item is a (name, tensor) pair.
for named_param in model_simple.named_parameters():
    print(named_param)

('weight', Tensor([[-1.1524795]], dtype=float32, track_gradient=True))
('bias', Tensor([0.16869831], dtype=float32, track_gradient=True))


In [5]:
class Model(Module):
    """Three stacked Linear layers: ``in_f -> hidden_f -> hidden_f -> out_f``.

    Only the middle layer is built with ``bias=True``; ``lin1`` and ``lin3``
    use ``Linear``'s default for ``bias`` (the parameter listing printed in
    this notebook shows no bias tensor for them).

    Parameters
    ----------
    in_f : int
        Number of input features.
    out_f : int
        Number of output features.
    hidden_f : int, optional
        Width of the two hidden layers. Defaults to 5, the width that used to
        be hard-coded here, so existing ``Model(in_f, out_f)`` calls build the
        exact same layers.
    """

    def __init__(self, in_f, out_f, hidden_f=5):
        super().__init__()
        # hidden_f is deliberately not stored on self: only the layers below
        # should be picked up as submodules/parameters.
        # NOTE(review): the printed listings come out in lin1, lin2, lin3
        # order, which suggests registration follows assignment order -- keep
        # these assignments in this order.
        self.lin1 = Linear(in_f, hidden_f)
        self.lin2 = Linear(hidden_f, hidden_f, bias=True)
        self.lin3 = Linear(hidden_f, out_f)

    def forward(self, X):
        """Apply ``lin1``, ``lin2`` and ``lin3`` to ``X`` in sequence.

        NOTE(review): there is no nonlinearity between the layers, so
        (assuming ``Linear`` is an affine map) the whole stack is equivalent
        to a single affine map from ``in_f`` to ``out_f`` features.
        """
        return self.lin3(self.lin2(self.lin1(X)))

    
class SuperModel(Module):
    """Two ``Model(1, 1)`` instances chained so that ``forward`` computes
    ``model2(model1(X))``."""

    def __init__(self):
        super().__init__()
        # model1 is assigned before model2; the flattened state_dict listing
        # below shows the first Model's tensors first.
        self.model1 = Model(in_f=1, out_f=1)
        self.model2 = Model(in_f=1, out_f=1)

    def forward(self, X):
        """Feed ``X`` through ``model1``, then feed that result through ``model2``."""
        first_stage = self.model1(X)
        return self.model2(first_stage)

In [6]:
# Keyword arguments make the 1 -> 1 feature mapping explicit.
model_complex = Model(in_f=1, out_f=1)

In [7]:
# NOTE(review): the printed names are bare ("weight", "bias") with no
# lin1/lin2/lin3 prefix, so three entries share the name "weight"; only their
# order and shapes tell the layers apart.
for param_name, param_tensor in model_complex.named_parameters():
    print(param_name, param_tensor)

weight Tensor([[ 0.16809037],
       [-0.44398898],
       [ 0.07091864],
       [ 0.32421875],
       [ 0.6750233 ]], dtype=float32, track_gradient=True)
weight Tensor([[ 0.38460317,  0.57683796,  0.6252425 ,  0.692226  ,  0.31461474],
       [ 0.66168797,  1.8071277 , -0.20569378,  0.6223298 , -0.48108304],
       [-1.5718416 ,  0.04897113,  0.45917118, -2.0962489 ,  0.8343824 ],
       [-0.42814   ,  1.562432  , -1.161408  ,  0.94462436,  0.87796783],
       [ 1.8187758 ,  1.0470984 ,  1.7618078 , -0.35621834,  1.6303357 ]],
      dtype=float32, track_gradient=True)
bias Tensor([-0.62646484, -1.5172968 , -0.11051818,  0.60567015, -0.69061065],
      dtype=float32, track_gradient=True)
weight Tensor([[-0.49345124,  0.4414633 , -0.07866151,  0.26608378,  0.43842435]],
      dtype=float32, track_gradient=True)


In [8]:
# Nested module: two chained Model(1, 1) instances (model1 -> model2). The next
# cell inspects how their parameters are flattened by state_dict().
sp_model = SuperModel()

In [9]:
# Bare last expression, so Jupyter renders the returned OrderedDict.
# NOTE(review): the keys shown are flat and index-suffixed (weight_0, weight_1,
# bias_2, weight_3, ...) with no model1/model2 or lin1..lin3 prefix, so the
# submodule that owns each tensor can only be inferred from its position.
sp_model.state_dict()

OrderedDict([('weight_0',
              Tensor([[ 1.5009221 ],
                     [ 1.8802804 ],
                     [-1.1215589 ],
                     [-0.07336602],
                     [-0.98108894]], dtype=float32, track_gradient=True)),
             ('weight_1',
              Tensor([[ 0.6491288 ,  0.16675629, -0.4888138 ,  0.19954038, -0.76729393],
                     [ 1.3105116 , -1.5407058 ,  0.13516453,  0.54066455,  0.04474811],
                     [ 1.8605776 , -0.24941117, -0.4930001 , -1.300886  ,  1.0397424 ],
                     [ 0.8769514 , -1.0941402 ,  0.218654  ,  1.2038121 , -0.03444668],
                     [ 1.4870265 , -2.0687523 , -0.6395416 ,  0.29797655, -0.712167  ]],
                    dtype=float32, track_gradient=True)),
             ('bias_2',
              Tensor([ 1.4455769, -1.9997764, -0.5442659, -0.5875887,  1.4555124],
                    dtype=float32, track_gradient=True)),
             ('weight_3',
              Tensor([[-0.28100106, -