In [1]:
import torch
from torch import nn
from torch.nn import functional as F

In [3]:
# A two-layer perceptron assembled from stock layers: a 256-unit hidden
# layer, a ReLU non-linearity, and a 10-unit output layer.
net = nn.Sequential(
    nn.LazyLinear(256),
    nn.ReLU(),
    nn.LazyLinear(10),
)
# Random minibatch of 2 examples with 20 features each; reused by later cells.
X = torch.rand(2, 20)
net(X).shape

torch.Size([2, 10])

In [6]:
class MLP(nn.Module):
    """Multilayer perceptron with one hidden layer.

    The hidden layer has 128 units and the output layer has 16 units. Both
    are ``nn.LazyLinear`` layers, so no input width is specified here.
    """

    def __init__(self):
        # The parent constructor must run first so that sub-modules assigned
        # below are registered on this module.
        super().__init__()
        self.hidden = nn.LazyLinear(128)
        self.out = nn.LazyLinear(16)

    def forward(self, X):
        """Return ``out(relu(hidden(X)))`` for the input batch ``X``."""
        hidden_pre_activation = self.hidden(X)
        hidden_activation = F.relu(hidden_pre_activation)
        return self.out(hidden_activation)

In [7]:
# Instantiate the custom MLP (rebinding the notebook-level name `net`) and
# push the batch X through it; the cell displays the shape of the output.
net = MLP()
net(X).shape

torch.Size([2, 16])

In [8]:
class MySequential(nn.Module):
    """Minimal re-implementation of a sequential container.

    Args:
        *args: Modules to chain. Each one is registered under the string form
            of its position ("0", "1", ...), and ``forward`` applies them in
            that order.
    """

    def __init__(self, *args):
        super().__init__()
        for index, module in enumerate(args):
            # add_module stores the module in self._modules, an ordered dict,
            # which is what makes its parameters visible to the parent.
            self.add_module(str(index), module)

    def forward(self, X):
        """Feed ``X`` through every registered module in registration order."""
        # Iterate over self._modules rather than self.children():
        # children() de-duplicates module instances, so a layer passed more
        # than once (weight sharing) would otherwise be applied only once.
        for module in self._modules.values():
            # children() skipped None entries; keep doing so.
            if module is not None:
                X = module(X)
        return X


In [10]:
# Rebuild the hidden-128 / ReLU / output-16 stack using the hand-written
# container, then show the output shape for the batch X.
net = MySequential(
    nn.LazyLinear(128),
    nn.ReLU(),
    nn.LazyLinear(16),
)
net(X).shape

torch.Size([2, 16])

In [16]:
class FixedHiddenMLP(nn.Module):
    """MLP mixing a trainable layer with constant weights and Python control flow.

    ``rand_weight`` is a 20x20 random matrix drawn once at construction. It is
    not an ``nn.Parameter``, so it is never updated by an optimizer. The single
    ``linear`` layer is applied twice in ``forward``, so both applications
    share the same parameters.
    """

    def __init__(self):
        super().__init__()
        # Registered as a buffer (not a plain attribute) so the constant
        # follows the module through .to()/.cuda()/.double() and is saved in
        # state_dict(), while still being excluded from parameters().
        self.register_buffer("rand_weight", torch.rand((20, 20)))
        self.linear = nn.LazyLinear(20)

    def forward(self, X):
        """Return a scalar tensor: the sum of the rescaled network output."""
        X = self.linear(X)
        # Constant (non-trainable) transformation followed by ReLU.
        X = F.relu(X @ self.rand_weight + 1)
        # Reuse the same layer, i.e. share its parameters.
        X = self.linear(X)

        # Plain Python control flow: halve X in place until its L1 norm is at
        # most 1 (terminates for finite values).
        while X.abs().sum() > 1:
            X /= 2
        return X.sum()

In [17]:
# Build the fixed-weight model (rebinding `net`) and evaluate it on X; the
# forward pass returns a single scalar tensor, which the cell displays.
net = FixedHiddenMLP()
net(X)



tensor(0.0665, grad_fn=<SumBackward0>)

In [18]:
class NestMLP(nn.Module):
    """Module that nests an ``nn.Sequential`` block inside a custom module.

    The inner block ``net`` is LazyLinear(128) -> ReLU -> LazyLinear(32) ->
    ReLU; its result is passed through a final 8-unit ``linear`` layer.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.LazyLinear(128),
            nn.ReLU(),
            nn.LazyLinear(32),
            nn.ReLU(),
        )
        self.linear = nn.LazyLinear(8)

    def forward(self, X):
        """Run ``X`` through the nested block, then the final linear layer."""
        features = self.net(X)
        return self.linear(features)

# Custom and built-in modules compose freely: chain a NestMLP, a 20-unit
# lazy linear layer, and a FixedHiddenMLP, then evaluate the stack on X.
cinma = nn.Sequential(
    NestMLP(),
    nn.LazyLinear(20),
    FixedHiddenMLP(),
)
cinma(X)

tensor(0.1188, grad_fn=<SumBackward0>)