In [25]:
from google.colab import drive
drive.mount('/content/drive')

import os, pprint
PROJECT_DIR = "/content/drive/MyDrive/DataShield_AI"
MODEL_DIR   = f"{PROJECT_DIR}/models/ner-distilbert"

print("Exists?", os.path.isdir(MODEL_DIR))
pprint.pp(os.listdir(MODEL_DIR))

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Exists? True
['checkpoint-1126',
 'checkpoint-1689',
 'config.json',
 'model.safetensors',
 'special_tokens_map.json',
 'vocab.txt',
 'training_args.bin',
 'tokenizer_config.json',
 'tokenizer.json',
 'labels.json']


In [26]:
import os
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"

In [27]:
import torch
from transformers import DistilBertTokenizerFast, DistilBertForTokenClassification
from transformers import TokenClassificationPipeline

# Load tokenizer & model from local folder only
tok = DistilBertTokenizerFast.from_pretrained(MODEL_DIR, local_files_only=True)
mdl = DistilBertForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True)

device = 0 if torch.cuda.is_available() else -1
ner = TokenClassificationPipeline(
    model=mdl, tokenizer=tok, aggregation_strategy="simple", device=device
)

text = "My SSN is 123-45-6789 and email is john@example.com"
ner(text)

Device set to use cpu


[{'entity_group': 'PERSON',
  'score': np.float32(0.7593688),
  'word': 'My',
  'start': 0,
  'end': 2},
 {'entity_group': 'PERSON',
  'score': np.float32(0.9343675),
  'word': 'SSN',
  'start': 3,
  'end': 6},
 {'entity_group': 'PHONE',
  'score': np.float32(0.9995421),
  'word': '123 - 45 - 6789',
  'start': 10,
  'end': 21},
 {'entity_group': 'EMAIL',
  'score': np.float32(0.99941206),
  'word': 'j',
  'start': 35,
  'end': 36},
 {'entity_group': 'EMAIL',
  'score': np.float32(0.9958656),
  'word': '##oh',
  'start': 36,
  'end': 38},
 {'entity_group': 'EMAIL',
  'score': np.float32(0.9260991),
  'word': '##n @ example. com',
  'start': 38,
  'end': 51}]

In [28]:
from google.colab import drive
drive.mount('/content/drive')

import os, pprint
PROJECT_DIR = "/content/drive/MyDrive/DataShield_AI"
MODEL_DIR   = f"{PROJECT_DIR}/models/ner-distilbert"

print("Exists?", os.path.isdir(MODEL_DIR))
pprint.pp(os.listdir(MODEL_DIR))

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Exists? True
['checkpoint-1126',
 'checkpoint-1689',
 'config.json',
 'model.safetensors',
 'special_tokens_map.json',
 'vocab.txt',
 'training_args.bin',
 'tokenizer_config.json',
 'tokenizer.json',
 'labels.json']


In [29]:
import sys, subprocess, pkgutil
def pipi(cmd): subprocess.check_call([sys.executable, "-m", "pip"] + cmd.split())

# Show current versions
import importlib, pkg_resources
for p in ["transformers","huggingface_hub","tokenizers"]:
    try:
        m = importlib.import_module(p)
        print(p, "→", m.__version__)
    except:
        print(p, "not installed")

transformers → 4.57.1
huggingface_hub → 0.36.0
tokenizers → 0.22.1


In [30]:
RE_EMAIL = re.compile(r'\b[a-zA-Z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b')
RE_PHONE = re.compile(r'\b(?:\+?\d{1,3}[\s-]?)?(?:\(?\d{3}\)?[\s-]?)?\d{3}[\s-]?\d{4}\b')
RE_SSN   = re.compile(r'\b\d{3}-\d{2}-\d{4}\b')
RE_CC    = re.compile(r'\b(?:\d[ -]*?){13,19}\b')

SEVERITY = {"SSN":"HIGH","CREDITCARD":"HIGH","APIKEY":"HIGH",
            "EMAIL":"MEDIUM","PHONE":"MEDIUM","ADDRESS":"MEDIUM","PERSON":"LOW"}

def normalize_email(s):
    s = re.sub(r'\s*@\s*', '@', s)
    s = re.sub(r'\s*\.\s*', '.', s)
    return s

def span_overlap(a, b):
    return not (a[1] <= b[0] or b[1] <= a[0])

