In [65]:
import torch
import torch.nn as nn

In [66]:
class RandomLearnableModule(nn.Module):
    """Toy module that left-multiplies its input by a learnable vector ``a``.

    Args:
        learnable_params: mapping of name -> tensor. It is wrapped in an
            ``nn.ParameterDict`` so its entries are registered as parameters
            (they appear in ``parameters()`` as ``learnable_params.<key>``).
            ``forward`` needs an entry named ``'a'``. May be ``None``, in which
            case the module owns no parameters and ``forward`` cannot run.
        fixed_params: stored unchanged on ``self.fixed_params``. It is a plain
            attribute, so its contents are not registered as parameters or
            buffers.
    """

    def __init__(self, learnable_params, fixed_params):
        super().__init__()
        # Keep ``None`` as ``None`` (rather than an empty ParameterDict) so
        # callers can tell "no learnable parameters" apart from "empty dict".
        self.learnable_params = (
            nn.ParameterDict(learnable_params) if learnable_params is not None else None
        )
        self.fixed_params = fixed_params

    def forward(self, x):
        """Return ``torch.matmul(a, x)`` with ``a = learnable_params['a']``.

        E.g. ``a`` of shape (3,) and ``x`` of shape (3, N) gives shape (N,).

        Raises:
            RuntimeError: if the module was built with ``learnable_params=None``.
            KeyError: if ``learnable_params`` has no entry named ``'a'``.
        """
        if self.learnable_params is None:
            raise RuntimeError(
                "RandomLearnableModule was constructed with learnable_params=None; "
                "forward() needs a learnable parameter named 'a'."
            )
        # Item access works for any string key, unlike attribute access.
        return torch.matmul(self.learnable_params['a'], x)

In [82]:
class MetaClass(nn.Module):
    """Two ``RandomLearnableModule`` children applied to the same input, summed.

    Each child is created with its own fresh ``a = [0., 1., 2.]`` tensor, so the
    children do not share parameters. The module also owns a separate
    ``ParameterDict`` holding ``v1 = [7.]``.

    NOTE(review): ``v1`` is not referenced in ``forward``, so it never receives a
    gradient (its ``.grad`` stays ``None`` after ``backward``) and an optimizer
    will not change it. Confirm whether that is intended.
    """

    def __init__(self):
        super().__init__()
        # Registered under keys '1' and '2', in that order.
        self.module_dict = nn.ModuleDict(
            {
                key: RandomLearnableModule({'a': torch.arange(3.0)}, {'b': 0.0})
                for key in ('1', '2')
            }
        )
        self.learnable_parameters = nn.ParameterDict({'v1': torch.tensor([7.0])})

    def forward(self, x):
        first, second = (self.module_dict[key](x) for key in ('1', '2'))
        return first + second

In [83]:
class MetaMetaClass(nn.Module):
    """Top-level model: two independent ``MetaClass`` branches, outputs summed.

    The branches are registered under the keys ``'I'`` and ``'II'``, so their
    parameters are named ``module_dict.I.*`` and ``module_dict.II.*``.
    """

    def __init__(self):
        super().__init__()
        self.module_dict = nn.ModuleDict(
            {branch: MetaClass() for branch in ('I', 'II')}
        )

    def forward(self, x):
        left, right = (self.module_dict[branch](x) for branch in ('I', 'II'))
        return left + right

In [84]:
# Two-level model: MetaMetaClass -> 2 x MetaClass -> 2 x RandomLearnableModule each.
meta = MetaMetaClass()

In [85]:
# Every registered parameter in registration order (unnamed): per MetaClass,
# two `a` vectors followed by its `v1`.
[*meta.parameters()]

[Parameter containing:
 tensor([0., 1., 2.], requires_grad=True),
 Parameter containing:
 tensor([0., 1., 2.], requires_grad=True),
 Parameter containing:
 tensor([7.], requires_grad=True),
 Parameter containing:
 tensor([0., 1., 2.], requires_grad=True),
 Parameter containing:
 tensor([0., 1., 2.], requires_grad=True),
 Parameter containing:
 tensor([7.], requires_grad=True)]

