In [1]:
from datasets import load_dataset
import torch

# Full IMDB sentiment dataset (train/test/unsupervised splits) from the HF hub/cache.
imdb = load_dataset(path="imdb")

  from .autonotebook import tqdm as notebook_tqdm
Found cached dataset imdb (/Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)
100%|██████████| 3/3 [00:00<00:00, 365.94it/s]


In [2]:
# Small fixed-seed subsets keep the influence computation tractable.
# Each split is shuffled before selecting so the subset is not simply the
# first rows of the split; the two splits use different seeds (42 and 4).
N_TRAIN_SAMPLES = 30
N_TEST_SAMPLES = 3

small_train_dataset = imdb["train"].shuffle(seed=42).select(range(N_TRAIN_SAMPLES))
small_test_dataset = imdb["test"].shuffle(seed=4).select(range(N_TEST_SAMPLES))

Loading cached shuffled indices for dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-9c48ce5d173413c7.arrow
Loading cached shuffled indices for dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-d00218895ddb9236.arrow


In [3]:
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Sanity-check the pretrained sentiment classifier on one sentence before
# using it for the influence computation below.
MODEL_NAME = "assemblyai/distilbert-base-uncased-sst2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

tokenized_segments = tokenizer(
    [
        "AssemblyAI is the best speech-to-text API for modern developers with performance being second to none!"
    ],
    return_tensors="pt",
    padding=True,
    truncation=True,
)

# Inference only: don't build an autograd graph for this forward pass.
with torch.no_grad():
    logits = model(
        input_ids=tokenized_segments.input_ids,
        attention_mask=tokenized_segments.attention_mask,
    )["logits"]
model_predictions = F.softmax(logits, dim=1)

# NOTE(review): assumes label index 1 = positive and 0 = negative for this
# checkpoint — confirm against model.config.id2label.
print(f"Positive probability: {model_predictions[0][1].item() * 100}%")
print(f"Negative probability: {model_predictions[0][0].item() * 100}%")

Positive probability: 96.0169792175293%
Negative probability: 3.9830222725868225%


In [4]:
def preprocess_function(examples):
    """Tokenize a batch of IMDB rows for ``Dataset.map(batched=True)``.

    Reads the ``"text"`` column of the batch and returns the tokenizer's
    output (e.g. ``input_ids`` and ``attention_mask``). ``truncation=True``
    caps sequence length; ``padding=True`` pads the sequences passed in this
    call to a common length.
    """
    texts = examples["text"]
    return tokenizer(texts, padding=True, truncation=True)


# batch_size=None hands each whole split to preprocess_function in a single
# call, so padding=True pads every row of the split to one common length.
# With map's default batch size, larger subsets would be padded per chunk to
# different lengths, and the DataLoader's default collation could not stack
# rows from different chunks.
tokenized_train = small_train_dataset.map(
    preprocess_function, batched=True, batch_size=None
)
tokenized_test = small_test_dataset.map(
    preprocess_function, batched=True, batch_size=None
)

Loading cached processed dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-2171cbc40247bcf4.arrow
Loading cached processed dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-66cb6f78496beb63.arrow


In [5]:
class ImdbDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenized IMDB reviews with their labels.

    Item ``idx`` is ``(x, y)``:
      * ``x`` stacks the token ids and the attention mask of example ``idx``
        into one tensor of shape ``(2, seq_len)`` (row 0 = ids, row 1 = mask);
      * ``y`` is the label of example ``idx`` as a scalar tensor.

    NOTE(review): torch.tensor requires the ids and mask lists of an example
    to have equal length, and batching requires all examples to share seq_len.
    """

    def __init__(self, encodings, attn_mask, labels):
        self.encodings = encodings
        self.attn_mask = attn_mask
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        ids_and_mask = [self.encodings[idx], self.attn_mask[idx]]
        features = torch.tensor(ids_and_mask)
        target = torch.tensor(self.labels[idx])
        return features, target

In [6]:
def _build_imdb_dataset(tokenized_split):
    """Wrap a tokenized split's ids, attention mask and labels in an ImdbDataset."""
    return ImdbDataset(
        tokenized_split["input_ids"],
        tokenized_split["attention_mask"],
        tokenized_split["label"],
    )


train_dataset = _build_imdb_dataset(tokenized_train)
test_dataset = _build_imdb_dataset(tokenized_test)