In [31]:
def ner_entities(text):
    out = []
    for r in ner(text):
        s,e = r["start"], r["end"]
        lab = r["entity_group"]
        chunk = text[s:e]
        if lab == "EMAIL": chunk = normalize_email(chunk)
        out.append({"start": s, "end": e, "label": lab, "score": float(r["score"]), "source": "NER"})
    return out

def rule_entities(text):
    spans = []
    for m in RE_SSN.finditer(text):
        spans.append({"start": m.start(), "end": m.end(), "label":"SSN", "score":1.0, "source":"RULE"})
    for m in RE_CC.finditer(text):
        spans.append({"start": m.start(), "end": m.end(), "label":"CREDITCARD", "score":1.0, "source":"RULE"})
    for m in RE_EMAIL.finditer(text):
        spans.append({"start": m.start(), "end": m.end(), "label":"EMAIL", "score":1.0, "source":"RULE"})
    for m in RE_PHONE.finditer(text):
        spans.append({"start": m.start(), "end": m.end(), "label":"PHONE", "score":1.0, "source":"RULE"})
    return spans

PRIORITY = {"RULE": 2, "NER": 1}

def resolve_overlaps(spans):
    def key(sp):
        sev = {"HIGH":3,"MEDIUM":2,"LOW":1}[SEVERITY.get(sp["label"], "LOW")]
        return (PRIORITY.get(sp["source"],0), sev, sp["score"], sp["end"]-sp["start"])
    spans = sorted(spans, key=key, reverse=True)
    kept = []
    for sp in spans:
        if all(not span_overlap((sp["start"], sp["end"]), (k["start"], k["end"])) for k in kept):
            kept.append(sp)
    return sorted(kept, key=lambda x: x["start"])

In [32]:
def redact_text(text, spans, style="stars"):
    chars = list(text)
    for sp in spans:
        s,e,lab = sp["start"], sp["end"], sp["label"]
        if style == "stars":
            chars[s:e] = "*" * (e - s)
        elif style == "label":
            chars[s:e] = f"[{lab}_REDACTED]"
        elif style == "partial_email" and lab == "EMAIL":
            chunk = normalize_email(text[s:e])
            try:
                user, domain = chunk.split("@",1)
                masked = f"{user[:1]}***@{domain}"
            except Exception:
                masked = "***"
            chars[s:e] = masked
        else:
            chars[s:e] = "***"
    return "".join(chars)

def coaching_message(findings):
    high = any(f["severity"]=="HIGH" for f in findings)
    kinds = sorted({f["label"] for f in findings})
    if high:
        return f"High-severity data detected ({', '.join(kinds)}). Keys/SSNs must be rotated or invalidated immediately. Do not paste secrets in chat/email."
    if kinds:
        return f"Detected {', '.join(kinds)}. Consider sharing via a secure vault and avoid plaintext."
    return "No sensitive data detected."

def detect_and_redact(text):
    spans = resolve_overlaps(ner_entities(text) + rule_entities(text))
    redacted = redact_text(text, spans, style="stars")
    findings = [{
        "label": sp["label"],
        "severity": SEVERITY.get(sp["label"], "LOW"),
        "span": [sp["start"], sp["end"]],
        "source": sp["source"]
    } for sp in spans]
    advice = coaching_message(findings)
    event = {
        "id": str(uuid.uuid4()),
        "original_text": text,
        "redacted_text": redacted,
        "findings": findings,
        "advice": advice,
        "ts": time.time()
    }
    return event

In [33]:
sample = "Email me at alice.lee @ example . com or call 415-555-1234. My SSN is 123-45-6789."
event = detect_and_redact(sample)
event["redacted_text"], event["findings"], event["advice"]

('Email ** at ************************* or call ************. ** *** is ***********.',
 [{'label': 'EMAIL', 'severity': 'MEDIUM', 'span': [6, 8], 'source': 'NER'},
  {'label': 'EMAIL', 'severity': 'MEDIUM', 'span': [12, 14], 'source': 'NER'},
  {'label': 'EMAIL', 'severity': 'MEDIUM', 'span': [14, 18], 'source': 'NER'},
  {'label': 'EMAIL', 'severity': 'MEDIUM', 'span': [18, 20], 'source': 'NER'},
  {'label': 'EMAIL', 'severity': 'MEDIUM', 'span': [20, 37], 'source': 'NER'},
  {'label': 'PHONE', 'severity': 'MEDIUM', 'span': [46, 58], 'source': 'RULE'},
  {'label': 'PERSON', 'severity': 'LOW', 'span': [60, 62], 'source': 'NER'},
  {'label': 'PERSON', 'severity': 'LOW', 'span': [63, 66], 'source': 'NER'},
  {'label': 'SSN', 'severity': 'HIGH', 'span': [70, 81], 'source': 'RULE'}],
 'High-severity data detected (EMAIL, PERSON, PHONE, SSN). Keys/SSNs must be rotated or invalidated immediately. Do not paste secrets in chat/email.')

