In [None]:
# %%script true
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
import scipy.stats
import geopandas as gpd
%matplotlib inline
from IPython.display import Markdown
from functools import reduce
from dash import Dash, dcc, html, Input, Output, State, ALL, MATCH
from dash.exceptions import PreventUpdate
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio

In [None]:
# Theme switch: change the condition to True for the dark look. `style` is the
# CSS dict merged into the app's root container further down, and the plotly
# default template is set to match.
if False:
    pio.templates.default = "plotly_dark"
    style = {
        "background-color": "#1b1b1b",  # rgb(27, 27, 27)
        "color": "white",  # font
    }
else:
    pio.templates.default = "plotly"
    style = {}

In [None]:
import sys

# Make the parent directory importable so the project-local module below
# resolves; this has to run before `import preprocessing`.
sys.path.append("..")

from os.path import join
import preprocessing

data_path = "../data"
# Order matters: the list is unpacked positionally into table_10 ... table_67
# below, and the Dash "table-index" dropdown uses positions in this list as its
# option values.
table_names = [
    "RV_O_010_L_OK_SK.CSV",
    "RV_O_040_L_OK_SK.CSV",
    "RV_O_047_L_OK_SK.CSV",
    "RV_O_067_L_OK_SK.CSV",
]
# Each file is read as a semicolon-separated CSV and passed through
# preprocessing.rename_columns, then preprocessing.translate_sex.
tables = [
    preprocessing.translate_sex(
        preprocessing.rename_columns(pd.read_csv(join(data_path, table), sep=";"))
    )
    for table in table_names
]
# Convenience names for the same objects held in `tables`
# (assert_connection() checks that this stays true).
table_10, table_40, table_47, table_67 = tables

# District metadata, fetched over the network.
districts_url = "https://bbrejova.github.io/viz/data/districts.json"
districts = gpd.read_file(districts_url)
# Origin of the local GeoJSON copy read below:
# https://raw.githubusercontent.com/drakh/slovakia-gps-data/master/GeoJSON/epsg_4326/districts_epsg_4326.geojson
# Note: despite the `_url` suffix this is a local path under data_path.
districts_geojson_url = join(data_path, "districts.geojson")
districts_geojson = gpd.read_file(districts_geojson_url, crs="EPSG:4326")

# Replace the non-numeric label "90 a viac rokov" with "90" so the whole column
# can be parsed as numbers. Assigned in place so table_40 stays the same object
# as tables[1].
table_40["age"] = pd.to_numeric(table_40["age"].replace({"90 a viac rokov": "90"}))

def assert_connection():
    """Verify that the ``table_*`` names still refer to the objects stored in ``tables``.

    Raises:
        AssertionError: if any of the four names has been rebound to a
            different object than the one at its position in ``tables``.
    """
    named_tables = (table_10, table_40, table_47, table_67)
    for position, frame in enumerate(named_tables):
        assert frame is tables[position]

In [None]:
# Index both district frames by "IDN3" so that the column assignment below is
# aligned on that key (pandas aligns the assigned frame on the index).
districts_geojson_indexed = districts_geojson.set_index("IDN3")
districts_indexed = districts.set_index("IDN3")
# Take the geometry and the two area columns from the local GeoJSON; the area
# columns are stored under the names "Area" and "AreaHA".
districts_indexed[["geometry", "Area", "AreaHA"]] = districts_geojson_indexed[
    ["geometry", "Shape_Area", "VYMERA_ha"]
]
# geo_frame keeps referring to this IDN3-indexed frame even though a later cell
# rebinds the name `districts_indexed` to a LAU1_CODE-indexed frame.
geo_frame = districts_indexed

In [None]:
from categorize_education import EDUCATION_CATEGORY_MAP
from isco_occupation import OCCUPATION_ISCO_MAP
from required_education import REQURED_EDUCATION_MAP

# District lookup keyed by LAU1 code; it supplies the region columns joined
# onto every table below. (This rebinds `districts_indexed`; `geo_frame` still
# refers to the earlier IDN3-indexed frame.)
districts_indexed = districts.set_index("LAU1_CODE")

