<a href="https://colab.research.google.com/github/ajitbubu/AI-power-UCM/blob/main/AI_power_cookie_catagories.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
## Set up
# @title Install deps (quiet)
!pip -q install xgboost scikit-learn shap fastapi uvicorn nest_asyncio pydantic==2.* > /dev/null
import warnings; warnings.filterwarnings("ignore")
print("✅ Deps installed")


✅ Deps installed


In [7]:
# Rulebook + Vendor DB (inline JSON)

# @title Inline rulebook & vendor DB
import re, json, math, hashlib, random
from dataclasses import dataclass

# Cookie-name rules as (regex, purpose, confidence, id suffix).
# Row order is significant: lookups return the first pattern that matches.
_NAME_REGEX_ROWS = (
    (r"^_ga", "analytics", 0.98, "_ga"),
    (r"^_gid", "analytics", 0.95, "_gid"),
    (r"^_fbp", "ads", 0.95, "_fbp"),
    (r"^fr$", "ads", 0.95, "fr"),
    (r"csrf", "necessary", 0.99, "csrf"),
    (r"^session", "necessary", 0.97, "session"),
)

# Cookie-prefix rules as (prefix, purpose, confidence).
_PREFIX_ROWS = (
    ("__Host-", "necessary", 0.99),
    ("__Secure-", "necessary", 0.95),
)

RULES = {
    "name_regex": [
        {"pattern": pattern, "purpose": purpose, "confidence": confidence, "id": "regex:" + suffix}
        for pattern, purpose, confidence, suffix in _NAME_REGEX_ROWS
    ],
    "prefix": [
        {"prefix": prefix, "purpose": purpose, "confidence": confidence, "id": "prefix:" + prefix}
        for prefix, purpose, confidence in _PREFIX_ROWS
    ],
}

# demo vendor ids (align with your backend seed ids if you want)
# Rows are (domain, vendor id, declared IAB purposes, risk prior).
_VENDOR_ROWS = (
    ("google-analytics.com", "00000000-0000-4000-a000-000000000111", (7,), 0.2),
    ("facebook.com", "00000000-0000-4000-a000-000000000222", (1, 3, 4, 7), 0.6),
    ("parentcompany.com", "00000000-0000-4000-a000-000000000333", (2,), 0.3),
)

VENDORS = {
    vendor_domain: {"id": vendor_id, "iab_purposes": list(purposes), "risk_prior": risk}
    for vendor_domain, vendor_id, purposes, risk in _VENDOR_ROWS
}
print("✅ Rules & vendor DB loaded")


✅ Rules & vendor DB loaded


In [10]:
# @title Feature engineering
from collections import defaultdict
from urllib.parse import urlparse

def _entropy_from_len(val_len: int) -> float:
    """Map a value length to a bounded entropy proxy in ``[0.0, 1.0]``.

    Lengths of 1 or less yield 0.0; otherwise the result is
    ``log2(val_len) / 16`` clipped at 1.0.
    """
    # entropy proxy (bounded)
    bits = math.log2(val_len) if val_len > 1 else 0.0
    scaled = bits / 16.0
    return scaled if scaled < 1.0 else 1.0

def _ttl_bucket(ttl: int) -> str:
    """Return the lifetime bucket label for ``ttl``.

    ``None`` maps to ``"unknown"``. Other values are compared against 1, 30 and
    400 multiples of 86400 with inclusive upper bounds; anything larger is
    ``">400d"``.
    """
    if ttl is None:
        return "unknown"
    day = 86400
    thresholds = ((day, "<=1d"), (30 * day, "1-30d"), (400 * day, "30-400d"))
    for upper_bound, label in thresholds:
        if ttl <= upper_bound:
            return label
    return ">400d"

def _hash(s: str) -> int:
    """Return a stable bucket in ``[0, 9999]`` derived from the MD5 of ``s``."""
    digest = hashlib.md5(s.encode()).digest()
    # Big-endian integer of the raw digest equals int(hexdigest, 16).
    return int.from_bytes(digest, "big") % 10000

def _name_regex_class(name: str):
    """Return ``(id, purpose, confidence)`` of the first name rule matching ``name``.

    Patterns are searched case-insensitively; a falsy ``name`` is treated as
    the empty string. Returns ``(None, None, None)`` when nothing matches.
    """
    subject = name or ""
    matched = next(
        (entry for entry in RULES["name_regex"]
         if re.search(entry["pattern"], subject, flags=re.I)),
        None,
    )
    if matched is None:
        return None, None, None
    return matched["id"], matched["purpose"], matched["confidence"]

def _prefix_rule(prefix: str):
    """Return ``(id, purpose, confidence)`` of the prefix rule equal to ``prefix``.

    The comparison ignores case. A falsy ``prefix`` or an unlisted one yields
    ``(None, None, None)``.
    """
    no_match = (None, None, None)
    candidates = RULES["prefix"]
    if not prefix:
        return no_match
    for entry in candidates:
        if prefix.lower() != entry["prefix"].lower():
            continue
        return entry["id"], entry["purpose"], entry["confidence"]
    return no_match

