In [None]:
!pip install pytorch-adapt

### Create some fake data and models

Model names:
- G: feature generator
- C: classifier
- D: discriminator (for adversarial methods)

Data names:
- src_imgs/target_imgs: source or target data. The ```_imgs``` suffix is misleading, as the data doesn't have to be images (or even 2D), so this will probably be changed in a future version of the library.
- src_labels: class labels for the source data.
- src_domain/target_domain: integers representing the source and target domain. The convention is 0 for source, and 1 for target.
- src_sample_idx/target_sample_idx: each sample's index in the dataset


In [None]:
from pprint import pprint

import torch

from pytorch_adapt.containers import Models, Optimizers
from pytorch_adapt.hooks import validate_hook

# Seed PyTorch's RNG so the randomly initialized models and the fake batch
# (and therefore the printed losses) are reproducible across fresh runs.
torch.manual_seed(0)

# Sizes shared by the fake models and the fake batch.
input_size = 1000
feature_size = 100
num_classes = 10
batch_size = 32

G = torch.nn.Linear(input_size, feature_size)
C = torch.nn.Linear(feature_size, num_classes)
# Flatten(start_dim=0) collapses the discriminator output into a 1-D tensor.
D = torch.nn.Sequential(torch.nn.Linear(feature_size, 1), torch.nn.Flatten(start_dim=0))

models = Models({"G": G, "C": C, "D": D})
# Presumably creates one Adam optimizer per entry in `models` (later cells
# slice `opts` as if it follows the G, C, D order) -- TODO confirm.
optimizers = Optimizers((torch.optim.Adam, {"lr": 0.00001}))
optimizers.create_with(models)
opts = list(optimizers.values())


dataset_size = 10000
# one batch worth of "data"
data = {
    "src_imgs": torch.randn(batch_size, input_size),
    "target_imgs": torch.randn(batch_size, input_size),
    "src_labels": torch.randint(0, num_classes, size=(batch_size,)),
    # Domain convention: 0 for source, 1 for target.
    "src_domain": torch.zeros(batch_size),
    "target_domain": torch.ones(batch_size),
    "src_sample_idx": torch.randint(0, dataset_size, size=(batch_size,)),
    "target_sample_idx": torch.randint(0, dataset_size, size=(batch_size,)),
}

### Register PyTorch forward hooks for demonstration

This will keep track of how many times each model is used.

In [None]:
def forward_count(self, *_):
    """Forward hook that counts how many times a module has been called.

    Args:
        self: The object the hook is attached to; its ``count`` attribute is
            incremented by one.
        *_: Remaining hook arguments; ignored.

    Returns:
        None, so the module's output is left untouched by the hook.
    """
    # Start from 0 when the counter has not been initialized yet, instead of
    # raising AttributeError on the first forward pass.
    self.count = getattr(self, "count", 0) + 1


# Attach the counter to each model so every forward pass increments its
# `count` attribute. The cells below set `count` to 0 before running a hook.
# NOTE(review): re-running this cell registers the hook again, so each forward
# pass would then be counted more than once.
G.register_forward_hook(forward_count)
C.register_forward_hook(forward_count)
D.register_forward_hook(forward_count)

### Helper function for this demo

In [None]:
def print_info(model_counts, losses, outputs, G, C, D=None):
    """Print expected vs. actual model usage counts, the losses, and output shapes.

    Args:
        model_counts: Expected number of forward calls per model; anything
            accepted by ``dict()``.
        losses: Loss dictionary returned by the hook; pretty-printed as is.
        outputs: Dictionary of hook outputs; only the shape of each value is shown.
        G: Object whose ``count`` attribute holds its actual forward-call count.
        C: Same as ``G``.
        D: Optional, same as ``G``. Its count is omitted when ``D`` is None.
    """

    def get_shape(v):
        # Tensor -> its shape; list -> list of element shapes; anything else -> None.
        if isinstance(v, torch.Tensor):
            return v.shape
        if isinstance(v, list):
            return [z.shape for z in v]
        return None

    print(f"Expected model counts = {dict(model_counts)}")
    true_str = f"True model counts = G: {G.count}, C: {C.count}"
    # Compare against None explicitly: objects that define __len__ (e.g. an
    # empty torch.nn.Sequential) are falsy and would otherwise be skipped.
    if D is not None:
        true_str += f", D: {D.count}"
    print(true_str)
    pprint(losses)
    pprint({k: get_shape(v) for k, v in outputs.items()})

