In [None]:
# embed.py
# pip install pandas transformers "supabase>=2" python-dotenv torch

import os, re, json, hashlib
from typing import List, Dict, Optional
import pandas as pd
import numpy as np

from supabase import create_client, Client
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from dotenv import load_dotenv

# ---- Env ----
# Load Supabase credentials from the dev env file (path is resolved relative to the
# current working directory). os.environ[...] raises KeyError when a variable is
# missing, so misconfiguration fails here, before the model is loaded.
load_dotenv("../secrets/.env.dev")
SUPABASE_URL = os.environ["SUPABASE_URL"]
SUPABASE_KEY = os.environ["SUPABASE_KEY"]
# Module-level client; upsert_rows writes through it.
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)

# ---- Model ----
MODEL_ID = "Qwen/Qwen3-Embedding-8B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🔧 Torch device available: {DEVICE}, CUDA count={torch.cuda.device_count()}")

# Left padding puts every sequence's final token in the last column, which is the
# fast path taken by _last_token_pool.
tok = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
mdl = AutoModel.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,  # bf16 on GPU, fp32 on CPU
    device_map="auto",        # will use GPU if possible
    low_cpu_mem_usage=True
)
# NOTE(review): with device_map="auto", weights that do not fit on the GPU may be
# offloaded to CPU; such parameters can report the "meta" device in the print below.
mdl.eval()
print(f"✅ Model loaded on device(s): {set(p.device for p in mdl.parameters())}")

@torch.no_grad()
def _last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    seq_lens = attention_mask.sum(dim=1) - 1
    bsz = last_hidden_states.shape[0]
    return last_hidden_states[torch.arange(bsz, device=last_hidden_states.device), seq_lens]