def _vendor_by_domain(domain: str):
    """Look up the vendor record whose domain equals or is a parent of ``domain``.

    Args:
        domain: Cookie domain; leading dots and case are ignored.

    Returns:
        The matching ``VENDORS`` entry, preferring the longest vendor domain,
        or ``None`` when ``domain`` is falsy or no vendor matches.
    """
    if not domain:
        return None
    domain = domain.lstrip(".").lower()
    # Longest candidate first so the most specific vendor domain wins.
    for cand in sorted(VENDORS, key=len, reverse=True):
        # Match on a DNS label boundary: a bare endswith() would also accept
        # unrelated hosts such as "notfacebook.com" for "facebook.com".
        if domain == cand or domain.endswith("." + cand):
            return VENDORS[cand]
    return None

def featurize(sample: dict):
    """Convert one classification request into model features plus rule/vendor context.

    Args:
        sample: Dict with a ``"cookie"`` dict and an optional ``"vendor_context"``
            dict. A non-empty vendor context takes precedence over the
            domain-based vendor lookup. It may use either the vendor-DB keys
            (``id``, ``iab_purposes``) or the request keys
            (``domain_vendor_id``, ``vendor_iab_purposes``); ``risk_prior`` is
            read from both shapes. Missing or ``None`` cookie fields fall back
            to the same defaults.

    Returns:
        ``(features_dict, context)``. ``features_dict`` is the flat dict of
        string/number features. ``context`` carries the matched rule ids,
        purposes and confidences (index-aligned) plus the resolved vendor id,
        purposes and risk prior.
    """
    c = sample.get("cookie") or {}
    v = sample.get("vendor_context") or {}

    def _flag(key: str, default: bool) -> bool:
        # An explicit null means "not supplied": use the default rather than
        # letting bool(None) silently turn it into False.
        raw = c.get(key)
        return default if raw is None else bool(raw)

    name = c.get("name") or ""
    domain = (c.get("domain") or "").lower().lstrip(".")
    path = c.get("path")
    if path is None:
        path = "/"
    fp = _flag("first_party", True)
    max_age = c.get("max_age")
    samesite = c.get("same_site") or "Lax"
    secure = _flag("secure", True)
    http_only = _flag("http_only", False)
    prefix = c.get("prefix") or ""
    initiator_host = (c.get("initiator_script_host") or "").lower()
    net_dests = c.get("network_destinations") or []
    creation_order = int(c.get("creation_order_idx") or 0)

    name_rule_id, name_rule_purpose, name_rule_conf = _name_regex_class(name)
    prefix_rule_id, prefix_rule_purpose, prefix_rule_conf = _prefix_rule(prefix)

    vend = v or _vendor_by_domain(domain) or {}
    # Vendor DB entries use "id"/"iab_purposes" while request payloads use
    # "domain_vendor_id"/"vendor_iab_purposes". Accept both so a supplied
    # vendor_context is not silently treated as an unknown vendor.
    vend_id = vend.get("id") or vend.get("domain_vendor_id")
    vend_purposes = vend.get("iab_purposes") or vend.get("vendor_iab_purposes") or []
    vend_risk = vend.get("risk_prior")
    if vend_risk is None:
        vend_risk = 0.3

    features_dict = {
        "name": name,
        "domain": domain,
        "path_hash": _hash(path),
        "first_party": int(fp),
        "max_age": int(max_age or 0),
        "ttl_bucket": _ttl_bucket(max_age or 0),
        "same_site": samesite,
        "secure": int(secure),
        "http_only": int(http_only),
        "prefix": prefix,
        "initiator_host": initiator_host,
        "creation_order": creation_order,
        "num_net_dests": len(net_dests),
        "dest_has_ga": int(any("google-analytics" in (d or "") for d in net_dests)),
        "dest_has_fb": int(any("facebook" in (d or "") for d in net_dests)),
        "entropy_proxy": _entropy_from_len(24 if prefix else 16),  # we don't use values, just proxy
        "cross_site_usage": int(samesite.lower() == "none" and not fp),
        "name_rule_hit": int(name_rule_id is not None),
        "prefix_rule_hit": int(prefix_rule_id is not None),
        "vendor_known": int(vend_id is not None),
        "vendor_risk": float(vend_risk),
        "vendor_declares_p7": int(7 in vend_purposes),
        "vendor_declares_ad": int(any(x in vend_purposes for x in [1, 3, 4])),
        "name_regex_class": name_rule_id or "none",
    }

    # Keep id/purpose/confidence of each hit together so the three context
    # lists stay index-aligned (name rule first, then prefix rule).
    rule_hits = [
        hit
        for hit in (
            (name_rule_id, name_rule_purpose, name_rule_conf),
            (prefix_rule_id, prefix_rule_purpose, prefix_rule_conf),
        )
        if hit[0] is not None
    ]
    context = {
        "rules_hit": [hit[0] for hit in rule_hits],
        "rule_purposes": [hit[1] for hit in rule_hits],
        "rule_confidences": [hit[2] for hit in rule_hits],
        "vendor_id": vend_id,
        "vendor_purposes": vend_purposes,
        "vendor_risk": vend_risk,
    }
    return features_dict, context


In [13]:
# @title Train a tiny XGBoost model (demo)
import numpy as np, pandas as pd
from sklearn.feature_extraction import DictVectorizer
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score
import xgboost as xgb
import re

