In [None]:
import json
import os


import ipywidgets as widgets
from IPython.display import display, clear_output
import plotly.express as px
from pandas import DataFrame


from otld import combine_appended_files
from otld.paths import input_dir
from otld.utils.string_utils import standardize_line_number

In [None]:
federal, state = combine_appended_files.main()

# Read the crosswalk inside a context manager so the file handle is closed
# deterministically; JSON is decoded as UTF-8 regardless of the platform default.
with open(
    os.path.join(input_dir, "column_dict_196.json"), encoding="utf-8"
) as crosswalk_file:
    crosswalk = json.load(crosswalk_file)

# Normalise the keys with standardize_line_number so they line up with the column
# names looked up in crosswalk by the chart functions below.
crosswalk = {standardize_line_number(key): value for key, value in crosswalk.items()}
# Explicit labels for lines 5b / 6b (presumably absent from column_dict_196.json — confirm).
crosswalk.update({"5b": "Assistance: Child Care", "6b": "Non-assistance: Child Care"})

In [None]:
federal.head(n=5)  # preview the first rows of the federal data

In [None]:
state.head(n=5)  # preview the first rows of the state data

In [None]:
def get_data(level: str) -> DataFrame:
    """Return one of the module-level datasets loaded at the top of the notebook.

    Args:
        level: "federal" or "state", matched case-insensitively (so widget labels
            such as "Federal" work directly).

    Returns:
        The module-level DataFrame with that name.

    Raises:
        KeyError: if ``level`` does not name one of those two datasets. Looking up
            arbitrary globals (functions, the crosswalk dict, ...) would otherwise
            return a non-DataFrame and fail confusingly later.
    """
    key = level.lower()
    if key not in ("federal", "state"):
        raise KeyError(f"Unknown dataset {level!r}; expected 'federal' or 'state'")
    return globals()[key]

In [None]:
def longitudinal_line_one_col(df: str, state: str | list[str], column: str):
    """Line chart of a single column over time, one line per state.

    Args:
        df (str): Name of the dataset to plot, "federal" or "state"
            (case-insensitive; resolved with ``get_data``).
        state (str | list[str]): A single value of the "STATE" column, or a
            list/tuple of such values. Any other value (e.g. None) keeps every state.
        column (str): Column to plot on the y-axis.

    Returns:
        plotly Figure with "year" on the x-axis and one line per state.
    """
    data = get_data(df.lower())

    if isinstance(state, str):
        data = data[data["STATE"] == state]
    elif isinstance(state, (list, tuple)):
        data = data[data["STATE"].isin(state)]

    # Title-case state names for the legend. assign() returns a new frame, so the
    # shared module-level DataFrame is never modified and no SettingWithCopy
    # warning is raised on the filtered slice.
    data = data.assign(STATE=data["STATE"].str.title())
    # px.line joins points in row order; sort so each line runs chronologically.
    data = data.sort_values(by="year")

    fig = px.line(
        data,
        x="year",
        y=column,
        color="STATE",
        # Fall back to the raw column name when the crosswalk has no label for it.
        labels={"year": "Year", column: crosswalk.get(column, column), "STATE": "State"},
    )
    # One tick per year.
    fig.update_layout(xaxis={"dtick": 1})

    return fig

In [None]:
def within_year_bar(df: str, state: str, year: int, column: str | list[str] = None):
    """Bar chart of line values for one state in one year.

    Args:
        df: Dataset name, "federal" or "state" (case-insensitive; resolved with
            ``get_data``).
        state: Value of the "STATE" column to plot (a single string; it is also
            title-cased for the chart title).
        year: Value of the "year" column to plot.
        column: A column name or a list of column names to include. When falsy,
            every column except "STATE" and "year" is plotted.

    Returns:
        plotly Figure with one bar per line, labelled via ``crosswalk`` where a
        label exists.
    """
    data = get_data(df.lower())
    data = data[(data["STATE"] == state) & (data["year"] == year)]

    if column:
        # A bare string would select a Series, which supports neither
        # rename(columns=...) nor melt(); always select with a list.
        columns = [column] if isinstance(column, str) else list(column)
        data = data[columns]
    else:
        # drop() without inplace returns a new frame instead of mutating a slice.
        data = data.drop(columns=["STATE", "year"])

    long = data.rename(columns=crosswalk).melt()

    fig = px.bar(
        long,
        "variable",
        "value",
        labels={"value": "$Amount", "variable": "Line"},
        title=f"Line Values for {state.title()} in {year}",
    )

    return fig

