In [None]:
# %%script true
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
import scipy.stats
import geopandas as gpd
%matplotlib inline
from IPython.display import Markdown
from functools import reduce
from dash import Dash, dcc, html, Input, Output, State, ALL, MATCH
from dash.exceptions import PreventUpdate
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
pd.options.display.float_format = '{:,.2f}'.format

In [None]:
# Theme switch: set the flag to True to render the dashboard with the dark
# palette; the default is the standard light plotly theme.
dark_theme = False
if dark_theme:
    pio.templates.default = "plotly_dark"
    style = {
        "background-color": "#1b1b1b",  # rgb(27, 27, 27)
        "color": "white",  # font
    }
else:
    pio.templates.default = "plotly"
    style = {}

In [None]:
import sys

# Extend the module search path with the parent directory before importing
# the local helper module (presumably that is where `preprocessing` lives —
# verify against the repository layout).
sys.path.append("..")

from os.path import join
import preprocessing

# Directory holding the CSV exports and the local GeoJSON file.
data_path = "../data"
# File names of the source tables; the position in this list is the table
# index used by the dashboard.
table_names = [
    "RV_O_010_L_OK_SK.CSV",
    "RV_O_040_L_OK_SK.CSV",
    "RV_O_047_L_OK_SK.CSV",
    "RV_O_067_L_OK_SK.CSV",
]
# Read every CSV (semicolon-separated) and pass it through the project's
# `rename_columns` and `translate_sex` helpers (their exact behaviour is
# defined in `preprocessing`, not visible here).
tables = [
    preprocessing.translate_sex(
        preprocessing.rename_columns(pd.read_csv(join(data_path, table), sep=";"))
    )
    for table in table_names
]
# Convenience aliases; they must stay the same objects as the list entries
# (checked by `assert_connection`).
table_10, table_40, table_47, table_67 = tables

# District table fetched from a remote JSON file.
districts_url = "https://bbrejova.github.io/viz/data/districts.json"
districts = gpd.read_file(districts_url)
# https://raw.githubusercontent.com/drakh/slovakia-gps-data/master/GeoJSON/epsg_4326/districts_epsg_4326.geojson
# District file read from the local data directory.
districts_geojson_url = join(data_path, "districts.geojson")
districts_geojson = gpd.read_file(districts_geojson_url, crs="EPSG:4326")

# The open-ended label "90 a viac rokov" is replaced by "90" so the whole
# column can be converted to numbers.
table_40["age"] = pd.to_numeric(table_40["age"].replace({"90 a viac rokov": "90"}))

def assert_connection():
    """Check that the ``table_*`` aliases are still the objects stored in ``tables``.

    Raises:
        AssertionError: if any alias has been rebound to a different object
            than the entry at the matching position of ``tables``.
    """
    aliases = (table_10, table_40, table_47, table_67)
    for position, alias in enumerate(aliases):
        assert alias is tables[position]

In [None]:
# Index both district tables by the shared "IDN3" key so the column
# assignment below aligns rows by district rather than by position.
districts_indexed = districts.set_index("IDN3")
districts_geojson_indexed = districts_geojson.set_index("IDN3")

# Copy the geometry and the two area columns from the GeoJSON table,
# renaming the area columns on the way.
geojson_columns = districts_geojson_indexed[["geometry", "Shape_Area", "VYMERA_ha"]]
districts_indexed[["geometry", "Area", "AreaHA"]] = geojson_columns

# Frame that `plot_groups` merges aggregated values into for map plots.
geo_frame = districts_indexed

In [None]:
from categorize_education import EDUCATION_CATEGORY_MAP
from isco_occupation import OCCUPATION_ISCO_MAP
from required_education import REQURED_EDUCATION_MAP

# Add derived categorical columns to every table that has the source column.
for table in tables:
    available = table.columns
    if "education" in available:
        education_category = table["education"].map(EDUCATION_CATEGORY_MAP)
        table["education_category"] = education_category.astype('category')
    if "ISCO_occupation" in available:
        occupation = table["ISCO_occupation"]
        table["ISCO_group"] = occupation.map(OCCUPATION_ISCO_MAP).astype('category')
        table["required_education"] = occupation.map(REQURED_EDUCATION_MAP).astype('category')

