# In [None]:
import json
import re
from typing import Dict, List

import numpy as np
import pandas as pd
import streamlit as st

# Page chrome. These literals were previously stored as mis-decoded UTF-8
# (e.g. "â€”" for an em dash); they now contain the intended characters.
st.set_page_config(page_title="Consistency Checker — Advanced", layout="wide")
st.title("📊 Questionnaire Logic Checker — Advanced Add-on")
st.caption("Brand-pattern rules, cross-wave checks, and client sample validations.")

# =============================
# Base uploads
# =============================
with st.sidebar:
    st.header("Inputs")
    # One CSV uploader per data source, each with its own widget key.
    uploaded_file, client_sample_file, last_wave_file, desk_research_file = (
        st.file_uploader(label, type=["csv"], key=widget_key)
        for label, widget_key in (
            ("Current wave data (CSV)", "this"),
            ("Client sample (CSV)", "client"),
            ("Last wave data (CSV)", "prev"),
            ("Desk research (CSV)", "desk"),
        )
    )
    st.markdown("---")
    st.subheader("Thresholds")
    main_cap = st.slider("A1a brand count cap", min_value=3, max_value=20, value=8)
    close_min = st.slider(
        "C-close expected minimum when strong intent", min_value=5, max_value=10, value=8
    )

# Every check below needs the current-wave file; halt this run until it is uploaded.
if not uploaded_file:
    st.info("Upload your current wave CSV to start.")
    st.stop()

# Load data
def _read_optional_csv(upload, label: str):
    """Read an optional uploaded CSV, returning None if it is absent or unreadable.

    A broken optional file shows a warning instead of crashing the app, so the
    checks that only need the current wave still run.
    """
    if not upload:
        return None
    try:
        return pd.read_csv(upload)
    except Exception as exc:  # pandas raises several parser/decoding error types
        st.warning(f"Could not read {label} CSV; related checks are skipped: {exc}")
        return None


try:
    df = pd.read_csv(uploaded_file)
except Exception as e:
    st.error(f"Failed to read current wave CSV: {e}")
    st.stop()

client_df = _read_optional_csv(client_sample_file, "client sample")
last_wave_df = _read_optional_csv(last_wave_file, "last wave")
desk_df = _read_optional_csv(desk_research_file, "desk research")

# Result frame: a copy of the input that the checks below annotate with CHK_* columns.
result = df.copy()

# Helpers

def ensure_in_list(val, allowed: List[int]) -> bool:
    """Return True when *val* parses as one of the integer codes in *allowed*.

    Accepts ints, floats and numeric strings such as ``"4"`` or ``"4.0"``.
    pandas reads integer answer codes as floats when a column has gaps.
    Non-integral numbers (e.g. ``4.5``), NaN, infinities, ``None`` and
    non-numeric text are rejected rather than truncated into a valid code.
    """
    try:
        num = float(val)
    except (TypeError, ValueError, OverflowError):
        return False
    # int(float(...)) alone would truncate 4.9 to 4; require an exact integer code.
    # is_integer() is also False for NaN and +/-inf.
    if not num.is_integer():
        return False
    return int(num) in allowed


def boolish(val) -> bool:
    """Interpret a survey cell as a "selected / yes" flag.

    The text tokens "1", "true", "yes", "y" and "t" (case-insensitive, surrounding
    whitespace ignored) count as selected. Any other value is accepted only if
    ``ensure_in_list`` recognises it as the code 1.
    """
    token = str(val).strip().lower()
    if token in {"1", "true", "yes", "y", "t"}:
        return True
    return ensure_in_list(val, [1])


def find_brand_columns(frame: pd.DataFrame, prefix: str) -> Dict[str, str]:
    """Map brand suffix -> column name for columns of the form ``<prefix><sep><brand>``.

    The prefix match is case-insensitive and *sep* is an optional ``_``, ``-`` or
    space. If two columns yield the same brand, the later column wins.

    NOTE(review): because the separator is optional, a short prefix also captures
    longer codes that start with it. For example, prefix ``"E4"`` maps column
    ``"E4c_Volvo"`` to brand ``"c_Volvo"``.
    """
    brand_re = re.compile(
        rf"^{re.escape(prefix)}[_\- ]?(?P<brand>.+)$", flags=re.IGNORECASE
    )
    hits = ((col, brand_re.match(col)) for col in frame.columns)
    return {m.group("brand"): col for col, m in hits if m is not None}

