In [1]:
import copy
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import DistilBertTokenizerFast, AutoModel, DistilBertModel

# Adapted from the repo: https://github.com/xszheng2020/memorization
# Data files come from https://github.com/xszheng2020/memorization/tree/master/sst/data
train, val, test = (
    pd.read_csv(csv_name) for csv_name in ("train.csv", "dev.csv", "test.csv")
)
# Number of distinct values in the training "label" column; sizes the classifier head.
num_classes = train["label"].nunique()

# Fast tokenizer matching the pretrained checkpoint used by the model below.
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")


class CustomDataset(Dataset):
    """Map-style dataset over a DataFrame, yielding ``(label, sentence)`` pairs.

    Columns are read by position: column 0 is returned first and column 1
    second (presumably the label and the sentence text; verify against the
    CSV layout).
    """

    def __init__(self, data):
        # data: pandas DataFrame with at least two columns.
        self.data = data

    def __getitem__(self, index):
        row = self.data.iloc[index]
        # Use explicit positional access: `row[0]` on a Series with string
        # labels relies on deprecated positional fallback and fails on newer
        # pandas versions.
        label = row.iloc[0]
        sentence = row.iloc[1]
        return label, sentence

    def __len__(self):
        return len(self.data)


class CustomDatasetWithMask(Dataset):
    """Map-style dataset yielding ``(label, sentence, mask)`` triples.

    Columns are read by position: column 0 and column 1 are returned as the
    first two items, and column 2 is used as the key into ``mask``
    (presumably the original sample index; verify against the CSV layout).

    ``mask`` must support indexing by the values of column 2; leaving it as
    ``None`` makes ``__getitem__`` fail.
    """

    def __init__(self, data, mask=None):
        self.data = data
        self.mask = mask

    def __getitem__(self, index):
        row = self.data.iloc[index]
        # Use explicit positional access: `row[0]` on a Series with string
        # labels relies on deprecated positional fallback and fails on newer
        # pandas versions.
        label = row.iloc[0]
        sentence = row.iloc[1]
        sample_index = row.iloc[2]
        mask = self.mask[sample_index]
        return label, sentence, mask

    def __len__(self):
        return len(self.data)

### Finetune model

In [2]:
class CustomModel(nn.Module):
    """Pretrained DistilBERT encoder with a linear classification head.

    The head maps the 768-dimensional hidden state at sequence position 0 to
    ``num_classes`` logits.
    """

    def __init__(self):
        super().__init__()
        encoder = DistilBertModel.from_pretrained(
            "distilbert-base-uncased",
            output_hidden_states=False,
            output_attentions=False,
        )
        self.bert = encoder
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, inputs, is_ids, attention_mask):
        """Encode ``inputs`` and classify from the position-0 hidden state.

        Args:
            inputs: token ids when ``is_ids`` is truthy, otherwise
                precomputed input embeddings.
            is_ids: selects whether ``inputs`` is passed to the encoder as
                ``input_ids`` or as ``inputs_embeds``.
            attention_mask: forwarded unchanged to the encoder.

        Returns:
            Tuple ``(cls_embedding, logits)``.
        """
        input_kwarg = "input_ids" if is_ids else "inputs_embeds"
        encoder_output = self.bert(
            **{input_kwarg: inputs}, attention_mask=attention_mask
        )
        last_hidden = encoder_output[0]
        # Hidden state at sequence position 0, used as the pooled output: (bs, dim)
        cls_embedding = last_hidden[:, 0, :]
        # (bs, num_labels)
        logits = self.classifier(cls_embedding)
        return cls_embedding, logits


# NOTE: the model is built and moved to the GPU once, immediately before the
# optimizer is created later in this cell. A second instance used to be created
# here and was discarded before any use, which only repeated the checkpoint
# load and the GPU allocation.