@torch.no_grad()
def embed_texts(texts: List[str], batch_size=1, max_length=1024, dim: int = 4000) -> np.ndarray:
    """Embed ``texts`` with the module-level embedding model ``mdl``.

    Each batch is tokenized (truncated to ``max_length`` tokens) and run through the
    model. Each sequence is then reduced to its last-token hidden state, cut to the
    first ``dim`` dimensions and L2-normalised in float32, so every returned row has
    unit norm.

    Args:
        texts: Strings to embed.
        batch_size: Number of texts per forward pass.
        max_length: Tokenizer truncation length in tokens.
        dim: Leading embedding dimensions to keep. The default of 4000 matches the
            length that ``upsert_rows`` requires for the ``halfvec(4000)`` column.

    Returns:
        A float32 array with one row per text and up to ``dim`` columns. Empty
        input returns an array of shape ``(0, dim)``.
    """
    if not texts:
        # np.vstack([]) raises; return an empty, correctly-typed array instead.
        return np.empty((0, dim), dtype=np.float32)

    # Inputs must go where the token-embedding table lives. Look it up once rather
    # than rebuilding mdl.state_dict() for every batch.
    dev = mdl.get_input_embeddings().weight.device
    if dev.type == "meta":
        # Offloaded weights may report the meta device; fall back to the main device.
        dev = torch.device(DEVICE)

    vecs = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        toks = tok(batch, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
        toks = {k: v.to(dev) for k, v in toks.items()}
        out = mdl(**toks)
        pooled = _last_token_pool(out.last_hidden_state, toks["attention_mask"])
        # Truncate first, then normalise. Normalising the full vector and then dropping
        # dimensions leaves norms below 1, which skews inner-product scores.
        pooled = pooled[:, :dim].to(torch.float32)
        pooled = F.normalize(pooled, p=2, dim=1)
        arr = pooled.cpu().numpy()
        norms = np.linalg.norm(arr, axis=1)
        print(f"➡️ Embedded batch {i//batch_size+1} on {dev} | shape={arr.shape} | norm≈{norms.mean():.4f}")
        vecs.append(arr)
    return np.vstack(vecs)

# =========================
# NEWS-AWARE CHUNKER
# =========================
# Whole-line boilerplate patterns (matched case-insensitively against stripped lines).
# "Editor's note" accepts both the curly (’) and the straight (') apostrophe, because
# scraped text uses either.
BOILERPLATE_PATTERNS = [
    r"^Advertisement$", r"^Ads?$", r"^Subscribe now.*$", r"^Sign up for.*newsletter.*$",
    r"^Read more:.*$", r"^Editor[’']s note:.*$", r"^Correction:.*$", r"^©.*$"
]
BOILERPLATE_RE = re.compile("|".join(BOILERPLATE_PATTERNS), flags=re.IGNORECASE)

def _clean_lines(text: str) -> List[str]:
    """Split ``text`` into stripped, non-empty lines with boilerplate removed.

    Carriage returns are deleted before splitting on ``\\n``. A line is dropped when
    it is blank after stripping or when ``BOILERPLATE_RE`` matches at its start.
    """
    kept: List[str] = []
    for raw in text.replace("\r", "").split("\n"):
        line = raw.strip()
        if not line:
            continue
        if BOILERPLATE_RE.match(line):
            continue
        kept.append(line)
    return kept

def _looks_headline(s: str) -> bool:
    return len(s) <= 140 and not s.endswith(".") and (
        s.istitle() or s.isupper() or re.match(r"^[A-Z0-9][A-Za-z0-9,:’'“”\-– ]+$", s)
    )

def _is_dek_or_byline(s: str) -> bool:
    sl = s.lower()
    return (len(s) <= 180) and (sl.startswith("by ") or "updated" in sl or "—" in s or "–" in s)

def _paragraph_blocks(lines: List[str]) -> List[str]:
    paras, cur = [], []
    for ln in lines:
        if re.match(r"^(\*|-|•|\d+\.)\s+", ln):   # treat bullets as their own para
            if cur: paras.append(" ".join(cur)); cur = []
            paras.append(ln)
        else:
            cur.append(ln)
    if cur: paras.append(" ".join(cur))
    return paras

def _split_sentences_news(para: str) -> List[str]:
    marked = re.sub(r'([.!?]["”\']?)\s+(?=[A-Z0-9“"(\[])', r'\1<SPLIT>', para.strip())
    parts = [p.strip() for p in marked.split('<SPLIT>') if p.strip()]
    return parts

def _count_tokens(s: str) -> int:
    """Return the number of tokens the module tokenizer ``tok`` yields for ``s``, excluding special tokens."""
    token_ids = tok.encode(s, add_special_tokens=False)
    return len(token_ids)

def chunk_news_smart(
    raw_text: str,
    target_tokens: int = 320,
    overlap_tokens: int = 48,
    *,
    debug: bool = False,
    debug_n: int = 5
) -> List[str]:
    """Split a news article into roughly ``target_tokens``-sized text chunks.

    Processing steps:

    - The text is cleaned with ``_clean_lines``.
    - An optional headline (first line) and dek/byline (next line) are detected.
      When present they are joined with " — " into a prefix that goes on the first
      chunk only.
    - The body is split into paragraphs and sentences and packed greedily.
      Consecutive chunks share up to ``overlap_tokens`` worth of trailing whole
      sentences.
    - Sentences longer than ``target_tokens`` are split again at ``,`` / ``;`` /
      ``:`` boundaries, and each resulting piece closes a chunk.

    Args:
        raw_text: Article text, with lines separated by newlines.
        target_tokens: Soft token budget per chunk, measured with ``_count_tokens``.
            The first chunk's budget is reduced by the prefix size. A single
            piece that cannot be split further may still exceed it.
        overlap_tokens: Maximum number of tokens of trailing sentences carried
            into the next chunk.
        debug: Print a summary and previews of the first ``debug_n`` chunks.
        debug_n: Number of chunks to preview when ``debug`` is set.

    Returns:
        A list of chunk strings, empty if nothing survives cleaning. A document
        that consists only of a headline and/or dek yields that text as a single
        chunk instead of being dropped.
    """
    lines = _clean_lines(raw_text)
    if not lines:
        return []

    headline = lines[0] if _looks_headline(lines[0]) else None
    i = 1 if headline else 0
    dek = lines[i] if i < len(lines) and _is_dek_or_byline(lines[i]) else None
    i = i + 1 if dek else i

    paras = _paragraph_blocks(lines[i:])

    prefix = " — ".join([x for x in [headline, dek] if x]) if (headline or dek) else None
    prefix_tok = _count_tokens(prefix) if prefix else 0

    chunks: List[str] = []
    current: List[str] = []
    current_tok = 0
    # Number of leading items in `current` that were copied from the previous chunk
    # as overlap. They must not be emitted again as a chunk of their own.
    carried = 0

    def flush_chunk():
        """Emit `current` as a chunk, then keep a whole-sentence overlap tail."""
        nonlocal current, current_tok, prefix, carried
        if len(current) <= carried:
            return  # nothing new since the last chunk
        text = " ".join(current)
        if prefix and not chunks:
            text = f"{prefix} — {text}"
        chunks.append(text)
        back, btok = [], 0
        for s in reversed(current):
            t = _count_tokens(s)
            if btok + t > overlap_tokens:
                break
            back.insert(0, s)
            btok += t
        current = back
        current_tok = btok
        carried = len(back)
        prefix = None

    def add_piece(piece: str, n_tok: int):
        """Append `piece`, flushing first if it would overflow the current budget."""
        nonlocal current_tok
        # The first chunk must also leave room for the headline/dek prefix.
        budget = target_tokens - (0 if chunks else prefix_tok)
        if len(current) > carried and current_tok + n_tok > budget:
            flush_chunk()
        current.append(piece)
        # Accumulate: after a flush, `current` still holds the overlap tail, and its
        # tokens count toward this chunk's budget.
        current_tok += n_tok

    for para in paras:
        for s in _split_sentences_news(para):
            tok_count = _count_tokens(s)
            if tok_count <= target_tokens:
                add_piece(s, tok_count)
                continue

            # Over-long sentence: re-split at clause punctuation (keeping the
            # delimiters) and make each roughly target-sized piece close a chunk.
            buf = ""
            for p in re.split(r'([,;:]\s+)', s):
                if buf and _count_tokens(buf + p) > target_tokens:
                    piece = buf.strip()
                    if piece:
                        add_piece(piece, _count_tokens(piece))
                        flush_chunk()
                    buf = ""
                buf += p
            tail = buf.strip()
            if tail:
                add_piece(tail, _count_tokens(tail))

    flush_chunk()

    # Headline/dek-only documents have no body sentences. Keep their text rather
    # than returning no chunks at all.
    if not chunks and prefix:
        chunks.append(prefix)

    if debug:
        print(f"🧩 Built {len(chunks)} chunks (target={target_tokens}, overlap={overlap_tokens})")
        for ci, ch in enumerate(chunks[:debug_n]):
            t = _count_tokens(ch)
            preview = (ch[:160] + "…") if len(ch) > 160 else ch
            print(f"  • chunk[{ci}] tokens={t}: {preview}")

    return chunks

def upsert_rows(rows: List[Dict]):
    """Insert-or-update chunk rows in the ``case_chunks`` table via the module client.

    Every row must carry a 4000-element ``embedding``. Rows are matched on
    ``(doc_id, chunk_id)`` so that existing chunks are updated rather than inserted
    again.

    NOTE(review): a 23505 duplicate-key error on
    ``case_chunks_doc_id_chunk_id_key`` would suggest the ON CONFLICT target was
    not applied to the request. If that happens, check the client version and
    arguments.

    Returns:
        The PostgREST response object.

    Raises:
        ValueError: If any embedding does not have length 4000.
        RuntimeError: If the response carries a truthy ``error`` attribute.
            supabase-py v2 may also raise its own ``APIError`` from ``execute()``.
    """
    if any(len(row["embedding"]) != 4000 for row in rows):
        raise ValueError("Each embedding must be length 4000 for halfvec(4000).")

    table = supabase.table("case_chunks")
    request = table.upsert(
        rows,
        on_conflict="doc_id,chunk_id",   # must match the (doc_id, chunk_id) unique constraint
        ignore_duplicates=False,         # update conflicting rows instead of skipping them
        returning="minimal",             # no row bodies in the response
    )
    resp = request.execute()

    error = getattr(resp, "error", None)
    if error:
        raise RuntimeError(error)
    print(f"✅ Upserted {len(rows)} rows into Supabase.")
    return resp


# ★ NEW: if CSV gives us a stable id, use it; else fall back to link+hash(text)
def _doc_id_from_csv(row, *, id_col="id", link_col="link", text_col="full_text", row_idx=0) -> str:
    csv_id = row.get(id_col)
    if csv_id is not None and str(csv_id).strip():
        return str(csv_id).strip()

    # fallback (previous logic)
    link = row.get(link_col)
    text = str(row.get(text_col) or "")
    base = str(link).strip() if link and str(link).strip() else None
    h = hashlib.sha1((text or f"row-{row_idx}").encode("utf-8")).hexdigest()[:10]
    return f"{base or 'hash'}#h{h}"


def ingest_one_doc(full_text: str, law=None, company=None, link=None, doc_id=None):
    """Chunk, embed and upsert one document into ``case_chunks``.

    Rows are keyed by ``(doc_id, chunk_id)``, where ``chunk_id`` is the 0-based
    chunk position. After a successful upsert, any rows for this ``doc_id`` with
    ``chunk_id >= len(chunks)`` are deleted. Without that step, re-ingesting a
    document that now produces fewer chunks would leave stale chunks from the
    earlier run in the table.

    Args:
        full_text: Raw article text.
        law, company, link: Metadata copied onto every chunk row.
        doc_id: Stable document id. If falsy, ``"hash-<sha1(full_text)[:10]>"`` is
            used as a last resort.

    Returns:
        None. Nothing is written when chunking yields no content.
    """
    if not doc_id:
        # Last-resort fallback only if someone calls this directly without a doc_id.
        h = hashlib.sha1((full_text or '').encode('utf-8')).hexdigest()[:10]
        doc_id = f"hash-{h}"
    print(f"\n🚀 Ingesting one doc: {doc_id}")

    chunks = chunk_news_smart(full_text, target_tokens=320, overlap_tokens=48)
    print(f"Split into {len(chunks)} chunks.")
    if not chunks:
        # Existing rows are deliberately left untouched here, since an empty result
        # may be a transient parsing problem rather than a real deletion.
        print("⚠️ No content after cleaning; skipping.")
        return

    embs = embed_texts(chunks)  # one embedding row per chunk
    rows = [
        {
            "doc_id": doc_id,
            "chunk_id": ci,
            "text": chunk_text_val,
            "law": law,
            "company": company,
            "link": link,
            "embedding": vec.astype(float).tolist(),
        }
        for ci, (chunk_text_val, vec) in enumerate(zip(chunks, embs))
    ]
    upsert_rows(rows)

    # The upsert only touches chunk_ids 0..n-1; remove higher ids left over from a
    # previous, longer version of this document.
    (
        supabase.table("case_chunks")
        .delete(returning="minimal")
        .eq("doc_id", doc_id)
        .gte("chunk_id", len(rows))
        .execute()
    )


# ★ NEW: CSV ingestion wrapper
def ingest_csv(
    csv_path: str,
    text_col: str = "full_text",
    law_col: str = "law",
    company_col: str = "company",
    link_col: str = "link",
    id_col: str = "id",
):
    """Ingest every CSV row that has non-empty text.

    For each row, the text, the optional law/company/link metadata and a document
    id from ``_doc_id_from_csv`` are passed to ``ingest_one_doc``.

    Column names are matched case-insensitively, with an exact match taking
    precedence, so the per-row lookups agree with the required-column check.

    Empty cells, which pandas reads as NaN, are passed on as ``None``. Otherwise
    they would become the literal text or id ``"nan"``, or reach the Supabase
    payload as NaN, which is not valid JSON. NumPy scalars are converted to plain
    Python values for the same reason.

    Args:
        csv_path: Path to the CSV file.
        text_col, law_col, company_col, link_col, id_col: Column names to read.
            Columns other than ``text_col`` are optional.

    Raises:
        ValueError: If no column matches ``text_col``.
    """
    print(f"\n📄 Loading CSV: {csv_path}")
    df = pd.read_csv(csv_path)

    # Map lower-cased header to actual header, for case-insensitive lookups.
    by_lower = {str(c).lower(): c for c in df.columns}
    if text_col.lower() not in by_lower:
        raise ValueError(f"CSV must have a '{text_col}' column. Found: {sorted(df.columns)}")

    def _resolve(col: str):
        # Prefer an exact header match; otherwise fall back to a case-insensitive one.
        return col if col in df.columns else by_lower.get(col.lower(), col)

    text_c = _resolve(text_col)
    law_c = _resolve(law_col)
    company_c = _resolve(company_col)
    link_c = _resolve(link_col)
    id_c = _resolve(id_col)

    def _clean(value):
        # Missing cell -> None; NumPy scalar -> plain Python scalar; anything else unchanged.
        if value is None:
            return None
        try:
            if pd.isna(value):
                return None
        except (TypeError, ValueError):
            return value  # non-scalar cell; leave it alone
        return value.item() if isinstance(value, np.generic) else value

    total = len(df)
    for pos, (idx, row) in enumerate(df.iterrows()):
        record = {col: _clean(val) for col, val in row.items()}
        text = str(record.get(text_c) or "").strip()
        if not text:
            continue

        # Columns absent from the CSV simply yield None.
        law = record.get(law_c)
        company = record.get(company_c)
        link = record.get(link_c)

        # The cleaned record keeps NaN ids/links from turning into "nan". The index
        # label still seeds the hash fallback, so ids stay consistent with earlier runs.
        doc_id = _doc_id_from_csv(record, id_col=id_c, link_col=link_c, text_col=text_c, row_idx=idx)

        ingest_one_doc(text, law=law, company=company, link=link, doc_id=doc_id)
        # Count by position, which stays correct for a non-default index.
        if (pos + 1) % 10 == 0:
            print(f"… processed {pos+1}/{total} rows")

if __name__ == "__main__":
    # Script entry point: ingest the CSV named by $CSV_PATH, falling back to
    # tech_law_violations.csv in the working directory.
    CSV_PATH = os.getenv("CSV_PATH", "tech_law_violations.csv")
    ingest_csv(csv_path=CSV_PATH)


🔧 Torch device available: cuda, CUDA count=1


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.


✅ Model loaded on device(s): {device(type='cuda', index=0), device(type='meta')}

📄 Loading CSV: tech_law_violations.csv

🚀 Ingesting one doc: https://www.ftc.gov/news-events/press-releases/2019/07/ftc-imposes-5-billion-penalty-sweeping-new-privacy-restrictions-facebook
Split into 2 chunks.
➡️ Embedded batch 1 on cuda:0 | shape=(1, 4000) | norm≈0.9879
➡️ Embedded batch 2 on cuda:0 | shape=(1, 4000) | norm≈0.9924
✅ Upserted 2 rows into Supabase.

🚀 Ingesting one doc: https://www.ftc.gov/news-events/news/press-releases/2022/05/ftc-justice-department-order-twitter-pay-150-million-violating-2011-ftc-order-misrepresenting
Split into 1 chunks.
➡️ Embedded batch 1 on cuda:0 | shape=(1, 4000) | norm≈0.9916
✅ Upserted 1 rows into Supabase.

🚀 Ingesting one doc: https://www.ftc.gov/news-events/press-releases/2019/09/google-youtube-pay-record-170-million-penalty-alleged-violations-coppa
Split into 1 chunks.
➡️ Embedded batch 1 on cuda:0 | shape=(1, 4000) | norm≈0.9933
✅ Upserted 1 rows into Supab

APIError: {'message': 'duplicate key value violates unique constraint "case_chunks_doc_id_chunk_id_key"', 'code': '23505', 'hint': None, 'details': 'Key (doc_id, chunk_id)=(https://noyb.eu/en/gdpr-complaints-against-google-facebook-instagram-and-whatsapp-forced-consent, 0) already exists.'}