In [11]:
import re
import torch
from transformers import DistilBertTokenizerFast, DistilBertForTokenClassification, TokenClassificationPipeline

POLICIES = {
    "EMAIL":      lambda x: re.sub(r"(.{2}).*@(.+)", r"\1***@\2", x),
    "PHONE":      lambda x: "***-***-" + re.sub(r"\D", "", x)[-4:],
    "CREDITCARD": lambda x: "**** **** **** " + re.sub(r"\D", "", x)[-4:],
    "SSN":        lambda x: "***-**-" + x[-4:],
    "APIKEY":     lambda x: "[REDACTED-API]",
    "PERSON":     lambda x: "[PERSON]",
    "ADDRESS":    lambda x: "[ADDRESS]"
}

def apply_policy(entity_group, token):
    base = entity_group.replace("B-","").replace("I-","")
    func = POLICIES.get(base, lambda x: x)
    return func(token)

# ---------- Rule-based pre-pass (fallbacks NER might miss) ----------
_RULES = [
    # Emails (tolerate spaces around @ and .)
    (re.compile(r'([A-Za-z0-9._%+-]{2})[A-Za-z0-9._%+-]*\s*@\s*([A-Za-z0-9.-]+(?:\s*\.\s*[A-Za-z]{2,})+)'),
     lambda m: f"{m.group(1)}***@{re.sub(r'\\s*\\.\\s*', '.', m.group(2).replace(' ', ''))}"),
    # SSN
    (re.compile(r'\b\d{3}-\d{2}-\d{4}\b'), lambda m: "***-**-" + m.group(0)[-4:]),
    # Credit cards (13â€“16 digits with spaces/dashes)
    (re.compile(r'\b(?:\d[ -]*?){13,16}\b'), lambda m: "**** **** **** " + re.sub(r'\D','', m.group(0))[-4:]),
    # US phones
    (re.compile(r'\b(?:\+?\d{1,3}[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?){2}\d{4}\b'),
     lambda m: "***-***-" + re.sub(r'\D','', m.group(0))[-4:]),
    # Stripe-like API keys
    (re.compile(r'\bsk_(?:live|test)_[A-Za-z0-9]{16,}\b'), lambda m: "[REDACTED-API]"),
]

def rule_based_redact(text: str) -> str:
    out = text
    for pat, repl in _RULES:
        out = pat.sub(repl, out)
    return out

# ---------- Span-safe NER redaction ----------
def redact_output(text, ner_preds):
    out = rule_based_redact(text)

    # Collect spans (start, end, label), drop obvious junk
    spans = []
    for p in ner_preds:
        s, e = int(p["start"]), int(p["end"])
        if 0 <= s < e <= len(text):
            spans.append((s, e, p["entity_group"]))

    # Apply right-to-left to avoid index shift; also skip if policy returns identical string
    for s, e, label in sorted(spans, key=lambda t: t[0], reverse=True):
        original = out[s:e]
        red = apply_policy(label, original)
        if red != original:
            out = out[:s] + red + out[e:]
    return out

# ---------- Load model locally ----------
from google.colab import drive
drive.mount('/content/drive')
import os, pprint
PROJECT_DIR = "/content/drive/MyDrive/DataShield_AI"
MODEL_DIR   = f"{PROJECT_DIR}/models/ner-distilbert"
print("Exists?", os.path.isdir(MODEL_DIR))
pprint.pp(os.listdir(MODEL_DIR))

tok = DistilBertTokenizerFast.from_pretrained(MODEL_DIR, local_files_only=True)
mdl = DistilBertForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True)

device = 0 if torch.cuda.is_available() else -1
ner = TokenClassificationPipeline(model=mdl, tokenizer=tok,
                                  aggregation_strategy="simple", device=device)

samples = [
    "Email me at alice.lee @ example . com or call 415-555-1234. My SSN is 123-45-6789.",
    "Temp API key: sk_live_1234567890abcdef",
    "Use 4242-4242-4242-4242 for tests; SSN 123-45-6789 should be blocked."
]

for text in samples:
    preds = ner(text)
    print(redact_output(text, preds))

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Exists? True
['checkpoint-1126',
 'checkpoint-1689',
 'config.json',
 'model.safetensors',
 'special_tokens_map.json',
 'vocab.txt',
 'labels.json',
 'training_args.bin',
 'tokenizer.json',
 'tokenizer_config.json']


Device set to use cpu


Email me at al***@example.com or call ***-***-***-***-1234 i[PERSON]**-6***-***-789
[PERSON]mp [PERSON] key: [REDACTED-API]EDACT[REDACTED-API]
Use ***-***-***-***-4242 for tests; [PERSON]***-***- ***-***-6789 should be blocked.