def collate_fn(batch):
    """Turn a list of ``(label, sentence, ...)`` samples into model-ready tensors.

    Only the first two fields of each sample are used. Sentences are
    tokenized together with padding and truncation at 512 tokens.

    Returns:
        Tuple ``(labels, input_ids, attention_mask)`` of tensors.
    """
    labels = []
    sentences = []
    for sample in batch:
        labels.append(sample[0])
        sentences.append(sample[1])

    encoded = tokenizer(
        sentences,
        None,
        add_special_tokens=True,
        padding=True,
        truncation=True,
        max_length=512,
        return_attention_mask=True,
        return_tensors="pt",
    )
    label_tensor = torch.tensor(labels)
    return label_tensor, encoded["input_ids"], encoded["attention_mask"]


BATCH_SIZE = 100
EPOCHS = 150
# Stop once this many consecutive epochs pass without matching or beating the
# best validation accuracy seen so far.
early_stop_steps = 20
# Epochs elapsed since the best checkpoint was last refreshed.
from_last_step = 0
# Prepare dataloaders
# NOTE(review): the training loader is not shuffled, so every epoch sees the
# batches in file order; confirm this is intended.
train_dataset = CustomDataset(data=train)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=0,
    shuffle=False,
    collate_fn=collate_fn,
)

val_dataset = CustomDataset(data=val)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    num_workers=0,
    shuffle=False,
    collate_fn=collate_fn,
)


model = CustomModel()
model.cuda()
optimizer_grouped_parameters = [
    {"params": [p for _, p in model.named_parameters() if p.requires_grad]},
]
optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=0.000001)


best_val_acc = 0
best_state_dict = None

# train model
for i in range(EPOCHS):
    model.train()
    train_correct = 0
    total_train_loss = 0.0
    total_val_loss = 0.0
    val_correct = 0

    for batch in train_dataloader:
        # The optimizer holds every trainable parameter, so this clears all
        # gradients of the model.
        optimizer.zero_grad()

        label = batch[0].cuda()
        input_ids = batch[1].cuda()
        attention_mask = batch[2].cuda()

        _, outputs = model(input_ids, True, attention_mask=attention_mask)

        loss = F.cross_entropy(outputs, label)
        loss.backward()
        optimizer.step()
        # `loss` is the batch mean; weight it by the batch size so that the
        # division by len(train) below yields a true per-sample average.
        total_train_loss += loss.item() * label.size(0)
        train_correct += (torch.argmax(outputs, dim=-1) == label).sum().item()

    model.eval()
    with torch.no_grad():
        for batch in val_dataloader:
            label = batch[0].cuda()
            input_ids = batch[1].cuda()
            attention_mask = batch[2].cuda()
            _, outputs = model(input_ids, True, attention_mask=attention_mask)
            # Compute the loss on the validation batch itself (summed, so the
            # division by len(val) below gives the per-sample average).
            total_val_loss += F.cross_entropy(
                outputs, label, reduction="sum"
            ).item()
            val_correct += (torch.argmax(outputs, dim=-1) == label).sum().item()

    val_acc = val_correct / len(val)
    # save best checkpoint (ties count as a new best)
    if val_acc >= best_val_acc:
        best_val_acc = val_acc
        best_state_dict = copy.deepcopy(model.state_dict())
        from_last_step = 0
    # early stop
    if from_last_step >= early_stop_steps:
        break
    from_last_step += 1
    if i % 10 == 0:
        print(
            i,
            "train_acc:",
            round(train_correct / len(train), 6),
            "train_loss:",
            round(total_train_loss / len(train), 6),
            "val_loss:",
            round(total_val_loss / len(val), 6),
            "val_acc:",
            round(val_correct / len(val), 6),
        )
# load best model
model.load_state_dict(best_state_dict)
best_val_acc

In [3]:
import random

# Coefficient of the squared-parameter penalty added to the loss in the
# Hessian / influence computations below.
L2_LAMBDA = 5e-3