In [None]:
def longitudinal_line_within_state(df: str, state: str, column: str | list[str]):
    """Line chart of several columns over time for one state, one line per column.

    Args:
        df: Dataset name, "federal" or "state" (case-insensitive; resolved with
            ``get_data``).
        state: Value of the "STATE" column to keep.
        column: A column name or a list of column names to plot. When falsy,
            every column except "STATE" is plotted.

    Returns:
        plotly Figure with "year" on the x-axis and one line per column,
        labelled via ``crosswalk`` where a label exists.
    """
    data = get_data(df.lower())
    data = data[data["STATE"] == state]

    if column:
        # Build a new list so the caller's list is never mutated. Add "year"
        # (the melt id column) only once, so repeated calls with the same list
        # do not accumulate duplicate "year" columns.
        columns = [column] if isinstance(column, str) else list(column)
        if "year" not in columns:
            columns.append("year")
        data = data[columns]
    else:
        data = data.drop(columns=["STATE"])

    long = data.rename(columns=crosswalk).melt(id_vars=["year"])

    fig = px.line(
        long,
        "year",
        "value",
        color="variable",
        labels={"year": "Year", "value": "$Amount", "variable": "Line"},
    )

    return fig

In [None]:
def display_line(func, state, column, df="federal"):
    """Display the ``state`` and ``column`` controls above a chart drawn by ``func``.

    The chart is drawn once from the widgets' current values; it is not redrawn
    when they change.

    Args:
        func: Plotting function accepting ``df``, ``state`` and ``column`` keyword
            arguments and returning a plotly Figure.
        state: Widget whose ``.value`` is passed as ``state``.
        column: Widget whose ``.value`` is passed as ``column``.
        df: Dataset name passed to ``func`` ("federal" or "state"). The plotting
            functions resolve it with ``get_data``, so this must be a name rather
            than a DataFrame.
    """
    # VBox children must be widgets; a plotly Figure is not one, so render it
    # inside an Output widget instead.
    chart = widgets.Output()
    with chart:
        func(df=df, state=state.value, column=column.value).show()

    display(widgets.VBox([widgets.HBox([state, column]), chart]))

In [None]:
def within_year_treemap(df: str, state: str, year: int, column: list[str] = None):
    """Treemap of line values for one state in one year.

    The hierarchy is year -> state -> line. Each leaf is sized by that line's
    value, with missing values treated as 0.

    Args:
        df: Dataset name, "federal" or "state" (case-insensitive; resolved with
            ``get_data``).
        state: Value of the "STATE" column to plot.
        year: Value of the "year" column to plot.
        column: A column name or a list of column names to include. When falsy,
            every column other than "year" and "STATE" is used.

    Returns:
        plotly Figure whose sectors show the label and amount.
    """
    data = get_data(df.lower())
    # Select the requested row first instead of melting every year and state.
    data = data[(data["year"] == year) & (data["STATE"] == state)]

    if column:
        columns = [column] if isinstance(column, str) else list(column)
        id_cols = ("year", "STATE")
        data = data[[*id_cols, *(name for name in columns if name not in id_cols)]]

    long = data.melt(id_vars=["year", "STATE"]).fillna(0)
    # Use crosswalk labels, falling back to the raw column name. A bare
    # map(crosswalk) leaves NaN labels for unmapped columns.
    long = long.assign(
        variable=long["variable"].map(lambda name: crosswalk.get(name, name))
    )
    long = long.sort_values(by=["variable"])

    fig = px.treemap(long, path=["year", "STATE", "variable"], values="value")
    # %{value} is defined for every sector, including the year/state parents
    # (which get the sum of their children). A per-row customdata list has no
    # entries for those parents and is not guaranteed to follow the treemap's
    # internal node order.
    fig.update_traces(texttemplate="%{label}<br>Amount: $%{value:,}")

    return fig

In [None]:
def cross_state_treemap(df: str, year: int, column: str):
    """Treemap comparing one column across all states in a single year.

    The hierarchy is year -> state. Each state is sized by its value in
    ``column``, with missing values treated as 0.

    NOTE(review): the "STATE" column also holds an aggregate "U.S. TOTAL" entry
    (it is the default state selection in the widgets below). If present in
    ``df`` it is drawn alongside the individual states; confirm whether it
    should be excluded here.

    Args:
        df: Dataset name, "federal" or "state" (case-insensitive; resolved with
            ``get_data``).
        year: Value of the "year" column to plot.
        column: Column whose values size the tiles.

    Returns:
        plotly Figure whose sectors show the label and amount.
    """
    data = get_data(df.lower())
    data = (
        data.loc[data["year"] == year, ["year", "STATE", column]]
        .fillna(0)
        .sort_values(by=["STATE"])
    )

    fig = px.treemap(
        data,
        path=["year", "STATE"],
        values=column,
        # Fall back to the raw column name when the crosswalk has no label for it.
        title=f"Comparison of {crosswalk.get(column, column)} in {year}",
    )
    # %{value} covers every sector, including the year root. A per-row
    # customdata list does not include that root.
    fig.update_traces(texttemplate="%{label}<br>Amount: $%{value:,}")

    return fig

In [None]:
# Compare line "1" across states for the most recent year in the federal data
# (years in the data are four-digit, e.g. 1997, so a literal like 200 matches no rows).
cross_state_treemap("federal", federal["year"].max(), "1")

