In [None]:
!pip install pytorch-adapt

### Setup

In [None]:
import copy
from pprint import pprint

import torch

from pytorch_adapt.utils import common_functions as c_f

# Seed the global torch RNG so model initialization and the example batch
# below are reproducible across Restart & Run All.
torch.manual_seed(0)

# Models
# G maps 1000-d inputs to 100-d features, C maps those features to 10 outputs
# (one per class label in src_labels), and D maps features to a single score
# per sample: Flatten(start_dim=0) turns its (batch, 1) output into (batch,).
G = torch.nn.Linear(1000, 100)
C = torch.nn.Linear(100, 10)
D = torch.nn.Sequential(torch.nn.Linear(100, 1), torch.nn.Flatten(start_dim=0))
G_opt = torch.optim.Adam(G.parameters())
C_opt = torch.optim.Adam(C.parameters())
D_opt = torch.optim.Adam(D.parameters())

dataset_size = 10000
batch_size = 32
# 1 batch of data
example_data = {
    "src_imgs": torch.randn(batch_size, 1000),
    "target_imgs": torch.randn(batch_size, 1000),
    "src_labels": torch.randint(0, 10, size=(batch_size,)),
    # Domain labels: 0 for source, 1 for target. They must differ, otherwise
    # the discriminator's targets are identical for both domains.
    "src_domain": torch.zeros(batch_size),
    "target_domain": torch.ones(batch_size),
    # Sample indices are drawn from [0, dataset_size).
    "src_sample_idx": torch.randint(0, dataset_size, size=(batch_size,)),
    "target_sample_idx": torch.randint(0, dataset_size, size=(batch_size,)),
}


def get_data(keys):
    """Select a subset of the example batch.

    Args:
        keys: iterable of key names to pull out of ``example_data``.

    Returns:
        A new dict mapping each requested key to the very same tensor object
        held in ``example_data`` (tensors are not copied), in the order the
        keys were given.

    Raises:
        KeyError: if a requested key is not present in ``example_data``.
    """
    selected = {}
    for key in keys:
        selected[key] = example_data[key]
    return selected

### [Adversarial Discriminative Domain Adaptation](https://arxiv.org/abs/1702.05464) (ADDA)

In [None]:
from pytorch_adapt.hooks import ADDAHook

# ADDA uses a separate target model T, initialized as a copy of G.
T = copy.deepcopy(G)
T_opt = torch.optim.Adam(T.parameters())

# The hook receives optimizers for T and D only; G and C are passed as inputs
# but their optimizers are not handed to this hook.
models = dict(G=G, C=C, D=D, T=T)
data = get_data(
    ["src_imgs", "target_imgs", "src_domain", "target_domain"]
)
hook = ADDAHook(d_opts=[D_opt], g_opts=[T_opt])

losses, _ = hook({}, {**models, **data})
pprint(losses)

### [Larger Norm More Transferable: An Adaptive Feature Norm Approach for Unsupervised Domain Adaptation](https://arxiv.org/abs/1811.07456) (AFN)

In [None]:
from pytorch_adapt.hooks import AFNHook, ClassifierHook

# AFNHook is attached to a ClassifierHook via `post=`; the classifier hook is
# given optimizers for both G and C.
models = dict(G=G, C=C)
data = get_data(
    ["src_imgs", "target_imgs", "src_labels"]
)
hook = ClassifierHook(post=[AFNHook()], opts=[G_opt, C_opt])

losses, _ = hook({}, {**models, **data})
pprint(losses)

### [Domain Adaptation with Auxiliary Target Domain-Oriented Classifier](https://arxiv.org/abs/2007.04171) (ATDOC)

In [None]:
from pytorch_adapt.hooks import ATDOCHook, ClassifierHook

# Derive the ATDOC sizes from the setup cell instead of repeating literals:
# - dataset_size: the setup cell draws target_sample_idx from
#   [0, dataset_size), and this hook consumes target_sample_idx. Presumably it
#   indexes per-sample state with it (TODO confirm), so the two must agree.
# - feature_dim: width of G's output (the features fed to C).
# - num_classes: width of C's output.
atdoc = ATDOCHook(
    dataset_size=dataset_size,
    feature_dim=G.out_features,
    num_classes=C.out_features,
)
hook = ClassifierHook(opts=[G_opt, C_opt], post=[atdoc])

models = {"G": G, "C": C}
data = get_data(["src_imgs", "target_imgs", "src_labels", "target_sample_idx"])
losses, _ = hook({}, {**models, **data})
pprint(losses)

### [Towards Discriminability and Diversity: Batch Nuclear-norm Maximization under Label Insufficient Situations](https://arxiv.org/abs/2003.12237) (BNM)

In [None]:
from pytorch_adapt.hooks import BNMHook, ClassifierHook

# BNMHook is attached to a ClassifierHook via `post=`; the classifier hook is
# given optimizers for both G and C.
models = dict(G=G, C=C)
data = get_data(
    ["src_imgs", "target_imgs", "src_labels"]
)
hook = ClassifierHook(post=[BNMHook()], opts=[G_opt, C_opt])

losses, _ = hook({}, {**models, **data})
pprint(losses)