In [7]:
class PatchedModel(torch.nn.Module):
    """Adapter exposing a Hugging Face sequence classifier as ``f(x) -> logits``.

    ``x`` is expected to be a batch of stacked ImdbDataset features, i.e. a
    ``(batch, 2, seq_len)`` tensor with token ids in ``x[:, 0]`` and the
    attention mask in ``x[:, 1]``. Returns the ``"logits"`` entry of the
    wrapped model's output.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        # Pass by keyword so the call does not depend on the wrapped model's
        # positional argument order.
        outputs = self.model(input_ids=x[:, 0], attention_mask=x[:, 1])
        return outputs["logits"]

In [8]:
# Freeze every parameter, then re-enable gradients only for nn.Linear layers
# whose qualified module name contains one of TRAINABLE_MODULE_KEYS: the
# feed-forward sublayers (names containing "ffn") and any Linear whose name
# contains "classifier" (which also covers "pre_classifier").
TRAINABLE_MODULE_KEYS = ("ffn", "classifier")

model.requires_grad_(False)

for module_name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear) and any(
        key in module_name for key in TRAINABLE_MODULE_KEYS
    ):
        # Unfreezes this Linear's own parameters (weight and, if present, bias).
        module.requires_grad_(True)

In [9]:
# List the parameters left trainable after freezing, with their shapes.
for param_name, tensor in model.named_parameters():
    if not tensor.requires_grad:
        continue
    print(param_name, tensor.shape)

distilbert.transformer.layer.0.ffn.lin1.weight torch.Size([3072, 768])
distilbert.transformer.layer.0.ffn.lin1.bias torch.Size([3072])
distilbert.transformer.layer.0.ffn.lin2.weight torch.Size([768, 3072])
distilbert.transformer.layer.0.ffn.lin2.bias torch.Size([768])
distilbert.transformer.layer.1.ffn.lin1.weight torch.Size([3072, 768])
distilbert.transformer.layer.1.ffn.lin1.bias torch.Size([3072])
distilbert.transformer.layer.1.ffn.lin2.weight torch.Size([768, 3072])
distilbert.transformer.layer.1.ffn.lin2.bias torch.Size([768])
distilbert.transformer.layer.2.ffn.lin1.weight torch.Size([3072, 768])
distilbert.transformer.layer.2.ffn.lin1.bias torch.Size([3072])
distilbert.transformer.layer.2.ffn.lin2.weight torch.Size([768, 3072])
distilbert.transformer.layer.2.ffn.lin2.bias torch.Size([768])
distilbert.transformer.layer.3.ffn.lin1.weight torch.Size([3072, 768])
distilbert.transformer.layer.3.ffn.lin1.bias torch.Size([3072])
distilbert.transformer.layer.3.ffn.lin2.weight torch.Size(

In [10]:
# Keep sample order fixed. The influence result is an (n_test, n_train)
# matrix whose rows/columns are only interpretable if they line up with
# test_dataset / train_dataset indices. With shuffle=True that correspondence
# is lost, and every pass over the loader draws a new (unseeded) order.
BATCH_SIZE = 3

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=False
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False
)

In [11]:
# Pull a single (features, labels) batch from the test loader to inspect its layout.
test_batch_iter = iter(test_dataloader)
first_batch = next(test_batch_iter)

In [12]:
# Features tensor of the batch: (batch, 2, seq_len) — ids and attention mask stacked.
first_batch[0].size()

torch.Size([3, 2, 364])

In [13]:
from pydvl.influence import compute_influences
from pydvl.influence.torch import TorchTwiceDifferentiable

patched_model = PatchedModel(model)
patched_model.eval()

# Wrap the adapted model and its loss in pyDVL's torch interface.
twice_differentiable = TorchTwiceDifferentiable(patched_model, F.cross_entropy)

# "up" influences of every training point on every test point, using the
# "ekfac" inversion method; hessian_regularization=0.1 is a damping term
# (value not tuned here — NOTE(review): confirm its effect against pyDVL docs).
ekfac_train_influences = compute_influences(
    twice_differentiable,
    training_data=train_dataloader,
    test_data=test_dataloader,
    influence_type="up",
    inversion_method="ekfac",
    hessian_regularization=0.1,
    progress=True,
)

Batch Test Gradients: 100%|██████████| 1/1 [00:01<00:00,  1.36s/it]
Batch Split Input Gradients: 100%|██████████| 10/10 [00:24<00:00,  2.50s/it]


In [15]:
# (number of test points, number of training points)
ekfac_train_influences.size()

torch.Size([3, 30])