### Source Classifier

This hook applies a cross entropy loss on the source data, so it requires source logits to be computed. 

Therefore, each model (G and C) will be used once:
```src_features_logits = C(G(src_imgs))```.

We can use ```validate_hook``` to verify that the hook will work with the given data. This function also returns the expected number of times each model will be used.

In [None]:
from pytorch_adapt.hooks import ClassifierHook

# Zero the counters that the forward hooks increment.
G.count = C.count = 0

# ClassifierHook on its own: source classification loss only.
hook = ClassifierHook(opts)

# validate_hook is given just the names of the available data.
model_counts = validate_hook(hook, list(data))

# The hook receives the models and the batch merged into a single dictionary.
losses, outputs = hook({}, {**models, **data})
print_info(model_counts, losses, outputs, G, C)

### Source Classifier + BSP + BNM

Now we'll use the same ```ClassifierHook``` but add some hooks that are useful for domain adaptation.

The ```BSPHook``` requires source and target features: 

- ```src_features = G(src_imgs)```

- ```target_features = G(target_imgs)```

The ```BNMHook``` requires target logits: ```target_features_logits = C(target_features)```

The source logits still need to be computed for the source classification loss. So in total, each model will be used twice.

To use these hooks, we pass them as a list into the ```post``` argument. This means that the losses will be computed in the following order: classification, BSP, BNM. The ```ClassifierHook``` takes in optimizers as its first argument, so after the loss is computed, it also computes gradients and updates model weights.

The BSP loss tends to be very large, so we add a ```MeanWeighter```. This multiplies each loss by a scalar (1 by default), and then returns the mean of the scaled losses. In this case, we change the weight for ```bsp_loss``` to ```1e-5```.

The hook outputs two dictionaries:

- losses: a two-level dictionary where the outer level is associated with a particular optimization step (relevant for GAN architectures), and the inner level contains the loss components.
- outputs: all the data that was generated by models.

In [None]:
from pytorch_adapt.hooks import BNMHook, BSPHook
from pytorch_adapt.weighters import MeanWeighter

# Zero the counters that the forward hooks increment.
G.count = C.count = 0

# Scale down bsp_loss; losses not listed here keep the weighter's default weight.
weighter = MeanWeighter(weights={"bsp_loss": 1e-5})

# BSP and BNM are passed via `post`, after the classification hook.
hook = ClassifierHook(
    opts,
    post=[BSPHook(), BNMHook()],
    weighter=weighter,
)
model_counts = validate_hook(hook, list(data))
losses, outputs = hook({}, {**models, **data})
print_info(model_counts, losses, outputs, G, C)

### DANN

Let's try DANN next. DANN uses a discriminator that tries to distinguish between source and target features. The required data for computing the adversarial loss is:

- ```src_features = G(src_imgs)```
- ```target_features = G(target_imgs)```
- ```src_features_dlogits = D(src_features)```
- ```target_features_dlogits = D(target_features)```

The ```_dlogits``` suffix represents the output of the discriminator model. In addition to these outputs, DANN uses a classification loss on source data:

- ```src_features_logits = C(src_features)```

Based on these requirements, the model counts should be G:2, D:2, C:1

In [None]:
from pytorch_adapt.hooks import DANNHook

# Zero all three counters; DANN also uses the discriminator.
G.count = C.count = D.count = 0

hook = DANNHook(opts)
model_counts = validate_hook(hook, list(data))

# The hook receives the models and the batch merged into a single dictionary.
losses, outputs = hook({}, {**models, **data})
print_info(model_counts, losses, outputs, G, C, D)

### DANN + MCC + ATDOC

Now we'll add two hooks to DANN:

- ```MCCHook``` requires target logits. This isn't normally required by DANN, so the count for C should increase by 1.
- ```ATDOCHook``` requires source features and logits. These are already required by DANN, so the count for G and C should remain the same.