# 1) S2 vs S3
if "S2" in df.columns and "S3" in df.columns:
    # Compare numerically where both answers are numeric. pandas upcasts an int
    # column with gaps to float, so a plain string compare reported "3" vs "3.0"
    # as a mismatch. Non-numeric answers fall back to a trimmed text comparison.
    s2_num = pd.to_numeric(df["S2"], errors="coerce")
    s3_num = pd.to_numeric(df["S3"], errors="coerce")
    both_num = s2_num.notna() & s3_num.notna()
    text_differs = df["S2"].astype(str).str.strip() != df["S3"].astype(str).str.strip()
    mis = (both_num & (s2_num != s3_num)) | (~both_num & text_differs)
    result["CHK_S2_vs_S3"] = np.where(mis, "Mismatch", "OK")

# 2) S4a1 vs client sample year
if client_df is not None and "S4a1" in df.columns:
    id_col = next((c for c in ["ID", "RespondentID", "RID", "id"] if c in df.columns and c in client_df.columns), None)
    year_col = next((c for c in client_df.columns if re.search(r"year", c, re.IGNORECASE)), None)
    if id_col and year_col:
        def _id_key(s: pd.Series) -> pd.Series:
            """Normalise respondent IDs to text so the two files can be merged.

            Merging an int64 ID column with an object one raises ValueError.
            A gappy float column renders 101 as "101.0", so that trailing ".0"
            is stripped.
            """
            return s.astype(str).str.strip().str.replace(r"\.0$", "", regex=True)

        left = pd.DataFrame({"_key": _id_key(df[id_col]), "S4a1": df["S4a1"]})
        right = pd.DataFrame({"_key": _id_key(client_df[id_col]), "_client_year": client_df[year_col]})
        merged = left.merge(right, on="_key", how="left")
        y1 = merged["S4a1"].astype(str).str.extract(r"(\d{4})")[0]
        y2 = merged["_client_year"].astype(str).str.extract(r"(\d{4})")[0]
        # NaN != NaN is True in pandas, so exclude rows where neither side has a year.
        # Respondents absent from the client sample (client year missing) are still
        # flagged, as is a client year with no year in S4a1.
        differs = (y1 != y2) & ~(y1.isna() & y2.isna())
        mis_keys = merged.loc[differs, "_key"]
        flagged = _id_key(result[id_col]).isin(mis_keys)
        result["CHK_S4a1_vs_ClientYear"] = np.where(flagged, "Mismatch", "OK")

# 3) A1a total sanity and cap vs S3
a1a_cols = [c for c in df.columns if c.lower().startswith("a1a_")]
if a1a_cols:
    # Build the selection matrix column by column with Series.map.
    # DataFrame.applymap is deprecated since pandas 2.1 and warns on every rerun.
    sel = pd.DataFrame({c: df[c].map(boolish) for c in a1a_cols}, index=df.index)
    counts = sel.sum(axis=1)  # number of A1a brand columns selected per respondent
    result["CHK_A1a_total_count"] = counts
    result["CHK_A1a_total_flag"] = np.where(counts > main_cap, f">{main_cap} brands", "OK")
    if "S3" in df.columns:
        s3n = pd.to_numeric(df["S3"], errors="coerce")
        has_s3 = s3n.notna()
        # Rows without a numeric S3 stay blank: there is nothing to compare against.
        result.loc[has_s3 & (counts < s3n), "CHK_A1a_vs_S3"] = "A1a < S3"
        result.loc[has_s3 & (counts >= s3n), "CHK_A1a_vs_S3"] = "OK"

