In [None]:
!pip install pytorch-adapt

### Helper function for the demo

In [None]:
from pytorch_adapt.utils.common_functions import get_lr


def print_optimizers_slim(optimizers):
    """Print a one-line summary of each optimizer, then a blank line.

    Each line has the form
    ``"<name>: <OptimizerClass> with lr=<lr> weight_decay=<wd>"``.
    The learning rate is obtained via ``get_lr``; the weight decay is read
    from the first parameter group only.

    Args:
        optimizers: mapping of name -> optimizer object exposing
            ``param_groups`` (e.g. a populated ``Optimizers`` container).
    """
    for name, opt in optimizers.items():
        cls_name = opt.__class__.__name__
        lr = get_lr(opt)
        weight_decay = opt.param_groups[0]["weight_decay"]
        print(f"{name}: {cls_name} with lr={lr} weight_decay={weight_decay}")
    # Blank separator line so consecutive demo outputs stay readable.
    print("")

### Initializing the containers

In [None]:
import torch

from pytorch_adapt.containers import LRSchedulers, Models, Optimizers

# Three toy linear layers. C and D both take a 100-dim input, matching the
# 100-dim output of G. (The names suggest generator / classifier /
# discriminator roles, but nothing in this cell enforces that.)
G, C, D = (
    torch.nn.Linear(n_in, n_out)
    for n_in, n_out in ((1000, 100), (100, 10), (100, 1))
)
models = Models({"G": G, "C": C, "D": D})

# Specs are (class, kwargs) pairs; the concrete optimizer / scheduler objects
# appear to be built later, when `create_with` is called.
optimizers = Optimizers(
    (torch.optim.Adam, {"lr": 0.456, "weight_decay": 0.123})
)
schedulers = LRSchedulers(
    (torch.optim.lr_scheduler.ExponentialLR, {"gamma": 0.99})
)

### Building optimizers and schedulers with `create_with`

In [None]:
# Build the optimizers from the models, then the schedulers from those
# optimizers. The schedulers are created from the `optimizers` container, so
# (presumably) it must already be populated — keep this order.
optimizers.create_with(models)
schedulers.create_with(optimizers)

# Show the models, a compact one-line-per-optimizer summary, and the schedulers.
print(models)
print_optimizers_slim(optimizers)
print(schedulers)

### Combining containers with `merge`

In [None]:
# Add a fourth model, "X", to the existing `models` container. `merge` is
# called for its side effect (return value unused), so `models` is updated in
# place and every later cell sees all four models.
more_models = Models({"X": torch.nn.Linear(20, 1)})
models.merge(more_models)

# New Optimizers container: an Adam default, plus an SGD spec restricted to
# keys "G" and "X". After the merge, G and X presumably get SGD while the
# remaining models keep Adam — the printed summary below confirms this.
optimizers = Optimizers((torch.optim.Adam, {"lr": 0.456}))
special_opt = Optimizers(
    (torch.optim.SGD, {"lr": 1, "weight_decay": 1e-5}), keys=["G", "X"]
)
optimizers.merge(special_opt)
# NOTE(review): `optimizers` was rebound above, so the `schedulers` created in
# an earlier cell still refer to the previous optimizer objects.
optimizers.create_with(models)

print(models)
print_optimizers_slim(optimizers)

### Removing keys with `DeleteKey`

In [None]:
from pytorch_adapt.containers import DeleteKey

# Default spec: SGD (lr=0.01, momentum=0.9), presumably applied to every model key.
opt1 = Optimizers((torch.optim.SGD, {"lr": 0.01, "momentum": 0.9}))
# A DeleteKey spec for "G" and "D"; merging it presumably drops those keys, so
# no optimizer is created for them. Expected survivors are C and X.
# NOTE(review): X exists only because an earlier cell merged it into `models`.
opt2 = Optimizers((DeleteKey, {}), keys=["G", "D"])
opt1.merge(opt2)
opt1.create_with(models)
print_optimizers_slim(opt1)