In [None]:
!pip install pytorch-adapt[lightning,ignite]

### Load a toy dataset

In [None]:
import torch
from tqdm.auto import tqdm

from pytorch_adapt.datasets import get_mnist_mnistm

# Directory where MNIST and MNIST-M are downloaded / cached.
DATA_ROOT = "."

# mnist is the source domain
# mnistm is the target domain
datasets = get_mnist_mnistm(["mnist"], ["mnistm"], DATA_ROOT, download=True)

# Loader over the "train" split, used by the hand-written training loops below.
# shuffle=True so each epoch visits the training samples in a fresh random
# order instead of the fixed dataset order.
dataloader = torch.utils.data.DataLoader(
    datasets["train"], batch_size=32, shuffle=True, num_workers=2
)

### Load toy models

In [None]:
from pytorch_adapt.models import Discriminator, mnistC, mnistG

# Use the GPU when one is available; otherwise fall back to the CPU so the
# plain-PyTorch examples still run on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_models():
    """Build a fresh set of models on the global ``device``.

    Returns:
        dict with keys "G" (pretrained mnistG generator), "C" (pretrained
        mnistC classifier) and "D" (Discriminator), each moved to ``device``.
    """
    generator = mnistG(pretrained=True).to(device)
    classifier = mnistC(pretrained=True).to(device)
    # NOTE(review): in_size=1200 presumably matches the size of mnistG's
    # output features — confirm before swapping in a different generator.
    discriminator = Discriminator(in_size=1200, h=256).to(device)
    return dict(G=generator, C=classifier, D=discriminator)


def get_optimizers(models, lr=0.0001):
    """Create one Adam optimizer per model.

    Args:
        models: dict holding the "G", "C" and "D" modules (as returned by
            ``get_models``).
        lr: learning rate shared by all three optimizers; defaults to 1e-4.

    Returns:
        list of ``torch.optim.Adam`` optimizers in the order [G, C, D].
    """
    return [
        torch.optim.Adam(models[name].parameters(), lr=lr)
        for name in ("G", "C", "D")
    ]

### Use in vanilla PyTorch

In [None]:
from pytorch_adapt.hooks import DANNHook
from pytorch_adapt.utils.common_functions import batch_to_device

# Fresh models and optimizers for this example.
models = get_models()
optimizers = get_optimizers(models)
hook = DANNHook(optimizers)

for data in tqdm(dataloader):
    data = batch_to_device(data, device)
    # The hook receives the models and the batch merged into a single dict.
    batch_inputs = {**models, **data}
    # All optimization (including the optimizer steps) happens inside the
    # hook; the loss it hands back is only for logging.
    # NOTE(review): `hook({}, inputs)` is the call form of older pytorch-adapt
    # releases — confirm it matches the installed version.
    loss, _ = hook({}, batch_inputs)

### Build complex algorithms

In [None]:
from pytorch_adapt.hooks import MCCHook, VATHook

models = get_models()
optimizers = get_optimizers(models)

# Chain the generator (G) and classifier (C) into one module and expose it to
# the hooks under the "combined_model" key (presumably consumed by VATHook —
# confirm against the pytorch-adapt docs).
G = models["G"]
C = models["C"]
misc = {"combined_model": torch.nn.Sequential(G, C)}

# DANN extended with MCC and VAT, attached through the `post_g` argument.
hook = DANNHook(optimizers, post_g=[MCCHook(), VATHook()])
for data in tqdm(dataloader):
    data = batch_to_device(data, device)
    batch_inputs = {**models, **data, **misc}
    loss, _ = hook({}, batch_inputs)

### Wrap with your favorite PyTorch framework

In [None]:
from pytorch_adapt.adapters import DANN
from pytorch_adapt.containers import Models
from pytorch_adapt.datasets import DataloaderCreator

# Wrap freshly built models in a Models container and build a DANN adapter.
models = get_models()
models_cont = Models(models)
adapter = DANN(models=models_cont)

# Turn the datasets into dataloaders (2 worker processes each); the result is
# keyed by split name, e.g. "train".
dc = DataloaderCreator(num_workers=2)
dataloaders = dc(**datasets)

#### Lightning

In [None]:
import pytorch_lightning as pl

from pytorch_adapt.frameworks.lightning import Lightning

# Wrap the adapter for PyTorch Lightning and train for one epoch on one GPU.
# `accelerator`/`devices` replace the `gpus` argument, which Lightning
# deprecated in 1.7 and removed in 2.0.
L_adapter = Lightning(adapter)
trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=1)
trainer.fit(L_adapter, dataloaders["train"])

#### Ignite

In [None]:
from pytorch_adapt.frameworks.ignite import Ignite

# Start again from freshly built models rather than reusing trained ones.
models = get_models()
models_cont = Models(models)
adapter = DANN(models=models_cont)

# The Ignite wrapper is given the raw datasets plus the DataloaderCreator
# (presumably it builds its own dataloaders from them) and then trains.
trainer = Ignite(adapter)
trainer.run(datasets, dataloader_creator=dc)

### Check your model's performance

In [None]:
from pytorch_adapt.validators import SNDValidator

# Random predictions as placeholder for the model's outputs on the target
# training set. A seeded local generator keeps the score reproducible without
# touching the global RNG, and softmax turns each row into a probability
# distribution over the 10 MNIST classes.
# NOTE(review): assumes "preds" should hold softmaxed class probabilities —
# confirm against the SNDValidator documentation.
NUM_SAMPLES, NUM_CLASSES = 1000, 10
logits = torch.randn(
    NUM_SAMPLES, NUM_CLASSES, generator=torch.Generator().manual_seed(0)
)
preds = logits.softmax(dim=1)

# Assuming predictions have been collected
target_train = {"preds": preds}
validator = SNDValidator()
score = validator.score(target_train=target_train)
score

#### Lightning

In [None]:
from pytorch_adapt.frameworks.utils import filter_datasets

models = get_models()
models_cont = Models(models)
adapter = DANN(models=models_cont)
validator = SNDValidator()

# Only build dataloaders for the splits the validator needs, plus the "train"
# split popped below. A separate name keeps the earlier `dataloaders` dict
# intact, so the previous Lightning cell can still be re-run.
val_dataloaders = dc(**filter_datasets(datasets, validator))
train_loader = val_dataloaders.pop("train")

L_adapter = Lightning(adapter, validator=validator)
# `gpus=` was removed in Lightning 2.0; accelerator/devices is the replacement.
trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=1)
# fit(model, train_dataloaders, val_dataloaders, ...): pass every validation
# loader as ONE list. Unpacking with * would push a second loader into the
# `datamodule` positional slot.
trainer.fit(L_adapter, train_loader, list(val_dataloaders.values()))

#### Ignite

In [None]:
from pytorch_adapt.validators import ScoreHistory

# Rebuild models and adapter so this run starts from scratch.
models = get_models()
models_cont = Models(models)
adapter = DANN(models=models_cont)

# SND validator wrapped in ScoreHistory (presumably keeps a history of scores
# across evaluations — confirm in the pytorch-adapt docs).
snd_validator = SNDValidator()
validator = ScoreHistory(snd_validator)
trainer = Ignite(adapter, validator=validator)
trainer.run(datasets, dataloader_creator=dc)