In [193]:
import torch
import random

In [162]:
# create a simple pytorch model
class SimpleModel(torch.nn.Module):
    """Small MLP with two parallel ReLU branches averaged into a 1-unit head.

    Both ``ff1`` and ``ff2`` read the raw input ``x`` (the branches are
    parallel, not stacked). Their ReLU outputs are averaged and passed to
    ``ff3``, which maps to a single output feature.

    Children are registered in the order ``ff1, act1, ff2, act2, ff3``.
    Coordinate systems built from ``named_children()`` rely on that order.
    """

    def __init__(self, input_size):
        super(SimpleModel, self).__init__()
        width = input_size
        # Registration order fixes both the named_children() order and the
        # order in which the Linear layers draw their random initial weights.
        self.ff1 = torch.nn.Linear(width, width)
        self.act1 = torch.nn.ReLU()
        self.ff2 = torch.nn.Linear(width, width)
        self.act2 = torch.nn.ReLU()
        self.ff3 = torch.nn.Linear(width, 1)

    def forward(self, x):
        """Average the two ReLU branches computed from ``x`` and project to 1 unit."""
        left = self.act1(self.ff1(x))
        right = self.act2(self.ff2(x))
        return self.ff3((left + right) / 2)


class Interventionable():
    """Wrap a ``torch.nn.Module`` to run interchange interventions.

    An interchange intervention works in three steps:

    1. Run the model on a *source* input and record the output of one named
       child layer.
    2. Run it on a *base* input.
    3. Run it on the base input again, with that layer's output replaced by
       the recorded source activation.

    Attributes:
        activation: maps layer name -> the output last captured from that
            layer during a source pass.
        model: the wrapped module.
        names_to_layers: maps each direct child's name (from
            ``model.named_children()``) to the child module. Only these names
            can be intervened on.
    """

    def __init__(self, model):
        self.activation = {}
        self.model = model

        self.names_to_layers = dict(self.model.named_children())

    def _get_activation(self, name):
        """Return a forward hook that stores the hooked layer's output under ``name``."""
        def hook(model, input, output):
            self.activation[name] = output
        return hook

    def _set_activation(self, name):
        """Return a forward hook that substitutes the stored activation for ``name``.

        When a forward hook returns a value, PyTorch uses it as the module's
        output. This avoids writing into ``output`` in place. In-place writes
        break backward for modules that save their own output for the backward
        pass (e.g. ``ReLU``), and they fail on tensors that forbid in-place
        writes. The stored tensor is cloned so the result never aliases the
        source-pass tensor. For the last layer, that tensor is
        ``source_logits`` itself. ``clone`` stays differentiable.
        """
        def hook(model, input, output):
            return self.activation[name].clone()
        return hook

    def forward(self, source, base, layer_name):
        """Run the source, base, and counterfactual (intervened) passes.

        Args:
            source: input whose activation at ``layer_name`` is transplanted.
            base: input that receives the transplanted activation. It must
                have the same shape as ``source``.
            layer_name: name of a direct child of ``self.model``.

        Returns:
            ``(source_logits, base_logits, counterfactual_logits)``: the
            model's outputs on ``source``, on ``base``, and on ``base`` with
            ``layer_name``'s output replaced by its value on ``source``.

        Raises:
            AssertionError: if the input shapes differ or ``layer_name`` is
                not a direct child of the model.
        """
        # NOTE: other ways that do not require constantly adding / removing hooks should exist
        assert source.shape == base.shape, (
            f"source shape {tuple(source.shape)} != base shape {tuple(base.shape)}")
        assert layer_name in self.names_to_layers, (
            f"unknown layer {layer_name!r}; expected one of {list(self.names_to_layers)}")
        layer = self.names_to_layers[layer_name]

        # Capture the layer's output on the source pass. try/finally removes
        # the hook even if the forward pass raises; otherwise the hook would
        # stay attached and fire on every later call to the model.
        get_handle = layer.register_forward_hook(self._get_activation(layer_name))
        try:
            source_logits = self.model(source)
        finally:
            get_handle.remove()

        # Plain base pass with no hooks attached.
        base_logits = self.model(base)

        # Counterfactual pass: the base input with the source activation patched in.
        set_handle = layer.register_forward_hook(self._set_activation(layer_name))
        try:
            counterfactual_logits = self.model(base)
        finally:
            set_handle.remove()

        return source_logits, base_logits, counterfactual_logits


In [163]:
# Two random (1, 10) inputs paired with opposite labels (ones vs. zeros).
# Fixing the seed makes these draws reproducible, as well as any later draws
# from torch's global RNG (e.g. model initialisation in the following cells).
torch.manual_seed(42)
x_base, y_base = torch.rand((1, 10)), torch.ones((1,))
x_source, y_source = torch.rand((1, 10)), torch.zeros((1,))

In [180]:
# Wrap a fresh 10-feature model and intervene on its output layer. Because
# ff3 is the last layer, the counterfactual equals the source logits.
model = Interventionable(SimpleModel(input_size=10))

model.forward(source=x_source, base=x_base, layer_name='ff3')

(tensor([[0.3744]], grad_fn=<AddmmBackward0>),
 tensor([[0.3403]], grad_fn=<AddmmBackward0>),
 tensor([[0.3744]], grad_fn=<CopySlices>))

In [191]:
# Users are free to define their own coordinate system. Here, each layer is
# indexed by its position among the model's direct children.
layer_coordinates = {
    index: layer_name
    for index, (layer_name, _) in enumerate(model.model.named_children())
}
print(layer_coordinates)

model.forward(x_source, x_base, layer_coordinates[0])

{0: 'ff1', 1: 'act1', 2: 'ff2', 3: 'act2', 4: 'ff3'}


(tensor([[0.3744]], grad_fn=<AddmmBackward0>),
 tensor([[0.3403]], grad_fn=<AddmmBackward0>),
 tensor([[0.4131]], grad_fn=<AddmmBackward0>))

In [198]:
# we can specify alignments over these coordinate systems
model1 = Interventionable(SimpleModel(10))
model2 = Interventionable(SimpleModel(10))

# Index each model's direct children by position (same order as named_children()).
coordinates1 = dict(enumerate(name for name, _ in model1.model.named_children()))
coordinates2 = dict(enumerate(name for name, _ in model2.model.named_children()))

# alignments could also be specified over layer names, of course
alignment = {
    0: 2,
    4: 4,
}
# this alignment is equivalent to:
# model1    model2
# ff1   <-> ff2
# ff3   <-> ff3

# Sample one aligned pair (k in model1's coordinates, v in model2's).
# The global `random` module is never seeded above, so use a dedicated seeded
# RNG to keep this choice reproducible under Restart & Run All.
ALIGNMENT_SEED = 0
alignment_rng = random.Random(ALIGNMENT_SEED)
k = alignment_rng.choice(list(alignment))
v = alignment[k]

source_logits1, base_logits1, counterfactual_logits1 = model1.forward(x_source, x_base, coordinates1[k])
source_logits2, base_logits2, counterfactual_logits2 = model2.forward(x_source, x_base, coordinates2[v])