In [5]:
def extract_entities(text):
    pattern = re.compile(r'<(.*?)>(.*?)</\1>')
    spans = []
    clean_text = ""
    last_idx = 0
    for match in pattern.finditer(text):
        start, end = match.span()
        tag, entity = match.groups()
        clean_text += text[last_idx:start]
        spans.append({
            "start": len(clean_text),
            "end": len(clean_text) + len(entity),
            "label": tag.lower()
        })
        clean_text += entity
        last_idx = end
    clean_text += text[last_idx:]
    return clean_text, spans

def align_labels(text, spans):
    labels = ["O"] * len(text)
    for span in spans:
        for i in range(span["start"], span["end"]):
            if i == span["start"]:
                labels[i] = f"B-{span['label']}"
            else:
                labels[i] = f"I-{span['label']}"
    return labels


In [6]:
def tokenize_and_align_labels(text, labels):
    encoding = tokenizer(text, truncation=True, padding="max_length", max_length=128, return_offsets_mapping=True)
    encoded_labels = []
    offset_mapping = encoding.pop("offset_mapping")
    for offsets in offset_mapping:
        if offsets[0] == offsets[1]:
            encoded_labels.append("O")
        else:
            start, end = offsets
            encoded_labels.append(labels[start] if start < len(labels) else "O")
    return encoding, encoded_labels


In [20]:
def load_and_process_dataset():
    data = []
    label_set = set()
    truncated = 0

    for file in sorted(os.listdir(REDACTED_DIR)):
        if not file.endswith("_pii_masked.txt"):
            continue

        index = file.replace("_pii_masked.txt", "")
        ocr_path = os.path.join(OCR_DIR, f"{index}_ocr_text.txt")
        redacted_path = os.path.join(REDACTED_DIR, file)

        if not os.path.exists(ocr_path):
            print(f"Skipping {index}, OCR file missing")
            continue

        try:
            with open(redacted_path, "r", encoding="utf-8") as f:
                redacted_text = f.read()

            clean_text, spans = extract_entities(redacted_text)  
            char_labels = align_labels(clean_text, spans)        

            encoding, token_labels = tokenize_and_align_labels(clean_text, char_labels)

            tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"])
            label_set.update(token_labels)

            # Track truncated samples
            if sum(encoding["attention_mask"]) == tokenizer.model_max_length:
                truncated += 1

            data.append({
                "input_ids": encoding["input_ids"],
                "attention_mask": encoding["attention_mask"],
                "labels": token_labels,
                "tokens": tokens
            })
        except Exception as e:
            print(f"Error processing {index}: {e}")

    print(f"Total truncated examples (>= {tokenizer.model_max_length} tokens): {truncated}")

    all_labels = sorted(list(label_set))
    label2id = {label: i for i, label in enumerate(all_labels)}
    id2label = {i: label for label, i in label2id.items()}

    for item in data:
        item["labels"] = [label2id.get(lbl, 0) for lbl in item["labels"]]

    train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)

    dataset = DatasetDict({
        "train": Dataset.from_list(train_data),
        "validation": Dataset.from_list(val_data)
    })

    return dataset, label2id, id2label


In [None]:
import os
import re
import json
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification
import torch

BASE_DIR = "../Text_PII/Dataset"
OCR_DIR = os.path.join(BASE_DIR, "OCR_Text")
REDACTED_DIR = os.path.join(BASE_DIR, "PII_mapped_Text_elements")
JSON_DIR = os.path.join(BASE_DIR, "PII_output_gemini")

MODEL_NAME = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

dataset, label2id, id2label = load_and_process_dataset()

model = AutoModelForTokenClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id
)

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    save_total_limit=2,
    report_to="none",
    disable_tqdm=False,  # Enable tqdm bars
    logging_strategy="steps",  # Ensure logging happens during training
)

data_collator = DataCollatorForTokenClassification(tokenizer)

def compute_metrics(p):
    predictions, labels = p
    preds = predictions.argmax(-1)
    correct = (preds == labels) & (labels != -100)
    total = (labels != -100)
    accuracy = correct.sum() / total.sum()
    return {"accuracy": accuracy.item()}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

trainer.train()

trainer.save_model("./pii-redaction-model")


Total truncated examples (>= 512 tokens): 0


model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy
1,0.5279,0.51876,0.826414
2,0.394,0.424496,0.851758


In [4]:
import os
print(os.getcwd())
!pip install 'accelerate>=0.26.0'
!pip install torch datasets transformers
os.environ["TOKENIZERS_PARALLELISM"] = "false"

/home/ec2-user/SageMaker/PII_Redaction/Text_PII


In [None]:
sample = dataset["validation"][0]
with torch.no_grad():
    input_ids = torch.tensor([sample["input_ids"]])
    attention_mask = torch.tensor([sample["attention_mask"]])
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)[0].tolist()

print("\nSample Predictions:")
for token_id, pred_id in zip(sample["input_ids"], predictions):
    token = tokenizer.decode([token_id])
    label = id2label[pred_id]
    print(f"{token}\t->\t{label}")

In [None]:
import os
BASE_DIR = "/home/ec2-user/SageMaker/PII_Redaction/Text_PII/Dataset/OCR_Text"
folder_path = "../Text_PII/Dataset/PII_output_gemini" 
num_files = len([f for f in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, f))])
print(f"Number of files: {num_files}")

In [None]:
!pip uninstall -y pandas
!pip install --upgrade --force-reinstall --no-cache-dir pandas==2.2.2
    
dataset, label2id, id2label = load_and_process_dataset()
original_texts = [" ".join(example["tokens"]) for example in dataset["train"]]

token_lengths = [len(tokenizer(text, padding=False, truncation=False)["input_ids"]) for text in original_texts]

average_original_length = sum(token_lengths) / len(token_lengths)
max_original_length = max(token_lengths)

print(f"Average token length before truncation: {average_original_length:.2f}")
print(f"Max token length before truncation: {max_original_length}")

In [None]:
%reset -f

import torch
torch.cuda.empty_cache()

import gc
gc.collect()