In [None]:
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import polars as pl
import yaml

from nixos.survey_analysis.answers import (
    get_answers_by_choice_id,
    get_categorical_question_answers,
)
from nixos.survey_analysis.config import Config
from nixos.survey_analysis.plot import (
    NotMultileChoicesQuestion,
    plot_categorical_question,
    plot_multiple_choices_question,
)

from nixos.survey_analysis.questions import (
    Choice,
    ChoicesByQuestionIdByType,
    group_by_question_type,
    group_choices_by_question_id,
    ungrouped_choices_from_columns,
)

In [None]:
# Load the 2023 analysis configuration and the raw survey export.
with open("../configs/2023.yml") as config_file:
    raw_config = yaml.safe_load(config_file)
config = Config.model_validate(raw_config)

df = pl.read_csv("../data/results-survey2023.csv")

# Derive the choices from the CSV column names, then group them per question.
ungrouped_choices = ungrouped_choices_from_columns(columns=df.columns)
questions_by_id = group_choices_by_question_id(ungrouped_choices)

# Drop the choices that the config marks as excluded. Only questions that are
# present in the config and have a non-None `exclude` entry are filtered;
# every other question keeps its choices untouched.
for question_id, choices in questions_by_id.items():
    if question_id in config.questions:
        excluded_ids = config.questions[question_id].exclude
        if excluded_ids is not None:
            # Re-assigning an existing key does not resize the mapping, so
            # doing it while iterating over `.items()` is safe.
            questions_by_id[question_id] = [
                kept for kept in choices if kept.choice_id not in excluded_ids
            ]

questions_by_id_by_type = group_by_question_type(
    choices_by_question_id=questions_by_id,
    df=df,
)

In [None]:
# Look up the CSV column behind the "gender" and "region" categorical
# questions; the first entry of each question is used.
# NOTE(review): presumably a categorical question has a single entry -- confirm.
categorical_questions = questions_by_id_by_type.categorical
gender_column, region_column = (
    categorical_questions[name][0].column for name in ("gender", "region")
)
# Display both column names as the cell output.
gender_column, region_column

In [None]:
# Count the survey rows for every (region, gender) pair and reshape the result
# into a pandas frame indexed by region with one column per gender answer,
# which is the layout the grouped bar chart consumes.
df_count_gender_per_region = (
    df
    .rename({gender_column: "gender", region_column: "region"})
    .with_columns(
        [
            # Relabel empty-string answers with a native polars expression
            # instead of a per-row Python lambda passed to `.apply`; the
            # expression is evaluated by the polars engine rather than by
            # calling back into Python for every value.
            # Null answers stay null (`null == ""` is null, so the
            # `otherwise` branch is taken).
            # NOTE(review): `.apply` skips nulls by default, so this should
            # match the previous output -- confirm on the pinned polars version.
            pl.when(pl.col(name) == "")
            .then(pl.lit("Not answered"))
            .otherwise(pl.col(name))
            # Name the result explicitly so it replaces the original column.
            .alias(name)
            for name in ("gender", "region")
        ]
    )
    # `groupby(...).count()` yields a "count" column, which the pivot spreads
    # across one column per gender value.
    .groupby(by=["region", "gender"])
    .count()
    .pivot(index="region", columns="gender", aggregate_function="sum", values="count")
    .to_pandas()
    .set_index("region")
)

In [None]:
# Grouped bar chart of the region x gender count table: one group of bars per
# region (the frame's index) and one bar per gender column.
ax = df_count_gender_per_region.plot.bar(figsize=(20, 5))
# Give the figure a title and axis labels so it can be read on its own when
# the notebook is skimmed.
ax.set(
    title="Survey responses per region, by gender",
    xlabel="Region",
    ylabel="Number of responses",
)
# Annotate every bar with the value it represents.
for container in ax.containers:
    ax.bar_label(container)