In [None]:
pip install pytorch-adapt

In [None]:
import sys

# Put the relative source directory at the very front of the import search
# path so it shadows any installed copy of the package. The path is resolved
# against the notebook's working directory (presumably the repo's src/ —
# confirm if the notebook is moved).
sys.path[:0] = ["../../src"]

### adapters/index.md initialization

In [None]:
import torch

from pytorch_adapt.adapters import DANN
from pytorch_adapt.containers import Models

# Toy networks, created in the order G, C, D:
#   G: 1000-d input  -> 100-d features
#   C: 100-d features -> 10 outputs (class logits)
#   D: 100-d features -> 1 output per sample; Flatten(start_dim=0) collapses a
#      (batch, 1) output to shape (batch,). D is presumably the domain
#      discriminator DANN trains adversarially — confirm in the DANN docs.
G, C = torch.nn.Linear(1000, 100), torch.nn.Linear(100, 10)
D = torch.nn.Sequential(
    torch.nn.Linear(100, 1),
    torch.nn.Flatten(start_dim=0),
)

# Wrap the networks under the keys "G", "C", "D" and hand them to the adapter.
models = Models(dict(G=G, C=C, D=D))
adapter = DANN(models=models)

### adapters/index.md training step

In [None]:
# Use the GPU when one is present, otherwise fall back to the CPU so the
# example also runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
adapter.models.to(device)

# A fake batch of 32 source and 32 target samples (1000-d inputs), with class
# labels in [0, 10) for the source samples only.
data = {
    "src_imgs": torch.randn(32, 1000),
    "target_imgs": torch.randn(32, 1000),
    "src_labels": torch.randint(0, 10, size=(32,)),
    # Domain labels: source = 0, target = 1. They must differ, otherwise the
    # domain discriminator sees a single class and the adversarial loss is
    # meaningless.
    "src_domain": torch.zeros(32),
    "target_domain": torch.ones(32),
}

# One training step on the batch; `device` is passed along, presumably so the
# adapter can move the batch there — confirm against training_step's docs.
loss = adapter.training_step(data, device)

### adapters/index.md inference

In [None]:
# A batch of 32 new 1000-d inputs, moved to the same device as the models.
data = torch.randn(32, 1000).to(device)

# Pure inference: disable autograd so no computation graph is built or kept
# alive by the returned tensors.
with torch.no_grad():
    features, logits = adapter.inference(data)

### containers/index.md `create_with`

In [None]:
import torch

from pytorch_adapt.containers import LRSchedulers, Models, Optimizers

# Three toy layers (created in the order G, C, D) to populate the containers.
G, C, D = (
    torch.nn.Linear(1000, 100),
    torch.nn.Linear(100, 10),
    torch.nn.Linear(100, 1),
)
models = Models(dict(G=G, C=C, D=D))

# The optimizer and scheduler containers are configured with a
# (class, kwargs) pair.
optimizers = Optimizers((torch.optim.Adam, {"lr": 0.456}))
schedulers = LRSchedulers((torch.optim.lr_scheduler.ExponentialLR, {"gamma": 0.99}))

# create_with populates each container from the previous one:
# one Adam optimizer per model (G, C, D), then one ExponentialLR scheduler
# per optimizer.
optimizers.create_with(models)
schedulers.create_with(optimizers)

for container in (models, optimizers, schedulers):
    print(container)

### containers/index.md merge

In [None]:
# Add a fourth model, "X", to the `models` container from the previous cell.
# Note: merge updates `models` in place, so this cell reuses (and changes)
# state created earlier.
more_models = Models({"X": torch.nn.Linear(20, 1)})
models.merge(more_models)

# Default config: Adam with lr 0.456 for every model. Special config: SGD
# with lr 1, restricted to the keys "G" and "X".
adam_config = (torch.optim.Adam, {"lr": 0.456})
sgd_config = (torch.optim.SGD, {"lr": 1})
optimizers = Optimizers(adam_config)
special_opt = Optimizers(sgd_config, keys=["G", "X"])

# Merge the special config into the default one, then build the optimizers.
# Expected result:
# - models: G, C, D and X
# - optimizers: Adam (lr 0.456) for C and D, SGD (lr 1) for G and X
optimizers.merge(special_opt)
optimizers.create_with(models)

for container in (models, optimizers):
    print(container)