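# How does LAP scale with the number of sources?

This notebook measures the runtime overhead of the Loss Adapted Plasticity (LAP) `SourceLossWeighting` module. It is compared against an MLP forward pass and a full training step. Each run uses a single, fixed `n_sources` (set in the configuration cell); re-run with different values of `n_sources` to see how the overhead scales.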

In [1]:
import torch
import torch.nn as nn

from loss_adapted_plasticity import SourceLossWeighting

In [2]:
# Fix the RNG so the synthetic batch and the MLP's initial weights are
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

device = "cpu"

In [3]:
class MLP(nn.Module):
    """Three-layer ReLU MLP used as the benchmark workload.

    Both hidden layers are ``hidden_multiplier * in_shape`` units wide. The
    middle layer therefore has a weight count quadratic in that width, and
    compute grows quickly with ``in_shape``.

    Args:
        in_shape: Number of input features.
        out_shape: Number of output logits (classes).
        hidden_multiplier: Hidden-layer width as a multiple of ``in_shape``.
            Defaults to 100, the width previously hard-coded here.
    """

    def __init__(self, in_shape, out_shape, hidden_multiplier=100):
        super().__init__()

        hidden = hidden_multiplier * in_shape
        self.net = nn.Sequential(
            nn.Linear(in_shape, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_shape),
        )
        # reduction="none" keeps one loss value per sample, so the losses can
        # be reweighted per sample/source before being reduced.
        self.criterion = nn.CrossEntropyLoss(reduction="none")

    def forward(self, x, y=None, return_loss=False):
        """Run the network and optionally compute per-sample losses.

        Args:
            x: Input features passed to the first ``nn.Linear`` layer.
            y: Class-index targets; required when ``return_loss`` is True.
            return_loss: If True, also compute the cross-entropy loss.

        Returns:
            The logits when ``return_loss`` is False. Otherwise a tuple
            ``(per_sample_loss, logits)``.

        Raises:
            ValueError: If ``return_loss`` is True but ``y`` is None.
        """
        out = self.net(x)
        if not return_loss:
            return out
        if y is None:
            raise ValueError("y must be provided when return_loss=True")
        loss = self.criterion(out, y)
        return loss, out

In [4]:
# Problem size for the benchmark.
n_features, n_classes = 20, 5  # input dimensionality, number of target classes
batch_size = 2_048             # samples per forward/backward pass
n_sources = 10                 # distinct source ids assigned to the samples

In [5]:
# LAP (SourceLossWeighting) hyperparameters, all tunable from this cell.
history_length = 25
depression_strength = 1.0
leniency = 1.0
warmup_iters = 100
discrete_amount = 0.005

# Optimiser hyperparameters.
lr = 0.01
weight_decay = 0.0001

# nn.Module.to returns the module itself, so build and move in one step.
mlp = MLP(n_features, n_classes).to(device)

optimiser = torch.optim.Adam(
    params=mlp.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

source_loss_weighting = SourceLossWeighting(
    history_length=history_length,
    warmup_iters=warmup_iters,
    depression_strength=depression_strength,
    discrete_amount=discrete_amount,
    leniency=leniency,
    device=device,
)

In [6]:
# Synthetic batch: Gaussian features, uniformly random class labels and
# source ids. Each tensor is generated in the same order as before, then moved
# to the target device.
x = torch.randn(batch_size, n_features).to(device)
y = torch.randint(0, n_classes, (batch_size,)).to(device)
sources = torch.randint(0, n_sources, (batch_size,)).to(device)

In [7]:
# Per-sample losses (the MLP's criterion uses reduction="none") and logits for
# the fixed synthetic batch; `label_loss` is the input to the LAP cells below.
label_loss, outputs = mlp(x, y=y, return_loss=True)

In [8]:
# Fill the LAP loss history before timing, so the measurements below reflect
# steady-state behaviour. With the settings above, 1000 calls is far more than
# warmup_iters + history_length.
#
# Feed the *raw* per-sample losses on every call. The old version reassigned
# the weighted output back to `label_loss`. That fed already-reweighted losses
# back in as inputs and chained a 1000-op autograd graph onto `label_loss`. It
# also made this cell non-idempotent on re-run.
n_history_fill_calls = 1000
for _ in range(n_history_fill_calls):
    source_loss_weighting(losses=label_loss, sources=sources)

In [9]:
# Time the LAP source-loss weighting on its own, using `label_loss` and
# `sources` from the cells above (loss history already filled by the loop).
%timeit source_loss_weighting(losses=label_loss, sources=sources)

449 μs ± 5.81 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [10]:
# Time a forward pass plus per-sample loss, for comparison with the LAP cost.
%timeit mlp(x, y, return_loss=True)

13.3 ms ± 133 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [11]:
%%timeit 
optimiser.zero_grad()
label_loss, _ = mlp(x, y, return_loss=True)
label_loss.mean().backward()
optimiser.step()

42 ms ± 1.45 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [12]:
%%timeit 
optimiser.zero_grad()
label_loss, _ = mlp(x, y, return_loss=True)
label_loss = source_loss_weighting(losses=label_loss, sources=sources)
label_loss.mean().backward()
optimiser.step()

42.1 ms ± 191 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