# Synthetic labeled examples (small but representative)
# Hand-written training rows: each has a "cookie" dict in the shape featurize()
# reads, plus the target "label". No row sets "path", "creation_order_idx" or
# "vendor_context", so featurize() applies its defaults and resolves the
# vendor from the cookie domain.
examples = [
    # analytics
    {"cookie":{"name":"_ga","domain":"shop.example.com","first_party":True,"max_age":63072000,"same_site":"Lax","secure":True,"http_only":False,"prefix":"" ,"initiator_script_host":"www.googletagmanager.com","network_destinations":["www.google-analytics.com"]},"label":"analytics"},
    {"cookie":{"name":"_gid","domain":"blog.example.com","first_party":True,"max_age":86400,"same_site":"Lax","secure":True,"http_only":False,"prefix":"" ,"initiator_script_host":"www.googletagmanager.com","network_destinations":["www.google-analytics.com"]},"label":"analytics"},
    # ads
    {"cookie":{"name":"_fbp","domain":".facebook.com","first_party":False,"max_age":7776000,"same_site":"None","secure":True,"http_only":False,"prefix":"","initiator_script_host":"connect.facebook.net","network_destinations":["graph.facebook.com"]},"label":"ads"},
    {"cookie":{"name":"fr","domain":".facebook.com","first_party":False,"max_age":7776000,"same_site":"None","secure":True,"http_only":False,"prefix":"","initiator_script_host":"connect.facebook.net","network_destinations":["www.facebook.com"]},"label":"ads"},
    # necessary
    {"cookie":{"name":"__Host-csrf","domain":"app.example.com","first_party":True,"max_age":3600,"same_site":"Strict","secure":True,"http_only":True,"prefix":"__Host-","initiator_script_host":"app.example.com","network_destinations":[]},"label":"necessary"},
    {"cookie":{"name":"sessionid","domain":"app.example.com","first_party":True,"max_age":7200,"same_site":"Lax","secure":True,"http_only":True,"prefix":"","initiator_script_host":"app.example.com","network_destinations":[]},"label":"necessary"},
    # functional
    {"cookie":{"name":"lang_pref","domain":"www.example.com","first_party":True,"max_age":31536000,"same_site":"Lax","secure":True,"http_only":False,"prefix":"","initiator_script_host":"www.example.com","network_destinations":[]},"label":"functional"},
    {"cookie":{"name":"theme","domain":"www.example.com","first_party":True,"max_age":31536000,"same_site":"Lax","secure":True,"http_only":False,"prefix":"","initiator_script_host":"www.example.com","network_destinations":[]},"label":"functional"},
    # unknown
    {"cookie":{"name":"x123","domain":"cdn.random.net","first_party":False,"max_age":2592000,"same_site":"None","secure":True,"http_only":False,"prefix":"","initiator_script_host":"cdn.random.net","network_destinations":["cdn.random.net/beacon"]},"label":"unknown"},
]

# Featurize every labelled example; the per-row context is not needed for training.
X_dicts = [featurize(example)[0] for example in examples]
y = [example["label"] for example in examples]

# One-hot encode string features and pass numeric ones through as a dense matrix.
vec = DictVectorizer(sparse=False)
X = vec.fit_transform(X_dicts)

# Sanitize feature names
_unsafe_name_chars = re.compile(r'[\[\]<]')
feature_names = [
    _unsafe_name_chars.sub('_', raw_name) for raw_name in vec.get_feature_names_out()
]

# Integer-encode the string labels for the multi-class objective.
le = LabelEncoder()
y_enc = le.fit_transform(y)

dtrain = xgb.DMatrix(X, label=y_enc, feature_names=feature_names)
params = {
    "max_depth": 4,
    "eta": 0.2,
    "subsample": 0.8,
    "colsample_bytree": 0.8,
    "objective": "multi:softprob",
    "num_class": len(le.classes_),
    "eval_metric": "mlogloss",
}
model = xgb.train(params, dtrain, num_boost_round=80)
print("✅ Model trained on synthetic set | classes:", list(le.classes_))

✅ Model trained on synthetic set | classes: [np.str_('ads'), np.str_('analytics'), np.str_('functional'), np.str_('necessary'), np.str_('unknown')]


In [None]:
# @title Predict function: Rules → ML → Post-policy → Regulatory mapping
import shap

# SHAP tree explainer built once over the trained model; classify() queries it
# per instance to report the top contributing features.
explainer = shap.TreeExplainer(model)

# Baseline consent-mode block in which every signal is "denied"; each purpose
# gets its own copy so entries never share a mutable dict.
_GCM_ALL_DENIED = {
    "ad_user_data": "denied",
    "ad_personalization": "denied",
    "analytics_storage": "denied",
}

# Purpose label -> TCF purposes, GPP impacts and consent-mode (gcm) signals.
REG_MAP = {
    "necessary": {
        "tcf_purposes": [],
        "gpp_impacts": [],
        "gcm": dict(_GCM_ALL_DENIED),
    },
    "functional": {
        "tcf_purposes": [2],
        "gpp_impacts": [],
        "gcm": dict(_GCM_ALL_DENIED),
    },
    "analytics": {
        "tcf_purposes": [7],
        "gpp_impacts": ["measurement"],
        "gcm": {**_GCM_ALL_DENIED, "analytics_storage": "granted"},
    },
    "ads": {
        "tcf_purposes": [1, 3, 4],
        "gpp_impacts": ["targeted_advertising", "sharing"],
        "gcm": {key: "granted" for key in _GCM_ALL_DENIED},
    },
    "unknown": {
        "tcf_purposes": [],
        "gpp_impacts": [],
        "gcm": dict(_GCM_ALL_DENIED),
    },
}

