In [10]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
from datasets import load_dataset, load_from_disk
import evaluate
import numpy as np
import torch
import os

model_dir = "./saved_models/pii_eraser"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForTokenClassification.from_pretrained(model_dir)
model.eval()

dataset_name = "ai4privacy/pii-masking-400k"
dataset_path = "saved_datasets/ai4privacy_pii-masking-400k"

if os.path.exists(dataset_path):
    dataset = load_from_disk(dataset_path)
else:
    dataset = load_dataset(dataset_name)

unique_labels = set()
for row in dataset["train"]["mbert_token_classes"]:
    unique_labels.update(row)
label_list = sorted(unique_labels)
label_to_id = {label: i for i, label in enumerate(label_list)}
id_to_label = {i: label for i, label in enumerate(label_list)}

val_samples = dataset["validation"].select(range(100))

def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(
        examples["mbert_tokens"],
        is_split_into_words=True,
        truncation=True,
        return_tensors=None
    )
    labels = []
    for i, label in enumerate(examples["mbert_token_classes"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        label_ids = []
        prev_word_id = None
        for word_id in word_ids:
            if word_id is None:
                label_ids.append(-100)
            elif word_id != prev_word_id:
                label_ids.append(label_to_id[label[word_id]])
            else:
                label_ids.append(label_to_id[label[word_id]])
            prev_word_id = word_id
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs

tokenized = val_samples.map(tokenize_and_align_labels, batched=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Map: 100%|██████████| 100/100 [00:00<00:00, 1353.48 examples/s]


BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12

In [11]:
all_preds = []
all_labels = []

for example in tokenized:
    input_ids = torch.tensor([example["input_ids"]]).to(device)
    attention_mask = torch.tensor([example["attention_mask"]]).to(device)
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    preds = torch.argmax(logits, dim=-1).squeeze().cpu().numpy()
    labels = np.array(example["labels"])

    word_preds = [id_to_label[p] for (p, l) in zip(preds, labels) if l != -100]
    word_labels = [id_to_label[l] for (p, l) in zip(preds, labels) if l != -100]

    all_preds.append(word_preds)
    all_labels.append(word_labels)

metric = evaluate.load("seqeval")
results = metric.compute(predictions=all_preds, references=all_labels)
print("Evaluation Metrics:")
for k, v in results.items():
    print(f"{k}: {v}")


Evaluation Metrics:
ACCOUNTNUM: {'precision': np.float64(0.6153846153846154), 'recall': np.float64(0.6666666666666666), 'f1': np.float64(0.64), 'number': np.int64(12)}
BUILDINGNUM: {'precision': np.float64(0.927536231884058), 'recall': np.float64(0.9411764705882353), 'f1': np.float64(0.9343065693430658), 'number': np.int64(68)}
CITY: {'precision': np.float64(0.8947368421052632), 'recall': np.float64(0.918918918918919), 'f1': np.float64(0.9066666666666667), 'number': np.int64(37)}
CREDITCARDNUMBER: {'precision': np.float64(0.76), 'recall': np.float64(1.0), 'f1': np.float64(0.8636363636363636), 'number': np.int64(19)}
DATEOFBIRTH: {'precision': np.float64(1.0), 'recall': np.float64(0.972972972972973), 'f1': np.float64(0.9863013698630138), 'number': np.int64(37)}
DRIVERLICENSENUM: {'precision': np.float64(1.0), 'recall': np.float64(0.6666666666666666), 'f1': np.float64(0.8), 'number': np.int64(9)}
EMAIL: {'precision': np.float64(1.0), 'recall': np.float64(0.9818181818181818), 'f1': np.flo

In [20]:
def predict_pii(sentence: str):
    encoding = tokenizer(sentence, return_tensors="pt", truncation=True, is_split_into_words=False)
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1).squeeze().cpu().numpy()

    tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze().cpu().numpy())
    labels = [id_to_label[pred] for pred in predictions]

    print("TOKEN".ljust(20), "PREDICTED LABEL")
    print("-" * 40)
    for token, label in zip(tokens, labels):
        print(f"{token.ljust(20)} {label}")

predict_pii("My SSN is 123-456-7890")
predict_pii("My name is Daniel Bhatti")
predict_pii("Da Gong is my boss - he's really great!")
predict_pii("My SSN is 1234567890")
predict_pii("My name is John Jacob Jingleheimersmith")
predict_pii("Please send me Alice's loan with id 832640173.  I need to call her at 947-2792-1849 tomorrow morning.")


TOKEN                PREDICTED LABEL
----------------------------------------
[CLS]                O
My                   O
SS                   O
##N                  O
is                   O
123                  B-SOCIALNUM
-                    I-SOCIALNUM
45                   I-SOCIALNUM
##6                  I-SOCIALNUM
-                    I-SOCIALNUM
78                   I-SOCIALNUM
##90                 I-SOCIALNUM
[SEP]                O
TOKEN                PREDICTED LABEL
----------------------------------------
[CLS]                O
My                   O
name                 O
is                   O
Daniel               B-GIVENNAME
B                    B-SURNAME
##hat                B-SURNAME
##ti                 B-SURNAME
[SEP]                O
TOKEN                PREDICTED LABEL
----------------------------------------
[CLS]                O
Da                   B-GIVENNAME
Gong                 I-SURNAME
is                   O
my                   O
boss                 O
