In [None]:
!pip install pytorch-adapt

### Create some fake data and models

In [None]:
from pprint import pprint

import torch

from pytorch_adapt.containers import Models, Optimizers
from pytorch_adapt.hooks import validate_hook

# Width of the feature vectors produced by G and consumed by C and D.
feature_size = 100

# G maps 1000-dim inputs to features; C maps features to 10 logits.
G = torch.nn.Linear(in_features=1000, out_features=feature_size)
C = torch.nn.Linear(in_features=feature_size, out_features=10)
# D emits one value per sample; Flatten(start_dim=0) collapses the
# trailing size-1 dimension so the output is 1-D.
D = torch.nn.Sequential(
    torch.nn.Linear(in_features=feature_size, out_features=1),
    torch.nn.Flatten(start_dim=0),
)

models = Models({"G": G, "C": C, "D": D})

# One Adam optimizer spec, instantiated against the models container.
optimizers = Optimizers((torch.optim.Adam, {"lr": 0.456}))
optimizers.create_with(models)
# Plain list of the created optimizers, used by the hooks below.
opts = [*optimizers.values()]

dataset_size = 10000
batch_size = 32

# One batch worth of fake "data".
# Domain labels: source samples are labelled 0 and target samples 1, so that
# hooks using these labels can tell the two domains apart.
data = {
    "src_imgs": torch.randn(batch_size, 1000),
    "target_imgs": torch.randn(batch_size, 1000),
    "src_labels": torch.randint(0, 10, size=(batch_size,)),
    "src_domain": torch.zeros(batch_size),
    "target_domain": torch.ones(batch_size),
    "src_sample_idx": torch.randint(0, dataset_size, size=(batch_size,)),
    "target_sample_idx": torch.randint(0, dataset_size, size=(batch_size,)),
}

### Register PyTorch forward hooks for demonstration

This keeps track of how many forward passes the ```G``` and ```C``` models perform (no hook is registered on ```D```).

In [None]:
def forward_count(self, *_):
    """Forward hook that counts how many times a module has been called.

    Args:
        self: The module the hook is attached to; its ``count`` attribute
            is incremented in place.
        *_: Remaining hook arguments, ignored.

    The counter starts from 0 if ``count`` has not been set yet, so the hook
    also works when a forward pass happens before any explicit reset.
    """
    self.count = getattr(self, "count", 0) + 1


# Detach hooks left over from a previous run of this cell; otherwise every
# re-run would stack another hook and each forward pass would be counted
# more than once.
for _stale_handle in globals().get("forward_hook_handles", []):
    _stale_handle.remove()

# Count forward passes of C and G; keep the handles so they can be removed.
forward_hook_handles = [
    module.register_forward_hook(forward_count) for module in (C, G)
]

### Source Classifier

This hook applies a cross-entropy loss on the source data, so it requires source logits to be computed.

Therefore, each model (G and C) will be used once:
```src_logits = C(G(src_imgs))```.

We can use ```validate_hook``` to verify that the hook will work with the given data. This function also returns the expected number of times each model will be used.

In [None]:
from pytorch_adapt.hooks import ClassifierHook

# Start both forward-pass counters from zero for this demonstration.
G.count = 0
C.count = 0

hook = ClassifierHook(opts)

# validate_hook is given only the key names of the batch, not the tensors.
model_counts = validate_hook(hook, list(data))

# Run the hook on the models plus the fake batch.
losses, outputs = hook({}, {**models, **data})

print(f"Expected model counts = {dict(model_counts)}")
print(f"True model counts = G: {G.count}, C: {C.count}")
pprint(losses)

### Source Classifier + BSP + BNM

Now we'll use the same ```ClassifierHook``` but add some hooks that are useful for domain adaptation.

The ```BSPHook``` requires source and target features: 

- ```src_features = G(src_imgs)```

- ```target_features = G(target_imgs)```

The ```BNMHook``` requires target logits: ```target_logits = C(target_features)```

The source logits still need to be computed for the source classification loss. So in total, each model will be used twice.

To use these hooks, we pass them as a list into the ```post``` argument. This means that the losses will be computed in the following order: classification, BSP, BNM. The ```ClassifierHook``` takes in optimizers as its first argument, so after the loss is computed, it also computes gradients and updates model weights.

The BSP loss tends to be very large, so we add a ```MeanWeighter```. This multiplies each loss by a scalar (1 by default), and then returns the mean of the scaled losses. In this case, we change the weight for ```bsp_loss``` to ```1e-5```.

The hook outputs two dictionaries:

