In [42]:
import os, time, json
from pathlib import Path
import pandas as pd
from dotenv import load_dotenv
from tenacity import retry, wait_exponential, stop_after_attempt
from openai import OpenAI

def load_azure_env():
    """Load Azure OpenAI settings from .env files and report whether they are complete.

    Checks ``scripts/.env`` and then ``.env`` under the current working directory.
    Each file that exists is loaded with ``override=True``, so when both define a
    variable the repo-root ``.env`` (loaded second) wins.

    Returns:
        True if every required AZURE_OPENAI_* variable is set to a non-empty value;
        otherwise prints the missing names and returns False.
    """
    root = Path.cwd()
    env_files = (root / "scripts" / ".env", root / ".env")
    for env_file in filter(Path.exists, env_files):
        load_dotenv(env_file, override=True)

    needed = (
        "AZURE_OPENAI_ENDPOINT",
        "AZURE_OPENAI_API_KEY",
        "AZURE_OPENAI_DEPLOYMENT",
        "AZURE_OPENAI_EMBEDDINGS",
    )
    absent = [name for name in needed if not os.getenv(name)]
    if not absent:
        return True
    print("Azure not configured; missing:", ", ".join(absent))
    return False

In [43]:
def _azure_deployment_base_url(deployment_var):
    """Build the base URL for one Azure OpenAI deployment.

    Args:
        deployment_var: name of the environment variable holding the deployment
            name (e.g. ``"AZURE_OPENAI_DEPLOYMENT"``).

    Returns:
        ``"<AZURE_OPENAI_ENDPOINT>/openai/deployments/<deployment>"``. A trailing
        slash on the endpoint is stripped first so the URL never contains ``//openai``.
    """
    endpoint = (os.getenv("AZURE_OPENAI_ENDPOINT") or "").rstrip("/")
    return f"{endpoint}/openai/deployments/{os.getenv(deployment_var)}"