def classify(sample: dict):
    """Classify one cookie: strong rules first, then the ML model plus post-policy.

    Args:
        sample: Request dict accepted by ``featurize`` (``"cookie"`` plus an
            optional ``"vendor_context"``).

    Returns:
        Dict with ``purpose``, per-class ``probs``, ``confidence``,
        ``top_features`` (the rule id, or the four largest per-feature SHAP
        contributions for the chosen class), ``rules_hit``, the ``regulatory``
        mapping for the purpose and a heuristic ``risk`` score capped at 1.0.
    """
    fdict, ctx = featurize(sample)
    rules_hit = list(ctx["rules_hit"])
    rule_purposes = ctx["rule_purposes"]
    rule_conf = max(ctx["rule_confidences"], default=0.0)
    # Plain-str labels keep the response free of numpy scalar types.
    labels = [str(lab) for lab in le.classes_]

    # 1) Strong rules override
    if rules_hit and rule_conf >= 0.95:
        purpose = rule_purposes[0]
        probs = {lab: (1.0 if lab == purpose else 0.0) for lab in labels}
        conf = 0.95
        top_features = [{"feature": "rules", "value": rules_hit[0], "shap": 0.5}]
    else:
        # 2) ML
        X_inst = vec.transform([fdict])
        feat_names = vec.get_feature_names_out()
        # The model was trained on sanitized names ('[', ']' and '<' replaced
        # by '_'). Raw vectorizer names such as "ttl_bucket=<=1d" are rejected
        # by DMatrix and would not match the booster's stored names anyway.
        safe_names = [re.sub(r"[\[\]<]", "_", str(n)) for n in feat_names]
        dm = xgb.DMatrix(X_inst, feature_names=safe_names)
        proba = model.predict(dm)[0]
        purpose = labels[int(proba.argmax())]
        conf = float(proba.max())
        probs = {labels[i]: float(p) for i, p in enumerate(proba)}
        # 3) Post-policy: conservative push to ads in cross-site long-lived vendor-ad context
        cross_site = bool(fdict["cross_site_usage"])
        long_ttl = fdict["ttl_bucket"] in ("30-400d", ">400d")
        vendor_ad = bool(fdict["vendor_declares_ad"])
        if purpose != "necessary" and cross_site and long_ttl and vendor_ad and probs.get("ads", 0) >= 0.55:
            purpose = "ads"
            conf = max(conf, probs["ads"])
        # SHAP (per instance)
        shap_vals = explainer.shap_values(X_inst)
        # For multi-class: pick the class index of current purpose
        k = labels.index(purpose)
        if isinstance(shap_vals, list):
            # Legacy layout: one (n_samples, n_features) array per class.
            sv = np.asarray(shap_vals[k])[0]
        else:
            # Newer layout: one (n_samples, n_features, n_classes) array.
            shap_arr = np.asarray(shap_vals)
            sv = shap_arr[0, :, k] if shap_arr.ndim == 3 else shap_arr[0]
        # Pick top absolute contributions
        order = np.argsort(np.abs(sv))[::-1][:4]
        top_features = [
            {"feature": str(feat_names[i]), "value": str(X_inst[0][i]), "shap": float(sv[i])}
            for i in order
        ]

    reg = REG_MAP.get(purpose, REG_MAP["unknown"])
    # Heuristic risk: base 0.2, +0.3 for ads, +0.1 for cross-site use, plus a
    # fifth of the vendor risk prior, capped at 1.0.
    risk = (
        0.2
        + (0.3 if purpose == "ads" else 0.0)
        + (0.1 if fdict["cross_site_usage"] else 0.0)
        + 0.2 * fdict["vendor_risk"]
    )
    out = {
        "purpose": purpose,
        "probs": probs,
        "confidence": round(conf, 4),
        "top_features": top_features,
        "rules_hit": rules_hit,
        "regulatory": reg,
        "risk": round(min(1.0, risk), 3),
    }
    return out

print("✅ Classifier ready")


In [None]:
# @title Try 3 examples: analytics, ads, necessary
# Demo requests in the shape classify() accepts. The first two carry a
# "vendor_context" (keys: domain_vendor_id, vendor_iab_purposes, risk_prior);
# the third omits it.
samples = [
    # Analytics (_ga)
    {"cookie": {"name":"_ga","domain":"shop.example.com","path":"/","first_party":True,"max_age":63072000,"same_site":"Lax","secure":True,"http_only":False,"prefix":"","initiator_script_host":"www.googletagmanager.com","network_destinations":["www.google-analytics.com"]},
     "vendor_context": {"domain_vendor_id":"00000000-0000-4000-a000-000000000111","vendor_iab_purposes":[7],"risk_prior":0.2}},

    # Ads (_fbp)
    {"cookie": {"name":"_fbp","domain":".facebook.com","path":"/","first_party":False,"max_age":7776000,"same_site":"None","secure":True,"http_only":False,"prefix":"","initiator_script_host":"connect.facebook.net","network_destinations":["graph.facebook.com","www.facebook.com"]},
     "vendor_context": {"domain_vendor_id":"00000000-0000-4000-a000-000000000222","vendor_iab_purposes":[1,3,4,7],"risk_prior":0.6}},

    # Necessary (__Host-csrf)
    {"cookie": {"name":"__Host-csrf","domain":"app.example.com","path":"/","first_party":True,"max_age":3600,"same_site":"Strict","secure":True,"http_only":True,"prefix":"__Host-","initiator_script_host":"app.example.com","network_destinations":[]}}
]