In [None]:
# Controls shared by the interactive charts below; option lists come from the
# federal dataset.
state_select = widgets.SelectMultiple(
    options=federal["STATE"].unique(), value=["U.S. TOTAL"]
)
# "STATE" and "year" are identifier columns, not plottable lines, so they are not
# offered. Each option shows "<code> - <crosswalk label>" (the code keeps labels
# unique), while the widget's value stays the raw column name.
column_dropdown = widgets.Dropdown(
    options=[
        (f"{name} - {crosswalk[name]}" if name in crosswalk else str(name), name)
        for name in federal.columns
        if name not in ("STATE", "year")
    ],
    value="5",
)
df_select = widgets.Dropdown(options=["State", "Federal"], value="State")
year_dropdown = widgets.Dropdown(
    options=sorted(federal["year"].unique().tolist()), value=1997
)

In [None]:
# Chart name -> [row of control widgets, plotting function].
hbox_map = {}
hbox_map["Longitudinal Line Chart One Column"] = [
    widgets.HBox([df_select, state_select, column_dropdown]),
    longitudinal_line_one_col,
]
hbox_map["Within-year Bar Chart Multiple Columns"] = [
    widgets.HBox([df_select, state_select, year_dropdown]),
    within_year_bar,
]

# "None" is a placeholder entry meaning "no chart selected".
hbox_options = ["None", *hbox_map.keys()]
hbox_dropdown = widgets.Dropdown(options=hbox_options, value="None")

In [None]:
graph_output = widgets.Output()


def update_change_handler(graph: str):
    """Build a widget observer that redraws the chart named ``graph``.

    Args:
        graph: A key of ``hbox_map``.

    Returns:
        A callback that accepts a traitlets change dict. It ignores the dict and
        redraws the chart into ``graph_output`` from the controls' current values.

    Raises:
        KeyError: if ``graph`` is not a key of ``hbox_map`` (e.g. "None").
    """
    func = hbox_map[graph][1]

    def handler(change):
        with graph_output:
            clear_output()
            if func is within_year_bar:
                # within_year_bar takes (df, state, year) and plots a single
                # state, whereas state_select is a multi-select returning a
                # tuple. Pass the year widget (not the column) and the first
                # selected state.
                states = state_select.value
                if not states:
                    return
                fig = func(df_select.value, states[0], year_dropdown.value)
            else:
                fig = func(df_select.value, state_select.value, column_dropdown.value)
            fig.show()

    return handler

In [None]:
def create_hbox(graph: str):
    """Return the row of control widgets registered for ``graph`` in ``hbox_map``."""
    return hbox_map[graph][0]


# Displayed once; render_controller swaps the selected chart's controls in here.
controls_output = widgets.Output()
# Redraw callback currently attached to the control widgets (None when no chart is selected).
_active_handler = None
# Every control whose change may require a redraw.
_CONTROL_WIDGETS = (df_select, state_select, column_dropdown, year_dropdown)


def render_controller(change):
    """React to a new selection in ``hbox_dropdown``.

    Shows the selected chart's controls in ``controls_output``, detaches the
    callback of the previously selected chart, attaches a fresh one built by
    ``update_change_handler`` and draws the chart immediately. Selecting "None"
    (or anything else not in ``hbox_map``) clears both the controls and the chart
    instead of raising KeyError.

    Args:
        change: traitlets change dict; only ``change["new"]`` is read.
    """
    global _active_handler

    # update_change_handler returns a new closure on every call, so previous
    # handlers must be detached explicitly or they accumulate and all fire.
    if _active_handler is not None:
        for control in _CONTROL_WIDGETS:
            control.unobserve(_active_handler, names="value")
        _active_handler = None

    graph = change["new"]

    # Render into persistent Output widgets: display() called directly inside a
    # widget callback would stack a new copy on every selection.
    with controls_output:
        clear_output()
        if graph in hbox_map:
            display(create_hbox(graph))

    if graph not in hbox_map:
        with graph_output:
            clear_output()
        return

    _active_handler = update_change_handler(graph)
    for control in _CONTROL_WIDGETS:
        control.observe(_active_handler, names="value")
    # Draw right away instead of waiting for the first control change.
    _active_handler({"name": "value", "new": None})


hbox_dropdown.observe(render_controller, names="value")
display(hbox_dropdown, controls_output, graph_output)

In [None]:
# Redraw callback for the longitudinal line chart, wired to the controls in the next cell.
handle_change = update_change_handler(graph="Longitudinal Line Chart One Column")

In [None]:
# Redraw the longitudinal chart whenever one of its controls changes.
# render_controller is already attached to hbox_dropdown in the cell that
# defines it, so it is not registered again here.
# NOTE(review): handle_change stays attached regardless of the chart chosen in
# hbox_dropdown and also draws into graph_output. Consider keeping only one of
# the two wiring approaches.
for control in (df_select, state_select, column_dropdown):
    control.observe(handle_change, names="value")

display(hbox_dropdown)
display(graph_output)