# Re-index the district table by LAU1 code; note this rebinds
# `districts_indexed` (the earlier "IDN3"-indexed frame stays reachable
# through `geo_frame`).
districts_indexed = districts.set_index("LAU1_CODE")
for table in tables:
    # `in` on a DataFrame tests the column labels.
    if "LAU1_CODE" in table:
        # Temporarily index by district code so the assignment below aligns
        # each row with its district's region data. Done in place because the
        # `table_*` aliases must keep pointing at the same objects.
        table.set_index("LAU1_CODE", inplace=True)
        table[["region_name", "NUTS3_CODE", "ecoregion_name", "NUTS2_CODE"]] = (
            districts_indexed[["NUTS3", "NUTS3_CODE", "NUTS2", "NUTS2_CODE"]]
        )
        # Constant top-level label shared by every row.
        table['state_name'] = 'Slovensko'
        # Turn "LAU1_CODE" back into an ordinary column.
        table.reset_index(inplace=True)

In [None]:
# Convert every object-dtype column to a categorical column, going through
# the "string" dtype first.
for frame in tables:
    object_columns = [
        name for name, dtype in frame.dtypes.items() if dtype == 'object'
    ]
    as_text = frame[object_columns].astype('string')
    frame[object_columns] = as_text.astype('category')

In [None]:
# Categories of the derived "required_education" column, followed by a
# labelled `.info()` summary of every loaded frame.
display(table_47['required_education'].cat.categories)
labelled_frames = [
    ('table_10', table_10),
    ('table_40', table_40),
    ('table_47', table_47),
    ('table_67', table_67),
    ('districts', districts),
    ('districts_geojson', districts_geojson),
]
for label, frame in labelled_frames:
    display(label)
    frame.info()

In [None]:
def compute_groups(data, groupby, chosen_query="", filter_query=""):
    """Aggregate the ``count`` column per group and relate it to group totals.

    Arguments
        data - frame with a ``count`` column and the ``groupby`` column(s)
        groupby - column name (or list of names) to group by
        chosen_query - ``DataFrame.query`` expression selecting the rows of
            interest; an empty string selects everything
        filter_query - ``DataFrame.query`` expression restricting the frame
            before anything else is computed; an empty string keeps all rows

    Returns
        A frame indexed by group with the columns
        ``number`` (sum of ``count`` over the chosen rows),
        ``number_percent`` (``number`` as a share of all chosen rows, in %),
        ``total`` (sum of ``count`` over the filtered rows) and
        ``percent`` (``number`` as a share of ``total``, in %).
    """
    population = data.query(filter_query) if filter_query != "" else data
    chosen = population.query(chosen_query) if chosen_query != "" else population

    chosen_per_group = chosen.groupby(groupby)["count"].sum()
    chosen_overall = chosen["count"].sum()

    result = chosen_per_group.to_frame(name="number")
    result["number_percent"] = result["number"] / chosen_overall * 100
    # Assigning a Series aligns on the group index.
    result["total"] = population.groupby(groupby)["count"].sum()
    result["percent"] = result["number"] / result["total"] * 100
    return result


def plot_groups(data, groupby, value, chosen_query="", filter_query=""):
    """Plot the aggregation produced by ``compute_groups``.

    Arguments
        data - frame handed to ``compute_groups``
        groupby - column to group by; the territorial code columns
            ("NUTS2_CODE", "NUTS3_CODE", "LAU1_CODE") produce a map
        value - aggregated column to display ("number" or "percent")
        chosen_query, filter_query - forwarded to ``compute_groups``

    Returns
        A choropleth map for territorial codes; otherwise a horizontal bar
        chart for "percent", a treemap for "number", and ``None`` for any
        other ``value``.
    """
    grouped = compute_groups(data, groupby, chosen_query, filter_query)

    if groupby in ("NUTS2_CODE", "NUTS3_CODE", "LAU1_CODE"):
        # Attach the aggregated values to the district geometries.
        merged = geo_frame.merge(grouped, on=groupby)
        return px.choropleth_mapbox(
            merged,
            geojson=merged.geometry,
            locations=merged.index,
            color=value,
            mapbox_style="carto-positron",
            center={"lat": 48.6737532, "lon": 19.696058},
            zoom=7,
            opacity=0.5,
            hover_data=["LAU1", "number", "percent"],
        )

    # Non-territorial grouping: largest groups first.
    ranked = grouped.reset_index().sort_values(by=value, ascending=False)
    if value == "percent":
        return px.bar(
            ranked,
            y=groupby,
            color=groupby,
            x="percent",
            orientation='h',
            hover_data=["number", "percent"],
        )
    if value == "number":
        return px.treemap(
            ranked,
            path=[px.Constant("all"), groupby],
            values="number",
            hover_data=["number", "number_percent"],
        )
    return None


# Example: share of people per region whose education matches the label below.
bachelor_query = "`education` == 'vysokoškolské vzdelanie - 1. stupeň (Bc.)'"
fig = plot_groups(
    table_40, groupby="region_name", value="percent", chosen_query=bachelor_query
)
fig.show()

In [None]:
# Sanity check: the table_* names must still be the objects stored in `tables`.
assert_connection()