# Classify each demo request and pretty-print the result under a numbered header.
for sample_no, sample in enumerate(samples, start=1):
    print(f"--- Sample {sample_no} ---")
    rendered = json.dumps(classify(sample), indent=2)
    print(rendered)


In [None]:
# @title Start a minimal FastAPI server (optional)
import nest_asyncio, asyncio, uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional, List, Dict, Any

# NOTE(review): presumably needed so uvicorn can be scheduled on the notebook's
# already-running event loop — confirm against the nest_asyncio docs.
nest_asyncio.apply()

# Application object that the POST /ai/classify route below is registered on.
app = FastAPI(title="UCM Cookie Classifier (demo)")

# Request schema for one observed cookie. Field names match the keys that
# featurize() reads from sample["cookie"]; only `name` is required.
class CookieIn(BaseModel):
    name: str
    domain: Optional[str] = ""
    path: Optional[str] = "/"
    first_party: Optional[bool] = True
    # Lifetime; featurize() buckets it against multiples of 86400.
    max_age: Optional[int] = 0
    same_site: Optional[str] = "Lax"
    secure: Optional[bool] = True
    http_only: Optional[bool] = False
    prefix: Optional[str] = ""
    initiator_script_host: Optional[str] = ""
    # featurize() counts these and checks them for "google-analytics"/"facebook".
    network_destinations: Optional[List[str]] = []
    creation_order_idx: Optional[int] = 0

# Optional caller-supplied vendor information accompanying a cookie.
class VendorCtx(BaseModel):
    domain_vendor_id: Optional[str] = None
    vendor_iab_purposes: Optional[List[int]] = []
    # Same default that featurize() uses when no risk prior is available.
    risk_prior: Optional[float] = 0.3

# Body of POST /ai/classify: one cookie plus an optional vendor context.
class ClassifyIn(BaseModel):
    cookie: CookieIn
    # None when the caller sends no vendor information.
    vendor_context: Optional[VendorCtx] = None

@app.post("/ai/classify")
def api_classify(body: ClassifyIn):
    # Round-trip through JSON so classify() receives plain dicts and lists
    # rather than pydantic model instances.
    payload = json.loads(body.model_dump_json())
    return classify(payload)

# Serve the demo app on port 8000 of every interface, logging warnings and above.
config = uvicorn.Config(app, host="0.0.0.0", port=8000, log_level="warning")
server = uvicorn.Server(config)

# Schedule the server on the current event loop without blocking the cell.
# The loop only keeps weak references to tasks, so hold a strong reference to
# stop the server task from being garbage-collected mid-run.
demo_server_task = asyncio.get_event_loop().create_task(server.serve())
print("✅ FastAPI running at http://127.0.0.1:8000  (POST /ai/classify)")


In [None]:
# @title Save model + artifacts to files (pickles/JSON) and ZIP them
import os, json, pickle, zipfile, datetime

ART_DIR = "/content/ucm_cookie_model_artifacts"
os.makedirs(ART_DIR, exist_ok=True)


def _artifact_path(filename):
    """Return the path of ``filename`` inside ART_DIR."""
    return os.path.join(ART_DIR, filename)


def _dump_pickle(obj, filename):
    """Pickle ``obj`` to ``filename`` inside ART_DIR."""
    with open(_artifact_path(filename), "wb") as fh:
        pickle.dump(obj, fh)


def _dump_json(obj, filename):
    """Write ``obj`` as 2-space-indented JSON to ``filename`` inside ART_DIR."""
    with open(_artifact_path(filename), "w") as fh:
        json.dump(obj, fh, indent=2)


# 1) XGBoost model (native JSON for portability)
model_path = _artifact_path("model_xgb.json")
model.save_model(model_path)

# 2) DictVectorizer & LabelEncoder
_dump_pickle(vec, "dict_vectorizer.pkl")
_dump_pickle(le, "label_encoder.pkl")

# 3) Rulebook & Vendor DB (from earlier cells)
_dump_json(RULES, "rules.json")
_dump_json(VENDORS, "vendors.json")

# 4) Regulatory map + feature spec (helpful for service)
_GCM_DENIED = {
    "ad_user_data": "denied",
    "ad_personalization": "denied",
    "analytics_storage": "denied",
}
REG_MAP = {
    "necessary": {"tcf_purposes": [], "gpp_impacts": [], "gcm": dict(_GCM_DENIED)},
    "functional": {"tcf_purposes": [2], "gpp_impacts": [], "gcm": dict(_GCM_DENIED)},
    "analytics": {
        "tcf_purposes": [7],
        "gpp_impacts": ["measurement"],
        "gcm": {**_GCM_DENIED, "analytics_storage": "granted"},
    },
    "ads": {
        "tcf_purposes": [1, 3, 4],
        "gpp_impacts": ["targeted_advertising", "sharing"],
        "gcm": {key: "granted" for key in _GCM_DENIED},
    },
    "unknown": {"tcf_purposes": [], "gpp_impacts": [], "gcm": dict(_GCM_DENIED)},
}
_dump_json(REG_MAP, "regulatory_map.json")

