In [1]:
import torch
from torch import nn
from torch import optim

In [2]:
class MyLinear(nn.Module):
    """Hand-rolled fully connected layer computing ``x @ w.T + b``.

    Args:
        inp: number of input features (size of the last dimension of ``x``).
        outp: number of output features.

    Both ``w`` (shape ``(outp, inp)``) and ``b`` (shape ``(outp,)``) are drawn
    from a standard normal distribution.
    """

    def __init__(self, inp, outp):
        super().__init__()
        # Wrapping in nn.Parameter registers the tensor with the module
        # (so it shows up in parameters()/state_dict()) and sets requires_grad=True.
        weight_shape = (outp, inp)
        self.w = nn.Parameter(torch.randn(*weight_shape))
        self.b = nn.Parameter(torch.randn(outp))

    def forward(self, x):
        # The weight is stored as (out, in), so transpose it before the matmul.
        projected = torch.matmul(x, self.w.t())
        return projected + self.b

In [3]:
class Flatten(nn.Module):
    """Collapse every dimension except the first (batch) dimension.

    ``(N, d1, d2, ...) -> (N, d1 * d2 * ...)``; a 1-D input of length ``N``
    becomes ``(N, 1)``.
    """

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, input):
        # reshape (unlike view) also accepts non-contiguous tensors, e.g. the
        # output of transpose()/permute(); it returns a view whenever one is
        # possible, so contiguous inputs behave exactly as before.
        return input.reshape(input.size(0), -1)

In [4]:
class TestNet(nn.Module):
    """Small CNN demo: conv -> max-pool -> flatten -> 10-way linear layer.

    NOTE(review): the ``16 * 14 * 14`` input size of the final Linear assumes
    single-channel 28x28 images (e.g. MNIST-sized) — confirm against the data.
    """

    def __init__(self):
        super(TestNet, self).__init__()

        self.net = nn.Sequential(
            # kernel_size is a required argument of nn.Conv2d (omitting it raised
            # TypeError). A 3x3 kernel with stride 1 and padding 1 preserves the
            # spatial size: 1x28x28 -> 16x28x28.
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            # Halves each spatial dimension: 16x28x28 -> 16x14x14.
            nn.MaxPool2d(2, 2),
            Flatten(),
            # The flattened size includes all 16 conv output channels.
            nn.Linear(16 * 14 * 14, 10),
        )

    def forward(self, x):
        return self.net(x)

In [5]:
class BasicNet(nn.Module):
    """Minimal module wrapping one fully connected layer (4 -> 3 features).

    The layer is stored under the attribute ``net``, so its parameters are
    named ``net.weight`` and ``net.bias``.
    """

    def __init__(self):
        super().__init__()
        in_features, out_features = 4, 3
        self.net = nn.Linear(in_features, out_features)

    def forward(self, x):
        linear = self.net
        return linear(x)

In [6]:
class Net(nn.Module):
    """Two-layer MLP: ``BasicNet`` (4 -> 3), ReLU, then Linear (3 -> 2).

    Nesting ``BasicNet`` inside an ``nn.Sequential`` stored as ``net`` yields
    hierarchical parameter names such as ``net.0.net.weight``.
    """

    def __init__(self):
        super().__init__()
        # Built in this order so sub-modules get indices 0, 1, 2.
        stages = [
            BasicNet(),
            nn.ReLU(),
            nn.Linear(3, 2),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

In [7]:
def main():
    """Build a ``Net``, toggle its train/eval mode, and print its structure.

    Prints each named parameter with its shape, then the direct children,
    then every sub-module recursively (the root module appears with name '').
    """
    # Fall back to CPU so the demo also runs on machines without a GPU;
    # an unconditional torch.device('cuda') made net.to(device) fail there.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = Net()
    net.to(device)  # nn.Module.to moves the parameters in place

    # train()/eval() switch the module's training flag. They are called
    # back-to-back here only to demonstrate the API; the net ends in eval mode.
    net.train()
    net.eval()

    # Checkpointing example: persist / restore only the state_dict.
    # net.load_state_dict(torch.load('ckpt.mdl'))
    # torch.save(net.state_dict(), 'ckpt.mdl')

    for name, param in net.named_parameters():
        print('parameters:', name, param.shape)

    for name, module in net.named_children():
        print('\nchildren:', name, module)

    for name, module in net.named_modules():
        print('\nmodules:', name, module)

In [8]:
# Entry point. Inside a notebook kernel __name__ is '__main__', so this
# cell always calls main(); as a script it runs only when executed directly.
if __name__ == '__main__':
    main()

parameters: net.0.net.weight torch.Size([3, 4])
parameters: net.0.net.bias torch.Size([3])
parameters: net.2.weight torch.Size([2, 3])
parameters: net.2.bias torch.Size([2])

children: net Sequential(
  (0): BasicNet(
    (net): Linear(in_features=4, out_features=3, bias=True)
  )
  (1): ReLU()
  (2): Linear(in_features=3, out_features=2, bias=True)
)

modules:  Net(
  (net): Sequential(
    (0): BasicNet(
      (net): Linear(in_features=4, out_features=3, bias=True)
    )
    (1): ReLU()
    (2): Linear(in_features=3, out_features=2, bias=True)
  )
)

modules: net Sequential(
  (0): BasicNet(
    (net): Linear(in_features=4, out_features=3, bias=True)
  )
  (1): ReLU()
  (2): Linear(in_features=3, out_features=2, bias=True)
)

modules: net.0 BasicNet(
  (net): Linear(in_features=4, out_features=3, bias=True)
)

modules: net.0.net Linear(in_features=4, out_features=3, bias=True)

modules: net.1 ReLU()

modules: net.2 Linear(in_features=3, out_features=2, bias=True)
