In [2]:
import torch
import random

In [4]:
class InterventionableTransformer():
    def __init__(self, model):
        self.activation = {}
        self.model = model

    # these functions are model dependent
    # they specify how the coordinate system works
    def _coordinate_to_getter(self, coord):
        layer, index = coord
        def hook(model, input, output):
            self.activation[f'{layer}-{index}'] = output[:,index]
        handler = self.model.bert.encoder.layer[layer].output.register_forward_hook(hook)
        return handler

    def _coordinate_to_setter(self, coord):
        layer, index = coord
        def hook(model, input, output):
            # NOTE: This might lead to errors about inplace manipulations during the backprop.
            output[:,index] = self.activation[f'{layer}-{index}']
        handler = self.model.bert.encoder.layer[layer].output.register_forward_hook(hook)
        return handler

    def _slice_to_getter(self, bl_coord, tr_coord):
        # NOTE: Alternative implementation would call _coordinate_to_getter for every coord in the slice
        # In this alternative implementation, user needs to only implement the _coordinate_ functions
        b_layer, l_index = bl_coord
        t_layer, r_index = tr_coord

        handlers = []
        for layer in range(b_layer, t_layer+1):
            def hook(model, input, output):
                self.activation[f'{layer}-{l_index}:{r_index}'] = output[:,l_index:r_index+1]
            handlers.append(self.model.bert.encoder.layer[layer].output.register_forward_hook(hook))
        
        class HandlerList():
            def __init__(self, handlers):
                self.handlers = handlers

            def remove(self):
                for handler in self.handlers:
                    handler.remove()

        return HandlerList(handlers)

    def _slice_to_setter(self, bl_coord, tr_coord):
        b_layer, l_index = bl_coord
        t_layer, r_index = tr_coord

        handlers = []
        for layer in range(b_layer, t_layer+1):
            def hook(model, input, output):
                output[:,l_index:r_index+1] = self.activation[f'{layer}-{l_index}:{r_index}']
            handlers.append(self.model.bert.encoder.layer[layer].output.register_forward_hook(hook))
        
        class HandlerList():
            def __init__(self, handlers):
                self.handlers = handlers

            def remove(self):
                for handler in self.handlers:
                    handler.remove()

        return HandlerList(handlers)


    def forward(self, source, base, coord, source_labels=None, base_labels=None, counterfactual_labels=None):
        # NOTE: other ways that do not require constantly adding / removing hooks should exist

        # set hook to get activation
        # get_handler = self.names_to_layers[layer_name].register_forward_hook(self._get_activation(layer_name))
        get_handler = self._coordinate_to_getter(coord)

        # get output on source examples (and also capture the activations)
        source_logits = self.model(**source, labels=source_labels)

        # remove the handler (don't store activations on base) 
        get_handler.remove()

        # get base logits
        base_logits = self.model(**base, labels=base_labels)
        
        # set hook to do the intervention
        set_handler = self._coordinate_to_setter(coord)

        # get counterfactual output on base examples
        counterfactual_logits = self.model(**base, labels=counterfactual_labels)

        # remove the handler
        set_handler.remove()

        return source_logits, base_logits, counterfactual_logits


    def forward_slice(self, source, base, bl_coord, tr_coord, source_labels=None, base_labels=None, counterfactual_labels=None):
        # NOTE: other ways that do not require constantly adding / removing hooks should exist

        # set hook to get activation
        # get_handler = self.names_to_layers[layer_name].register_forward_hook(self._get_activation(layer_name))
        get_handler = self._slice_to_getter(bl_coord, tr_coord)

        # get output on source examples (and also capture the activations)
        source_logits = self.model(**source, labels=source_labels)

        # remove the handler (don't store activations on base) 
        get_handler.remove()

        # get base logits
        base_logits = self.model(**base, labels=base_labels)
        
        # set hook to do the intervention
        set_handler = self._slice_to_setter(bl_coord, tr_coord)

        # get counterfactual output on base examples
        counterfactual_logits = self.model(**base, labels=counterfactual_labels)

        # remove the handler
        set_handler.remove()

        return source_logits, base_logits, counterfactual_logits