# 4) A2b main make ∈ A1a
a2b_col = next((c for c in df.columns if c.lower().startswith("a2b")), None)
if a2b_col and a1a_cols:
    # Group A1a columns by lower-cased brand suffix, e.g. "A1a_Volvo" -> "volvo".
    a1a_by_brand: Dict[str, List[str]] = {}
    for a1a_col in a1a_cols:
        suffix = a1a_col.split("_", 1)[-1].strip().lower()
        a1a_by_brand.setdefault(suffix, []).append(a1a_col)

    def a2b_in_a1a(row) -> bool:
        """True if the A2b main make is blank or is selected in a matching A1a column."""
        main_make = str(row[a2b_col]).strip()
        if main_make == "" or main_make.lower() == "nan":
            return True  # no main make given: nothing to verify
        candidates = a1a_by_brand.get(main_make.lower(), [])
        return any(boolish(row[c]) for c in candidates)

    ok = df.apply(a2b_in_a1a, axis=1)
    result.loc[~ok, "CHK_A2b_in_A1a"] = "Main make not in A1a"
    result["CHK_A2b_in_A1a"] = result["CHK_A2b_in_A1a"].fillna("OK")

# 5 & 16) S4a1 year ∈ A3 years
a3_cols = [c for c in df.columns if c.lower().startswith("a3")]
if "S4a1" in df.columns and a3_cols:
    def a3_has_year(row) -> bool:
        """True if S4a1 has no 4-digit year, or that year appears in some A3 column."""
        raw = row["S4a1"]
        year_match = re.search(r"(\d{4})", str(raw) if pd.notna(raw) else "")
        if year_match is None:
            return True  # nothing to cross-check
        year = year_match.group(1)
        # Substring match against the text of each A3 answer.
        return any(year in str(row[col]) for col in a3_cols)

    ok = df.apply(a3_has_year, axis=1)
    result.loc[~ok, "CHK_S4a1_vs_A3"] = "Year not listed in A3"
    result["CHK_S4a1_vs_A3"] = result["CHK_S4a1_vs_A3"].fillna("OK")

# 7) B1 vs A2a (used brands must have B1 ∈ {4,5})
b1_map = find_brand_columns(df, "B1")
a2a_map = find_brand_columns(df, "A2a")
for brand, used_col in a2a_map.items():
    if brand not in b1_map:
        continue  # no B1 rating column for this brand
    brand_used = df[used_col].apply(boolish)
    rated_high = df[b1_map[brand]].apply(lambda x: ensure_in_list(x, [4, 5]))
    result.loc[brand_used & ~rated_high, f"CHK_B1_for_used_{brand}"] = "B1 should be 4/5"

# 8) B2 vs B3a (consider → good impression 4/5)
b2_map = find_brand_columns(df, "B2")
b3a_map = find_brand_columns(df, "B3a")
for brand, considered_col in b3a_map.items():
    if brand not in b2_map:
        continue  # no B2 impression column for this brand
    considered = df[considered_col].apply(boolish)
    good_impression = df[b2_map[brand]].apply(lambda x: ensure_in_list(x, [4, 5]))
    result.loc[considered & ~good_impression, f"CHK_B2_for_consider_{brand}"] = "B2 should be 4/5"

# 9) B3b mentioned → B2 4/5
b3b_map = find_brand_columns(df, "B3b")
for brand, mention_col in b3b_map.items():
    if brand not in b2_map:
        continue  # no B2 impression column for this brand
    mentioned = df[mention_col].apply(boolish)
    b2_positive = df[b2_map[brand]].apply(lambda x: ensure_in_list(x, [4, 5]))
    result.loc[mentioned & ~b2_positive, f"CHK_B2_for_B3b_{brand}"] = "B2 should be 4/5"

