In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import polars as pl
import numpy as np
from scipy.stats import norm

# Load the survey data; the path is relative to the current working directory.
df = pl.read_csv("../data/Mental_Health_and_Social_Media_Balance_Dataset.csv")

# Number of rows, used as the denominator when turning counts into relative
# frequencies. `height` is the row count itself, whereas `count()` reports the
# non-null values per column and would undercount if that column had nulls.
TOTAL = df.height

# 3x3 grid of panels on a 12x12 inch figure, with extra spacing between
# panels so titles and tick labels have room.
fig, axs = plt.subplots(3, 3, figsize=(12, 12))
fig.subplots_adjust(wspace=0.35, hspace=0.35)

# Gender
# Count respondents per gender. `group_by` does not guarantee the order of
# its groups, so sort by label to keep the slice order (and therefore the
# slice colors) stable from run to run.
gender_distribution = (
    df.group_by("Gender")
    .agg(pl.col("Gender").count().alias("Count"))
    .sort("Gender")
)

# Pie chart in the top-left panel; each slice is annotated with its
# percentage to one decimal place.
axs[0, 0].pie(
    gender_distribution.get_column("Count"),
    labels=gender_distribution.get_column("Gender"),
    autopct="%1.1f%%",
)
axs[0, 0].set_title("Gender Distribution")

# Social Media
# Share of respondents per platform (count divided by TOTAL). `group_by`
# does not guarantee the order of its groups and the bar order follows the
# row order of the frame, so sort by platform name for a reproducible plot.
social_media_distribution = (
    df.group_by("Social_Media_Platform")
    .agg(
        (pl.col("Social_Media_Platform").count() / TOTAL).alias("Relative_Amount")
    )
    .sort("Social_Media_Platform")
)

sns.barplot(
    social_media_distribution,
    x="Social_Media_Platform",
    y="Relative_Amount",
    ax=axs[0, 1],
)
axs[0, 1].set_title("Social Media Platform Distribution")
# The title already names the variable, so the axis labels are blanked.
axs[0, 1].set_xlabel("")
axs[0, 1].set_ylabel("")

# Age
age_values = df.get_column("Age")

min_age = age_values.min()
max_age = age_values.max()

# Break points every 4 years, starting at the youngest age and stopping
# before the oldest (np.arange excludes `stop`). With polars' default
# settings `cut` produces right-closed bins, plus an open-ended bin below the
# first break and above the last one.
age_breaks = np.arange(start=min_age, stop=max_age, step=4)

# Share of respondents per age bin (count divided by TOTAL), sorted by the
# binned column.
age_distribution = (
    age_values.cut(breaks=age_breaks)
    .to_frame()
    .group_by("Age")
    .agg((pl.col("Age").count() / TOTAL).alias("Relative_Amount"))
    .sort("Age")
)

# Pass the sorted bin labels as `order` so the bars follow the frame's order.
sns.barplot(
    age_distribution,
    x="Age",
    y="Relative_Amount",
    order=age_distribution.get_column("Age"),
    ax=axs[0, 2],
)
axs[0, 2].set_title("Age Distribution")
axs[0, 2].set_xlabel("")
axs[0, 2].set_ylabel("")

# Daily Screen Time
screen_time = df.get_column("Daily_Screen_Time(hrs)")
min_screen_time = screen_time.min()
max_screen_time = screen_time.max()

# One-hour break points from the smallest value up to, but excluding, the
# largest (np.arange excludes `stop`).
screen_time_breaks = np.arange(start=min_screen_time, stop=max_screen_time, step=1)