def set_seeds(seed):
    """Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# NOTE(review): this runs after the fine-tuning cell, so it does not make the
# fine-tuning itself reproducible; confirm whether that is intended.
set_seeds(42)


def compute_s(model, v, train_data_loader, damp, scale, num_samples):
    """Recursively estimate an inverse-Hessian-vector product for ``v``.

    Starting from ``v``, each step draws one batch from ``train_data_loader``
    and updates the estimate as
    ``v + (1 - damp) * estimate - HVP(estimate) / scale``, where the
    Hessian-vector product comes from ``compute_hessian_vector_products``.

    Args:
        model: model passed through to ``compute_hessian_vector_products``.
        v: iterable of tensors, one per trainable parameter.
        train_data_loader: iterable of ``(label, input_ids, attention_mask)``
            batches; each element must provide ``.cuda()``.
        damp: damping factor of the recursion.
        scale: scaling factor of the recursion; the final estimate is divided
            by it as well.
        num_samples: number of recursion steps (batches) to run. Fewer steps
            are run if the loader is exhausted first; with ``0`` the result is
            simply ``v / scale``.

    Returns:
        List of tensors with the same structure as ``v``.
    """
    last_estimate = list(v)
    for i, batch in enumerate(train_data_loader):
        # Run exactly `num_samples` steps (the previous `i > num_samples`
        # check ran two extra steps).
        if i >= num_samples:
            break

        label = batch[0].cuda()
        input_ids = batch[1].cuda()
        attention_mask = batch[2].cuda()

        this_estimate = compute_hessian_vector_products(
            model=model,
            vectors=last_estimate,
            label=label,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Recursively calculate h_estimate
        with torch.no_grad():
            last_estimate = [
                a + (1 - damp) * b - c / scale
                for a, b, c in zip(v, last_estimate, this_estimate)
            ]

    # References:
    # https://github.com/kohpangwei/influence-release/blob/master/influence/genericNeuralNet.py#L475
    # Do this for each iteration of estimation
    # Since we use one estimation, we put this at the end
    inverse_hvp = [X / scale for X in last_estimate]

    return inverse_hvp


def compute_hessian_vector_products(model, vectors, label, input_ids, attention_mask):
    """Return the Hessian-vector product of the regularized loss with ``vectors``.

    The loss is the cross-entropy of the model's logits on the given batch
    plus ``L2_LAMBDA`` times the sum of squared trainable parameters whose
    names contain none of the fragments in the module-level ``no_decay``
    list. ``no_decay`` and ``L2_LAMBDA`` are read as globals and must be
    defined before this function is called.

    Args:
        model: callable as ``model(inputs=..., is_ids=..., attention_mask=...)``
            returning ``(_, logits)``.
        vectors: sequence of tensors, one per trainable parameter, in
            ``model.named_parameters()`` order.
        label: target classes for the batch.
        input_ids: token ids for the batch.
        attention_mask: attention mask for the batch.

    Returns:
        Tuple of tensors, one per trainable parameter.
    """
    trainable = [
        weight for _, weight in model.named_parameters() if weight.requires_grad
    ]

    _, logits = model(
        inputs=input_ids,
        is_ids=True,
        attention_mask=attention_mask,
    )
    data_loss = F.cross_entropy(logits, label)

    # Squared-parameter penalty over every trainable parameter that is not
    # excluded through `no_decay`.
    squared_terms = []
    for param_name, weight in model.named_parameters():
        if not weight.requires_grad:
            continue
        if any(fragment in param_name for fragment in no_decay):
            continue
        squared_terms.append((weight**2).view(-1))
    penalty = torch.cat(squared_terms).sum() * (L2_LAMBDA)

    total_loss = data_loss + penalty

    model.zero_grad()
    # First-order gradients, keeping the graph so they can be differentiated again.
    first_order = torch.autograd.grad(
        outputs=total_loss,
        inputs=trainable,
        create_graph=True,
    )
    # Differentiating <first_order, vectors> w.r.t. the parameters gives H @ vectors.
    second_order = torch.autograd.grad(
        outputs=first_order,
        inputs=trainable,
        grad_outputs=vectors,
    )

    return second_order

### Calculate memorization scores for finetuned model

In [4]:
from contexttimer import Timer

# One example per batch, in dataset order (no shuffling).
train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn,
)

# Storage filled by the forward hooks below (key "out" holds the hooked output).
my_hook_1, my_hook_2 = {}, {}

# Tensor whose values hook_func_2 writes into the hooked module's output; it
# must be assigned at module level before a forward pass with that hook.
emb = None


def hook_func_1(module, input_, output):
    """Forward hook: remember the module's output in ``my_hook_1["out"]``."""
    my_hook_1.update(out=output)