In [34]:
LOG_PATH = f"{PROJECT_DIR}/logs/audit.jsonl"
import os, json
os.makedirs(f"{PROJECT_DIR}/logs", exist_ok=True)

def stream_demo(chunks):
    for ch in chunks:
        ev = detect_and_redact(ch)
        print("IN :", ch)
        print("OUT:", ev["redacted_text"])
        print("FND:", ev["findings"])
        print("ADVICE:", ev["advice"])
        with open(LOG_PATH, "a") as f:
            f.write(json.dumps(ev) + "\n")
        time.sleep(0.2)

chunks = [
  "Here is my email: john.doe @ gmail . com",
  "Temp API key: sk_live_1234567890abcdef",
  "Use 4242-4242-4242-4242 for tests; SSN 123-45-6789 should be blocked."
]
stream_demo(chunks)

print("Audit log →", LOG_PATH)

IN : Here is my email: john.doe @ gmail . com
OUT: Here is my email: **********************
FND: [{'label': 'EMAIL', 'severity': 'MEDIUM', 'span': [18, 19], 'source': 'NER'}, {'label': 'EMAIL', 'severity': 'MEDIUM', 'span': [19, 21], 'source': 'NER'}, {'label': 'EMAIL', 'severity': 'MEDIUM', 'span': [21, 23], 'source': 'NER'}, {'label': 'EMAIL', 'severity': 'MEDIUM', 'span': [23, 25], 'source': 'NER'}, {'label': 'EMAIL', 'severity': 'MEDIUM', 'span': [25, 40], 'source': 'NER'}]
ADVICE: Detected EMAIL. Consider sharing via a secure vault and avoid plaintext.
IN : Temp API key: sk_live_1234567890abcdef
OUT: Temp *** key: ***********************f
FND: [{'label': 'PERSON', 'severity': 'LOW', 'span': [5, 8], 'source': 'NER'}, {'label': 'APIKEY', 'severity': 'HIGH', 'span': [14, 15], 'source': 'NER'}, {'label': 'APIKEY', 'severity': 'HIGH', 'span': [15, 16], 'source': 'NER'}, {'label': 'EMAIL', 'severity': 'MEDIUM', 'span': [16, 17], 'source': 'NER'}, {'label': 'APIKEY', 'severity': 'HIGH', 

In [35]:
import re, time, uuid, json
from typing import List, Dict

# --- High-precision rules ---
RE_EMAIL = re.compile(r'\b[a-zA-Z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b')
RE_PHONE = re.compile(r'\b(?:\+?\d{1,3}[\s\-]?)?(?:\(?\d{3}\)?[\s\-]?)?\d{3}[\s\-]?\d{4}\b')
RE_SSN   = re.compile(r'\b\d{3}-\d{2}-\d{4}\b')
RE_CC    = re.compile(r'\b(?:\d[ -]?){13,19}\b')

# Common API key shapes (very conservative demo patterns)
RE_APIKEY = re.compile(
    r'\b(?:sk_(?:live|test)_[A-Za-z0-9]{16,}|'
    r'AIza[0-9A-Za-z\-_]{35}|'
    r'(?:ghp|gitpat)_[A-Za-z0-9]{30,}|'
    r'AKIA[0-9A-Z]{16})\b'
)

SEVERITY = {
    "SSN":"HIGH","CREDITCARD":"HIGH","APIKEY":"HIGH",
    "EMAIL":"MEDIUM","PHONE":"MEDIUM","ADDRESS":"MEDIUM","PERSON":"LOW"
}
PRIORITY = {"RULE": 2, "NER": 1}

STOPWORD_PERSON = {"my", "the", "a", "an", "this", "that", "here", "there", "it"}  # tiny noise filter

def normalize_email(s:str)->str:
    s = re.sub(r'\s*@\s*', '@', s)
    s = re.sub(r'\s*\.\s*', '.', s)
    return s

def span_overlap(a, b):
    return not (a[1] <= b[0] or b[1] <= a[0])

def merge_adjacent(text:str, ents:List[Dict], join_gap:int=1)->List[Dict]:
    """Merge adjacent same-type spans (fix WordPiece splits)."""
    if not ents: return []
    ents = sorted(ents, key=lambda x: x["start"])
    out = [ents[0].copy()]
    for e in ents[1:]:
        last = out[-1]
        same = e["label"] == last["label"]
        touching = e["start"] <= last["end"] + join_gap
        if same and touching:
            last["end"] = max(last["end"], e["end"])
            last["score"] = max(float(last.get("score",1.0)), float(e.get("score",1.0)))
            last["text"] = text[last["start"]:last["end"]]
        else:
            e2 = e.copy()
            e2["text"] = text[e2["start"]:e2["end"]]
            out.append(e2)
    # normalize/clean
    for sp in out:
        if sp["label"] == "EMAIL":
            sp["text"] = normalize_email(sp["text"])
    return out

def ner_entities(text:str)->List[Dict]:
    raw = ner(text)
    ents = []
    for r in raw:
        s, e = int(r["start"]), int(r["end"])
        lab  = r["entity_group"]
        tok  = text[s:e]
        # tiny PERSON de-noiser
        if lab == "PERSON" and tok.lower() in STOPWORD_PERSON:
            continue
        ents.append({"start": s, "end": e, "label": lab, "score": float(r["score"]), "source": "NER", "text": tok})
    return merge_adjacent(text, ents, join_gap=1)

def rule_entities(text:str)->List[Dict]:
    spans = []
    for m in RE_SSN.finditer(text):
        spans.append({"start": m.start(), "end": m.end(), "label":"SSN", "score":1.0, "source":"RULE", "text": m.group(0)})
    for m in RE_CC.finditer(text):
        spans.append({"start": m.start(), "end": m.end(), "label":"CREDITCARD", "score":1.0, "source":"RULE", "text": m.group(0)})
    for m in RE_EMAIL.finditer(text):
        spans.append({"start": m.start(), "end": m.end(), "label":"EMAIL", "score":1.0, "source":"RULE", "text": normalize_email(m.group(0))})
    for m in RE_PHONE.finditer(text):
        spans.append({"start": m.start(), "end": m.end(), "label":"PHONE", "score":1.0, "source":"RULE", "text": m.group(0)})
    for m in RE_APIKEY.finditer(text):
        spans.append({"start": m.start(), "end": m.end(), "label":"APIKEY", "score":1.0, "source":"RULE", "text": m.group(0)})
    return spans

def resolve_overlaps(spans:List[Dict])->List[Dict]:
    def key(sp):
        sev = {"HIGH":3,"MEDIUM":2,"LOW":1}[SEVERITY.get(sp["label"], "LOW")]
        return (PRIORITY.get(sp["source"],0), sev, sp.get("score",1.0), sp["end"]-sp["start"])
    spans = sorted(spans, key=key, reverse=True)
    kept = []
    for sp in spans:
        if all(not span_overlap((sp["start"], sp["end"]), (k["start"], k["end"])) for k in kept):
            kept.append(sp)
    return sorted(kept, key=lambda x: x["start"])

def redact_text(text:str, spans:List[Dict], style="stars")->str:
    """Apply redaction; ensures no trailing chars leak."""
    chars = list(text)
    for sp in spans:
        s, e, lab = sp["start"], sp["end"], sp["label"]
        red = "*" * (e - s) if style=="stars" else f"[{lab}_REDACTED]"
        # replace exactly s:e
        chars[s:e] = list(red)
    return "".join(chars)

def coaching_message(findings):
    high = any(SEVERITY.get(f["label"],"LOW")=="HIGH" for f in findings)
    kinds = sorted({f["label"] for f in findings})
    if high:
        return f"High-severity data detected ({', '.join(kinds)}). Rotate keys / invalidate SSNs immediately. Avoid sharing secrets in plaintext."
    if kinds:
        return f"Detected {', '.join(kinds)}. Use a secrets vault and minimize plaintext sharing."
    return "No sensitive data detected."

def detect_and_redact(text:str)->Dict:
    spans = resolve_overlaps(ner_entities(text) + rule_entities(text))
    redacted = redact_text(text, spans, style="stars")
    findings = [{
        "label": sp["label"],
        "severity": SEVERITY.get(sp["label"], "LOW"),
        "span": [sp["start"], sp["end"]],
        "source": sp["source"],
        "text": sp.get("text","")
    } for sp in spans]
    return {
        "id": str(uuid.uuid4()),
        "original_text": text,
        "redacted_text": redacted,
        "findings": findings,
        "advice": coaching_message(findings),
        "ts": time.time()
    }

In [36]:
import time

def benchmark_inference(sample, runs=50):
    start = time.time()
    for _ in range(runs):
        _ = ner(sample)
    end = time.time()
    t = (end - start) / runs
    print(f"Avg latency: {t*1000:.2f} ms per request | {1/t:.2f} req/sec")

In [37]:
benchmark_inference("User email is john.doe@example.com and card 4444 3333 2222 1111")

Avg latency: 98.36 ms per request | 10.17 req/sec


In [38]:
samples = [
  "Email me at alice.lee @ example . com or call 415-555-1234. My SSN is 123-45-6789.",
  "Temp API key: sk_live_1234567890abcdef",
  "Use 4242-4242-4242-4242 for tests; SSN 123-45-6789 should be blocked."
]
for s in samples:
    ev = detect_and_redact(s)
    print("IN :", s)
    print("OUT:", ev["redacted_text"])
    print("FND:", ev["findings"])
    print("ADVICE:", ev["advice"])
    print("---")

IN : Email me at alice.lee @ example . com or call 415-555-1234. My SSN is 123-45-6789.
OUT: Email ** at ************************* or call ************. My *** is ***********.
FND: [{'label': 'EMAIL', 'severity': 'MEDIUM', 'span': [6, 8], 'source': 'NER', 'text': 'me'}, {'label': 'EMAIL', 'severity': 'MEDIUM', 'span': [12, 37], 'source': 'NER', 'text': 'alice.lee@example.com'}, {'label': 'PHONE', 'severity': 'MEDIUM', 'span': [46, 58], 'source': 'RULE', 'text': '415-555-1234'}, {'label': 'PERSON', 'severity': 'LOW', 'span': [63, 66], 'source': 'NER', 'text': 'SSN'}, {'label': 'SSN', 'severity': 'HIGH', 'span': [70, 81], 'source': 'RULE', 'text': '123-45-6789'}]
ADVICE: High-severity data detected (EMAIL, PERSON, PHONE, SSN). Rotate keys / invalidate SSNs immediately. Avoid sharing secrets in plaintext.
---
IN : Temp API key: sk_live_1234567890abcdef
OUT: Temp *** key: ************************
FND: [{'label': 'PERSON', 'severity': 'LOW', 'span': [5, 8], 'source': 'NER', 'text': 'API'}, 

In [39]:
def coaching_agent(entities):
    tips = []
    for e in entities:
        if "CREDITCARD" in e["entity_group"]:
            tips.append("Never paste credit card numbers in chat tools.")
        if "SSN" in e["entity_group"]:
            tips.append("SSNs should be handled only in secure HR systems.")
        if "APIKEY" in e["entity_group"]:
            tips.append("API keys belong in secrets managers, not messages.")
    return "\n".join(tips)

In [41]:
pred = ner(sample)
# print(redact_output(sample, pred))
print(coaching_agent(pred))




In [42]:
SEVERITY = {
    "APIKEY": 10,
    "SSN": 9,
    "CREDITCARD": 8,
    "EMAIL": 5,
    "PHONE": 4,
    "ADDRESS": 3,
    "PERSON": 2
}

def severity_score(preds):
    return sum(SEVERITY.get(p["entity_group"].split("-")[-1],0) for p in preds)

In [43]:
score = severity_score(pred)
print("Risk score:", score)

Risk score: 37


In [44]:
def safe_infer(text, retries=3):
    for i in range(retries):
        try:
            return ner(text)
        except Exception as e:
            if i == retries-1:
                raise e
            time.sleep(0.1)

In [45]:
def batch_infer(texts):
    return ner(texts, batch_size=8, truncation=True)

In [46]:
import re
import torch
from transformers import DistilBertTokenizerFast, DistilBertForTokenClassification, TokenClassificationPipeline

POLICIES = {
    "EMAIL":      lambda x: re.sub(r"(.{2}).*@(.+)", r"\1***@\2", x),
    "PHONE":      lambda x: "***-***-" + re.sub(r"\D", "", x)[-4:],
    "CREDITCARD": lambda x: "**** **** **** " + re.sub(r"\D", "", x)[-4:],
    "SSN":        lambda x: "***-**-" + x[-4:],
    "APIKEY":     lambda x: "[REDACTED-API]",
    "PERSON":     lambda x: "[PERSON]",
    "ADDRESS":    lambda x: "[ADDRESS]"
}

def apply_policy(entity_group, token):
    base = entity_group.replace("B-","").replace("I-","")
    func = POLICIES.get(base, lambda x: x)
    return func(token)

# ---------- Rule-based pre-pass (fallbacks NER might miss) ----------
_RULES = [
    # Emails (tolerate spaces around @ and .)
    (re.compile(r'([A-Za-z0-9._%+-]{2})[A-Za-z0-9._%+-]*\s*@\s*([A-Za-z0-9.-]+(?:\s*\.\s*[A-Za-z]{2,})+)'),
     lambda m: f"{m.group(1)}***@{re.sub(r'\\s*\\.\\s*', '.', m.group(2).replace(' ', ''))}"),
    # SSN
    (re.compile(r'\b\d{3}-\d{2}-\d{4}\b'), lambda m: "***-**-" + m.group(0)[-4:]),
    # Credit cards (13–16 digits with spaces/dashes)
    (re.compile(r'\b(?:\d[ -]*?){13,16}\b'), lambda m: "**** **** **** " + re.sub(r'\D','', m.group(0))[-4:]),
    # US phones
    (re.compile(r'\b(?:\+?\d{1,3}[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?){2}\d{4}\b'),
     lambda m: "***-***-" + re.sub(r'\D','', m.group(0))[-4:]),
    # Stripe-like API keys
    (re.compile(r'\bsk_(?:live|test)_[A-Za-z0-9]{16,}\b'), lambda m: "[REDACTED-API]"),
]

def rule_based_redact(text: str) -> str:
    out = text
    for pat, repl in _RULES:
        out = pat.sub(repl, out)
    return out

# ---------- Span-safe NER redaction ----------
def redact_output(text, ner_preds):
    out = rule_based_redact(text)

    # Collect spans (start, end, label), drop obvious junk
    spans = []
    for p in ner_preds:
        s, e = int(p["start"]), int(p["end"])
        if 0 <= s < e <= len(text):
            spans.append((s, e, p["entity_group"]))

    # Apply right-to-left to avoid index shift; also skip if policy returns identical string
    for s, e, label in sorted(spans, key=lambda t: t[0], reverse=True):
        original = out[s:e]
        red = apply_policy(label, original)
        if red != original:
            out = out[:s] + red + out[e:]
    return out

# ---------- Load model locally ----------
from google.colab import drive
drive.mount('/content/drive')
import os, pprint
PROJECT_DIR = "/content/drive/MyDrive/DataShield_AI"
MODEL_DIR   = f"{PROJECT_DIR}/models/ner-distilbert"
print("Exists?", os.path.isdir(MODEL_DIR))
pprint.pp(os.listdir(MODEL_DIR))

tok = DistilBertTokenizerFast.from_pretrained(MODEL_DIR, local_files_only=True)
mdl = DistilBertForTokenClassification.from_pretrained(MODEL_DIR, local_files_only=True)

device = 0 if torch.cuda.is_available() else -1
ner = TokenClassificationPipeline(model=mdl, tokenizer=tok,
                                  aggregation_strategy="simple", device=device)

samples = [
    "Email me at alice.lee @ example . com or call 415-555-1234. My SSN is 123-45-6789.",
    "Temp API key: sk_live_1234567890abcdef",
    "Use 4242-4242-4242-4242 for tests; SSN 123-45-6789 should be blocked."
]

for text in samples:
    preds = ner(text)
    print(redact_output(text, preds))

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Exists? True
['config.json',
 'model.safetensors',
 'special_tokens_map.json',
 'vocab.txt',
 'training_args.bin',
 'tokenizer_config.json',
 'tokenizer.json',
 'labels.json',
 'checkpoint-1126',
 'checkpoint-1689']


Device set to use cpu


Email me at al***@example.com or call ***-***-***-***-1234 i[PERSON]*[PERSON]**-6***-***-789
Temp [PERSON] key: [REDACTED-API][REDACTED-API]E[REDACTED-API][REDACTED-API][REDACTED-API]
Use ***-***-***-***-4242 for tests; SS***-***- ***-***-6789 should be blocked.
