In [1]:
import torch
import torch.nn as nn

In [2]:
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        num_inputs: Number of input features (``in_features`` of ``lin1``).
        num_hiddens: Width of the hidden layer.
        num_outputs: Number of output features (``out_features`` of ``lin2``).
    """

    def __init__(self, num_inputs, num_hiddens, num_outputs):
        super().__init__()
        # Registration order (lin1, lin2, relu) fixes the parameter
        # initialisation order as well as the state_dict / children order.
        self.lin1 = nn.Linear(num_inputs, num_hiddens)
        self.lin2 = nn.Linear(num_hiddens, num_outputs)
        self.relu = nn.ReLU()

    def forward(self, X):
        """Map ``X`` through ``lin1``, a ReLU, then ``lin2``."""
        hidden = self.relu(self.lin1(X))
        return self.lin2(hidden)

In [4]:
# A batch of two random examples with 20 features each, pushed through a
# 20 -> 20 -> 2 MLP; the last expression is displayed as the cell output.
X = torch.rand(size=(2, 20))

mlp = MLP(num_inputs=20, num_hiddens=20, num_outputs=2)
mlp(X)

tensor([[0.3141, 0.4658],
        [0.2512, 0.4324]], grad_fn=<AddmmBackward0>)

In [13]:
class MySequential(nn.Module):
    """Minimal re-implementation of ``nn.Sequential``.

    Each positional argument is registered as a child module under its
    position ("0", "1", ...), so its parameters appear in ``parameters()``
    and ``state_dict()``. ``forward`` feeds the input through the children
    in registration order.

    Args:
        *args: The layers, in the order they are applied. Each must be an
            ``nn.Module`` (``add_module`` raises ``TypeError`` otherwise).
        **kwargs: Accepted for signature compatibility; ignored.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__()
        for idx, module in enumerate(args):
            # Public equivalent of ``self._modules[str(idx)] = module`` that
            # also validates the module type and name up front, instead of
            # failing later inside parameters()/to().
            self.add_module(str(idx), module)

    def forward(self, X):
        # Iterate _modules directly (not children()) so a layer passed more
        # than once is applied each time, as in the original.
        for module in self._modules.values():
            X = module(X)
        return X

In [14]:
class MySequential2(nn.Module):
    """Sequential container whose layers are held in an ``nn.ModuleList``.

    A plain Python list would hide the layers from ``nn.Module``: their
    parameters would be missing from ``parameters()`` / ``state_dict()``
    (so an optimizer would never train them) and ``.to(device)`` would not
    move them. The attribute is named ``layers`` rather than ``modules``
    because an instance attribute called ``modules`` shadows
    ``nn.Module.modules()``.

    Args:
        *args: The layers, in the order they are applied.
        **kwargs: Accepted for signature compatibility; ignored.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__()
        self.layers = nn.ModuleList(args)

    def forward(self, X):
        for layer in self.layers:
            X = layer(X)
        return X

In [16]:
# The same two Linear layers are shared by all three containers, so the
# three outputs below are computed with identical weights.
lin1 = nn.Linear(20, 20)
lin2 = nn.Linear(20, 2)

net1 = MySequential(lin1, nn.ReLU(), lin2)
net2 = nn.Sequential(lin1, nn.ReLU(), lin2)
net3 = MySequential2(lin1, nn.ReLU(), lin2)
X1, X2, X3 = net1(X), net2(X), net3(X)

for header, output in (
    ("-------X1-----------", X1),
    ("-------X2-----------", X2),
    ("-------X3-----------", X3),
):
    print(header)
    print(output)


-------X1-----------
tensor([[-0.3467, -0.0927],
        [-0.3443, -0.2894]], grad_fn=<AddmmBackward0>)
-------X2-----------
tensor([[-0.3467, -0.0927],
        [-0.3443, -0.2894]], grad_fn=<AddmmBackward0>)
-------X3-----------
tensor([[-0.3467, -0.0927],
        [-0.3443, -0.2894]], grad_fn=<AddmmBackward0>)


In [20]:
class FixedHiddenMLP(nn.Module):
    """MLP that reuses one Linear layer around a fixed, untrained projection.

    ``forward`` computes ``lin1(lin1(X) @ C)``. ``lin1`` is applied twice, so
    both applications share the same weights. ``C`` is a random matrix that
    is never trained.

    Args:
        num_examples: Unused; kept for interface compatibility.
        num_inputs: Number of input features.
        num_outputs: Number of output features; also the width of the
            intermediate result of the first ``lin1`` application.
    """

    def __init__(self, num_examples, num_inputs, num_outputs):
        super().__init__()
        self.lin1 = nn.Linear(num_inputs, num_outputs)
        # C maps lin1's output (num_outputs wide) back to num_inputs so lin1
        # can be applied a second time. The former square
        # (num_outputs, num_outputs) shape only worked when
        # num_inputs == num_outputs; for that case this shape (and the random
        # draw) is identical. It is a non-persistent buffer so it follows
        # .to(device) / .double() without becoming a trainable parameter or
        # a state_dict entry.
        self.register_buffer(
            "C",
            torch.randn((num_outputs, num_inputs), requires_grad=False),
            persistent=False,
        )

        # NOTE(review): self.relu is not applied in forward — TODO confirm
        # whether a ReLU after the fixed projection was intended.
        self.relu = nn.ReLU()

    def forward(self, X):
        X = self.lin1(X)
        # `@` instead of torch.mm so inputs with extra leading dims (or a
        # single 1-D example) work, just as nn.Linear accepts them.
        X = X @ self.C
        X = self.lin1(X)
        return X