In [1]:
import os
import numpy as np
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm import tqdm
import evaluate
import wandb
import datasets




In [2]:
dataset = load_dataset("ai4privacy/pii-masking-200k")

# Filter to keep only English samples
print("Filtering dataset for English samples only...")
dataset_en = dataset.filter(lambda example: example["language"] == "en")

# Split the train dataset into train and validation
train_test_split = dataset_en["train"].train_test_split(test_size=0.2, seed=42)

# Create a new DatasetDict with the splits
dataset = datasets.DatasetDict({
    "train": train_test_split["train"],
    "validation": train_test_split["test"]  # The "test" split from train_test_split becomes our validation
})

print(f"Train size: {len(dataset['train'])}")
print(f"Validation size: {len(dataset['validation'])}")

Filtering dataset for English samples only...
Train size: 34800
Validation size: 8701


In [3]:
sample = dataset["train"][0]
for key, value in sample.items():
    print(f"{key}: {value}")
print(f"Number of samples: {len(dataset['train'])}")

source_text: This is to notify you about the changes in your tax plan. The state of Basel-Landschaft and county of Lake County have revised laws affecting businesses. Will connect with Legacy Branding Associate soon. Check in with https://basic-reparation.com/ for details.
target_text: This is to notify you about the changes in your tax plan. The state of [STATE] and county of [COUNTY] have revised laws affecting businesses. Will connect with [JOBTITLE] soon. Check in with [URL] for details.
privacy_mask: [{'value': 'Basel-Landschaft', 'start': 71, 'end': 87, 'label': 'STATE'}, {'value': 'Lake County', 'start': 102, 'end': 113, 'label': 'COUNTY'}, {'value': 'Legacy Branding Associate', 'start': 172, 'end': 197, 'label': 'JOBTITLE'}, {'value': 'https://basic-reparation.com/', 'start': 218, 'end': 247, 'label': 'URL'}]
span_labels: [[0, 71, "O"], [71, 87, "STATE"], [87, 102, "O"], [102, 113, "COUNTY"], [113, 172, "O"], [172, 197, "JOBTITLE"], [197, 218, "O"], [218, 247, "URL"], [247, 260

In [4]:
dataset["train"].features

{'source_text': Value(dtype='string', id=None),
 'target_text': Value(dtype='string', id=None),
 'privacy_mask': [{'value': Value(dtype='string', id=None),
   'start': Value(dtype='int64', id=None),
   'end': Value(dtype='int64', id=None),
   'label': Value(dtype='string', id=None)}],
 'span_labels': Value(dtype='string', id=None),
 'mbert_text_tokens': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
 'mbert_bio_labels': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
 'id': Value(dtype='int64', id=None),
 'language': Value(dtype='string', id=None),
 'set': Value(dtype='string', id=None)}

In [8]:
# Create label mappings from the dataset
# First, extract all unique labels from the privacy_mask field
print("Extracting unique labels...")
entity_labels = set()
for example in tqdm(dataset["train"]):
    for mask in example["privacy_mask"]:
        entity_labels.add(mask["label"])

# Create BIO tags for each entity type
label_list = ["O"]  # Outside tag
for entity in sorted(entity_labels):
    label_list.append(f"B-{entity}")  # Beginning tag
    label_list.append(f"I-{entity}")  # Inside tag

label2id = {label: i for i, label in enumerate(label_list)}
id2label = {i: label for i, label in enumerate(label_list)}

print(f"Number of labels: {len(label_list)}")
print("Labels:", label_list[:10], "..." if len(label_list) > 10 else "")

Extracting unique labels...


100%|██████████| 34800/34800 [00:07<00:00, 4933.55it/s]

Number of labels: 113
Labels: ['O', 'B-ACCOUNTNAME', 'I-ACCOUNTNAME', 'B-ACCOUNTNUMBER', 'I-ACCOUNTNUMBER', 'B-AGE', 'I-AGE', 'B-AMOUNT', 'I-AMOUNT', 'B-BIC'] ...





In [6]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load MobileBERT tokenizer and model
model_name = "google/mobilebert-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForTokenClassification.from_pretrained(
#     model_name,
#     num_labels=len(label_list),
#     id2label=id2label,
#     label2id=label2id
# )
# model.to(device);

Using device: cuda


In [9]:
# Function to process dataset with MobileBERT tokenizer
def prepare_with_mobilebert_tokenizer(examples, tokenizer=tokenizer, max_length=256, label2id=label2id):
    """
    Tokenize text with MobileBERT's tokenizer and align labels based on character offsets.
    
    This function:
    1. Takes the raw text from source_text
    2. Uses the privacy_mask information to build a character-level entity map
    3. Tokenizes with MobileBERT's tokenizer and gets character offsets
    4. Aligns entity labels with the new tokens
    """
    # Get the original text
    texts = examples["source_text"]
    
    # Extract character-level entity information from privacy_mask
    entity_maps = []
    for example_masks in examples["privacy_mask"]:
        entity_map = {}
        for mask in example_masks:
            for pos in range(mask["start"], mask["end"]):
                entity_map[pos] = mask["label"]
        entity_maps.append(entity_map)
    
    # Tokenize with MobileBERT's tokenizer and get character offsets
    tokenized = tokenizer(
        texts, 
        truncation=True,
        max_length=max_length,
        padding=False,
        return_offsets_mapping=True
    )
    
    # Align labels with new tokens
    labels = []
    for i, offset_mapping in enumerate(tokenized.pop("offset_mapping")):
        label_ids = []
        entity_map = entity_maps[i]
        
        previous_entity = None
        for j, (start, end) in enumerate(offset_mapping):
            # Skip special tokens which have empty offsets (0,0)
            if start == end == 0:
                label_ids.append(-100)
                continue
                
            # Find if this token overlaps with any entity
            current_entity = None
            for pos in range(start, end):
                if pos in entity_map:
                    current_entity = entity_map[pos]
                    break
            
            # Determine if this is a beginning or inside token
            if current_entity is None:
                # Not an entity
                label_ids.append(label2id["O"])
                previous_entity = None
            elif previous_entity != current_entity:
                # Beginning of entity or new entity
                label_ids.append(label2id[f"B-{current_entity}"])
                previous_entity = current_entity
            else:
                # Continuation of the entity
                label_ids.append(label2id[f"I-{current_entity}"])
        
        labels.append(label_ids)
    
    tokenized["labels"] = labels
    return tokenized

tokenized_datasets = dataset.map(
    prepare_with_mobilebert_tokenizer,
    batched=True,
    batch_size=32,
    remove_columns=dataset["train"].column_names,
    num_proc=4,
)

In [28]:
sample = tokenized_datasets["train"][0]
for key, value in sample.items():
    print(f"{key}: {value}")
print(f"Number of samples: {len(dataset['train'])}")

input_ids: [101, 2023, 2003, 2000, 2025, 8757, 2017, 2055, 1996, 3431, 1999, 2115, 4171, 2933, 1012, 1996, 2110, 1997, 14040, 1011, 4915, 29043, 1998, 2221, 1997, 2697, 2221, 2031, 8001, 4277, 12473, 5661, 1012, 2097, 7532, 2007, 8027, 16140, 5482, 2574, 1012, 4638, 1999, 2007, 16770, 1024, 1013, 1013, 3937, 1011, 16360, 25879, 3258, 1012, 4012, 1013, 2005, 4751, 1012, 102]
token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
attention_mask: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
labels: [-100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 95, 96, 96, 96, 0, 0, 0, 19, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 61, 62, 62, 0, 0, 0, 0, 0, 101, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 0, 0, 0, -100]
N

In [29]:
# Sanity check for tokenized dataset
def sanity_check_tokenized_dataset(tokenized_dataset, tokenizer, id2label, num_samples=5):
    """
    Perform sanity checks on the tokenized dataset to ensure proper alignment 
    between tokens and labels.
    
    Args:
        tokenized_dataset: The dataset after tokenization and processing
        tokenizer: The tokenizer used to process the dataset
        id2label: Mapping from label IDs to label names
        num_samples: Number of samples to check
    """
    print(f"Performing sanity check on {num_samples} samples...")
    
    # Get a subset of samples
    samples = tokenized_dataset["train"].select(range(min(num_samples, len(tokenized_dataset["train"]))))
    
    for i, sample in enumerate(samples):
        print(f"\n--- Sample {i+1} ---")
        
        # Get input IDs and labels
        input_ids = sample["input_ids"]
        labels = sample["labels"]
        
        # Decode tokens
        tokens = tokenizer.convert_ids_to_tokens(input_ids)
        
        # Print alignment of tokens and labels
        print("Token\tLabel")
        print("-" * 30)
        
        for token, label_id in zip(tokens, labels):
            # Skip tokens with label -100 (special tokens)
            if label_id == -100:
                label_text = "SPECIAL"
            else:
                label_text = id2label[label_id]
                
            print(f"{token}\t{label_text}")
        
        # Check label distribution
        label_counts = {}
        for label_id in labels:
            if label_id != -100:
                label_name = id2label[label_id]
                label_counts[label_name] = label_counts.get(label_name, 0) + 1
        
        print("\nLabel distribution:")
        for label, count in label_counts.items():
            print(f"{label}: {count}")
        
        # Check if special tokens have -100 labels
        special_token_ids = [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]
        for idx, input_id in enumerate(input_ids):
            if input_id in special_token_ids and labels[idx] != -100:
                print(f"WARNING: Special token {tokenizer.convert_ids_to_tokens([input_id])[0]} has label {id2label[labels[idx]]} instead of -100")
    
    # Check overall dataset statistics
    print("\n--- Overall Dataset Statistics ---")
    print(f"Number of training examples: {len(tokenized_dataset['train'])}")
    print(f"Number of validation examples: {len(tokenized_dataset['validation'])}")
    
    # Check token length distribution
    train_lengths = [len(x["input_ids"]) for x in tokenized_dataset["train"]]
    print(f"Average token length (train): {sum(train_lengths) / len(train_lengths):.2f}")
    print(f"Max token length (train): {max(train_lengths)}")
    print(f"Min token length (train): {min(train_lengths)}")
    
    # Check label distribution in training set
    all_labels = []
    for example in tokenized_dataset["train"]:
        for label in example["labels"]:
            if label != -100:
                all_labels.append(label)
    
    label_count = {}
    for label in all_labels:
        label_name = id2label[label]
        label_count[label_name] = label_count.get(label_name, 0) + 1
    
    print("\nTop 10 most common labels in training set:")
    sorted_labels = sorted(label_count.items(), key=lambda x: x[1], reverse=True)
    for label, count in sorted_labels[:10]:
        print(f"{label}: {count} ({count / len(all_labels) * 100:.2f}%)")
    
    print("\nLeast common labels in training set:")
    for label, count in sorted_labels[-10:]:
        print(f"{label}: {count} ({count / len(all_labels) * 100:.2f}%)")


# Perform sanity check on the tokenized dataset
sanity_check_tokenized_dataset(tokenized_datasets, tokenizer, id2label, num_samples=5)

Performing sanity check on 5 samples...

--- Sample 1 ---
Token	Label
------------------------------
[CLS]	SPECIAL
this	O
is	O
to	O
not	O
##ify	O
you	O
about	O
the	O
changes	O
in	O
your	O
tax	O
plan	O
.	O
the	O
state	O
of	O
basel	B-STATE
-	I-STATE
lands	I-STATE
##chaft	I-STATE
and	O
county	O
of	O
lake	B-COUNTY
county	I-COUNTY
have	O
revised	O
laws	O
affecting	O
businesses	O
.	O
will	O
connect	O
with	O
legacy	B-JOBTITLE
branding	I-JOBTITLE
associate	I-JOBTITLE
soon	O
.	O
check	O
in	O
with	O
https	B-URL
:	I-URL
/	I-URL
/	I-URL
basic	I-URL
-	I-URL
rep	I-URL
##arat	I-URL
##ion	I-URL
.	I-URL
com	I-URL
/	I-URL
for	O
details	O
.	O
[SEP]	SPECIAL

Label distribution:
O: 37
B-STATE: 1
I-STATE: 3
B-COUNTY: 1
I-COUNTY: 1
B-JOBTITLE: 1
I-JOBTITLE: 2
B-URL: 1
I-URL: 11

--- Sample 2 ---
Token	Label
------------------------------
[CLS]	SPECIAL
dear	O
drew	B-FIRSTNAME
,	O
your	O
mentor	O
##ship	O
program	O
application	O
is	O
tentatively	O
approved	O
!	O
please	O
confirm	O
if	O
all	O
details	O
includin

In [11]:
def convert_to_binary_classification(examples):
    labels = examples["labels"]
    binary_labels = []
    
    for example_labels in labels:
        example_binary_labels = []
        for label in example_labels:
            if label == -100:  # Keep special token labels intact
                example_binary_labels.append(-100)
            elif label == label2id["O"]:  # Keep "O" label intact
                example_binary_labels.append(0)  # 0 for non-PII
            else:  # All PII types become a single class
                example_binary_labels.append(1)  # 1 for PII
                
        binary_labels.append(example_binary_labels)
    
    examples["labels"] = binary_labels
    return examples

# Create a binary version of the dataset
binary_id2label = {0: "O", 1: "PII"}
binary_label2id = {"O": 0, "PII": 1}

# Apply the conversion
binary_tokenized_datasets = tokenized_datasets.map(
    convert_to_binary_classification,
    batched=True
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Update model configuration for binary classification
binary_model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=2,
    id2label=binary_id2label,
    label2id=binary_label2id
)
binary_model.to(device);

Using device: cuda


Some weights of MobileBertForTokenClassification were not initialized from the model checkpoint at google/mobilebert-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
sample = binary_tokenized_datasets["train"][0]
sample1 = dataset["train"][0]
for key, value in sample.items():
    print(f"{key}: {value}")
print(f'Source text: {sample1["source_text"]}')
untokenized = tokenizer.convert_ids_to_tokens(sample["input_ids"])
print(f"Untokenized: {untokenized}")
print(f"Number of samples: {len(dataset['train'])}")

input_ids: [101, 2023, 2003, 2000, 2025, 8757, 2017, 2055, 1996, 3431, 1999, 2115, 4171, 2933, 1012, 1996, 2110, 1997, 14040, 1011, 4915, 29043, 1998, 2221, 1997, 2697, 2221, 2031, 8001, 4277, 12473, 5661, 1012, 2097, 7532, 2007, 8027, 16140, 5482, 2574, 1012, 4638, 1999, 2007, 16770, 1024, 1013, 1013, 3937, 1011, 16360, 25879, 3258, 1012, 4012, 1013, 2005, 4751, 1012, 102]
token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
attention_mask: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
labels: [-100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, -100]
Source text: This is to notify you

In [13]:
seqeval = evaluate.load("seqeval")

# def compute_metrics(p):
#     predictions, labels = p
#     predictions = np.argmax(predictions, axis=2)

#     true_predictions = [
#         [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
#         for prediction, label in zip(predictions, labels)
#     ]
#     true_labels = [
#         [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
#         for prediction, label in zip(predictions, labels)
#     ]

#     results = seqeval.compute(predictions=true_predictions, references=true_labels)
#     return {
#         "precision": results["overall_precision"],
#         "recall": results["overall_recall"],
#         "f1": results["overall_f1"],
#         "accuracy": results["overall_accuracy"],
#     }

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    true_predictions = [
        [binary_id2label[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [binary_id2label[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    results = seqeval.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

In [14]:
import wandb
import os
wandb.login()

os.environ["WANDB_PROJECT"] = "pii_NER"  # Name of your W&B project
os.environ["WANDB_LOG_MODEL"] = "checkpoint"       # Options: 'checkpoint', 'end', or 'false'
os.environ["WANDB_WATCH"] = "all"                  # Options: 'gradients', 'all', or 'false'

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mchainathanss[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [17]:
# Data collator
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding="longest")

run_name = "mobilebert-pii-binary3"

# Define training arguments
training_args = TrainingArguments(
    output_dir=f"./results/{run_name}",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    report_to=["wandb"],
    push_to_hub=False,
    run_name=run_name,
    # gradient_checkpointing=True,
    fp16=True,  # Enable mixed precision training if supported
)

# Initialize Trainer
trainer = Trainer(
    model=binary_model,
    args=training_args,
    train_dataset=binary_tokenized_datasets["train"],
    eval_dataset=binary_tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

# Train the model
print("Training model...")
trainer.train()

# Evaluate the model
print("Evaluating model...")
evaluation_results = trainer.evaluate()
print(f"Evaluation results: {evaluation_results}")

# Save model
model_save_path = f"./results/{run_name}/final_model"
if not os.path.exists(model_save_path):
    os.makedirs(model_save_path)
trainer.save_model(model_save_path)
tokenizer.save_pretrained(model_save_path)
print(f"Model saved to {model_save_path}")

  trainer = Trainer(


Training model...


Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,1.2581,1.248093,0.01227,0.029015,0.017247,0.420666
2,1.2537,1.248093,0.01227,0.029015,0.017247,0.420666
3,1.2645,1.248093,0.01227,0.029015,0.017247,0.420666
4,1.25,1.248093,0.01227,0.029015,0.017247,0.420666


[34m[1mwandb[0m: Adding directory to artifact (.\results\mobilebert-pii-binary3\checkpoint-2175)... Done. 0.2s
[34m[1mwandb[0m: Adding directory to artifact (.\results\mobilebert-pii-binary3\checkpoint-4350)... Done. 0.2s
[34m[1mwandb[0m: Adding directory to artifact (.\results\mobilebert-pii-binary3\checkpoint-6525)... Done. 0.2s


KeyboardInterrupt: 

In [6]:
# Function to make predictions with the model for new text
def mask_pii_text(text, model, tokenizer, device, id2label):
    """
    Identify and mask PII entities in new text
    """
    # Tokenize input with offset mapping to track character positions
    inputs = tokenizer(text, return_tensors="pt", return_offsets_mapping=True)
    offset_mapping = inputs.pop("offset_mapping")[0]
    
    # Move to device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # Get predictions
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=2)[0]
    
    # Convert token predictions to entities
    entities = []
    current_entity = None
    start_idx = None
    
    for i, (pred, (token_start, token_end)) in enumerate(zip(predictions, offset_mapping)):
        # Skip special tokens
        if token_start == token_end == 0:
            continue
            
        pred_label = id2label[pred.item()]
        
        # Check if it's beginning of entity
        if pred_label.startswith("B-"):
            # If we were tracking a previous entity, add it to result
            if current_entity:
                entities.append((current_entity, start_idx, prev_end))
            
            # Start new entity
            current_entity = pred_label[2:]  # Remove B- prefix
            start_idx = token_start.item()
            prev_end = token_end.item()
        
        # Check if it's inside an entity
        elif pred_label.startswith("I-") and current_entity == pred_label[2:]:
            # Continue current entity
            prev_end = token_end.item()
        
        # Not an entity or different entity
        elif current_entity:
            # Add previous entity to result
            entities.append((current_entity, start_idx, prev_end))
            current_entity = None
    
    # Add final entity if there is one
    if current_entity:
        entities.append((current_entity, start_idx, prev_end))
    
    # Create masked text
    masked_text = list(text)
    for entity_type, start, end in entities:
        for i in range(start, end):
            masked_text[i] = '*'
        # Insert entity type at the beginning
        masked_text[start:start] = f"[{entity_type}]"
    
    return ''.join(masked_text)

# Example usage (uncomment to test)
# sample_text = "Hello, my name is John Doe and I live in New York. My phone number is 555-123-4567."
# model.to(device)  # Make sure model is on the correct device
# masked_sample = mask_pii_text(sample_text, model, tokenizer)
# print(f"Original: {sample_text}")
# print(f"Masked: {masked_sample}")

In [8]:
# Path to your saved model
model_path = "./results/mobilebert-pii-masking3/final_model"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForTokenClassification.from_pretrained(
    model_path,
    id2label=id2label,
    label2id=label2id
)
model.to(device);

In [10]:
sample_text = "Hello, my name is John Doe and I live in New York. My phone number is 555-123-4567."
masked_sample = mask_pii_text(sample_text, model, tokenizer, device, id2label)
print(f"Original: {sample_text}")
print(f"Masked: {masked_sample}")

Original: Hello, my name is John Doe and I live in New York. My phone number is 555-123-4567.
Masked: Hello[S[ZI[SSN][SSN]****E]*[CRE[LITECO[LI[STATE]*[CRED[STATE[ACCOUN[LI[AC[STA[ST[[[LASTNAME]*OB]*SERNAME]*TE]*E]*OUNTNAME]***OINADDRESS]**AME]******DNUMBER]**OINADDRESS]**ADDRESS]*ITCARDNUMBER]*** my name is John Doe and I live in New York. My phone number is 555-123-4567.


In [8]:
import os
import torch
import numpy as np
from transformers import AutoModelForTokenClassification, AutoTokenizer

def load_pii_model(model_path):
    """
    Load the trained PII detection model and tokenizer
    
    Args:
        model_path: Path to the saved model directory
        
    Returns:
        model: Loaded model
        tokenizer: Loaded tokenizer
        id2label: Dictionary mapping from id to label
    """
    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForTokenClassification.from_pretrained(model_path)
    
    # Get id2label mapping from model config
    id2label = model.config.id2label
    
    # Move model to appropriate device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    return model, tokenizer, id2label

def detect_pii(text, model, tokenizer, id2label, max_length=256):
    """
    Detect PII in the input text using the trained model
    
    Args:
        text: Input text to analyze
        model: Trained PII detection model
        tokenizer: Tokenizer for the model
        id2label: Dictionary mapping from id to label
        max_length: Maximum sequence length for tokenization
        
    Returns:
        tokens: List of tokens from the input text
        predictions: List of predicted labels for each token
        pii_spans: List of detected PII spans (start, end, label)
    """
    # Tokenize the input text
    inputs = tokenizer(
        text, 
        return_tensors="pt", 
        truncation=True,
        max_length=max_length,
        return_offsets_mapping=True,
        padding="max_length"
    )
    
    # Get token offsets and move tensors to the same device as model
    offset_mapping = inputs.pop("offset_mapping").cpu().numpy()[0]
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # Get model predictions
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=2)
    
    # Convert predictions to labels
    predicted_labels = [id2label[p.item()] for p in predictions[0]]
    
    # Get text tokens
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    
    # Identify PII spans (start, end, label)
    pii_spans = []
    current_span = None
    
    for i, (token, label, (start, end)) in enumerate(zip(tokens, predicted_labels, offset_mapping)):
        # Skip special tokens like [CLS], [SEP], [PAD]
        if token in tokenizer.special_tokens_map.values() or start == end == 0:
            continue
            
        # For binary classification, we only have "O" and "PII"
        if label == "PII":
            # Start a new span if none exists
            if current_span is None:
                current_span = {"start": start, "label": "PII"}
            # Otherwise, extend the current span
            current_span["end"] = end
        else:  # "O" label or other non-PII label
            # If we were tracking a PII span, finalize it
            if current_span is not None:
                pii_spans.append(current_span)
                current_span = None
    
    # Don't forget to add the last span if we had one
    if current_span is not None:
        pii_spans.append(current_span)
    
    return tokens, predicted_labels, pii_spans

def mask_pii_text(text, pii_spans, mask_char="*"):
    """
    Mask detected PII in the original text
    
    Args:
        text: Original text
        pii_spans: List of PII spans (start, end, label)
        mask_char: Character to use for masking
        
    Returns:
        masked_text: Text with PII masked
    """
    # Convert text to list for easier manipulation
    chars = list(text)
    
    # Apply masks
    for span in pii_spans:
        for i in range(span["start"], span["end"]):
            chars[i] = mask_char
    
    # Join characters back into a string
    masked_text = "".join(chars)
    
    return masked_text

def main():
    # Path to your saved model
    model_path = r"C:\Users\Sai\Documents\Neu\Masters_Project\PerceptionPrivacy\pii_token_classification\results\mobilebert-pii-binary3\checkpoint-8700"
    
    # Load model
    model, tokenizer, id2label = load_pii_model(model_path)
    
    # Example text for inference
    sample_texts = [
        "My name is John Smith and my email is johnsmith@example.com.",
        "Please contact me at 555-123-4567 or visit me at 123 Main Street, New York, NY 10001.",
        "My social security number is 123-45-6789 and my credit card is 4111-1111-1111-1111."
    ]
    
    # Process each text
    for idx, text in enumerate(sample_texts):
        print(f"\nSample {idx+1}: {text}")
        
        # Detect PII
        tokens, predicted_labels, pii_spans = detect_pii(text, model, tokenizer, id2label)
        
        # Print detected PII
        print("Detected PII spans:")
        for span in pii_spans:
            pii_text = text[span["start"]:span["end"]]
            print(f"  - {pii_text} ({span['label']}): positions {span['start']}-{span['end']}")
        
        # Mask the text
        masked_text = mask_pii_text(text, pii_spans)
        print(f"Masked text: {masked_text}")

main()


Sample 1: My name is John Smith and my email is johnsmith@example.com.
Detected PII spans:
  - My (PII): positions 0-2
  - is John Smith (PII): positions 8-21
  - my (PII): positions 26-28
  - johnsmith (PII): positions 38-47
Masked text: ** name ************* and ** email is *********@example.com.

Sample 2: Please contact me at 555-123-4567 or visit me at 123 Main Street, New York, NY 10001.
Detected PII spans:
  - contact me (PII): positions 7-17
  - 45 (PII): positions 29-31
  - visit me (PII): positions 37-45
  - Main Street, New York, (PII): positions 53-75
Masked text: Please ********** at 555-123-**67 or ******** at 123 ********************** NY 10001.

Sample 3: My social security number is 123-45-6789 and my credit card is 4111-1111-1111-1111.
Detected PII spans:
  - My social security (PII): positions 0-18
  - 45 (PII): positions 33-35
  - 678 (PII): positions 36-39
  - my (PII): positions 45-47
Masked text: ****************** number is 123-**-***9 and ** credit card is 411