def hook_func_2(module, input_, output):
    """Forward hook: overwrite the output's values with ``emb`` and remember it.

    The output tensor object is kept (only its ``.data`` is replaced) and is
    stored in ``my_hook_2["out"]``.
    """
    output.data = emb.data
    my_hook_2.update(out=output)
    # Setting my_hook_2["out"].requires_grad = True is not needed for a model
    # without frozen parameters.

####
####
# Parameter-name fragments excluded from the squared-parameter penalty. This
# list is also read as a global by compute_hessian_vector_products.
no_decay = ["bias", "output_layer_norm.weight"]
output_collections = []
####
# Settings of the inverse-HVP recursion (see compute_s).
HVP_DAMP = 5e-3
HVP_SCALE = 1e4
HVP_NUM_SAMPLES = 100
# Number of interpolation steps between the masked baseline and the real input.
INTEGRATION_STEPS = 50
# The original loop was written as `range(4)` but unconditionally stopped
# after the first pass, i.e. exactly one repetition was recorded per example.
NUM_REPETITIONS = 1

# Output LayerNorm of the first transformer block: its output is captured and
# later overwritten by the forward hooks.
first_block_norm = model.bert.transformer.layer[0].output_layer_norm
trainable_params = [
    param for name, param in model.named_parameters() if param.requires_grad
]
decayed_params = [
    p
    for n, p in model.named_parameters()
    if p.requires_grad and not any(nd in n for nd in no_decay)
]


def _capture_first_block_output(input_ids, attention_mask):
    """Run the model once; return (detached copy of the hooked output, logits)."""
    handle = first_block_norm.register_forward_hook(hook_func_1)
    try:
        _, logits = model(
            inputs=input_ids, is_ids=True, attention_mask=attention_mask
        )
    finally:
        handle.remove()
    # Only the values are needed later, so drop the autograd graph.
    return my_hook_1["out"].detach().clone(), logits


def _negative_inner_product(left, right):
    """Return -sum_k <left[k], right[k]> as a Python float."""
    return sum(-torch.sum(x * y) for x, y in zip(left, right)).item()


# Do not depend on the mode an earlier cell happened to leave the model in.
model.eval()