for table in tables:
    # Derived categorical columns, added only where the source column exists.
    if "education" in table.columns:
        table["education_category"] = (
            table["education"].map(EDUCATION_CATEGORY_MAP).astype("category")
        )
    if "ISCO_occupation" in table.columns:
        table["ISCO_group"] = (
            table["ISCO_occupation"].map(OCCUPATION_ISCO_MAP).astype("category")
        )
        table["required_education"] = (
            table["ISCO_occupation"].map(REQURED_EDUCATION_MAP).astype("category")
        )

    # Temporarily index by LAU1 code so the district columns line up with the
    # rows; everything is done in place so the objects held in `tables` keep
    # their identity.
    table.set_index("LAU1_CODE", inplace=True)
    table[["region_name", "NUTS3_CODE", "ecoregion_name", "NUTS2_CODE"]] = (
        districts_indexed[["NUTS3", "NUTS3_CODE", "NUTS2", "NUTS2_CODE"]]
    )
    table.reset_index(inplace=True)

In [None]:
# Convert every object-dtype column to a categorical: the values are cast to
# the pandas "string" dtype first and to "category" afterwards. The assignment
# is column-wise on the existing frame, so the objects in `tables` are kept.
for i, table in enumerate(tables):
    object_columns = [
        name for name, dtype in table.dtypes.items() if dtype == "object"
    ]
    table[object_columns] = (
        table[object_columns].astype("string").astype("category")
    )

In [None]:
# Print a labelled `.info()` summary for each census table and both district frames.
for label, frame in (
    ("table_10", table_10),
    ("table_40", table_40),
    ("table_47", table_47),
    ("table_67", table_67),
    ("districts", districts),
    ("districts_geojson", districts_geojson),
):
    display(label)
    frame.info()

In [None]:
def compute_groups(data, query, groupby, value, restriction=""):
    """Sum ``count`` per group for the rows matching ``query`` and relate it to the group total.

    Arguments
        data - frame holding a ``count`` column and the ``groupby`` column(s)
        query - ``DataFrame.query`` expression selecting the rows of interest;
            an empty string selects everything
        groupby - column(s) to group by
        value - number | percent (not used inside this function)
        restriction - ``DataFrame.query`` expression applied before anything
            else; it narrows both the selection and the totals. An empty
            string means no restriction.

    Returns
        A frame indexed by group with the columns ``number`` (sum over the
        selected rows), ``total`` (sum over the restricted data) and
        ``percent`` (``number / total * 100``).
    """
    population = data if restriction == "" else data.query(restriction)
    selection = population if query == "" else population.query(query)

    result = selection.groupby(groupby)["count"].sum().to_frame("number")
    result["total"] = population.groupby(groupby)["count"].sum()
    # Same operation order as number / total * 100, so the floats are identical.
    result["percent"] = result["number"].div(result["total"]).mul(100)
    return result


def plot_treemap(data, query, groupby, value, restriction=""):
    """Plot the grouped counts either as a treemap or as a district choropleth.

    Arguments
        data, query, groupby, value, restriction - forwarded unchanged to
            ``compute_groups``; ``value`` ("number" | "percent") additionally
            chooses which computed column sizes the treemap tiles or colours
            the map.

    Returns
        A choropleth map when ``groupby`` is one of the territorial code
        columns ("NUTS2_CODE", "NUTS3_CODE", "LAU1_CODE"), otherwise a
        single-level treemap.
    """
    groups = compute_groups(data, query, groupby, value, restriction)

    if groupby in ("NUTS2_CODE", "NUTS3_CODE", "LAU1_CODE"):
        # Attach the computed numbers to the district geometries via the code column.
        merged = geo_frame.merge(groups, on=groupby)
        return px.choropleth_mapbox(
            merged,
            geojson=merged.geometry,
            locations=merged.index,
            color=value,
            mapbox_style="carto-positron",
            center={"lat": 48.6737532, "lon": 19.696058},
            zoom=7,
            opacity=0.5,
            hover_data=["LAU1", "number", "percent"],
        )

    return px.treemap(
        groups.reset_index(),
        path=[groupby],
        values=value,
        hover_data=["number", "percent"],
    )


# Example call: rows of table_40 whose `education` equals the given value,
# summed per region and drawn as a treemap sized by the absolute number.
fig = plot_treemap(
    data=table_40,
    query="`education` == 'vysokoškolské vzdelanie - 1. stupeň (Bc.)'",
    groupby="region_name",
    value="number",
)
fig.show()

In [None]:
# Sanity checks before building the app: the column lookup raises if
# `economical_age_groups` is missing, assert_connection() verifies the table
# aliases, and every table prints its `.info()` summary.
str(table_10["economical_age_groups"].dtype)
assert_connection()
[frame.info() for frame in tables];

