In [9]:
from typing import Optional

import torch
from torch import nn, optim

import numpy as np

from tqdm.auto import tqdm, trange
from tqdm.contrib import tenumerate, tzip

from layer_to_layer_pytorch import Layer2Layer

In [10]:
class Model(nn.Module):
    """Stack of square ``nn.Linear`` layers applied one after another.

    The layers are stored in ``self.layers`` (an ``nn.ModuleList``) and applied
    in order, with no activation function between them.

    Args:
        dim: Input and output width of every linear layer. Defaults to 5.
        num_layers: Number of linear layers in the stack. Defaults to 2.

    Raises:
        ValueError: If ``dim`` or ``num_layers`` is smaller than 1.
    """

    def __init__(self, dim: int = 5, num_layers: int = 2):
        super().__init__()
        if dim < 1:
            raise ValueError(f"dim must be >= 1, got {dim}")
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        # The defaults reproduce the original fixed architecture:
        # two Linear(5, 5) layers created in this order.
        self.layers = nn.ModuleList(
            [nn.Linear(dim, dim) for _ in range(num_layers)]
        )

    def forward(self, x):
        """Run ``x`` through every layer in ``self.layers`` and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x

In [11]:
# Seed the global RNG so the randomly initialised weights are the same on
# every "Restart & Run All".
torch.manual_seed(0)

# `net` is the instance handed to the Layer2Layer wrapper; `nnet` is a second
# instance of the same class that the baseline cell trains directly.
net = Model()
nnet = Model()

# Start both models from identical weights so their loss curves are comparable.
# load_state_dict copies the values in place without going through `.data`.
nnet.load_state_dict(net.state_dict())

l2l_model = Layer2Layer(net, microbatch_size=20, mixed_precision=True)

In [12]:
# Synthetic regression data: 80 samples with 5 features each, uniform in [0, 1).
# A dedicated seeded generator keeps the data identical across runs without
# touching the global RNG state.
data_rng = torch.Generator().manual_seed(0)
x = torch.rand((80, 5), generator=data_rng)
y = torch.rand((80, 5), generator=data_rng)

In [13]:
# Baseline run: train the plain model on the full batch with Adam for 600
# steps, recording the loss of every step in `r_loss`.
criterion = nn.MSELoss(reduction="mean")
optimizer = optim.Adam(nnet.parameters())
r_loss = []
for _ in trange(600):
    optimizer.zero_grad()
    loss = criterion(nnet(x), y)
    # Record the loss computed with the pre-update weights.
    r_loss.append(loss.item())
    loss.backward()
    optimizer.step()

HBox(children=(FloatProgress(value=0.0, max=600.0), HTML(value='')))




In [14]:
# L2L run: same number of steps as the baseline, but every stage of the
# training step goes through the Layer2Layer wrapper. The loss of each step is
# recorded in `w_loss`.
optimizer = optim.Adam(l2l_model.main_params)
w_loss = []
for i in trange(600):
    l2l_model.zero_grad()
    l2l_model.forward(x)  # return value is not used here
    loss_value = l2l_model.compute_loss(y, criterion)
    w_loss.append(loss_value)

    # Report progress every 50th iteration without breaking the progress bar.
    if not i % 50:
        tqdm.write(f"[{i}] loss = {loss_value}")

    l2l_model.backward()
    optimizer.step()
    # NOTE(review): presumably copies the optimizer-updated `main_params` back
    # into the wrapped model — confirm against the Layer2Layer implementation.
    l2l_model.update_main_model_params()

HBox(children=(FloatProgress(value=0.0, max=600.0), HTML(value='')))

[0] loss = 0.40304672718048096
[50] loss = 0.29087238758802414
[100] loss = 0.21799978241324425
[150] loss = 0.15478910133242607
[200] loss = 0.1116107702255249
[250] loss = 0.09347709640860558
[300] loss = 0.08926261961460114
[350] loss = 0.08778456225991249
[400] loss = 0.08662739209830761
[450] loss = 0.08561820536851883
[500] loss = 0.08475672826170921
[550] loss = 0.08399175852537155



In [15]:
import plotly.graph_objects as go

In [16]:
# Overlay the two loss curves. The x axis is derived from the recorded
# histories, so the plot stays correct if the number of iterations changes.
fig = go.Figure(
    data=[
        go.Scatter(x=np.arange(len(w_loss)), y=w_loss, name="L2L wrapper"),
        go.Scatter(x=np.arange(len(r_loss)), y=r_loss, name="Usual model"),
    ]
)
# Label the figure so it can be read on its own when skimming the notebook.
fig.update_layout(
    title="Training loss: L2L wrapper vs. usual model",
    xaxis_title="Iteration",
    yaxis_title="MSE loss",
)
fig.show()