# 5) Minimal feature schema hint (optional)
feature_schema = {
    "required_fields": ["cookie.name","cookie.domain","cookie.same_site","cookie.first_party","cookie.max_age"],
    "engineered_by_service": ["path_hash","ttl_bucket","entropy_proxy","cross_site_usage","name_regex_class"],
    "notes": "Raw cookie values are never stored; derive entropy proxies only."
}
_dump_json(feature_schema, "feature_schema.json")

# 6) Bundle to a zip for easy download
zip_path = "/content/ucm_cookie_model_artifacts.zip"
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as bundle:
    for folder, _subdirs, filenames in os.walk(ART_DIR):
        for filename in filenames:
            source = os.path.join(folder, filename)
            # Store entries relative to ART_DIR so the archive has no absolute paths.
            bundle.write(source, os.path.relpath(source, ART_DIR))

print("✅ Artifacts saved at:", ART_DIR)
print("📦 Zip bundle:", zip_path)


In [None]:
# @title FastAPI loader using saved artifacts (run after the "Save model + artifacts" cell)
import os, json, pickle, nest_asyncio, asyncio, uvicorn
import numpy as np, xgboost as xgb
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional, List, Dict, Any
import re, math, hashlib

ART_DIR = "/content/ucm_cookie_model_artifacts"


def _load_pickle(filename):
    """Unpickle ``filename`` from ART_DIR."""
    with open(os.path.join(ART_DIR, filename), "rb") as fh:
        return pickle.load(fh)


def _load_json(filename):
    """Parse the JSON file ``filename`` from ART_DIR."""
    with open(os.path.join(ART_DIR, filename)) as fh:
        return json.load(fh)


# --- Load artifacts ---
booster = xgb.Booster()
booster.load_model(os.path.join(ART_DIR, "model_xgb.json"))
# NOTE(security): unpickling can execute arbitrary code; only load pickle
# files that this notebook (or another trusted source) produced.
vec = _load_pickle("dict_vectorizer.pkl")
le = _load_pickle("label_encoder.pkl")
RULES = _load_json("rules.json")
VENDORS = _load_json("vendors.json")
REG_MAP = _load_json("regulatory_map.json")

# --- Helpers (same as training-time, condensed) ---
def _entropy_from_len(val_len: int) -> float:
    """Return ``log2(val_len) / 16`` clipped to ``[0.0, 1.0]``; lengths <= 1 give 0.0."""
    effective_len = val_len if val_len > 1 else 1
    ratio = math.log2(effective_len) / 16.0
    if ratio < 1.0:
        return ratio
    return 1.0

def _ttl_bucket(ttl: int) -> str:
    """Return the lifetime bucket label for ``ttl`` (``"unknown"`` for ``None``).

    Upper bounds are inclusive and expressed as 1, 30 and 400 multiples of 86400.
    """
    if ttl is None:
        return "unknown"
    one_day = 86400
    ladder = (("<=1d", one_day), ("1-30d", 30 * one_day), ("30-400d", 400 * one_day))
    # First bucket whose inclusive upper bound covers ttl, else the open-ended one.
    return next((label for label, limit in ladder if ttl <= limit), ">400d")

def _hash(s: str) -> int:
    """Return the MD5 of ``s`` reduced to a stable bucket in ``[0, 9999]``."""
    encoded = s.encode()
    md5_value = int.from_bytes(hashlib.md5(encoded).digest(), "big")
    return md5_value % 10000

def _name_regex_class(name: str):
    """Return ``(id, purpose, confidence)`` of the first name rule matching ``name``.

    Rules come from ``RULES["name_regex"]`` (an absent key means no rules) and
    are searched case-insensitively. Returns ``(None, None, None)`` on no match.
    """
    subject = name or ""
    hits = (
        (entry["id"], entry["purpose"], entry["confidence"])
        for entry in RULES.get("name_regex", [])
        if re.search(entry["pattern"], subject, flags=re.I)
    )
    return next(hits, (None, None, None))

def _prefix_rule(prefix: str):
    """Return ``(id, purpose, confidence)`` of the prefix rule equal to ``prefix``.

    Rules come from ``RULES["prefix"]`` (an absent key means no rules) and are
    compared ignoring case. A falsy or unlisted prefix gives ``(None, None, None)``.
    """
    nothing = (None, None, None)
    entries = RULES.get("prefix", [])
    if not prefix:
        return nothing
    for entry in entries:
        if prefix.lower() != entry["prefix"].lower():
            continue
        return entry["id"], entry["purpose"], entry["confidence"]
    return nothing

def _vendor_by_domain(domain: str):
    """Look up the vendor record whose domain equals or is a parent of ``domain``.

    Args:
        domain: Cookie domain; leading dots and case are ignored.

    Returns:
        The matching ``VENDORS`` entry, preferring the longest vendor domain,
        or ``None`` when ``domain`` is falsy or no vendor matches.
    """
    if not domain:
        return None
    domain = domain.lstrip(".").lower()
    # Longest candidate first so the most specific vendor domain wins.
    for cand in sorted(VENDORS, key=len, reverse=True):
        # Require a DNS label boundary: a bare endswith() would also match
        # unrelated hosts such as "notfacebook.com" for "facebook.com".
        if domain == cand or domain.endswith("." + cand):
            return VENDORS[cand]
    return None