In [None]:
# %%script true # Skip
# The Dash application object; the layout and the callbacks defined below are
# attached to it.
app = Dash(__name__)


def get_filters(data, filter_attributes):
    """Build one filter widget per selected attribute of ``data``.

    Arguments
        data - table whose columns are being filtered
        filter_attributes - iterable of column names; ``None`` (what a
            multi-select dropdown reports when nothing is chosen) is treated
            as an empty selection

    Returns
        A list of ``html.Div`` blocks, each holding a heading and a control
        chosen by the column dtype:
          * ``category`` - multi-select dropdown over the categories, with the
            pattern-matching id ``{"type": "checklist", "attr": <column>}``
          * ``int64`` (any capitalisation) - range slider from the column
            minimum to maximum + 1, with the id
            ``{"type": "range-slider", "attr": <column>}``
          * anything else - a short note, because no filter widget exists for
            that dtype (previously this case crashed the callback)
    """
    assert_connection()

    def build_control(attr):
        """Return the filter control matching the dtype of column ``attr``."""
        column = data[attr]
        dtype_name = str(column.dtype).lower()

        if dtype_name == "category":
            return dcc.Dropdown(
                id={"type": "checklist", "attr": attr},
                options=column.cat.categories,
                persistence=True,
                multi=True,
            )

        if dtype_name == "int64":
            # Plain Python ints keep the slider properties JSON-friendly.
            # The upper bound is max + 1 because the generated restriction
            # treats the upper end as exclusive.
            low = int(column.min())
            high = int(column.max()) + 1
            return dcc.RangeSlider(
                id={"type": "range-slider", "attr": attr},
                min=low,
                max=high,
                step=1,
                marks={mark: str(mark) for mark in range(low, high + 1, 10)},
                value=[low, high],
                persistence=True,
            )

        # Unsupported dtype: show a note instead of raising.
        return html.P(f"No filter available for dtype {column.dtype}")

    return [
        html.Div(
            [
                html.H4(f"Select {data[attr].name}"),
                build_control(attr),
            ],
        )
        for attr in (filter_attributes or [])
    ]


# Page layout: a flex row with the controls on the left and the figure plus
# the generated function call on the right. The component ids are referenced
# by the callbacks below.
app.layout = html.Div(
    [
        # Left panel: table selection, attribute selection, dynamic filters,
        # title and display options.
        html.Div(
            [
                html.H4("Select table"),
                # Option values are positions in `table_names` / `tables`.
                dcc.Dropdown(
                    id="table-index",
                    options=[
                        {"value": v, "label": l} for v, l in enumerate(table_names)
                    ],
                ),
                html.H4("Select filter attributes"),
                # Options are filled in by update_attributes once a table is chosen.
                dcc.Dropdown(
                    id="filter-attributes",
                    multi=True,
                    persistence=True,
                ),
                html.H4("Select groupby attributes"),
                # Options are filled in by update_attributes once a table is chosen.
                dcc.Dropdown(
                    id="groupby-attributes",
                    multi=True,
                    persistence=True,
                ),
                # Container for the per-attribute filter widgets built by get_filters.
                html.Div(id="filter-zone"),
                html.H4("Enter title"),
                dcc.Input(
                    id="title", type="text", persistence=True, style={"width": "100%"}
                ),
                html.H4("Select groupby"),
                # Options are the attributes picked in "groupby-attributes" (set by update_fg).
                dcc.RadioItems(
                    id="groupby",
                    options=[],
                    persistence=True,
                ),
                html.H4("Select display value"),
                # Which computed column is plotted; defaults to "number".
                dcc.RadioItems(
                    ["number", "percent"],
                    "number",
                    id="display-value",
                    persistence=True,
                ),
            ],
            style={"flex": 1, "minWidth": 400, "padding": 10},
        ),
        # html.Br(),
        # Right panel: the figure and a text area showing the equivalent plot_treemap call.
        html.Div(
            [
                dcc.Graph(
                    id="line-plot",
                    style={"aspect-ratio": "1.6"},
                    figure=px.treemap(),
                ),
                dcc.Textarea(id="function-call", style={"width": "100%"}),
            ],
            style={"flex": 2, "padding": 10},
        ),
    ],
    # Theme colours from `style`, overlaid with the flex-row settings.
    style=style | {"padding": 10, "display": "flex", "flexDirection": "row"},
)