- losses: a two-level dictionary where the outer level is associated with a particular optimization step (relevant for GAN architectures), and the inner level contains the loss components.
- outputs: all the data that was generated by models.

In [None]:
from pytorch_adapt.hooks import BNMHook, BSPHook
from pytorch_adapt.weighters import MeanWeighter

# Zero the forward-pass counters again before running the combined hook.
G.count = 0
C.count = 0

# Scale the BSP loss down to 1e-5; the other losses keep the default weight.
weighter = MeanWeighter(weights={"bsp_loss": 1e-5})
hook = ClassifierHook(
    opts,
    post=[BSPHook(), BNMHook()],
    weighter=weighter,
)

# validate_hook is given only the key names of the batch, not the tensors.
model_counts = validate_hook(hook, list(data))
losses, outputs = hook({}, {**models, **data})

print(f"Expected model counts = {dict(model_counts)}")
print(f"True model counts = G: {G.count}, C: {C.count}")
pprint(losses)
# Show only the shape of each generated output.
pprint({name: tensor.shape for name, tensor in outputs.items()})

### DANN

In [None]:
from pytorch_adapt.hooks import DANNHook

# Build the DANN hook from the full optimizer list.
hook = DANNHook(opts)

# Check the hook against the available batch keys, then run it once.
validate_hook(hook, list(data))
losses, outputs = hook({}, {**models, **data})

### DANN + MCC + ATDOC

In [None]:
from pytorch_adapt.hooks import ATDOCHook, MCCHook

mcc = MCCHook()
# feature_dim is tied to feature_size (the output width of G) instead of a
# hard-coded 100, so the two cannot drift apart if feature_size is changed.
atdoc = ATDOCHook(
    dataset_size=dataset_size, feature_dim=feature_size, num_classes=10
)

# The extra hooks are passed to DANNHook through its post_g argument.
hook = DANNHook(opts, post_g=[mcc, atdoc])
validate_hook(hook, list(data.keys()))
losses, outputs = hook({}, {**models, **data})

### CDAN

In [None]:
from pytorch_adapt.hooks import CDANHook
from pytorch_adapt.layers import RandomizedDotProduct
from pytorch_adapt.utils import common_functions as c_f

# Split the optimizer list: the first two entries go to the "g" step and the
# remainder to the "d" step (presumably G, C and D order; verify against
# how `opts` was created).
g_opts, d_opts = opts[:2], opts[2:]

# Extra, non-model input that is passed to the hook alongside the models.
misc = {
    "feature_combiner": RandomizedDotProduct(
        [feature_size, 10],
        feature_size,
    ),
}

hook = CDANHook(d_opts=d_opts, g_opts=g_opts)
validate_hook(hook, list(data))
losses, outputs = hook({}, {**models, **misc, **data})

### CDAN + VAT

In [None]:
from pytorch_adapt.hooks import VATHook

# Add a single module that chains G and C to the extra hook inputs.
misc["combined_model"] = torch.nn.Sequential(G, C)

hook = CDANHook(
    d_opts=d_opts,
    g_opts=g_opts,
    post_g=[VATHook()],
)
validate_hook(hook, list(data))
losses, outputs = hook({}, {**models, **misc, **data})

### MCD

In [None]:
from pytorch_adapt.hooks import MCDHook
from pytorch_adapt.layers import MultipleModels

# Wrap the classifier together with a second one built by c_f.reinit
# (presumably a re-initialised copy; confirm against pytorch-adapt).
# The isinstance guard keeps the cell idempotent: re-running it does not
# nest MultipleModels inside MultipleModels.
if not isinstance(C, MultipleModels):
    C2 = c_f.reinit(C)
    C = MultipleModels(C, C2)
models["C"] = C

g_opts = opts[0:1]
# opts[1] was created before C was wrapped, so it only holds the original
# classifier's parameters and C2 would never be updated. Build a fresh
# optimizer over all parameters of the wrapped classifier, using the same
# optimizer class and learning rate as the setup cell.
c_opts = [torch.optim.Adam(C.parameters(), lr=0.456)]

hook = MCDHook(g_opts=g_opts, c_opts=c_opts)
validate_hook(hook, list(data.keys()))
losses, outputs = hook({}, {**models, **data})

### MCD + AFN + MMD

In [None]:
from pytorch_adapt.hooks import AFNHook, AlignerHook

# Same MCD hook, with AFNHook passed via post_x and AlignerHook via post_z.
hook = MCDHook(
    g_opts=g_opts,
    c_opts=c_opts,
    post_x=[AFNHook()],
    post_z=[AlignerHook()],
)
validate_hook(hook, list(data))
losses, outputs = hook({}, {**models, **data})