In [None]:
pip install -q pytorch-adapt

In [None]:
import sys

# Put "../../src" (relative to the working directory) at the front of the
# import path, so the pytorch_adapt imports below look there before any
# installed copy. NOTE(review): presumably this points at the repository
# checkout the notebook lives in - confirm the relative path is still right.
sys.path.insert(0, "../../src")
import torch
import tqdm

from pytorch_adapt.datasets.getters import get_mnist_mnistm
from pytorch_adapt.hooks import DANNHook
from pytorch_adapt.models import Classifier, Discriminator, MNISTFeatures
from pytorch_adapt.utils.common_functions import batch_to_device

In [None]:
# Request the datasets from the getter with "mnist" and "mnistm" as the two
# domain lists (presumably source first, target second - confirm against the
# pytorch-adapt docs). folder="." and download=True are passed straight through.
datasets = get_mnist_mnistm(
    ["mnist"],
    ["mnistm"],
    folder=".",
    download=True,
)

# Batch the "train" entry for the training loop: 32 items per batch, loaded by
# 2 worker processes. No shuffle argument is given, so DataLoader's default
# ordering applies.
train_split = datasets["train"]
dataloader = torch.utils.data.DataLoader(
    dataset=train_split,
    batch_size=32,
    num_workers=2,
)

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs end to end on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The three networks, each moved to the selected device. The in_size of C and D
# is 1200 - presumably the width of G's output features; confirm against
# MNISTFeatures.
G = MNISTFeatures().to(device)
C = Classifier(num_classes=10, in_size=1200, h=256).to(device)
D = Discriminator(in_size=1200, h=256).to(device)
# The training loop merges this dict with each batch before calling the hook,
# so these keys are the names the hook sees the models under.
models = {"G": G, "C": C, "D": D}

# One Adam optimizer per network, all with the same learning rate.
G_opt = torch.optim.Adam(G.parameters(), lr=0.0001)
C_opt = torch.optim.Adam(C.parameters(), lr=0.0001)
D_opt = torch.optim.Adam(D.parameters(), lr=0.0001)
opts = [G_opt, C_opt, D_opt]

# The hook is handed the optimizers. NOTE(review): presumably it runs the
# backward pass and optimizer steps itself when called - confirm in the
# DANNHook documentation.
hook = DANNHook(opts)

In [None]:
# Wrap the loader in a tqdm progress bar and make one pass over it.
progress = tqdm.tqdm(dataloader)
for i, data in enumerate(progress):
    # Hand the batch to batch_to_device together with the chosen device.
    data = batch_to_device(data, device)
    # Merge the models and the batch into a single dict; on a key clash the
    # batch entry wins because it is unpacked last.
    hook_inputs = {**models, **data}
    # The hook is called with an empty dict plus the merged inputs. Only the
    # first returned item is kept, as `loss`. NOTE(review): no explicit
    # backward()/step() happens here - presumably the hook does it; confirm.
    loss, _ = hook({}, hook_inputs)