# Share of respondents per one-hour bin (count divided by TOTAL). Each bin
# also records its smallest observed value, which is used as the sort key so
# the bars appear in ascending numeric order.
screen_time_distribution = (
    df.with_columns(
        pl.col("Daily_Screen_Time(hrs)")
        .cut(screen_time_breaks)
        .alias("Binned_Daily_Screen_Time(hrs)")
    )
    .group_by("Binned_Daily_Screen_Time(hrs)")
    .agg(
        (pl.col("Daily_Screen_Time(hrs)").count() / TOTAL).alias("Relative_Amount"),
        pl.col("Daily_Screen_Time(hrs)").min().alias("Min"),
    )
    .sort("Min")
)

# Normal density with the sample mean and standard deviation, evaluated on
# 500 points between the observed minimum and maximum.
mean_screen_time = screen_time.mean()
std_screen_time = screen_time.std()

x_norm_screen_time = np.linspace(min_screen_time, max_screen_time, 500)
y_norm_screen_time = norm.pdf(x_norm_screen_time, loc=mean_screen_time, scale=std_screen_time)

sns.barplot(
    screen_time_distribution,
    x="Binned_Daily_Screen_Time(hrs)",
    y="Relative_Amount",
    order=screen_time_distribution.get_column("Binned_Daily_Screen_Time(hrs)"),
    ax=axs[1, 0],
)

# Draw the density on a twin x axis that shares the bars' y axis; its ticks
# are hidden so only the bin labels are shown.
# NOTE(review): the twin axis keeps its autoscaled x limits, so the curve is
# a visual guide and is not exactly aligned with the bin positions.
ax_norm_screen_time = axs[1, 0].twiny()
ax_norm_screen_time.plot(x_norm_screen_time, y_norm_screen_time, color="red", label="Normalverteilung")
ax_norm_screen_time.set_xticks([])
axs[1, 0].set_title("Daily_Screen_Time(hrs) Distribution")
axs[1, 0].set_xlabel("")
axs[1, 0].set_ylabel("")

# Remaining
# Each entry: (column name, subplot row, subplot column, overlay a normal curve?)
remaining_cols = [
    ("Sleep_Quality(1-10)", 1, 1, True),
    ("Stress_Level(1-10)", 1, 2, True),
    ("Days_Without_Social_Media", 2, 0, True),
    ("Exercise_Frequency(week)", 2, 1, True),
    ("Happiness_Index(1-10)", 2, 2, False),
]

for col, plot_row, plot_col, show_normal in remaining_cols:
    panel = axs[plot_row, plot_col]
    values = df.get_column(col)

    # Share of respondents per distinct value (count divided by TOTAL),
    # sorted by value so the bars appear in ascending order.
    distribution = (
        df.group_by(col)
        .agg((pl.col(col).count() / TOTAL).alias("Relative_Amount"))
        .sort(col)
    )

    sns.barplot(distribution, x=col, y="Relative_Amount", ax=panel)

    if show_normal:
        # Names chosen so the builtins `min` and `max` are not overwritten
        # at module level for the rest of the session.
        value_min = values.min()
        value_max = values.max()
        mean = values.mean()
        std = values.std()

        # Normal density with the sample mean and standard deviation, drawn
        # on a twin x axis that shares the bars' y axis.
        x_norm = np.linspace(value_min, value_max, 500)
        y_norm = norm.pdf(x_norm, loc=mean, scale=std)
        ax_norm = panel.twiny()
        ax_norm.plot(x_norm, y_norm, color="red", label="Normalverteilung")
        # NOTE(review): the bars sit on categorical positions, so pinning the
        # twin axis to [value_min, value_max] aligns the curve with the bars
        # only approximately - confirm whether exact alignment is required.
        ax_norm.set_xlim(value_min, value_max)
        ax_norm.set_xticks([])

    panel.set_title(f"{col} Distribution")
    panel.set_xlabel("")
    panel.set_ylabel("")
    
# fix xticklabel overlapping
# Rotate the x tick labels on every panel of the subplot grid; `axs.flat`
# iterates over the 2-D array of axes one panel at a time.
for ax in axs.flat:
    ax.tick_params(axis="x", rotation=40)

plt.show()