# 10) C-close vs B3b/B3a/B2 (expect >= close_min)
cclose_map = find_brand_columns(df, "Cclose") or find_brand_columns(df, "C_close")
# Codes close_min..10 count as "high"; computed once rather than once per cell.
close_ok_codes = list(range(close_min, 11))
for brand, c_c in cclose_map.items():
    # Strong intent: mentioned at B3b, considered at B3a, or rated 4/5 at B2.
    strong = pd.Series(False, index=df.index)
    if brand in b3b_map:
        strong |= df[b3b_map[brand]].apply(boolish)
    if brand in b3a_map:
        strong |= df[b3a_map[brand]].apply(boolish)
    if brand in b2_map:
        strong |= df[b2_map[brand]].apply(lambda x: ensure_in_list(x, [4, 5]))
    bad = strong & ~df[c_c].apply(lambda x: ensure_in_list(x, close_ok_codes))
    # Flag text previously held mis-decoded UTF-8 ("â‰¥") instead of "≥".
    result.loc[bad, f"CHK_Cclose_high_{brand}"] = f"Expect ≥{close_min}"

# 11) Cfunc vs B2 (monotonic alignment)
cfunc_map = find_brand_columns(df, "Cfunc")
for brand, cf_c in cfunc_map.items():
    b2_c = b2_map.get(brand)
    if not b2_c:
        continue
    # Vectorised rather than a row-wise apply: a low B2 (< 4) paired with a high
    # Cfunc (>= 4) is misaligned. Missing or non-numeric values coerce to NaN,
    # and NaN comparisons are False, so incomplete pairs are never flagged.
    b2v = pd.to_numeric(df[b2_c], errors="coerce")
    cfv = pd.to_numeric(df[cf_c], errors="coerce")
    bad = (b2v < 4) & (cfv >= 4)
    result.loc[bad, f"CHK_Cfunc_vs_B2_{brand}"] = "Misaligned"

# 12) G2 vs G1 (industry-based rule; example for Mining)
MINING_MAX_KM = 50  # plausibility cap on the G2 distance (km) for mining respondents
if "G1" in df.columns and "G2" in df.columns:
    industry = df["G1"].astype(str).str.strip().str.lower()
    dist = pd.to_numeric(df["G2"], errors="coerce")  # km
    # Missing or non-numeric distances cannot be judged, so they stay unflagged
    # (NaN > cap is False). float(nan) does not raise, so a row-wise float()
    # guard let NaN through to the "<= 50" test and flagged it.
    implausible = industry.str.contains("mining", regex=False) & (dist > MINING_MAX_KM)
    result.loc[implausible, "CHK_G2_vs_G1"] = "Range not plausible for industry"

# 23) Straight-liners in F2/F4/F6
for blk in ["F2", "F4", "F6"]:
    blk_cols = [c for c in df.columns if c.lower().startswith(blk.lower() + "_")]
    if not blk_cols:
        continue
    vals = df[blk_cols].apply(pd.to_numeric, errors="coerce")
    # A straight-liner gives the same answer to every item they answered.
    # Require at least two answered items; otherwise a one-item grid, or a row
    # with a single answer, would be flagged trivially.
    answered = vals.notna().sum(axis=1)
    straight = (vals.nunique(axis=1) == 1) & (answered >= 2)
    result.loc[straight, f"CHK_{blk}_straightliner"] = "Straight-liner"

# 24) B3a vs E4 (consider → probably/definitely)
e4_map = find_brand_columns(df, "E4")
for brand, considered_col in b3a_map.items():
    if brand not in e4_map:
        continue  # no E4 column for this brand
    considered = df[considered_col].apply(boolish)
    e4_high = df[e4_map[brand]].apply(lambda x: ensure_in_list(x, [4, 5]))
    result.loc[considered & ~e4_high, f"CHK_E4_for_B3a_{brand}"] = "E4 should be 4/5"

# 25) E1/E4/E4c vs B2 (positive B2 → higher E*)
e1_map = find_brand_columns(df, "E1")
e4c_map = find_brand_columns(df, "E4c")
for brand, b2_c in b2_map.items():
    b2_positive = df[b2_c].apply(lambda x: ensure_in_list(x, [4, 5]))
    for lab, followup_map in (("E1", e1_map), ("E4", e4_map), ("E4c", e4c_map)):
        followup_col = followup_map.get(brand)
        if followup_col is None:
            continue  # this follow-up question has no column for the brand
        followup_high = df[followup_col].apply(lambda x: ensure_in_list(x, [4, 5]))
        result.loc[b2_positive & ~followup_high, f"CHK_{lab}_for_posB2_{brand}"] = f"{lab} should be 4/5"

