In [1]:
!pip install pytorch-adapt

Collecting pytorch-adapt
  Downloading pytorch_adapt-0.0.61-py3-none-any.whl (137 kB)
[K     |████████████████████████████████| 137 kB 12.4 MB/s 
[?25hCollecting torchmetrics
  Downloading torchmetrics-0.7.2-py3-none-any.whl (397 kB)
[K     |████████████████████████████████| 397 kB 29.1 MB/s 
Collecting pytorch-metric-learning>=1.1.0
  Downloading pytorch_metric_learning-1.2.0-py3-none-any.whl (107 kB)
[K     |████████████████████████████████| 107 kB 47.3 MB/s 
Collecting pyDeprecate==0.3.*
  Downloading pyDeprecate-0.3.2-py3-none-any.whl (10 kB)
Installing collected packages: pyDeprecate, torchmetrics, pytorch-metric-learning, pytorch-adapt
Successfully installed pyDeprecate-0.3.2 pytorch-adapt-0.0.61 pytorch-metric-learning-1.2.0 torchmetrics-0.7.2


### Inputs to hooks
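Every hook takes two arguments that represent the current context:

- A dictionary of models and tensors. In the examples below, it is built by merging the two: `hook({**models, **inputs})`.
- An optional dictionary of losses.

Every hook returns a tuple `(outputs, losses)`. The examples below print the keys of `outputs`, which holds the new tensors the hook added to the context.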

### FeaturesHook

In [2]:
import torch

from pytorch_adapt.hooks import FeaturesHook


def forward_count(self, *_):
    """Forward hook that increments ``self.count`` each time it fires.

    Meant to be passed to ``register_forward_hook``, so ``self`` is the module
    being run; any further positional arguments are ignored.
    """
    self.count = self.count + 1


def print_keys_and_count(inputs, outputs, models):
    """Print the input keys, the output keys, and each model's ``count``.

    A trailing blank line separates successive reports.
    """
    print(f"Inputs {list(inputs)}")
    print(f"Outputs {list(outputs)}")
    for name, model in models.items():
        print(f"{name}.count = {model.count}")
    print("")


G = torch.nn.Linear(in_features=1000, out_features=100)
G.count = 0
G.register_forward_hook(forward_count)

models = {"G": G}
data = dict(
    src_imgs=torch.randn(32, 1000),
    target_imgs=torch.randn(32, 1000),
)


def run_and_report(hook, inputs):
    """Call ``hook`` on ``models`` merged with ``inputs``, print a report, and return ``(outputs, losses)``."""
    outputs, losses = hook({**models, **inputs})
    print_keys_and_count(inputs, outputs, models)
    return outputs, losses


# Outputs contains src_imgs_features and target_imgs_features.
hook = FeaturesHook()
inputs = data
outputs, losses = run_and_report(hook, inputs)

# Outputs is empty because the required outputs are already in the inputs.
# G.count remains the same because G wasn't used for anything.
inputs = {**data, **outputs}
outputs, losses = run_and_report(hook, inputs)

# Detached data is kept separate.
# G.count remains the same because the existing tensors
# were simply detached, and this requires no computation.
hook = FeaturesHook(detach=True)
outputs, losses = run_and_report(hook, inputs)

# G.count increases because the undetached data wasn't passed in,
# so it has to be computed.
inputs = data
hook = FeaturesHook(detach=True)
outputs, losses = run_and_report(hook, inputs)

# Even though detached data is passed in,
# G.count increases because you can't get undetached data from detached data.
inputs = {**data, **outputs}
hook = FeaturesHook()
outputs, losses = run_and_report(hook, inputs)

Inputs ['src_imgs', 'target_imgs']
Outputs ['src_imgs_features', 'target_imgs_features']
G.count = 2

Inputs ['src_imgs', 'target_imgs', 'src_imgs_features', 'target_imgs_features']
Outputs []
G.count = 2

Inputs ['src_imgs', 'target_imgs', 'src_imgs_features', 'target_imgs_features']
Outputs ['src_imgs_features_detached', 'target_imgs_features_detached']
G.count = 2

Inputs ['src_imgs', 'target_imgs']
Outputs ['src_imgs_features_detached', 'target_imgs_features_detached']
G.count = 4

Inputs ['src_imgs', 'target_imgs', 'src_imgs_features_detached', 'target_imgs_features_detached']
Outputs ['src_imgs_features', 'target_imgs_features']
G.count = 6



### LogitsHook

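`LogitsHook` works the same way as `FeaturesHook`, but expects features as input. It applies model `C` to each `*_features` tensor and outputs the corresponding `*_features_logits` tensor.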

In [3]:
from pytorch_adapt.hooks import LogitsHook

C = torch.nn.Linear(in_features=100, out_features=10)
C.count = 0
C.register_forward_hook(forward_count)

models = {"C": C}
data = dict(
    src_imgs_features=torch.randn(32, 100),
    target_imgs_features=torch.randn(32, 100),
)
hook = LogitsHook()

# Outputs contains src_imgs_features_logits and target_imgs_features_logits.
inputs = data
outputs, losses = hook({**models, **inputs})
print_keys_and_count(inputs, outputs, models)

Inputs ['src_imgs_features', 'target_imgs_features']
Outputs ['src_imgs_features_logits', 'target_imgs_features_logits']
C.count = 2



### FeaturesAndLogitsHook

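`FeaturesAndLogitsHook` combines `FeaturesHook` and `LogitsHook`. Starting from the raw `*_imgs` tensors, it computes features with `G` and then logits with `C`, all in a single hook call.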

In [4]:
from pytorch_adapt.hooks import FeaturesAndLogitsHook

G.count = C.count = 0
models = {"G": G, "C": C}
data = dict(
    src_imgs=torch.randn(32, 1000),
    target_imgs=torch.randn(32, 1000),
)
hook = FeaturesAndLogitsHook()

# Each model runs once per input tensor (src and target), so both counts end at 2.
inputs = data
outputs, losses = hook({**models, **inputs})
print_keys_and_count(inputs, outputs, models)

Inputs ['src_imgs', 'target_imgs']
Outputs ['src_imgs_features', 'target_imgs_features', 'src_imgs_features_logits', 'target_imgs_features_logits']
G.count = 2
C.count = 2



### ChainHook

`ChainHook` lets you chain together an arbitrary number of hooks. The hooks run sequentially: the outputs of hook `n` are added to the context, so they become part of the inputs to hook `n+1`. For example, chaining `FeaturesHook` and `LogitsHook` produces the same outputs as `FeaturesAndLogitsHook`.

In [5]:
from pytorch_adapt.hooks import ChainHook

G.count = C.count = 0
hook = ChainHook(FeaturesHook(), LogitsHook())

# The features computed by FeaturesHook become inputs to LogitsHook,
# giving the same outputs and counts as FeaturesAndLogitsHook.
inputs = data
outputs, losses = hook({**models, **inputs})
print_keys_and_count(inputs, outputs, models)

Inputs ['src_imgs', 'target_imgs']
Outputs ['src_imgs_features', 'target_imgs_features', 'src_imgs_features_logits', 'target_imgs_features_logits']
G.count = 2
C.count = 2