@app.callback(
    Output("filter-attributes", "options"),
    Output("groupby-attributes", "options"),
    [Input("table-index", "value")],
    prevent_initial_call=True,
)
def update_attributes(table_index):
    """Offer the columns of the chosen table in both attribute dropdowns.

    Each option maps the column name (the value) to a "<dtype>: <column>"
    label. Nothing is updated while no table is selected.
    """
    if table_index is None:
        raise PreventUpdate

    table = tables[table_index]
    options = {}
    for column in table.columns:
        options[column] = f"{table[column].dtype}: {column}"
    # The same mapping serves both dropdowns.
    return [options, options]


# update filters and groupby
@app.callback(
    Output("filter-zone", "children"),
    Output("groupby", "options"),
    Input("groupby-attributes", "value"),
    Input("filter-attributes", "value"),
    State("table-index", "value"),
    prevent_initial_call=True,
)
def update_fg(groupby_attributes, filter_attributes, table_index):
    """Rebuild the filter widgets and the group-by radio options.

    Arguments
        groupby_attributes - columns picked as group-by candidates, or None
        filter_attributes - columns picked for filtering, or None
        table_index - position of the selected table in ``tables``, or None

    Returns
        [filter widgets for "filter-zone", options for the "groupby" radio].

    Raises
        PreventUpdate - when no table is selected, since there is nothing to
            build the widgets from.
    """
    if table_index is None:
        raise PreventUpdate
    # A multi-select dropdown reports None when nothing is chosen; normalise
    # to empty lists so neither the widget builder nor the radio options
    # receive None.
    return [
        get_filters(tables[table_index], filter_attributes or []),
        groupby_attributes or [],
    ]


@app.callback(
    Output("line-plot", "figure"),
    Output("function-call", "value"),
    Input("title", "value"),
    Input("display-value", "value"),
    Input("groupby", "value"),
    Input({"type": "checklist", "attr": ALL}, "value"),
    Input({"type": "range-slider", "attr": ALL}, "value"),
    State({"type": "checklist", "attr": ALL}, "id"),
    State({"type": "range-slider", "attr": ALL}, "id"),
    State("table-index", "value"),
    prevent_initial_call=True,
)
def update_figure(
    title,
    display_value,
    groupby,
    checklists,
    range_sliders,
    checklists_id,
    range_sliders_id,
    table_index,
):
    """Redraw the figure from the current control values.

    Arguments
        title - figure title
        display_value - "number" | "percent", the column to plot
        groupby - column to group by
        checklists / checklists_id - selected categories and the ids of the
            category dropdowns (parallel lists)
        range_sliders / range_sliders_id - selected [low, high] pairs and the
            ids of the range sliders (parallel lists)
        table_index - position of the selected table in ``tables``

    Returns
        [figure, text of the equivalent ``plot_treemap`` call].

    Raises
        PreventUpdate - when no table or group-by column is selected, or when
            the (possibly persisted) group-by column does not exist in the
            selected table.
    """
    if groupby is None or table_index is None:
        raise PreventUpdate

    table = tables[table_index]
    # The radio value can be stale after switching tables.
    if groupby not in table.columns:
        raise PreventUpdate

    # Category selections pick the rows that are counted ("number").
    query = " and ".join(
        f"`{widget_id['attr']}`.isin({selected})"
        for widget_id, selected in zip(checklists_id, checklists)
        if selected is not None and len(selected) > 0
    )
    # Range selections restrict the data as a whole, so they also affect the
    # totals behind "percent". The upper end is exclusive.
    restriction = " and ".join(
        f"{bounds[0]} <= `{widget_id['attr']}` < {bounds[1]}"
        for widget_id, bounds in zip(range_sliders_id, range_sliders)
        if bounds is not None
    )

    figure: go.Figure = plot_treemap(
        table,
        query,
        groupby,
        display_value,
        restriction=restriction,
    )
    figure.update_layout(title=title)
    return [
        figure,
        f"plot_treemap(pd.read_csv({table_names[table_index]!r}, sep=';'), {query!r}, {groupby!r}, {display_value!r}, restriction={restriction!r})",
    ]


# Start the Dash server on port 8054 with debug mode and the reloader enabled.
# NOTE(review): newer Dash releases prefer `app.run`, and the reloader may not
# behave well inside a notebook kernel - confirm against the installed version.
app.run_server(port=8054, debug=True, use_reloader=True)
# No-op final statement; presumably here so the cell does not end in an expression.
pass