# 31, 33, 34) Combined intent consistency
hi_close_codes = list(range(close_min, 11))  # hoisted: was rebuilt for every cell
# sorted(): iterating a raw set of strings follows the per-process hash seed,
# so the generated CHK_* columns came out in a different order between runs.
for brand in sorted(set(b3a_map) & set(b2_map) & set(cclose_map)):
    sel = df[b3a_map[brand]].apply(boolish)
    b2_hi = df[b2_map[brand]].apply(lambda x: ensure_in_list(x, [5]))
    b2_low = df[b2_map[brand]].apply(lambda x: ensure_in_list(x, [1, 2, 3]))
    c_hi = df[cclose_map[brand]].apply(lambda x: ensure_in_list(x, hi_close_codes))
    result.loc[sel & b2_low, f"CHK_B2_low_with_B3a_{brand}"] = "Conflict"
    # Flag text previously held mis-decoded UTF-8 ("â‰¥") instead of "≥".
    result.loc[sel & b2_hi & ~c_hi, f"CHK_Cclose_low_with_hiB2_{brand}"] = f"Expect ≥{close_min}"

# 37) E1 vs F1 proximity
if {"E1", "F1"}.issubset(df.columns):
    # Absolute gap between the two scores; non-numeric entries become NaN.
    gap = (
        pd.to_numeric(df["E1"], errors="coerce") - pd.to_numeric(df["F1"], errors="coerce")
    ).abs()
    result.loc[gap > 2, "CHK_E1_vs_F1"] = ">2 pts diff"
    # NOTE(review): rows where E1 or F1 is missing also end up "OK" even though
    # they could not be compared. Confirm this is intended.
    result["CHK_E1_vs_F1"] = result["CHK_E1_vs_F1"].fillna("OK")

# Cross-wave summaries (14, 15, 17, 35, 36) — quick metrics
notes = []


def _a2a_not_in_a1a_rates(frame: pd.DataFrame) -> List[float]:
    """Per-brand share of rows flagged at A2a but not selected at A1a.

    Only brands that have both an A1a and an A2a column in *frame* contribute.
    Returns an empty list when no brand has both columns.
    """
    a1a_map = find_brand_columns(frame, "A1a")
    rates = []
    for brand, a2_col in find_brand_columns(frame, "A2a").items():
        a1_col = a1a_map.get(brand)
        if not a1_col:
            continue
        rates.append(((~frame[a1_col].apply(boolish)) & frame[a2_col].apply(boolish)).mean())
    return rates


if last_wave_df is not None:
    # The same metric is computed for both waves so the two averages are comparable.
    this_rates = _a2a_not_in_a1a_rates(df)
    if this_rates:
        notes.append(f"This wave A2a-not-in-A1a avg: {np.mean(this_rates):.1%}")
    prev_rates = _a2a_not_in_a1a_rates(last_wave_df)
    if prev_rates:
        notes.append(f"Last wave avg: {np.mean(prev_rates):.1%}")

if desk_df is not None:
    notes.append("Desk research attached: use for availability/awareness plausibility checks.")

st.subheader("Results")
st.dataframe(result, use_container_width=True)

csv = result.to_csv(index=False).encode("utf-8")
# Button label previously held mis-decoded UTF-8 ("ðŸ’¾") instead of the emoji.
st.download_button("💾 Download checked CSV", csv, file_name="checked_advanced.csv", mime="text/csv")

if notes:
    # st.info renders markdown, where a single "\n" is a soft break that folds the
    # notes onto one line; a blank line gives each note its own paragraph.
    st.info("\n\n".join(notes))

st.markdown("---")
st.markdown("**Notes:** This add-on implements a first pass of rules 1, 2, 3, 4, 5/16, 7, 8, 9, 10, 11, 12, 23, 24, 25, 31/33/34, and 37. Others require custom mappings or manual recode inputs (industry/region recodes, open-end cleaning, desk research comparisons, Volvo/BCS specifics). Upload last-wave and client sample files to enable the related checks.")