In [40]:
# Current value of every trainable parameter, by fully-qualified name.
# NOTE(review): the output below is from In [40], before `MetaClass` was
# re-defined (In [82]) and `meta` re-built (In [84]). It therefore omits the two
# `learnable_parameters.v1` entries that `list(meta.parameters())` shows.
# Restart & Run All to refresh it.
for qualified_name, parameter in meta.named_parameters():
    if not parameter.requires_grad:
        continue
    print(qualified_name, parameter.detach())

module_dict.I.module_dict.1.learnable_params.a tensor([0., 1., 2.])
module_dict.I.module_dict.2.learnable_params.a tensor([0., 1., 2.])
module_dict.II.module_dict.1.learnable_params.a tensor([0., 1., 2.])
module_dict.II.module_dict.2.learnable_params.a tensor([0., 1., 2.])


In [53]:
# Two identical rows [0, 1, 2]; float literals give the default float dtype directly.
x = torch.tensor([[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]])

In [55]:
# x is (2, 3). Each leaf computes matmul(a, input) with len(a) == 3, so the
# length-3 axis must come first: pass the (3, 2) transpose.
z = meta(x.transpose(0, 1))

In [56]:
# One output per column of the transposed input.
z.size()

torch.Size([2])

In [59]:
# Each of the 4 leaf modules contributes a . [0, 1, 2] = 0 + 1 + 4 = 5 per column.
# Summed, every entry is 20.
z

tensor([20., 20.], grad_fn=<AddBackward0>)

In [57]:
# Constant regression target of 30, same shape/dtype/device as z, not tracked by autograd.
y = torch.full_like(z, 30.0)

In [58]:
# NOTE(review): no `optim.step()` call is visible in this notebook, so the
# optimizer never actually updates the parameters; only gradients are inspected.
optim = torch.optim.Adam(params=meta.parameters(), lr=1e-3)

In [60]:
# Reset the gradients of every parameter handed to `optim` before backpropagating.
optim.zero_grad()

In [61]:
# Mean squared error over the batch (reduction='mean'); with z = 20 and y = 30 this is 100.
loss = torch.nn.functional.mse_loss(z, y)

In [62]:
# Gradients before backward(): every `.grad` is still None.
# NOTE(review): the output below is from In [62], before `MetaClass` was
# re-defined (In [82]), so it lists only 4 of the 6 current parameters
# (no `learnable_parameters.v1`). Restart & Run All to refresh it.
for qualified_name, parameter in meta.named_parameters():
    if not parameter.requires_grad:
        continue
    print(qualified_name, parameter.grad)

module_dict.I.module_dict.1.learnable_params.a None
module_dict.I.module_dict.2.learnable_params.a None
module_dict.II.module_dict.1.learnable_params.a None
module_dict.II.module_dict.2.learnable_params.a None


In [63]:
# Populate `.grad` on every parameter that contributed to `loss`.
loss.backward()

In [64]:
# Gradients after backward(). Each `a` gets 2 * (z - y) / N summed against its
# input columns: [0, -20, -40].
# NOTE(review): the output below is from In [64], before `MetaClass` was
# re-defined (In [82]). A fresh run also prints the two
# `learnable_parameters.v1` entries with grad None, because `v1` is not used
# in `MetaClass.forward`.
for qualified_name, parameter in meta.named_parameters():
    if not parameter.requires_grad:
        continue
    print(qualified_name, parameter.grad)

module_dict.I.module_dict.1.learnable_params.a tensor([  0., -20., -40.])
module_dict.I.module_dict.2.learnable_params.a tensor([  0., -20., -40.])
module_dict.II.module_dict.1.learnable_params.a tensor([  0., -20., -40.])
module_dict.II.module_dict.2.learnable_params.a tensor([  0., -20., -40.])