def featurize(sample: dict):
    """Convert one classification request into model features plus rule/vendor context.

    Args:
        sample: Dict with a ``"cookie"`` dict and an optional ``"vendor_context"``
            dict. A non-empty vendor context takes precedence over the
            domain-based vendor lookup. It may use either the vendor-DB keys
            (``id``, ``iab_purposes``) or the request keys
            (``domain_vendor_id``, ``vendor_iab_purposes``); ``risk_prior`` is
            read from both shapes. Missing or ``None`` cookie fields fall back
            to the same defaults.

    Returns:
        ``(fdict, ctx)``. ``fdict`` is the flat dict of string/number features.
        ``ctx`` carries the matched rule ids, purposes and confidences
        (index-aligned) plus the resolved vendor id, purposes and risk prior.
    """
    c = sample.get("cookie") or {}
    v = sample.get("vendor_context") or {}

    def _flag(key: str, default: bool) -> bool:
        # An explicit null means "not supplied": use the default rather than
        # letting bool(None) silently turn it into False.
        raw = c.get(key)
        return default if raw is None else bool(raw)

    name = c.get("name") or ""
    domain = (c.get("domain") or "").lower().lstrip(".")
    path = c.get("path")
    if path is None:
        path = "/"
    fp = _flag("first_party", True)
    max_age = c.get("max_age")
    samesite = c.get("same_site") or "Lax"
    secure = _flag("secure", True)
    http_only = _flag("http_only", False)
    prefix = c.get("prefix") or ""
    initiator_host = (c.get("initiator_script_host") or "").lower()
    net_dests = c.get("network_destinations") or []
    creation_order = int(c.get("creation_order_idx") or 0)

    name_rule_id, name_rule_purpose, name_rule_conf = _name_regex_class(name)
    prefix_rule_id, prefix_rule_purpose, prefix_rule_conf = _prefix_rule(prefix)

    vend = v or _vendor_by_domain(domain) or {}
    # Vendor DB entries use "id"/"iab_purposes" while request payloads use
    # "domain_vendor_id"/"vendor_iab_purposes". Accept both so a supplied
    # vendor_context is not silently treated as an unknown vendor.
    vend_id = vend.get("id") or vend.get("domain_vendor_id")
    vend_purposes = vend.get("iab_purposes") or vend.get("vendor_iab_purposes") or []
    vend_risk = vend.get("risk_prior")
    if vend_risk is None:
        vend_risk = 0.3

    fdict = {
        "name": name,
        "domain": domain,
        "path_hash": _hash(path),
        "first_party": int(fp),
        "max_age": int(max_age or 0),
        "ttl_bucket": _ttl_bucket(max_age or 0),
        "same_site": samesite,
        "secure": int(secure),
        "http_only": int(http_only),
        "prefix": prefix,
        "initiator_host": initiator_host,
        "creation_order": creation_order,
        "num_net_dests": len(net_dests),
        "dest_has_ga": int(any("google-analytics" in (d or "") for d in net_dests)),
        "dest_has_fb": int(any("facebook" in (d or "") for d in net_dests)),
        "entropy_proxy": _entropy_from_len(24 if prefix else 16),
        "cross_site_usage": int(samesite.lower() == "none" and not fp),
        "name_rule_hit": int(name_rule_id is not None),
        "prefix_rule_hit": int(prefix_rule_id is not None),
        "vendor_known": int(vend_id is not None),
        "vendor_risk": float(vend_risk),
        "vendor_declares_p7": int(7 in vend_purposes),
        "vendor_declares_ad": int(any(x in vend_purposes for x in [1, 3, 4])),
        "name_regex_class": name_rule_id or "none",
    }

    # Keep id/purpose/confidence of each hit together so the three context
    # lists stay index-aligned (name rule first, then prefix rule).
    rule_hits = [
        hit
        for hit in (
            (name_rule_id, name_rule_purpose, name_rule_conf),
            (prefix_rule_id, prefix_rule_purpose, prefix_rule_conf),
        )
        if hit[0] is not None
    ]
    ctx = {
        "rules_hit": [hit[0] for hit in rule_hits],
        "rule_purposes": [hit[1] for hit in rule_hits],
        "rule_confidences": [hit[2] for hit in rule_hits],
        "vendor_id": vend_id,
        "vendor_purposes": vend_purposes,
        "vendor_risk": vend_risk,
    }
    return fdict, ctx

