In [4]:
import torch

In [78]:
class CausalArithmetic(torch.nn.Module):
    """Toy causal model computing O = (x + y) + z over three input columns.

    Every node of the causal graph (x, y, z, the partial sum S and the
    output O) is routed through its own ``torch.nn.Identity`` child module.
    The identity layers do not change any value; they only exist so that
    forward hooks can be attached to the intermediate values.

    Args:
        input_size: accepted for interface compatibility; not used by the model.
    """

    def __init__(self, input_size = 1):
        super().__init__()
        # One pass-through layer per node, registered in graph order:
        # the three leaf variables first, then the sum S and the output O.
        for node_name in ("x", "y", "z", "S", "O"):
            setattr(self, node_name, torch.nn.Identity())

    def forward(self,input):
        """Return the sum of columns 0, 1 and 2 of ``input``.

        ``input`` is indexed as ``input[:, i]`` for i in 0..2, so it must
        have at least two dimensions and at least three columns.
        """
        # NOTE: would like to abstract the node wiring below.
        # NOTE: open question - PyTorch tracks the backward graph automatically,
        # which may be unnecessary overhead since this model is never trained.

        # Clone each column first: per the original author's note, without the
        # copy an intervention on x, y or z also changes the input tensor.
        col_x, col_y, col_z = (input[:, idx].clone() for idx in range(3))

        # Pass every leaf through its identity layer so it can be hooked.
        x_val = self.x(col_x)
        y_val = self.y(col_y)
        z_val = self.z(col_z)

        # Intermediate node S = x + y, then output node O = S + z.
        partial_sum = self.S(x_val + y_val)
        return self.O(partial_sum + z_val)


In [79]:
# TODO: Refactor in package
class Interventionable():
    """Wrap a model so the output of one named child layer can be interchanged.

    An interchange intervention runs the model on ``source`` inputs while
    recording the output of a chosen layer, then re-runs the model on ``base``
    inputs with that layer's output replaced by the recorded value.

    Attributes are set in ``__init__``:
        activation: dict mapping layer name -> most recently recorded output.
        model: the wrapped ``torch.nn.Module``.
        names_to_layers: dict of the model's direct children, keyed by name.
    """

    def __init__(self, model):
        self.activation = {}
        self.model = model

        # Only direct children of the model can be targeted by name.
        self.names_to_layers = dict(self.model.named_children())

    def _get_activation(self, name):
        """Build a forward hook that records the layer output under ``name``."""
        def hook(model, input, output):
            self.activation[name] = output
        return hook

    def _set_activation(self, name):
        """Build a forward hook that swaps in the output recorded under ``name``."""
        def hook(model, input, output):
            # A non-None return value of a forward hook replaces the layer output.
            return self.activation[name]
        return hook

    def forward(self, source, base, layer_name):
        """Run source, base and counterfactual passes for one layer.

        Args:
            source: inputs whose activation at ``layer_name`` is recorded.
            base: inputs that receive the recorded activation; must have the
                same shape as ``source``.
            layer_name: name of a direct child of the wrapped model.

        Returns:
            Tuple ``(source_logits, base_logits, counterfactual_logits)``:
            the model output on ``source``, on ``base`` without intervention,
            and on ``base`` with the layer output replaced.

        Raises:
            AssertionError: if the shapes differ or ``layer_name`` is unknown.
        """
        # NOTE: other ways that do not require constantly adding / removing hooks should exist
        assert source.shape == base.shape, (
            f"source and base must have the same shape, got {source.shape} and {base.shape}"
        )
        assert layer_name in self.names_to_layers, (
            f"unknown layer {layer_name!r}; available: {list(self.names_to_layers)}"
        )
        layer = self.names_to_layers[layer_name]

        # Record the activation while running on the source examples. The
        # finally-clause guarantees the hook is detached even if the model
        # raises, so a failed call cannot leave a stale hook on the model.
        get_handler = layer.register_forward_hook(self._get_activation(layer_name))
        try:
            source_logits = self.model(source)
        finally:
            get_handler.remove()

        # Plain run on the base examples (no hooks attached, nothing recorded).
        base_logits = self.model(base)

        # Counterfactual run: base examples with the recorded activation
        # swapped in. Again make sure the hook never outlives this call.
        set_handler = layer.register_forward_hook(self._set_activation(layer_name))
        try:
            counterfactual_logits = self.model(base)
        finally:
            set_handler.remove()

        return source_logits, base_logits, counterfactual_logits


In [80]:
# Sanity check of the plain causal model: each output is the sum of its row.
model = CausalArithmetic()

# One base example and one source example, both of shape (1, 3).
base, source = torch.tensor([[3,6,9]]), torch.tensor([[4,4,4]])

print(model(base))
print(model(source))

tensor([18])
tensor([12])


In [83]:
# Wrap a fresh causal model and intervene on node "x": the result is the tuple
# (source output, base output, counterfactual output).
model = Interventionable(CausalArithmetic())
print(model.forward(source=source, base=base, layer_name="x"))

# Print the inputs afterwards to check that the intervention left them unchanged.
print(base, source, sep="\n")


(tensor([12]), tensor([18]), tensor([19]))
tensor([[3, 6, 9]])
tensor([[4, 4, 4]])


In [10]:
# The challenge is that hooks attach to layers
# to implement a simple causal model we don't necessarily need layers
# but we would still like to define nodes of the computational graph

# --> check how the PyTorch computational graph works
# We can probably hack our way around this problem by defining identity-operation layers

# Is this the cleanest way to extend interventions to causal models?

# Alternative would be to use a different CompGraph package and define our own set of hooks (meh)