def _azure_deployment_client(deployment_var):
    """Create an OpenAI SDK client bound to the Azure deployment named by ``deployment_var``."""
    api_key = os.getenv("AZURE_OPENAI_API_KEY")
    return OpenAI(
        api_key=api_key,
        base_url=_azure_deployment_base_url(deployment_var),
        default_query={"api-version": os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-15-preview")},
        # Azure key auth reads the `api-key` header, so send it explicitly.
        default_headers={"api-key": api_key},
    )


def azure_chat_client():
    """Client for the chat deployment named by AZURE_OPENAI_DEPLOYMENT."""
    return _azure_deployment_client("AZURE_OPENAI_DEPLOYMENT")


def azure_embed_client():
    """Client for the embeddings deployment named by AZURE_OPENAI_EMBEDDINGS."""
    return _azure_deployment_client("AZURE_OPENAI_EMBEDDINGS")

In [44]:
# --- Merchant labeling (robust JSON, schema, UTC, CI-safe) ---

from datetime import datetime, timezone
import os, json, re, time
from tenacity import retry, wait_exponential, stop_after_attempt

# Where merchant labels are persisted between runs.
MERCHANT_DIM_PATH = Path("config") / "merchants_dim.csv"
# GitHub Actions exports GITHUB_ACTIONS=true on its runners.
IN_CI = os.environ.get("GITHUB_ACTIONS", "").lower() == "true"
# AI_STRICT=1 turns AI batch failures into hard errors instead of warnings.
AI_STRICT = os.environ.get("AI_STRICT", "0") == "1"
# Merchants per labeling request; smaller in CI to reduce the chance of a truncated reply.
BATCH = 20 if not IN_CI else 10
# When False, label_new_merchants skips labeling entirely.
MAP_ALL = True

# Column -> pandas dtype for the merchants table; dict order is the column order
# used when an empty table is created.
SCHEMA = {
    **dict.fromkeys(
        ("merchant_key", "display_name", "category", "subcategory", "tags", "source"),
        "string",
    ),
    "confidence": "float64",
    "last_updated": "string",
}

def _ensure_merchants_dim():
    """Load the merchant dimension table, or return an empty one if the CSV is absent.

    When ``MERCHANT_DIM_PATH`` exists it is read with every column as text and then
    cast to ``SCHEMA``. Any ``SCHEMA`` column missing from the file is added as
    all-missing, and extra columns in the file are kept. When the file does not
    exist, its parent directory is created and an empty frame with the ``SCHEMA``
    columns/dtypes is returned. Nothing is written to the CSV here.

    Returns:
        pd.DataFrame containing at least the ``SCHEMA`` columns with ``SCHEMA`` dtypes.
    """
    if not MERCHANT_DIM_PATH.exists():
        MERCHANT_DIM_PATH.parent.mkdir(parents=True, exist_ok=True)
        return pd.DataFrame({c: pd.Series(dtype=dt) for c, dt in SCHEMA.items()})

    # dtype=str: let SCHEMA decide types. Default inference would turn key "00123"
    # into 123 (or "123.0" next to NaNs), so known merchants would stop matching.
    md = pd.read_csv(MERCHANT_DIM_PATH, dtype=str)
    for col, dt in SCHEMA.items():
        if col not in md.columns:
            md[col] = pd.Series(dtype=dt)
    return md.astype(SCHEMA)

def _coerce_label_items(obj):
    """Return the label dicts carried by a decoded JSON value, or None if it has none.

    Accepts ``{"items": [...]}``, a bare list, or a single label object (a dict with a
    ``merchant_key``). Non-dict list elements are dropped. A non-empty list with no
    dicts at all is treated as unusable (None) so the caller can try another candidate.
    """
    if isinstance(obj, dict):
        if isinstance(obj.get("items"), list):
            obj = obj["items"]
        elif "merchant_key" in obj:
            return [obj]
        else:
            return None
    if not isinstance(obj, list):
        return None
    labels = [x for x in obj if isinstance(x, dict)]
    return labels if labels or not obj else None


def _parse_labels_strict_or_salvage(txt: str):
    """Parse the model's reply into a list of label dicts.

    Candidates are tried in order: the whole (stripped) reply, then its outermost
    ``{...}`` span, then its outermost ``[...]`` span. The spans rescue replies that
    are wrapped in code fences or surrounded by prose. The first candidate that
    decodes and yields label items (see ``_coerce_label_items``) is returned.

    Raises:
        RuntimeError: if ``txt`` is empty/None or no candidate yields label items.
    """
    if not txt:
        raise RuntimeError("Failed to parse AI JSON: empty response")
    txt = txt.strip()
    candidates = [txt]
    for pattern in (r"\{[\s\S]*\}", r"\[[\s\S]*\]"):
        m = re.search(pattern, txt)
        if m and m.group(0) != txt:
            candidates.append(m.group(0))
    for cand in candidates:
        try:
            obj = json.loads(cand)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            continue
        items = _coerce_label_items(obj)
        if items is not None:
            return items
    raise RuntimeError(f"Failed to parse AI JSON (first 400 chars):\n{txt[:400]}")

@retry(wait=wait_exponential(multiplier=1, min=1, max=20), stop=stop_after_attempt(5), reraise=True)
def azure_label_batch(keys_batch):
    """Label one batch of merchant keys with the Azure chat deployment.

    Sends the keys (each stringified and truncated to 100 characters) as a JSON user
    message at temperature 0 with JSON response format, asks for
    ``{"items": [...]}``, and parses the reply with ``_parse_labels_strict_or_salvage``.

    Retries up to 5 attempts with exponential backoff (1-20 s). ``reraise=True``
    makes the last underlying exception propagate instead of tenacity's wrapping
    ``RetryError``, so callers log the real cause.

    Returns:
        The parsed label items. For a key longer than 100 characters, a dict item
        whose ``merchant_key`` equals the truncated form is rewritten to the full key,
        so it matches what the caller will look up.
    """
    keys = [str(k) for k in keys_batch]
    compact = [k[:100] for k in keys]
    # Truncated -> original, only for keys that were actually truncated. If two long
    # keys share a 100-char prefix the later one wins (the model cannot tell them apart).
    full_by_short = {k[:100]: k for k in keys if len(k) > 100}
    sys_prompt = (
        "You label merchant identifiers for a personal finance dashboard.\n"
        "For each merchant_key, produce: merchant_key, display_name, category, subcategory, tags.\n"
        "- display_name: short human name (e.g., 'ARCO', 'Apple Card').\n"
        "- category: one of Dining, Groceries, Gas, Shopping, Utilities, Subscriptions, Transfers, Income, "
        "Health, Travel, Entertainment, Education, Fees, Misc.\n"
        "- subcategory: specific subtype (e.g., 'Gas Station', 'Internet Service').\n"
        "- tags: array of 1–5 lowercase keywords.\n"
        'Return ONLY JSON: {"items":[{...}]} with no extra commentary.'
    )
    usr_payload = {"merchant_keys": compact}
    c = azure_chat_client()
    r = c.chat.completions.create(
        model=os.getenv("AZURE_OPENAI_DEPLOYMENT"),
        messages=[
            {"role":"system","content": sys_prompt},
            {"role":"user","content": json.dumps(usr_payload)}
        ],
        temperature=0,
        max_tokens=1400,
        response_format={"type": "json_object"},
    )
    items = _parse_labels_strict_or_salvage(r.choices[0].message.content)
    if full_by_short:
        for it in items:
            if isinstance(it, dict):
                short = str(it.get("merchant_key") or "").strip()
                if short in full_by_short:
                    it["merchant_key"] = full_by_short[short]
    return items

def _merchant_row(mk, item, now):
    """Normalize one AI label dict into a merchants-table row keyed by ``mk``.

    Null/empty fields fall back to defaults instead of the literal string "None":
    display_name -> upper-cased ``mk``, category/subcategory -> "". ``tags`` may be a
    list (non-null entries joined with ",") or a plain string (stripped).
    """
    tags_val = item.get("tags")
    if isinstance(tags_val, list):
        tags_csv = ",".join(str(t).strip() for t in tags_val if t is not None)
    elif isinstance(tags_val, str):
        tags_csv = tags_val.strip()
    else:
        tags_csv = ""
    return {
        "merchant_key": mk,
        "display_name": str(item.get("display_name") or mk).upper().strip(),
        "category": str(item.get("category") or "").strip(),
        "subcategory": str(item.get("subcategory") or "").strip(),
        "tags": tags_csv,
        "source": "azure",
        "confidence": 0.90,
        "last_updated": now,
    }


def label_new_merchants(df, merchant_key_col="merchant_key"):
    """Label merchant keys from ``df`` that are not yet in the merchants table.

    New keys are the distinct non-null, non-blank values of ``df[merchant_key_col]``
    (as str) that are absent from the table returned by ``_ensure_merchants_dim``.
    They are sent to ``azure_label_batch`` in sorted batches of ``BATCH``. Only
    returned items whose ``merchant_key`` is one of the requested keys are kept.
    After each batch that produced rows, the table is de-duplicated on
    ``merchant_key`` (the newest ``last_updated`` wins) and rewritten to
    ``MERCHANT_DIM_PATH``, so completed batches survive a later failure.

    Args:
        df: frame containing the merchant key column.
        merchant_key_col: name of that column.

    Returns:
        Number of merchant rows added. It is 0 if the column is missing, ``MAP_ALL``
        is false, or there is nothing new.

    Raises:
        Whatever ``azure_label_batch`` raises, but only when ``AI_STRICT`` is set.
        Otherwise a failed batch is reported and skipped.
    """
    md = _ensure_merchants_dim()
    if merchant_key_col not in df.columns:
        print(f"Column '{merchant_key_col}' not in dataframe; skipping labeling.")
        return 0

    known = set(md["merchant_key"].dropna().astype(str))
    # dropna first: astype(str) would turn NaN into the bogus key "nan".
    keys = df[merchant_key_col].dropna().astype(str)
    candidates = sorted(set(keys[keys.str.strip() != ""]) - known)
    if not MAP_ALL or not candidates:
        print("No new merchants to label."); return 0

    added = 0
    for start in range(0, len(candidates), BATCH):
        batch = candidates[start:start + BATCH]
        try:
            items = azure_label_batch(batch)
        except Exception as e:
            print(f"⚠️ AI batch failed [{start}:{start+len(batch)}] — {e}")
            if AI_STRICT:
                raise
            continue

        now = datetime.now(timezone.utc).isoformat()
        wanted = set(batch)
        rows = []
        for it in items or []:
            if not isinstance(it, dict):
                continue  # salvaged JSON may contain non-object elements
            raw = str(it.get("merchant_key") or "")
            mk = raw if raw in wanted else raw.strip()
            if mk not in wanted:
                continue  # a key we did not ask for could never match df; don't store it
            rows.append(_merchant_row(mk, it, now))

        if rows:
            chunk = pd.DataFrame(rows)
            for col, dt in SCHEMA.items():
                if col not in chunk.columns:
                    chunk[col] = pd.Series(dtype=dt)
            # The model may label the same key twice; keep its last answer only.
            chunk = chunk.drop_duplicates(["merchant_key"], keep="last").astype(SCHEMA)

            md = pd.concat([md.astype(SCHEMA), chunk], ignore_index=True)
            # Stable sort so equal timestamps keep append order; NA timestamps sort
            # first so they never win over a real one under keep="last".
            md = (
                md.sort_values("last_updated", kind="stable", na_position="first")
                .drop_duplicates(["merchant_key"], keep="last")
            )
            md.to_csv(MERCHANT_DIM_PATH, index=False)
            added += len(chunk)
            print(f"Added {len(chunk)} merchant mappings (running total {added}).")

        time.sleep(0.1)  # brief pause between API calls

    return added


In [45]:
# --- Embeddings builder (caches to vectorstore/embeddings.parquet) ---

import numpy as np, pyarrow as pa, pyarrow.parquet as pq

VECTOR_PATH = Path("vectorstore/embeddings.parquet")
EMBED_ROWS = 500
# Texts per embeddings request. Kept small because older Azure OpenAI API versions
# capped the input array at 16. Raise it if your deployment accepts more.
EMBED_BATCH = 16


def _embed_texts(client, texts, batch_size=EMBED_BATCH):
    """Embed ``texts`` with ``client`` in requests of ``batch_size`` inputs each.

    Returns one embedding vector per text, in input order. Each response's items are
    ordered by their ``index`` field rather than relying on list order.

    Raises:
        ValueError: if ``batch_size`` < 1.
        RuntimeError: if a response does not contain one embedding per input.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    model = os.getenv("AZURE_OPENAI_EMBEDDINGS")
    vectors = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        resp = client.embeddings.create(model=model, input=chunk)
        if len(resp.data) != len(chunk):
            raise RuntimeError(f"Expected {len(chunk)} embeddings, got {len(resp.data)}")
        vectors.extend(d.embedding for d in sorted(resp.data, key=lambda d: d.index))
    return vectors


def build_embeddings(df, text_cols=("display_name","description")):
    """Embed the last ``EMBED_ROWS`` rows of ``df`` and write them to ``VECTOR_PATH``.

    Each row's text is the values of whichever ``text_cols`` exist in ``df``,
    stringified (so missing values appear as "nan") and joined with " | ". Texts are
    embedded ``EMBED_BATCH`` at a time. The parquet file gets a ``row_idx`` column
    (``df`` index labels cast to int) and a fixed-size-list ``embedding`` column, and
    is overwritten on each call.

    Returns:
        Number of embeddings written. It is 0 when no text column is present or
        ``df`` is empty; in those cases no client is created.
    """
    cols = [c for c in text_cols if c in df.columns]
    if not cols:
        print("No text columns found; skipping embeddings."); return 0

    recent = df.tail(EMBED_ROWS)
    texts = recent[cols[0]].astype(str)
    for c in cols[1:]:
        texts = texts + " | " + recent[c].astype(str)
    texts = texts.tolist()
    if not texts:
        print("No embeddings produced."); return 0

    embs = _embed_texts(azure_embed_client(), texts)
    dim = len(embs[0])

    flat = np.asarray(embs, dtype="float32").ravel()
    arr = pa.FixedSizeListArray.from_arrays(pa.array(flat), dim)
    table = pa.Table.from_pydict({"row_idx": pa.array(recent.index.astype(int)), "embedding": arr})
    VECTOR_PATH.parent.mkdir(parents=True, exist_ok=True)
    pq.write_table(table, VECTOR_PATH)
    print(f"Wrote {len(embs)} embeddings (dim {dim}) → {VECTOR_PATH}")
    return len(embs)


In [46]:
def azure_ai_enabled():
    """Return True only if every AZURE_OPENAI_* setting for chat and embeddings is non-empty."""
    for var in (
        "AZURE_OPENAI_ENDPOINT",
        "AZURE_OPENAI_API_KEY",
        "AZURE_OPENAI_DEPLOYMENT",
        "AZURE_OPENAI_EMBEDDINGS",
    ):
        if not os.getenv(var):
            return False
    return True

def run_azure_ai(enriched_df):
    """Run merchant labeling and then embedding generation if Azure is configured.

    Returns:
        ``{"labeled": n, "embedded": m}`` with the counts reported by
        ``label_new_merchants`` and ``build_embeddings``. If configuration is
        incomplete, a message is printed, neither step runs, and both counts are 0.
    """
    configured = load_azure_env() and azure_ai_enabled()
    if not configured:
        print("Azure not configured; skipping AI steps.")
        return {"labeled": 0, "embedded": 0}
    summary = {"labeled": label_new_merchants(enriched_df)}
    summary["embedded"] = build_embeddings(enriched_df)
    return summary
