In [None]:
!pip install pytorch-adapt

### Computing Features

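Every hook is called with 2 arguments that together represent the current context:

1. A dictionary of previously computed losses.
2. A dictionary of everything else that has been previously computed or passed in (for example models, input data, and features).

Each call returns a `(losses, outputs)` tuple. As the example below shows, `outputs` contains only what the hook newly computed. Anything already present in the inputs is reused rather than recomputed.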

In [None]:
import torch

from pytorch_adapt.hooks import FeaturesHook


def forward_count(self, *_):
    """Forward hook that tallies how many times a module has been run.

    Intended for ``register_forward_hook``. Any extra positional arguments
    (the module's inputs and output) are ignored. Nothing is returned, so the
    module's output passes through unchanged.
    """
    self.count = self.count + 1


def print_keys_and_count(inputs, outputs, G):
    """Print the keys of the input and output dicts, then G's forward-call count.

    A blank line is printed after the count so that successive calls are
    visually separated.
    """
    for label, mapping in (("Inputs", inputs), ("Outputs", outputs)):
        print(label, list(mapping.keys()))
    print(f"G.count = {G.count}\n")


# Seed the RNG so the random weights and inputs below are reproducible.
torch.manual_seed(0)

# A small stand-in model. The forward hook increments G.count on every
# forward pass, which shows when FeaturesHook actually runs G.
G = torch.nn.Linear(1000, 100)
G.count = 0  # initialize before registering, since the hook reads G.count
G.register_forward_hook(forward_count)

models = {"G": G}
data = {
    "src_imgs": torch.randn(32, 1000),
    "target_imgs": torch.randn(32, 1000),
}

# 1. Only raw data is passed in, so G must be run to compute the features.
hook = FeaturesHook()

inputs = data
count_before = G.count
losses, outputs = hook({}, {**models, **inputs})
# Outputs contains src_imgs_features and target_imgs_features.
assert {"src_imgs_features", "target_imgs_features"} <= outputs.keys()
assert G.count > count_before
print_keys_and_count(inputs, outputs, G)

# 2. The features are passed back in, so there is nothing left to compute.
inputs = {**data, **outputs}
count_before = G.count
losses, outputs = hook({}, {**models, **inputs})
# Outputs is empty because the required outputs are already in the inputs.
# G.count remains the same because G wasn't used for anything.
assert not outputs
assert G.count == count_before
print_keys_and_count(inputs, outputs, G)

# 3. Detached features are requested and the undetached features are available.
hook = FeaturesHook(detach=True)
count_before = G.count
losses, outputs = hook({}, {**models, **inputs})
# Detached data is kept separate from the undetached features in inputs.
# G.count remains the same because the existing tensors
# were simply detached, and this requires no computation.
assert G.count == count_before
print_keys_and_count(inputs, outputs, G)

# 4. Detached features are requested, but only raw data is passed in.
inputs = data
hook = FeaturesHook(detach=True)
count_before = G.count
losses, outputs = hook({}, {**models, **inputs})
# G.count increases because the undetached data wasn't passed in
# so it has to be computed
assert G.count > count_before
print_keys_and_count(inputs, outputs, G)

# 5. Undetached features are requested, but only detached ones are passed in.
inputs = {**data, **outputs}
hook = FeaturesHook()
count_before = G.count
losses, outputs = hook({}, {**models, **inputs})
# Even though detached data is passed in,
# G.count increases because you can't get undetached data from detached data
assert G.count > count_before
print_keys_and_count(inputs, outputs, G)