In [1]:
import torch
import torch.nn as nn

from concept_erasure import LeaceFitter
from functools import partial
from transformers import AutoModelForSequenceClassification, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint shared by the model and its tokenizer. Keeping the name in one
# place prevents the two from silently drifting apart when it is changed.
MODEL_NAME = "bert-base-uncased"
# Size of the freshly initialised classification head.
NUM_LABELS = 3

bert = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=NUM_LABELS
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Print the dotted path of every submodule (the root module has an empty name)
# so that hook targets can be located by name.
for module_name, _submodule in bert.named_modules():
    print(module_name)


bert
bert.embeddings
bert.embeddings.word_embeddings
bert.embeddings.position_embeddings
bert.embeddings.token_type_embeddings
bert.embeddings.LayerNorm
bert.embeddings.dropout
bert.encoder
bert.encoder.layer
bert.encoder.layer.0
bert.encoder.layer.0.attention
bert.encoder.layer.0.attention.self
bert.encoder.layer.0.attention.self.query
bert.encoder.layer.0.attention.self.key
bert.encoder.layer.0.attention.self.value
bert.encoder.layer.0.attention.self.dropout
bert.encoder.layer.0.attention.output
bert.encoder.layer.0.attention.output.dense
bert.encoder.layer.0.attention.output.LayerNorm
bert.encoder.layer.0.attention.output.dropout
bert.encoder.layer.0.intermediate
bert.encoder.layer.0.intermediate.dense
bert.encoder.layer.0.intermediate.intermediate_act_fn
bert.encoder.layer.0.output
bert.encoder.layer.0.output.dense
bert.encoder.layer.0.output.LayerNorm
bert.encoder.layer.0.output.dropout
bert.encoder.layer.1
bert.encoder.layer.1.attention
bert.encoder.layer.1.attention.self
bert.e

In [4]:
# Collect the direct children of the encoder's layer container; these are the
# modules that will receive forward hooks.
bert_layers = [layer for layer in bert.bert.encoder.layer.children()]

In [5]:
class HookManager:
    """Attach a single forward hook to a collection of modules and track it.

    Attributes:
        module_list: The modules that hooks are registered on.
        handles: Handles returned by ``register_forward_hook`` that have not
            been removed yet.
    """

    def __init__(self, module_list):
        self.module_list = module_list
        self.handles = []

    def register_hooks(self, hook_fn):
        """Register ``hook_fn`` as a forward hook on every managed module.

        Args:
            hook_fn: Callable passed unchanged to ``register_forward_hook``.
        """
        self.handles.extend(
            target.register_forward_hook(hook_fn) for target in self.module_list
        )

    def remove_hooks(self):
        """Remove every tracked hook, then start over with an empty handle list."""
        for registered in tuple(self.handles):
            registered.remove()
        self.handles = list()


class LeaceCLS(HookManager):
    """Fit and apply LEACE erasers on position 0 of each hooked module's output.

    Each hooked module gets its own ``LeaceFitter``. Concept labels are not
    passed in explicitly: every batch seen while fitting must be laid out as
    ``num_concepts`` contiguous groups of equal size, where group ``i`` holds
    the examples of concept ``i``.

    Args:
        module_list: Modules to hook. ``output[0]`` of each module is indexed
            as ``[:, 0, :]``, so it must be a 3-D tensor; presumably
            (batch, sequence, hidden) with the [CLS] token first -- confirm
            for the model in use.
        num_concepts: Number of concept groups per batch (at least 2).

    Raises:
        ValueError: If ``num_concepts`` is smaller than 2.
    """

    def __init__(self, module_list, num_concepts: int = 2):
        if num_concepts < 2:
            raise ValueError(
                f"num_concepts must be at least 2, got {num_concepts}"
            )
        super().__init__(module_list)
        self.num_concepts = num_concepts
        # One lazily created fitter per hooked module; None until the first
        # batch reveals the feature dimension.
        self.leace_erasers = {module: None for module in module_list}

    def fit(self):
        """Install the statistics-collecting hook on every managed module."""
        # Drop any hooks that are still registered so that repeated calls do
        # not stack hooks and count every batch more than once.
        self.remove_hooks()
        self.register_hooks(self.leace_fit_hook)

    def erase(self):
        """Install the erasing hook on every managed module."""
        # A leftover fit hook would keep updating the statistics during
        # erasure, so start from a clean slate.
        self.remove_hooks()
        self.register_hooks(self.leace_erase_hook)

    def leace_fit_hook(self, module, input, output):
        """Forward hook that updates ``module``'s fitter with the current batch.

        Raises:
            ValueError: If the batch is empty or cannot be split into
                ``num_concepts`` equally sized groups.
        """
        cls_rep = output[0][:, 0, :]
        batch_size = cls_rep.shape[0]
        if batch_size == 0 or batch_size % self.num_concepts != 0:
            raise ValueError(
                f"Batch size {batch_size} cannot be split into "
                f"{self.num_concepts} equally sized concept groups"
            )
        n_per_concept = batch_size // self.num_concepts

        # Label i is repeated n_per_concept times: 0..0, 1..1, ...
        labels = torch.arange(
            self.num_concepts, dtype=torch.long, device=cls_rep.device
        ).repeat_interleave(n_per_concept)
        # One-hot encode and drop the first column so the target has exactly
        # ``num_concepts - 1`` columns, matching the fitter's z dimension.
        # The dropped column is redundant (the columns sum to one), and for
        # two concepts the result equals the raw 0/1 label.
        targets = torch.nn.functional.one_hot(labels, self.num_concepts)[:, 1:]

        if self.leace_erasers[module] is None:
            # Keep the running statistics on the same device as the
            # activations so models on an accelerator work as well.
            self.leace_erasers[module] = LeaceFitter(
                cls_rep.shape[-1],
                self.num_concepts - 1,
                device=cls_rep.device,
            )

        self.leace_erasers[module].update(cls_rep, targets)

    def leace_erase_hook(self, module, input, output):
        """Forward hook that overwrites position 0 of ``output[0]`` in place.

        Returns:
            The same ``output`` object, after modification.

        Raises:
            RuntimeError: If no fitter has been fitted for ``module`` yet.
        """
        fitter = self.leace_erasers[module]
        if fitter is None:
            raise RuntimeError(
                "No eraser has been fitted for this module; call fit() and "
                "run at least one forward pass before erase()"
            )
        cls_rep = output[0][:, 0, :]
        new_cls_rep = fitter.eraser(cls_rep)
        output[0][:, 0, :] = new_cls_rep
        return output

In [6]:
# Manage one eraser per collected encoder layer, with two concept groups.
leace_cls = LeaceCLS(module_list=bert_layers, num_concepts=2)

In [7]:
# Tokenize the demo batch once and reuse it for both passes instead of
# repeating the tokenizer call.
demo_batch = tokenizer(
    ["coucou"] * 50,
    return_tensors="pt",
    truncation=True,
    padding=True,
    max_length=12,
)

# Pass 1: collect statistics. The model outputs are discarded, so no autograd
# graph is needed; the hooks are removed even if the forward pass fails.
leace_cls.fit()
try:
    with torch.no_grad():
        bert(**demo_batch)
finally:
    leace_cls.remove_hooks()

# Pass 2: run the model with the fitted erasers applied.
leace_cls.erase()
try:
    with torch.no_grad():
        bert(**demo_batch)
finally:
    leace_cls.remove_hooks()