In [None]:
# %%script true # Skip
# Dash application instance; the layout and the callbacks below attach to it.
app = Dash(__name__)


def _selector_for(data, attr, type):
    """Build the input widget that matches the dtype of ``data[attr]``.

    Categorical columns get a multi-select dropdown, ``int64`` columns get a
    range slider; any other dtype gets a short notice instead of a widget.
    """
    column = data[attr]
    component_id = {"type": type, "attr": attr}
    dtype_name = str(column.dtype).lower()

    if dtype_name == "category":
        return dcc.Dropdown(
            id=component_id,
            options=list(column.cat.categories),
            persistence=True,
            multi=True,
        )

    if dtype_name == "int64":
        # Plain Python ints keep the component props JSON-friendly.
        lower = int(column.min())
        # The queries built from this slider use a half-open interval, so the
        # slider extends one past the largest value to keep it selectable.
        upper = int(column.max()) + 1
        # Space the labelled marks so that roughly ten of them are shown.
        mark_step = max(1, (upper - lower) // 10)
        return dcc.RangeSlider(
            id=component_id,
            min=lower,
            max=upper,
            step=1,
            marks={tick: str(tick) for tick in range(lower, upper + 1, mark_step)},
            value=[lower, upper],
            persistence=True,
        )

    # No pattern-matching id here, so this notice never feeds the figure callback.
    return html.P(f"No selector available for dtype {column.dtype}")


def get_selectivity(data, attributes, type):
    """Create one labelled selector per attribute.

    Arguments
        data - table whose columns are being selected on
        attributes - iterable of column names; ``None`` (what Dash sends for
            an empty multi-dropdown) is treated as no attributes
        type - value stored under the "type" key of each widget id
            ("chosen" or "filter" in this notebook)

    Returns
        A list of ``html.Div`` blocks, each holding a heading and a widget.
    """
    assert_connection()
    return [
        html.Div(
            [
                html.H4(f"Select {data[attr].name}"),
                _selector_for(data, attr, type),
            ]
        )
        for attr in (attributes or [])
    ]


# Placeholder shown in the graph area until the first real plot is produced:
# an empty figure with a large, faint "START" watermark across the centre.
start_watermark = dict(
    name="draft watermark",
    text="START",
    textangle=-30,
    opacity=0.1,
    font=dict(color="black", size=100),
    xref="paper",
    yref="paper",
    x=0.5,
    y=0.5,
    showarrow=False,
)
figure = go.Figure()
figure.add_annotation(**start_watermark)
# Left column: every control that parameterises the plot.
control_panel = html.Div(
    [
        html.H4("Select table"),
        dcc.Dropdown(
            id="table-index",
            options=[
                {"value": position, "label": file_name}
                for position, file_name in enumerate(table_names)
            ],
        ),
        html.H4("Select chosen/percented attributes"),
        dcc.Dropdown(id="chosen-attributes", multi=True, persistence=True),
        html.H4("Select filter attributes"),
        dcc.Dropdown(id="filter-attributes", multi=True, persistence=True),
        html.H4("Select groupby attributes"),
        dcc.Dropdown(id="groupby-attributes", multi=True, persistence=True),
        # The two zones are filled by a callback with per-attribute widgets.
        html.H4("Choose zone"),
        html.Div(id="choose-zone"),
        html.H4("Filter zone"),
        html.Div(id="filter-zone"),
        html.H4("Enter title"),
        dcc.Input(
            id="title", type="text", persistence=True, style={"width": "100%"}
        ),
        html.H4("Select groupby"),
        dcc.RadioItems(id="groupby", options=[], persistence=True),
        html.H4("Select display value"),
        dcc.RadioItems(
            ["number", "percent"],
            "number",
            id="display-value",
            persistence=True,
        ),
    ],
    style={"flex": 1, "minWidth": 400, "padding": 10},
)

# Right column: the plot and the text of the equivalent function call.
output_panel = html.Div(
    [
        dcc.Graph(
            id="line-plot",
            style={"aspect-ratio": "1.6"},
            figure=figure,
        ),
        dcc.Textarea(id="function-call", style={"width": "100%"}),
    ],
    style={"flex": 2, "padding": 10},
)

app.layout = html.Div(
    [control_panel, output_panel],
    style=style | {"padding": 10, "display": "flex", "flexDirection": "row"},
)


@app.callback(
    Output("chosen-attributes", "options"),
    Output("filter-attributes", "options"),
    Output("groupby-attributes", "options"),
    [Input("table-index", "value")],
    prevent_initial_call=True,
)
def update_attributes(table_index):
    """Offer the columns of the selected table in all three attribute dropdowns.

    Each option maps the column name to a "<dtype>: <column>" label. Nothing
    is updated while no table is selected.
    """
    if table_index is None:
        raise PreventUpdate
    selected_table = tables[table_index]
    column_options = {
        column: f"{selected_table[column].dtype}: {column}"
        for column in selected_table.columns
    }
    # The same options object serves the chosen, filter and groupby dropdowns.
    return [column_options, column_options, column_options]


# update filters and groupby
@app.callback(
    Output("choose-zone", "children"),
    Output("filter-zone", "children"),
    Output("groupby", "options"),
    Input("chosen-attributes", "value"),
    Input("filter-attributes", "value"),
    Input("groupby-attributes", "value"),
    State("table-index", "value"),
    prevent_initial_call=True,
)
def update_fg(chosen_attributes, filter_attributes, groupby_attributes, table_index):
    """Rebuild the selector zones and the groupby choices.

    A multi-select dropdown with nothing selected reports ``None``; such
    values are treated as empty lists so that filling in only one of the
    three attribute lists does not fail. Nothing is updated while no table
    is selected.
    """
    if table_index is None:
        raise PreventUpdate
    selected_table = tables[table_index]
    return [
        get_selectivity(selected_table, chosen_attributes or [], type="chosen"),
        get_selectivity(selected_table, filter_attributes or [], type="filter"),
        groupby_attributes or [],
    ]


def _build_query(data, component_ids, values):
    """Turn widget selections into a ``DataFrame.query`` expression.

    Arguments
        data - table the widgets refer to (used to look up column dtypes)
        component_ids - widget ids, each a dict with an "attr" key
        values - current widget values, in the same order as the ids

    Returns
        The individual conditions joined with " and "; an empty string when
        nothing is selected. Categorical columns contribute an ``isin`` test,
        ``int64`` columns a half-open range test. Unset widgets, empty
        selections and columns of any other dtype are skipped.
    """
    clauses = []
    for component_id, entry in zip(component_ids, values):
        if entry is None:
            continue
        attr = component_id["attr"]
        dtype_name = str(data[attr].dtype).lower()
        if dtype_name == "category" and len(entry) > 0:
            clauses.append(f"`{attr}`.isin({entry})")
        elif dtype_name == "int64" and len(entry) == 2:
            clauses.append(f"{entry[0]} <= `{attr}` < {entry[1]}")
    return " and ".join(clauses)


@app.callback(
    Output("line-plot", "figure"),
    Output("function-call", "value"),
    Input("title", "value"),
    Input("display-value", "value"),
    Input("groupby", "value"),
    Input({"type": "chosen", "attr": ALL}, "value"),
    State({"type": "chosen", "attr": ALL}, "id"),
    Input({"type": "filter", "attr": ALL}, "value"),
    State({"type": "filter", "attr": ALL}, "id"),
    State("table-index", "value"),
    prevent_initial_call=True,
)
def update_figure(
    title,
    display_value,
    groupby,
    chosen,
    chosen_id,
    filter,
    filter_id,
    table_index,
):
    """Redraw the plot from the current control values.

    Returns
        The figure and the text of an equivalent ``plot_groups`` call. If
        plotting fails, the error is printed and an empty figure with an
        "ERROR" watermark is returned instead, so the dashboard stays usable.
        Nothing is updated until both a table and a groupby column are chosen.
    """
    if groupby is None or table_index is None:
        raise PreventUpdate

    data = tables[table_index]
    chosen_query = _build_query(data, chosen_id, chosen)
    filter_query = _build_query(data, filter_id, filter)

    try:
        figure = plot_groups(data, groupby, display_value, chosen_query, filter_query)
    except Exception as e:
        # Deliberate best-effort: show the failure in the plot area and log it.
        figure = go.Figure()
        figure.add_annotation(
            name="draft watermark",
            text="ERROR",
            textangle=-30,
            opacity=0.1,
            font=dict(color="red", size=100),
            xref="paper",
            yref="paper",
            x=0.5,
            y=0.5,
            showarrow=False,
        )
        print(e)
    if figure is None:
        # plot_groups has no plot for this display value.
        figure = go.Figure()
    # Apply the title last so it lands on the figure that is actually returned.
    figure.update_layout(title=title)
    return [
        figure,
        # `value=` is the keyword plot_groups actually accepts.
        f"plot_groups(pd.read_csv({table_names[table_index]!r}, sep=';'), {groupby=!r}, value={display_value!r}, {chosen_query=!r}, {filter_query=!r})",
    ]


# Start the dashboard. `app.run` is the current entry point; `run_server` is
# the deprecated spelling and is only looked up when `run` does not exist
# (older Dash releases).
run_app = getattr(app, "run", None) or app.run_server
run_app(port=8054, debug=True, use_reloader=True)