In [1]:
!pip install pytorch-adapt

Collecting pytorch-adapt
  Downloading pytorch_adapt-0.0.61-py3-none-any.whl (137 kB)
[?25l[K     |██▍                             | 10 kB 29.3 MB/s eta 0:00:01[K     |████▊                           | 20 kB 26.0 MB/s eta 0:00:01[K     |███████▏                        | 30 kB 19.2 MB/s eta 0:00:01[K     |█████████▌                      | 40 kB 15.7 MB/s eta 0:00:01[K     |████████████                    | 51 kB 10.9 MB/s eta 0:00:01[K     |██████████████▎                 | 61 kB 12.7 MB/s eta 0:00:01[K     |████████████████▊               | 71 kB 11.9 MB/s eta 0:00:01[K     |███████████████████             | 81 kB 11.4 MB/s eta 0:00:01[K     |█████████████████████▍          | 92 kB 12.5 MB/s eta 0:00:01[K     |███████████████████████▉        | 102 kB 13.0 MB/s eta 0:00:01[K     |██████████████████████████▏     | 112 kB 13.0 MB/s eta 0:00:01[K     |████████████████████████████▋   | 122 kB 13.0 MB/s eta 0:00:01[K     |███████████████████████████████ | 133 kB 13

### Helper function for demo

In [2]:
from pytorch_adapt.utils.common_functions import get_lr


def print_optimizers_slim(optimizers):
    """Print one summary line (class, lr, weight_decay) per optimizer, then a blank line.

    Args:
        optimizers: a mapping of name -> torch optimizer, e.g. an ``Optimizers``
            container after ``create_with`` has been called.
    """
    for name, opt in optimizers.items():
        # Only the first param group is inspected; assumes each optimizer here
        # was built from a single model's parameters (one group) — TODO confirm.
        # .get(): not every torch optimizer has a weight_decay option (e.g. LBFGS),
        # so indexing directly would raise KeyError.
        weight_decay = opt.param_groups[0].get("weight_decay")
        print(
            f"{name}: {opt.__class__.__name__} with lr={get_lr(opt)} weight_decay={weight_decay}"
        )
    print("")

### Container initialization

In [3]:
import torch

from pytorch_adapt.containers import LRSchedulers, Models, Optimizers

# Fall back to CPU so the notebook also runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed so the randomly initialized weights are identical across Restart & Run All.
torch.manual_seed(0)

# Toy models: G maps 1000 -> 100 features, C maps those to 10 outputs, D maps them to 1.
# (Presumably generator / classifier / discriminator in domain-adaptation terms.)
G = torch.nn.Linear(1000, 100)
C = torch.nn.Linear(100, 10)
D = torch.nn.Linear(100, 1)

models = Models({"G": G, "C": C, "D": D})
# A single (class, kwargs) spec is instantiated once per model key by create_with.
optimizers = Optimizers((torch.optim.Adam, {"lr": 0.456, "weight_decay": 0.123}))
schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {"gamma": 0.99}))

### Create optimizers and schedulers with `create_with`

In [4]:
# Build one optimizer per model, then one scheduler per (now existing) optimizer.
for target, source in ((optimizers, models), (schedulers, optimizers)):
    target.create_with(source)

# Containers are keyed by model name, so each listing shows G, C and D.
print(models)
print_optimizers_slim(optimizers)
print(schedulers)

G: Linear(in_features=1000, out_features=100, bias=True)
C: Linear(in_features=100, out_features=10, bias=True)
D: Linear(in_features=100, out_features=1, bias=True)

G: Adam with lr=0.456 weight_decay=0.123
C: Adam with lr=0.456 weight_decay=0.123
D: Adam with lr=0.456 weight_decay=0.123

G: <torch.optim.lr_scheduler.ExponentialLR object at 0x7fecdace11d0>
C: <torch.optim.lr_scheduler.ExponentialLR object at 0x7febc76b3bd0>
D: <torch.optim.lr_scheduler.ExponentialLR object at 0x7febc76b3cd0>



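### Merge

`merge` combines two containers; entries from the merged-in container override matching keys (here, G and X switch from Adam to SGD).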

In [5]:
# merge adds the new "X" entry to the existing `models` container in place.
models.merge(Models({"X": torch.nn.Linear(20, 1)}))

# Adam is the default for every key; the keyed SGD spec overrides G and X.
optimizers = Optimizers((torch.optim.Adam, {"lr": 0.456}))
sgd_overrides = Optimizers(
    (torch.optim.SGD, {"lr": 1, "weight_decay": 1e-5}),
    keys=["G", "X"],
)
optimizers.merge(sgd_overrides)
optimizers.create_with(models)

print(models)
print_optimizers_slim(optimizers)

G: Linear(in_features=1000, out_features=100, bias=True)
C: Linear(in_features=100, out_features=10, bias=True)
D: Linear(in_features=100, out_features=1, bias=True)
X: Linear(in_features=20, out_features=1, bias=True)

G: SGD with lr=1 weight_decay=1e-05
C: Adam with lr=0.456 weight_decay=0
D: Adam with lr=0.456 weight_decay=0
X: SGD with lr=1 weight_decay=1e-05



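### Delete keys with `DeleteKey`

Merging a `DeleteKey` entry removes that key, so no optimizer is created for it (G and D below).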

In [6]:
from pytorch_adapt.containers import DeleteKey

# Start from SGD for every model, then merge in DeleteKey entries for G and D
# so that no optimizer is created for those two keys.
sgd_everywhere = Optimizers((torch.optim.SGD, {"lr": 0.01, "momentum": 0.9}))
sgd_everywhere.merge(Optimizers((DeleteKey, {}), keys=["G", "D"]))
sgd_everywhere.create_with(models)
print_optimizers_slim(sgd_everywhere)

C: SGD with lr=0.01 weight_decay=0
X: SGD with lr=0.01 weight_decay=0



### Model Container Functions

In [7]:
# train()/eval() are broadcast to every wrapped model; print each model's flag after each switch.
for switch_mode in (models.train, models.eval):
    switch_mode()
    for name, module in models.items():
        print(name, "training", module.training)

# zero_grad() and .to(device) are likewise applied to every model in the container.
models.zero_grad()
models.to(device)
for name, module in models.items():
    print(name, "device", module.weight.device)

G training True
C training True
D training True
X training True
G training False
C training False
D training False
X training False
G device cuda:0
C device cuda:0
D device cuda:0
X device cuda:0


### Optimizer Container Functions

In [8]:
inputs = torch.randn(32, 1000).to(device)
models.to(device)

for keys in (None, ["C"]):
    loss = C(G(inputs)).sum()

    if keys is None:
        # zero gradients, compute gradients, update weights
        optimizers.zero_back_step(loss)
    else:
        # only apply zero_back_step to specific optimizers
        optimizers.zero_back_step(loss, keys=keys)

### Optimizer LR Multiplier

In [9]:
# Each model's lr is the base lr times its multiplier: G -> 0.1 * 50, C -> 0.1 * 0.5.
# Keys without a multiplier (D, X) keep the base lr.
lr_multipliers = {"G": 50, "C": 0.5}
optimizers = Optimizers((torch.optim.Adam, {"lr": 0.1}), multipliers=lr_multipliers)
optimizers.create_with(models)
print_optimizers_slim(optimizers)

G: Adam with lr=5.0 weight_decay=0
C: Adam with lr=0.05 weight_decay=0
D: Adam with lr=0.1 weight_decay=0
X: Adam with lr=0.1 weight_decay=0



### LR Scheduler Functions

In [10]:
# Tag each scheduler with a type so they can be stepped at different cadences.
scheduler_types = {"per_step": ["G", "C"], "per_epoch": ["D", "X"]}
schedulers = LRSchedulers(
    (torch.optim.lr_scheduler.ExponentialLR, {"gamma": 0.99}),
    scheduler_types=scheduler_types,
)
schedulers.create_with(optimizers)

# Step only the schedulers of each type.
# NOTE(review): these optimizers have not called step() yet, so PyTorch may warn
# that lr_scheduler.step() was called before optimizer.step().
for scheduler_type in ("per_step", "per_epoch"):
    schedulers.step(scheduler_type)

# Retrieve the schedulers belonging to each type.
per_step = schedulers.filter_by_scheduler_type("per_step")
per_epoch = schedulers.filter_by_scheduler_type("per_epoch")

