In [1]:
# !pip install layer-to-layer-pytorch

In [2]:
from typing import Optional

import torch
from torch import nn, optim

import numpy as np

from tqdm.auto import tqdm, trange
from tqdm.contrib import tenumerate, tzip

from layer_to_layer_pytorch import Layer2Layer

In [3]:

class M(nn.Module):
    """Stack of fully connected blocks exposed as ``self.layers``.

    The ``nn.ModuleList`` in ``self.layers`` holds, in order:

    * one ``Linear -> BatchNorm1d -> LeakyReLU`` block mapping ``dim`` to
      ``hidden_dim``;
    * ``depth`` further blocks of the same shape mapping ``hidden_dim`` to
      ``hidden_dim``;
    * a final ``Linear`` mapping ``hidden_dim`` back to ``dim``;
    * a ``Sigmoid``.

    So ``len(model) == depth + 3``.

    Args:
        depth: Number of hidden-to-hidden blocks.
        dim: Size of the input and output feature dimension.
        hidden_dim: Width of the hidden blocks. Falls back to ``dim`` when it
            is ``None`` (or any other falsy value).
    """

    def __init__(self, depth: int, dim: int, hidden_dim: Optional[int] = None):
        super().__init__()
        width = hidden_dim or dim

        def make_block(in_features: int) -> nn.Sequential:
            # One Linear/BatchNorm/LeakyReLU unit that always outputs `width` features.
            return nn.Sequential(
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                nn.LeakyReLU(),
            )

        # Modules are created in the same order they are applied in forward().
        stack = [make_block(dim)]
        stack.extend(make_block(width) for _ in range(depth))
        stack.append(nn.Linear(width, dim))
        stack.append(nn.Sigmoid())
        self.layers = nn.ModuleList(stack)

    def __len__(self) -> int:
        """Return the number of entries in ``self.layers``."""
        return len(self.layers)

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Apply every entry of ``self.layers`` to ``batch`` in sequence."""
        out = batch
        for layer in self.layers:
            out = layer(out)
        return out


In [4]:
# Seed PyTorch's global RNG before the first random operation so that the
# weight initialisation below (and, on a top-to-bottom run, the random data
# drawn in the following cells) is the same on every Restart & Run All.
torch.manual_seed(0)

# 40 -> 100 -> ... -> 100 -> 40 network: one input block, five hidden blocks,
# then the output Linear and Sigmoid.
model = M(depth=5, dim=40, hidden_dim=100)

# Wrap the model for layer-to-layer execution. `layers_attr` is the name of the
# attribute on `model` that holds its layer list (`M.layers`).
# NOTE(review): presumably inputs are processed in microbatches of 100 rows —
# confirm against the layer-to-layer-pytorch documentation.
l2l_model = Layer2Layer(model, layers_attr="layers", microbatch_size=100, verbose=False)

In [5]:
# Two independent uniform-random tensors of shape (1000, 40), drawn in order:
# `x` is fed to the model and `y` is the target passed to the loss in the
# training loop.
x, y = (torch.rand(1_000, 40) for _ in range(2))

In [6]:
# List every parameter of the wrapped model together with its requires_grad flag.
for param_name, parameter in l2l_model.model.named_parameters():
    print(f"[{param_name}]: {parameter.requires_grad}")

[layers.0.0.weight]: True
[layers.0.0.bias]: True
[layers.0.1.weight]: True
[layers.0.1.bias]: True
[layers.1.0.weight]: True
[layers.1.0.bias]: True
[layers.1.1.weight]: True
[layers.1.1.bias]: True
[layers.2.0.weight]: True
[layers.2.0.bias]: True
[layers.2.1.weight]: True
[layers.2.1.bias]: True
[layers.3.0.weight]: True
[layers.3.0.bias]: True
[layers.3.1.weight]: True
[layers.3.1.bias]: True
[layers.4.0.weight]: True
[layers.4.0.bias]: True
[layers.4.1.weight]: True
[layers.4.1.bias]: True
[layers.5.0.weight]: True
[layers.5.0.bias]: True
[layers.5.1.weight]: True
[layers.5.1.bias]: True
[layers.6.weight]: True
[layers.6.bias]: True


In [8]:
# Optimisation setup: MSE objective, AdamW over the parameters exposed by the
# Layer2Layer wrapper, and a list collecting the loss of every step.
criterion = nn.MSELoss()
optimizer = optim.AdamW(l2l_model.main_params)
losses = []

for step in trange(2000):
    # One L2L training step: clear gradients, run the forward pass, then
    # evaluate the loss against the targets.
    l2l_model.zero_grad()
    _ = l2l_model.forward(x)
    loss_value = l2l_model.compute_loss(y, criterion)
    losses.append(loss_value)

    # Report every 50th step; tqdm.write is used instead of print so the
    # message goes through tqdm alongside the progress bar.
    if not step % 50:
        tqdm.write(f"[{step}] loss = {loss_value}")

    l2l_model.backward()
    optimizer.step()
    l2l_model.update_main_model_params()  # Sync params with CPU
    

HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))

[0] loss = 0.09251299779862165
[50] loss = 0.04679687600582838
[100] loss = 0.024244420928880572
[150] loss = 0.012858417234383523
[200] loss = 0.007469074393156916
[250] loss = 0.004840212088311091
[300] loss = 0.0034365870524197817
[350] loss = 0.0025724523584358394
[400] loss = 0.0020228293433319777
[450] loss = 0.001649867874220945
[500] loss = 0.001370955680613406
[550] loss = 0.001156660437118262
[600] loss = 0.0010158663790207356
[650] loss = 0.0008867938231560402
[700] loss = 0.0008164640748873353
[750] loss = 0.0007635831498191692
[800] loss = 0.0006904192341607995
[850] loss = 0.0006020103282935452
[900] loss = 0.000585917350690579
[950] loss = 0.0005711435842385981
[1000] loss = 0.0005517072750080843
[1050] loss = 0.00042251174090779386
[1100] loss = 0.00043898035073652864
[1150] loss = 0.000410193380957935
[1200] loss = 0.00040243057810585015
[1250] loss = 0.00041606493323342875
[1300] loss = 0.0003635622069850797
[1350] loss = 0.0003410119952604873
[1400] loss = 0.00034775

In [10]:
import plotly.graph_objects as go

# Plot the recorded loss curve. The x axis is derived from the number of
# recorded losses rather than a hard-coded length, so it always matches `y`
# (the original used 5000 points although the loop records 2000 losses).
fig = go.Figure(data=go.Scatter(x=np.arange(len(losses)), y=losses))
fig.update_layout(
    title="Training loss",
    xaxis_title="Training step",
    yaxis_title="MSE loss",
)
fig.show()