In [5]:
from transformers import BertForSequenceClassification, BertTokenizer, BertConfig

config = BertConfig.from_pretrained('bert-base-uncased')
config.num_labels = 1
model = BertForSequenceClassification(config)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = InterventionableTransformer(model)

In [6]:
base = tokenizer("tokenize this sentence which is now longer!", return_tensors='pt')
source = tokenizer("short sentence!", return_tensors='pt')

In [7]:
# intervene on a single coord
coord = (10,0)

model.forward(source, base, coord)

(SequenceClassifierOutput(loss=None, logits=tensor([[-0.2717]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None),
 SequenceClassifierOutput(loss=None, logits=tensor([[-0.4269]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None),
 SequenceClassifierOutput(loss=None, logits=tensor([[-0.3380]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None))

In [8]:
# intervene on a slice
bl_coord = (9,1)
tr_coord = (10,3)

print(model.forward_slice(source, base, bl_coord, tr_coord))
print(model.forward_slice(source, base, bl_coord, tr_coord))

(SequenceClassifierOutput(loss=None, logits=tensor([[-0.4341]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None), SequenceClassifierOutput(loss=None, logits=tensor([[-0.5688]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None), SequenceClassifierOutput(loss=None, logits=tensor([[-0.4428]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None))
(SequenceClassifierOutput(loss=None, logits=tensor([[-0.3204]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None), SequenceClassifierOutput(loss=None, logits=tensor([[-0.2707]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None), SequenceClassifierOutput(loss=None, logits=tensor([[-0.2740]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None))


### checking backpropagation using the interventionable

In [9]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# creating some fake labels
source_labels=torch.randint(high=1,size=(1,1),device=device).float()
base_labels=torch.randint(high=1,size=(1,1),device=device).float()
counterfactual_labels=torch.randint(high=1,size=(1,1),device=device).float()

source = source.to(device)
base = base.to(device)
model.model.to(device)

coord = (10,0)

model.model.train()
source_logits, base_logits, counterfactual_logits = model.forward(source, base, coord, source_labels=source_labels, base_labels=base_labels, counterfactual_labels=counterfactual_labels)

# config
config = model.model.config
# extra params
lr = 5e-5
num_epochs = 10
num_steps_per_epoch = 10 # NOTE: normally get this from the dataloader
num_training_steps = num_epochs * num_steps_per_epoch
num_warmup_steps = 0

# optimizer
from transformers import AdamW
optimizer = AdamW(model.model.parameters(), lr=lr)

# scheduler
from transformers import get_scheduler
lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)

# update
source_loss = source_logits.loss
base_loss = base_logits.loss
counterfactual_loss = counterfactual_logits.loss


loss = source_loss #+ base_loss + counterfactual_loss
print(loss)
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()

print("source_loss:",source_loss)
print("base_loss:",base_loss)
print("counterfactual_loss:",counterfactual_loss)



tensor(0.1240, grad_fn=<MseLossBackward0>)
source_loss: tensor(0.1240, grad_fn=<MseLossBackward0>)
base_loss: tensor(0.0793, grad_fn=<MseLossBackward0>)
counterfactual_loss: tensor(0.1069, grad_fn=<MseLossBackward0>)


### checking how arithmetic sequences are tokenized

In [10]:
input_ids = tokenizer("5+4-(3+2)", return_tensors='pt')['input_ids'].squeeze()

In [11]:
print(tokenizer.decode(input_ids))

tokens = []
for id in input_ids:
    tokens.append(tokenizer.decode([id]))
print(tokens)

[CLS] 5 + 4 - ( 3 + 2 ) [SEP]
['[CLS]', '5', '+', '4', '-', '(', '3', '+', '2', ')', '[SEP]']
