In [1]:
import polars as pl
import utils
utils.mount_src()

from hypothesis_row import HypothesisRow

# Relative path of the main CSV; read into `df` by the loading cell.
DATA = "../data/movie_lens.csv"
# Relative path of the second CSV; read into `users_df`.
# NOTE(review): presumably per-user counts, judging by the directory name — confirm.
USERS_DF = "../data/user_counts/movie_lens.csv"

In [2]:
# Load both CSVs, then hand them to HypothesisRow.df_to_list.
# NOTE(review): `hs` is presumably a list of HypothesisRow objects — confirm against hypothesis_row.
df = pl.read_csv(DATA)
users_df = pl.read_csv(USERS_DF)
hs = HypothesisRow.df_to_list(df, users_df)

In [7]:
# NOTE: this rebinds `df`. From here on it is the frame built from `hs`,
# not the raw CSV frame loaded earlier; every later cell reads this version.
df = HypothesisRow.hypothesis_rows_to_group_df(hs)
# Frequency of each runtime_minutes value; the row with an empty label in the
# output is the count of null entries.
df["runtime_minutes"].value_counts()

runtime_minutes,count
str,u32
"""Long""",104
"""Short""",3
,56


In [20]:
# Column tallied in this cell and plotted by the following chart cell.
col = "genre"

def get_counts_for_col(s: pl.Series) -> pl.DataFrame:
    """Tally how often each distinct value occurs in ``s``.

    Args:
        s: The series to tally.

    Returns:
        The frame produced by ``Series.value_counts()``: one row per
        distinct value together with its ``count``. Nulls are not removed.
    """
    tally = s.value_counts()
    return tally

# Tally the selected column; the bare `counts` on the last line renders the table.
counts = get_counts_for_col(df[col])
counts

genre,count
str,u32
"""Drama""",63
,40
"""Film-Noir""",1
"""Thriller""",14
"""Action""",11
"""War""",2
"""Romance""",5
"""Crime""",13
"""Comedy""",14


In [35]:
import plotly.express as px

# Horizontal bar chart of the tallied column. Rows are sorted ascending by count;
# plotly draws the first row at the bottom of a horizontal chart, so the largest
# bar ends up on top.
fig = px.bar(counts.sort(by="count", descending=False), x="count", y=col, orientation="h")
fig.show()

In [42]:
import plotly.graph_objects as go

def make_group_bar_chart(df: pl.DataFrame, col: str, descending=True, horizontal=True) -> go.Figure:
    """Plot the value counts of one column as a bar chart.

    Note: a later cell in this notebook defines a function with the same
    name, which replaces this one once that cell has run.

    Args:
        df: Frame containing the column to tally.
        col: Name of the column whose values are counted.
        descending: Requested bar order by count. For horizontal charts the
            sort passed to plotly is inverted (``not descending``).
        horizontal: Draw horizontal bars when true, vertical bars otherwise.

    Returns:
        The plotly figure; null values in ``col`` are included as a bar.
    """
    tallies = df[col].value_counts()

    if not horizontal:
        ordered = tallies.sort(by="count", descending=descending)
        return px.bar(ordered, x=col, y="count")

    # Horizontal layout: the sort direction is flipped relative to `descending`.
    ordered = tallies.sort(by="count", descending=(not descending))
    return px.bar(ordered, x="count", y=col, orientation="h")

# Horizontal genre chart using the helper's default ordering (descending=True).
fig = make_group_bar_chart(df, "genre", horizontal=True)
fig.show()

In [48]:
def all_counts(df: pl.DataFrame) -> pl.DataFrame:
    """Collect the value counts of every column into one long-format frame.

    Args:
        df: Frame whose columns are each tallied with ``value_counts()``.

    Returns:
        A frame with columns ``category`` (source column name), ``label``
        (the distinct value) and ``count``. Rows containing a null, such as
        the tally of missing values, are dropped.

    NOTE(review): labels from all columns land in a single ``label`` column;
    presumably every column holds strings — confirm.
    """
    records = [
        {"category": name, "label": label, "count": n}
        for name in df.columns
        for label, n in df[name].value_counts().iter_rows()
    ]
    return pl.DataFrame(records).drop_nulls()

# Long-format counts for every column of `df`; shown as the cell output.
all_counts_df = all_counts(df)
all_counts_df

category,label,count
str,str,i64
"""runtime_minutes""","""Short""",3
"""runtime_minutes""","""Long""",104
"""occupation""","""technician-engineer""",2
"""occupation""","""programmer""",2
"""occupation""","""homemaker""",1
…,…,…
"""year""","""40s""",1
"""year""","""90s""",57
"""year""","""60s""",3
"""year""","""80s""",8


In [49]:
# One horizontal bar per label, coloured by the column (category) it came from.
fig = px.bar(all_counts_df, x="count", y="label", color="category", orientation="h")
# Pin the x-axis to [0, largest count] instead of relying on autorange.
fig.update_layout(xaxis=dict(range=[0, max(all_counts_df["count"])]))
fig.show()

In [53]:
def make_group_bar_chart(df: pl.DataFrame, col: str | None, descending=True, horizontal=True) -> go.Figure:
    """Plot value counts as a bar chart for one column, or for all columns.

    Args:
        df: Frame to tally.
        col: Column to chart. When falsy (``None``), the counts of every
            column are charted together, coloured by source column.
        descending: Requested bar order by count. Only applied in the
            single-column case; for horizontal charts the sort passed to
            plotly is inverted (``not descending``). The all-columns chart
            keeps the row order returned by ``all_counts``.
        horizontal: Draw horizontal bars when true, vertical bars otherwise.

    Returns:
        The plotly figure. Null entries are excluded from the tallies.
    """
    if col:
        counts = df[col].value_counts().drop_nulls()
        if horizontal:
            return px.bar(
                counts.sort(by="count", descending=(not descending)),
                x="count",
                y=col,
                orientation="h",
            )
        return px.bar(counts.sort(by="count", descending=descending), x=col, y="count")

    counts = all_counts(df).drop_nulls()
    if not horizontal:
        return px.bar(counts, x="label", y="count", color="category")

    fig = px.bar(counts, x="count", y="label", color="category", orientation="h")
    # Derive the axis range from the counts of *this* call's frame. The previous
    # version read the notebook-level `all_counts_df`, which fails on a fresh
    # kernel and gives the wrong range for any other `df`.
    max_count = counts["count"].max()
    if max_count is not None:  # max() is None when every row was dropped
        fig.update_layout(xaxis=dict(range=[0, max_count]))
    return fig

# col=None selects the all-columns chart; horizontal=False draws vertical bars.
fig = make_group_bar_chart(df, col=None, horizontal=False)
fig.show()