In [None]:
!pip install pytorch-adapt

### Inputs to hooks
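Every hook takes two arguments that together represent the current context:

- A dictionary of previously computed losses.
- A dictionary of everything else that has been previously computed or passed in (in the examples below, the models and the data).

Each hook returns a tuple `(losses, outputs)` containing the values it computed. As the examples below show, anything already present in the context is not recomputed or returned again.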

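### FeaturesHook

`FeaturesHook` computes features by passing the source and target images through the model `G`. A forward hook on `G` counts how many times `G` is called. The count shows when the hook reuses tensors that are already in the context instead of recomputing them.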

In [None]:
import torch

from pytorch_adapt.hooks import FeaturesHook


def forward_count(self, *_):
    """Forward hook that bumps ``self.count`` every time the module runs.

    Intended for ``register_forward_hook``, which calls it with the module
    first, followed by the forward inputs and output. Only the module is
    used. Returning None leaves the module's output unchanged.
    """
    self.count = self.count + 1


def print_keys_and_count(inputs, outputs, models):
    """Print the keys of ``inputs`` and ``outputs``, then each model's call count.

    A trailing blank line separates consecutive reports.
    """
    print(f"Inputs {list(inputs.keys())}")
    print(f"Outputs {list(outputs.keys())}")
    for name in models:
        print(f"{name}.count = {models[name].count}")
    print()


# Toy stand-in for the feature model; forward_count tracks how often it runs.
G = torch.nn.Linear(1000, 100)
G.count = 0
G.register_forward_hook(forward_count)

models = dict(G=G)

# Random stand-ins for a batch of 32 source and 32 target "images".
data = dict(
    src_imgs=torch.randn(32, 1000),
    target_imgs=torch.randn(32, 1000),
)

hook = FeaturesHook()

# Step 1: only the images are available, so G has to run to produce features.
inputs = data
losses, outputs = hook({}, {**models, **inputs})
# Outputs contains src_imgs_features and target_imgs_features.
print_keys_and_count(inputs, outputs, models)

# Step 2: pass the features computed in step 1 back in alongside the images.
inputs = {**data, **outputs}
losses, outputs = hook({}, {**models, **inputs})
# Outputs is empty because the required outputs are already in the inputs.
# G.count remains the same because G wasn't used for anything.
print_keys_and_count(inputs, outputs, models)

# Step 3: request detached features while the undetached ones are already
# in the inputs (inputs still holds the step-1 features here).
hook = FeaturesHook(detach=True)
losses, outputs = hook({}, {**models, **inputs})
# Detached features are returned under their own keys, separate from the
# undetached ones. G.count remains the same because the existing tensors
# were simply detached, and this requires no computation.
print_keys_and_count(inputs, outputs, models)

# Step 4: request detached features when only the images are available.
inputs = data
hook = FeaturesHook(detach=True)
losses, outputs = hook({}, {**models, **inputs})
# G.count increases because the undetached data wasn't passed in,
# so it has to be computed.
print_keys_and_count(inputs, outputs, models)

# Step 5: request undetached features when only the detached ones from
# step 4 are available.
inputs = {**data, **outputs}
hook = FeaturesHook()
losses, outputs = hook({}, {**models, **inputs})
# Even though detached data is passed in, G.count increases: a detached
# tensor is cut off from the autograd graph, so undetached features can't
# be recovered from it and must be recomputed.
print_keys_and_count(inputs, outputs, models)

### LogitsHook

`LogitsHook` works the same way as `FeaturesHook`, but it expects features (rather than images) as input and uses the model `C` instead of `G`.

In [None]:
from pytorch_adapt.hooks import LogitsHook

# Toy stand-in for the classifier (100-dim features -> 10 outputs), with a call counter.
C = torch.nn.Linear(100, 10)
C.count = 0
C.register_forward_hook(forward_count)

models = dict(C=C)
# LogitsHook starts from features, so the context holds features instead of images.
data = dict(
    src_imgs_features=torch.randn(32, 100),
    target_imgs_features=torch.randn(32, 100),
)
hook = LogitsHook()

inputs = data
losses, outputs = hook({}, {**models, **inputs})
print_keys_and_count(inputs, outputs, models)

### FeaturesAndLogitsHook

`FeaturesAndLogitsHook` combines `FeaturesHook` and `LogitsHook` into a single hook, so it needs both `G` and `C`.

In [None]:
from pytorch_adapt.hooks import FeaturesAndLogitsHook

# Reset both call counters so this cell's counts start from zero.
G.count = C.count = 0
models = dict(G=G, C=C)
# Start from images again; the combined hook computes features and logits.
data = dict(
    src_imgs=torch.randn(32, 1000),
    target_imgs=torch.randn(32, 1000),
)
hook = FeaturesAndLogitsHook()

inputs = data
losses, outputs = hook({}, {**models, **inputs})
print_keys_and_count(inputs, outputs, models)

### ChainHook

`ChainHook` lets you chain together an arbitrary number of hooks. The hooks run sequentially, and the outputs of hook `n` are added to the context so that they become part of the inputs to hook `n+1`.

In [None]:
from pytorch_adapt.hooks import ChainHook

# Reset both call counters so this cell's counts start from zero.
G.count = C.count = 0
# FeaturesHook runs first; its outputs are added to the context for LogitsHook.
hook = ChainHook(FeaturesHook(), LogitsHook())

inputs = data
losses, outputs = hook({}, {**models, **inputs})
print_keys_and_count(inputs, outputs, models)