In [1]:
import torch
from torch import nn
from torch.nn import functional as F


In [6]:
# Baseline MLP built with nn.Sequential:
# Flatten -> LazyLinear(256) -> ReLU -> LazyLinear(10).
# The LazyLinear layers infer their input width from the first batch they see.
net = nn.Sequential(
    nn.Flatten(),
    nn.LazyLinear(256),
    nn.ReLU(),
    nn.LazyLinear(10),
)

In [7]:
# Fix the RNG so X and the lazily-initialized weights (materialized on the
# first forward pass below) are reproducible on Restart & Run All.
# NOTE(review): ideally this seed lives in a config cell right after the imports.
torch.manual_seed(0)
X = torch.rand(2, 20)
net(X).shape

torch.Size([2, 10])

In [11]:
class NLP(nn.Module):
    """Hand-written two-layer MLP as a custom ``nn.Module``.

    The name reads like a typo for "MLP"; it is kept because the notebook
    instantiates the class as ``NLP()``.

    Pipeline: Flatten -> LazyLinear(num_hidden) -> ReLU -> LazyLinear(num_output).
    Both linear layers infer their input width on the first forward call.

    Args:
        num_hidden: width of the hidden layer.
        num_output: number of output features.
    """

    def __init__(self, num_hidden=256, num_output=10):
        super().__init__()
        # Registration order (fl, ln1, ln2) fixes the named_children() order.
        self.fl = nn.Flatten()
        self.ln1 = nn.LazyLinear(num_hidden)
        self.ln2 = nn.LazyLinear(num_output)

    def forward(self, X):
        flat = self.fl(X)
        hidden = self.ln1(flat)
        activated = F.relu(hidden)
        return self.ln2(activated)

In [12]:
# Same architecture as the nn.Sequential model, expressed as a custom Module
# (the widths below are NLP's defaults, spelled out for the reader).
net1 = NLP(num_hidden=256, num_output=10)
net1(X).shape

torch.Size([2, 10])

In [14]:
class MySequential(nn.Module):
    """Minimal re-implementation of ``nn.Sequential``.

    Each positional argument is registered as a child module under its
    position ("0", "1", ...), so its parameters are tracked by the parent.
    ``forward`` feeds the input through the children in registration order.
    """

    def __init__(self, *args):
        super().__init__()
        for idx, module in enumerate(args):
            # add_module registers the child so parameters()/state_dict()/.to()
            # see it; the stringified index matches nn.Sequential's naming.
            self.add_module(str(idx), module)

    def forward(self, X):
        # Iterate the registry directly instead of self.children():
        # children() de-duplicates, so a module passed more than once (e.g. a
        # shared layer with tied weights) would silently run only once.
        for module in self._modules.values():
            if module is not None:
                X = module(X)
        return X

In [16]:
# The hand-rolled container should behave like nn.Sequential on the same
# layer stack (no Flatten here: X is already 2-D).
net = MySequential(
    nn.LazyLinear(256),
    nn.ReLU(),
    nn.LazyLinear(10),
)
net(X).shape

torch.Size([2, 10])

In [17]:
class FixedHiddenMLP(nn.Module):
    """MLP with a constant random weight, a reused layer, and Python control flow.

    forward: linear -> (@ rand_w + 1) -> ReLU -> the *same* linear again, then
    halves the activations until their absolute sum is <= 1 and returns their
    scalar sum. Because the single LazyLinear is applied to its own 20-wide
    output, the input's last dimension must be 20.
    """

    def __init__(self):
        super().__init__()
        # Constant random weight. It is registered as a non-persistent buffer
        # (not a Parameter), so it gets no gradient, is not in state_dict, and
        # follows .to()/.cuda()/.double() together with the rest of the module.
        self.register_buffer('rand_w', torch.rand((20, 20)), persistent=False)
        self.linear = nn.LazyLinear(20)

    def forward(self, X):
        X = self.linear(X)
        X = F.relu(X @ self.rand_w + 1)
        # Reuse the same layer: both calls share one set of parameters.
        X = self.linear(X)
        # Data-dependent rescaling until the L1 norm is <= 1.
        # NOTE: does not terminate if X contains inf (inf / 2 stays inf).
        while X.abs().sum() > 1:
            X /= 2
        return X.sum()

In [19]:
# FixedHiddenMLP.forward ends with X.sum(), so this displays a 0-d tensor.
net = FixedHiddenMLP()
net(X)

tensor(0.2494, grad_fn=<SumBackward0>)

In [22]:
class NestMLP(nn.Module):
    """MLP that nests an ``nn.Sequential`` inside a custom module.

    ``self.net`` maps the input to 32 features
    (LazyLinear(64) -> ReLU -> LazyLinear(32) -> ReLU), and ``self.linear``
    projects those 32 features to 16 outputs.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.LazyLinear(64),
            nn.ReLU(),
            nn.LazyLinear(32),
            nn.ReLU())
        self.linear = nn.LazyLinear(16)

    def forward(self, X):
        # The nested block must run before the projection; skipping it leaves
        # its lazy layers uninitialized and contributing nothing.
        return self.linear(self.net(X))

In [23]:
# Mix a custom nested module, a built-in layer, and another custom module:
# any nn.Module can be a child of nn.Sequential. LazyLinear(20) produces the
# 20 features FixedHiddenMLP expects.
chimera = nn.Sequential(
    NestMLP(),
    nn.LazyLinear(20),
    FixedHiddenMLP(),
)
chimera(X)

tensor(-0.2656, grad_fn=<SumBackward0>)