In [3]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go

metadata_path = "../data/CA_metadata_and_annual_results.csv"

problem_counties = [
    'Calaveras County', 'Toulumne County', 'El Dorado County', 'Placer County',
    'Yuba County', 'Glenn County', 'Lassen County', 'Siskiyou County', 'Modoc County', 'Inyo County'
]

problem_counties = [
    'Alameda County'
]

HOUSING_NAME_MAP = {
    "single-family-detached": "Single-Family Detached",
    "single-family-attached": "Single-Family Attached",
}
housing_type = "single-family-detached"
housing_value = HOUSING_NAME_MAP[housing_type]
metadata = pd.read_csv(metadata_path, low_memory=False)

# Subset only the target counties...
subset = metadata[metadata["in.county_name"].isin(problem_counties)].copy()

# Define filters in the order I want them applied
# function returns a boolean mask
filters = [
    ("Upgrade = 0", lambda df: df["upgrade"] == 0),
    ("Housing Type", lambda df: df["in.geometry_building_type_recs"] == housing_value),
    ("Occupied", lambda df: df["in.vacancy_status"] == "Occupied"),
    ("Cooking Range = Gas", lambda df: df["in.cooking_range"].isin(["Gas"])),
    ("Heating Fuel = Natural Gas", lambda df: df["in.heating_fuel"] == "Natural Gas"),
    ("Water Heater Fuel = Natural Gas", lambda df: df["in.water_heater_fuel"] == "Natural Gas"),
    # ("Has PV = No", lambda df: df["in.has_pv"] == "No"),
    # ("HVAC Cooling Type = NaN", lambda df: df["in.hvac_cooling_type"].isna()),
    ("Tenure = Owner", lambda df: df["in.tenure"] == "Owner"),
]

def create_sankey_for_county(df_county, county_name, filters):
    # 1) Calculate pass counts at each step
    pass_counts = []
    current_mask = np.ones(len(df_county), dtype=bool)

    # Step 0: total
    pass_counts.append(current_mask.sum())

    for _, filter_func in filters:
        current_mask = current_mask & filter_func(df_county)
        pass_counts.append(current_mask.sum())

    # 2) Create labels for each node
    # num of nodes is len(filters) + 1 (including the "Total" step)
    step_labels = []
    num_steps = len(pass_counts)
    for i in range(num_steps):
        if i == 0:
            # First node: "Total"
            # step_labels.append(f"Total\n(N = {pass_counts[i]})")
            print("first node")
        else:
            # Subsequent nodes: use the filter name + pass count
            label = filters[i-1][0]  # filter label
            step_labels.append(f"{label}\n(N = {pass_counts[i]})")

    # 3) Build the Sankey "links"
    # For each step i, link node i to node i+1 with the value pass_counts[i+1].
    source_indices = []
    target_indices = []
    values = []
    link_labels = []
    for i in range(num_steps - 1):
        source_indices.append(i)
        target_indices.append(i + 1)
        values.append(pass_counts[i + 1])
        link_labels.append(f"Passed: {pass_counts[i + 1]}")

    # 4) Create the Sankey flame graph
    fig = go.Figure(data=[go.Sankey(
        arrangement="snap",
        node=dict(
            label=step_labels,
            pad=15,
            thickness=20
        ),
        link=dict(
            source=source_indices,
            target=target_indices,
            value=values,
            label=link_labels  # show how many pass at each link
        )
    )])

    fig.update_layout(
        title_text=f"Filtered Buildings in {county_name} ({housing_type})",
        font_size=10
    )
    return fig

for county in problem_counties:
    df_county = subset[subset["in.county_name"] == county]
    if df_county.empty:
        print(f"No data for {county} in the metadata.")
        continue
    
    fig = create_sankey_for_county(df_county, county, filters)
    fig.show()

first node