def classify_payload(sample: dict):
    """Classify one cookie with the reloaded artifacts: rules first, then the booster.

    Args:
        sample: Request dict accepted by ``featurize`` (``"cookie"`` plus an
            optional ``"vendor_context"``).

    Returns:
        Dict with ``purpose``, per-class ``probs``, ``confidence``,
        ``top_features`` (the rule id, or the four features with the largest
        absolute encoded values as a rough hint), ``rules_hit``, the
        ``regulatory`` mapping for the purpose and a heuristic ``risk`` score
        capped at 1.0.
    """
    fdict, ctx = featurize(sample)
    rules_hit = list(ctx["rules_hit"])
    rule_purposes = ctx["rule_purposes"]
    rule_conf = max(ctx["rule_confidences"], default=0.0)
    # Plain-str labels keep the response free of numpy scalar types.
    labels = [str(lab) for lab in le.classes_]

    # Rules-first (high-confidence)
    if rules_hit and rule_conf >= 0.95:
        purpose = rule_purposes[0]
        probs = {lab: (1.0 if lab == purpose else 0.0) for lab in labels}
        conf = 0.95
        top_features = [{"feature": "rules", "value": rules_hit[0], "shap": 0.5}]
    else:
        X_inst = vec.transform([fdict])
        feat_names = vec.get_feature_names_out()
        # Apply the same name sanitization used at training time ('[', ']' and
        # '<' replaced by '_'). Raw vectorizer names such as "ttl_bucket=<=1d"
        # are rejected by DMatrix and would not match the names stored in the
        # saved model.
        safe_names = [re.sub(r"[\[\]<]", "_", str(n)) for n in feat_names]
        dm = xgb.DMatrix(X_inst, feature_names=safe_names)
        proba = booster.predict(dm)[0]
        purpose = labels[int(proba.argmax())]
        conf = float(proba.max())
        probs = {labels[i]: float(p) for i, p in enumerate(proba)}

        # Post-policy: conservative ads bump
        cross_site = bool(fdict["cross_site_usage"])
        long_ttl = fdict["ttl_bucket"] in ("30-400d", ">400d")
        vendor_ad = bool(fdict["vendor_declares_ad"])
        if purpose != "necessary" and cross_site and long_ttl and vendor_ad and probs.get("ads", 0) >= 0.55:
            purpose = "ads"
            conf = max(conf, probs["ads"])

        # Lightweight indicative "top features" (no SHAP to keep deps small in service)
        # Just take largest absolute values as a hint (proxy, not exact SHAP)
        order = np.argsort(np.abs(X_inst[0]))[::-1][:4]
        top_features = [{"feature": str(feat_names[i]), "value": str(X_inst[0][i])} for i in order]

    reg = REG_MAP.get(purpose, REG_MAP["unknown"])
    # Heuristic risk: base 0.2, +0.3 for ads, +0.1 for cross-site use, plus a
    # fifth of the vendor risk prior, capped at 1.0.
    risk = (
        0.2
        + (0.3 if purpose == "ads" else 0.0)
        + (0.1 if fdict["cross_site_usage"] else 0.0)
        + 0.2 * fdict["vendor_risk"]
    )
    out = {
        "purpose": purpose,
        "probs": probs,
        "confidence": round(conf, 4),
        "top_features": top_features,
        "rules_hit": rules_hit,
        "regulatory": reg,
        "risk": round(min(1.0, risk), 3),
    }
    return out

# --- FastAPI types + route ---
# Request schema for one observed cookie. Field names match the keys that
# featurize() reads from sample["cookie"]; only `name` is required.
class CookieIn(BaseModel):
    name: str
    domain: Optional[str] = ""
    path: Optional[str] = "/"
    first_party: Optional[bool] = True
    # Lifetime; featurize() buckets it against multiples of 86400.
    max_age: Optional[int] = 0
    same_site: Optional[str] = "Lax"
    secure: Optional[bool] = True
    http_only: Optional[bool] = False
    prefix: Optional[str] = ""
    initiator_script_host: Optional[str] = ""
    # featurize() counts these and checks them for "google-analytics"/"facebook".
    network_destinations: Optional[List[str]] = []
    creation_order_idx: Optional[int] = 0

# Optional caller-supplied vendor information accompanying a cookie.
class VendorCtx(BaseModel):
    domain_vendor_id: Optional[str] = None
    vendor_iab_purposes: Optional[List[int]] = []
    # Same default that featurize() uses when no risk prior is available.
    risk_prior: Optional[float] = 0.3

# Body of POST /ai/classify: one cookie plus an optional vendor context.
class ClassifyIn(BaseModel):
    cookie: CookieIn
    # None when the caller sends no vendor information.
    vendor_context: Optional[VendorCtx] = None

# NOTE(review): presumably needed so uvicorn can be scheduled on the notebook's
# already-running event loop — confirm against the nest_asyncio docs.
nest_asyncio.apply()
# Application object that the artifact-backed POST /ai/classify route is registered on.
app = FastAPI(title="UCM Cookie Classifier (artifacts)")

@app.post("/ai/classify")
def api_classify(body: ClassifyIn):
    # Round-trip through JSON so classify_payload() receives plain dicts and
    # lists rather than pydantic model instances.
    payload = json.loads(body.model_dump_json())
    return classify_payload(payload)

# Serve the artifact-backed app on port 8001 of every interface, logging warnings and above.
config = uvicorn.Config(app, host="0.0.0.0", port=8001, log_level="warning")
server = uvicorn.Server(config)
# Schedule the server on the current event loop without blocking the cell.
# The loop only keeps weak references to tasks, so hold a strong reference to
# stop the server task from being garbage-collected mid-run.
artifact_server_task = asyncio.get_event_loop().create_task(server.serve())
print("✅ Artifact-backed API running at http://127.0.0.1:8001 (POST /ai/classify)")