# `train_dataloader` must iterate in dataset order here so that `idx` lines up
# with `train.iloc[idx]`.
for idx, batch in enumerate(tqdm(train_dataloader)):
    ####
    z_label = batch[0].cuda()
    z_input_ids = batch[1].cuda()
    z_attention_mask = batch[2].cuda()
    ####
    row = train.iloc[idx]
    tokens = tokenizer.convert_ids_to_tokens(z_input_ids[0].tolist())
    ####
    # Baseline input: every non-padding token replaced by the mask token
    # (ids 0 and 103 for this tokenizer).
    baseline = z_input_ids.clone()
    baseline[baseline != tokenizer.pad_token_id] = tokenizer.mask_token_id

    with torch.no_grad():
        bemb, _ = _capture_first_block_output(baseline, z_attention_mask)
    ####
    wemb, logits = _capture_first_block_output(z_input_ids, z_attention_mask)
    ####
    prob = F.softmax(logits, dim=-1)
    prediction = torch.argmax(prob, dim=1)

    prob_gt = torch.gather(prob, 1, z_label.unsqueeze(1))
    ####
    model.zero_grad()

    # Gradient of the ground-truth class probability w.r.t. the parameters.
    v = torch.autograd.grad(
        outputs=prob_gt,
        inputs=trainable_params,
        create_graph=False,
    )
    ####
    for repetition in range(NUM_REPETITIONS):
        with Timer() as timer:
            ####
            # Separate, shuffled loader for the recursion; kept under its own
            # name so the loader driving the outer loop is not rebound.
            hvp_dataloader = DataLoader(
                train_dataset,
                batch_size=1,
                num_workers=0,
                shuffle=True,
                collate_fn=collate_fn,
            )
            ####
            s = compute_s(
                model=model,
                v=v,
                train_data_loader=hvp_dataloader,
                damp=HVP_DAMP,
                scale=HVP_SCALE,
                num_samples=HVP_NUM_SAMPLES,
            )
            ####
            time_elapsed = timer.elapsed
        ####
        s_flat = nn.utils.parameters_to_vector(s).detach()
        hessian = None
        ####
        # Left Riemann sum: the gradient w.r.t. the hooked output is
        # accumulated at alpha = 0, 1/steps, ..., (steps - 1)/steps. The final
        # pass at alpha = 1 only provides the parameter gradient at the real
        # input, which `influence` needs.
        alphas = np.linspace(0, 1.0, num=INTEGRATION_STEPS + 1, endpoint=True)
        for step, alpha in enumerate(alphas):
            is_last_step = step == INTEGRATION_STEPS
            # `emb` has to be a module-level name: hook_func_2 reads it.
            emb = bemb + float(alpha) * (wemb - bemb)
            ####
            hook_2 = first_block_norm.register_forward_hook(hook_func_2)
            _, step_logits = model(
                inputs=z_input_ids, is_ids=True, attention_mask=z_attention_mask
            )
            hook_2.remove()
            ####
            ce_loss_gt = F.cross_entropy(step_logits, z_label)
            z_hack_loss = torch.cat(
                [(p**2).view(-1) for p in decayed_params]
            ).sum() * (L2_LAMBDA)
            ####
            model.zero_grad()

            grad_tuple_ = torch.autograd.grad(
                outputs=ce_loss_gt + z_hack_loss,
                inputs=trainable_params,
                # The graph is only needed when differentiating again below.
                create_graph=not is_last_step,
            )
            if is_last_step:
                break

            dot = s_flat @ nn.utils.parameters_to_vector(grad_tuple_)  # scalar
            grad_grad_tuple_ = torch.autograd.grad(
                outputs=dot, inputs=[my_hook_2["out"]]
            )
            ####
            if step == 0:
                hessian = grad_grad_tuple_[0]
                # Score at the fully masked baseline input.
                influence_prime = _negative_inner_product(s, grad_tuple_)
            else:
                hessian += grad_grad_tuple_[0]
        ####
        hessian = hessian / INTEGRATION_STEPS
        ####
        # Score at the real input (grad_tuple_ comes from the alpha = 1 pass).
        influence = _negative_inner_product(s, grad_tuple_)
        ####
        result = hessian * (wemb - bemb)
        theta = torch.sum(result).detach().cpu().numpy()
        attributions = torch.sum(result, dim=-1)[0].detach().cpu().numpy()

        record = {
            "index": idx,
            "sentence": row["sentence"],
            "label": row["label"],
            "prob": prob.detach().cpu().numpy()[0],
            "prediction": prediction.detach().cpu().numpy()[0],
            "influence_prime": influence_prime,
            "influence": influence,
            "diff": influence_prime - influence,
            "theta": theta,
            "tokens": tokens,
            "attributions": attributions,
            "repetition": repetition,
            "time_elapsed": time_elapsed,
        }

        output_collections.append(record)

In [5]:
# One row per scored training example, sorted by influence, highest first.
df = pd.DataFrame(output_collections)
df = df.rename({"index": "sample_index"}, axis=1)
df = df.sort_values("influence", ascending=False).reset_index(drop=True)
# Decile of each row in the sorted order: 10 for the first tenth, ..., 100 for the last.
df["percentage"] = ((df.index // (df.shape[0] / 10)).astype(int) + 1) * 10

# save data without top m% memorized examples to data/{m}.csv
# Create the output directory first so that to_csv does not fail on a fresh checkout.
output_dir = Path("data")
output_dir.mkdir(parents=True, exist_ok=True)
for i in range(0, 100, 10):
    df.loc[df["percentage"] > i, ["label", "sentence", "sample_index"]].to_csv(
        output_dir / f"{i}.csv", index=False
    )