We pass these hooks into the ```post_g``` argument, because we want them to use raw source and target features. (If you passed them in as ```post_d``` then they would use the output of the gradient reversal layer, which we don't want in this case.)

In [None]:
from pytorch_adapt.hooks import ATDOCHook, MCCHook

# Reset counts
G.count, C.count, D.count = 0, 0, 0
mcc = MCCHook()
# Take the sizes from the existing setup instead of repeating the literals,
# so this cell stays correct if the model dimensions are changed.
atdoc = ATDOCHook(
    dataset_size=dataset_size,
    feature_dim=feature_size,
    num_classes=C.out_features,
)

# post_g hooks operate on the raw source and target features.
hook = DANNHook(opts, post_g=[mcc, atdoc])
model_counts = validate_hook(hook, list(data.keys()))
losses, outputs = hook({}, {**models, **data})
print_info(model_counts, losses, outputs, G, C, D)

### CDAN

The ```CDANHook``` is adversarial like ```DANNHook```, but it doesn't use a gradient reversal layer. Thus, optimization occurs in two steps: one for updating the generator, and one for updating the discriminator. In each step, the discriminator has to recompute its logits, so it will be used 4 times instead of 2.

```CDANHook``` also requires a separate ```feature_combiner``` model that we pass in along with all the other models and data.

You'll notice the outputs have different names from DANN's outputs:

- All of the ```feature_combiner``` outputs contain the ```_combined``` suffix, as well as the names of the tensors that were combined. 
- Tensors with the ```_detached``` suffix are detached from the autograd graph. This is done during the discriminator update, to avoid computing gradients for the generator.

In [None]:
from pytorch_adapt.hooks import CDANHook
from pytorch_adapt.layers import RandomizedDotProduct
from pytorch_adapt.utils import common_functions as c_f

# Zero all three counters.
G.count = C.count = D.count = 0

# Split the optimizers between the generator step and the discriminator step.
# Presumably `opts` follows the G, C, D order of `models` -- TODO confirm.
g_opts, d_opts = opts[:2], opts[2:]

# CDAN needs an extra feature_combiner model, passed in alongside the others.
misc = {"feature_combiner": RandomizedDotProduct([feature_size, 10], feature_size)}

hook = CDANHook(d_opts=d_opts, g_opts=g_opts)
model_counts = validate_hook(hook, list(data))
losses, outputs = hook({}, {**models, **misc, **data})
print_info(model_counts, losses, outputs, G, C, D)

### CDAN + VAT

Here we present a current failure case of ```validate_hook```. The ```VATHook``` uses ```VATLoss```, and inside of ```VATLoss```, the ```combined_model``` is used twice. ```VATHook``` uses ```VATLoss``` twice, so the ```combined_model``` is used a total of 4 times. However, there is no way for ```validate_hook``` to know this, so its estimates for G and C are off by 4.

In [None]:
from pytorch_adapt.hooks import VATHook

# Zero all three counters.
G.count = C.count = D.count = 0

# Add the combined model (G followed by C) to the extra inputs.
misc.update(combined_model=torch.nn.Sequential(G, C))

hook = CDANHook(d_opts=d_opts, g_opts=g_opts, post_g=[VATHook()])
model_counts = validate_hook(hook, list(data))
losses, outputs = hook({}, {**models, **misc, **data})
print_info(model_counts, losses, outputs, G, C, D)

### MCD

In [None]:
from pytorch_adapt.hooks import MCDHook
from pytorch_adapt.layers import MultipleModels

# Build a second classifier from C and wrap both in a single model.
# Presumably c_f.reinit returns a re-initialized copy of C -- TODO confirm.
C2 = c_f.reinit(C)
C_multiple = MultipleModels(C, C2)
# From here on, the "C" entry of `models` is the two-classifier wrapper.
models["C"] = C_multiple

# Count forward passes of the wrapper itself.
C_multiple.register_forward_hook(forward_count)
G.count = C_multiple.count = 0

# NOTE(review): `opts` was created before C2 existed, so opts[1] presumably
# only covers the original C -- fine for counting model usage, confirm before reuse.
g_opts, c_opts = opts[:1], opts[1:2]

hook = MCDHook(g_opts=g_opts, c_opts=c_opts)
model_counts = validate_hook(hook, list(data))
losses, outputs = hook({}, {**models, **data})
print_info(model_counts, losses, outputs, G, C_multiple)

### MCD + AFN + MMD

In [None]:
from pytorch_adapt.hooks import AFNHook, AlignerHook

# Reset counts so the printed numbers reflect only this hook.
G.count, C_multiple.count = 0, 0
hook = MCDHook(g_opts=g_opts, c_opts=c_opts, post_x=[AFNHook()], post_z=[AlignerHook()])
# Keep the expected counts and report them, as the other demo cells do.
model_counts = validate_hook(hook, list(data.keys()))
losses, outputs = hook({}, {**models, **data})
print_info(model_counts, losses, outputs, G, C_multiple)