In [23]:
import torch
import random

In [24]:
class InterventionableTransformer():
    def __init__(self, model):
        self.activation = {}
        self.model = model

    # these functions are model dependent
    # they specify how the coordinate system works
    def _coordinate_to_getter(self, coord):
        layer, index = coord
        def hook(model, input, output):
            self.activation[f'{layer}-{index}'] = output[:,index]
        handler = self.model.bert.encoder.layer[layer].output.register_forward_hook(hook)
        return handler

    def _coordinate_to_setter(self, coord):
        layer, index = coord
        def hook(model, input, output):
            # NOTE: This might lead to errors about inplace manipulations during the backprop.
            output[:,index] = self.activation[f'{layer}-{index}']
        handler = self.model.bert.encoder.layer[layer].output.register_forward_hook(hook)
        return handler

    def _slice_to_getter(self, bl_coord, tr_coord):
        # NOTE: Alternative implementation would call _coordinate_to_getter for every coord in the slice
        # In this alternative implementation, user needs to only implement the _coordinate_ functions
        b_layer, l_index = bl_coord
        t_layer, r_index = tr_coord

        handlers = []
        for layer in range(b_layer, t_layer+1):
            def hook(model, input, output):
                self.activation[f'{layer}-{l_index}:{r_index}'] = output[:,l_index:r_index+1]
            handlers.append(self.model.bert.encoder.layer[layer].output.register_forward_hook(hook))
        
        class HandlerList():
            def __init__(self, handlers):
                self.handlers = handlers

            def remove(self):
                for handler in self.handlers:
                    handler.remove()

        return HandlerList(handlers)

    def _slice_to_setter(self, bl_coord, tr_coord):
        b_layer, l_index = bl_coord
        t_layer, r_index = tr_coord

        handlers = []
        for layer in range(b_layer, t_layer+1):
            def hook(model, input, output):
                output[:,l_index:r_index+1] = self.activation[f'{layer}-{l_index}:{r_index}']
            handlers.append(self.model.bert.encoder.layer[layer].output.register_forward_hook(hook))
        
        class HandlerList():
            def __init__(self, handlers):
                self.handlers = handlers

            def remove(self):
                for handler in self.handlers:
                    handler.remove()

        return HandlerList(handlers)


    def forward(self, source, base, coord):
        # NOTE: other ways that do not require constantly adding / removing hooks should exist

        # set hook to get activation
        # get_handler = self.names_to_layers[layer_name].register_forward_hook(self._get_activation(layer_name))
        get_handler = self._coordinate_to_getter(coord)

        # get output on source examples (and also capture the activations)
        source_logits = self.model(**source)

        # remove the handler (don't store activations on base) 
        get_handler.remove()

        # get base logits
        base_logits = self.model(**base)
        
        # set hook to do the intervention
        set_handler = self._coordinate_to_setter(coord)

        # get counterfactual output on base examples
        counterfactual_logits = self.model(**base)

        # remove the handler
        set_handler.remove()

        return source_logits, base_logits, counterfactual_logits


    def forward_slice(self, source, base, bl_coord, tr_coord):
        # NOTE: other ways that do not require constantly adding / removing hooks should exist

        # set hook to get activation
        # get_handler = self.names_to_layers[layer_name].register_forward_hook(self._get_activation(layer_name))
        get_handler = self._slice_to_getter(bl_coord, tr_coord)

        # get output on source examples (and also capture the activations)
        source_logits = self.model(**source)

        # remove the handler (don't store activations on base) 
        get_handler.remove()

        # get base logits
        base_logits = self.model(**base)
        
        # set hook to do the intervention
        set_handler = self._slice_to_setter(bl_coord, tr_coord)

        # get counterfactual output on base examples
        counterfactual_logits = self.model(**base)

        # remove the handler
        set_handler.remove()

        return source_logits, base_logits, counterfactual_logits


In [25]:
from transformers import BertForSequenceClassification, BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = InterventionableTransformer(BertForSequenceClassification.from_pretrained('bert-base-uncased'));

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [26]:
base = tokenizer("tokenize this sentence which is now longer!", return_tensors='pt')
source = tokenizer("short sentence!", return_tensors='pt')

In [27]:
# intervene on a single coord
coord = (10,0)

model.forward(source, base, coord)

(SequenceClassifierOutput(loss=None, logits=tensor([[0.0317, 0.1740]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None),
 SequenceClassifierOutput(loss=None, logits=tensor([[-0.3456,  0.1440]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None),
 SequenceClassifierOutput(loss=None, logits=tensor([[-0.1326,  0.1118]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None))

In [28]:
# intervene on a slice
bl_coord = (9,1)
tr_coord = (10,3)

model.forward_slice(source, base, bl_coord, tr_coord)

(SequenceClassifierOutput(loss=None, logits=tensor([[0.0317, 0.1740]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None),
 SequenceClassifierOutput(loss=None, logits=tensor([[-0.3456,  0.1440]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None),
 SequenceClassifierOutput(loss=None, logits=tensor([[-0.2397,  0.1655]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None))