In [None]:
import os
import sqlite3
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset

from coral.Geography_Helper import Geography_Helper

# When the kernel was started inside a directory called "notebooks", step up
# to its parent so the relative paths used below (e.g. the output directory
# for figures) resolve from there.
working_dir = Path.cwd()
if working_dir.name == "notebooks":
    os.chdir(working_dir.parent)

In [None]:
def load_tables_from_db(
    db_path: Path, tables: list[str] = ["Recordings", "Speakers"]
) -> tuple[pd.DataFrame, ...]:
    """Load specified tables from a SQLite database and return them as a tuple of data frames.

    Args:
        db_path: The path to the SQLite database file.
        tables: A list of table names to load. Defaults to ["Recordings", "Speakers"].
            The names are interpolated into the SQL text, so they must come from
            trusted code and never from user input.

    Returns:
        A tuple containing pandas DataFrames for the specified tables, in the same
        order as `tables`.

    Raises:
        FileNotFoundError: If `db_path` does not point to an existing file.
    """
    # sqlite3.connect() silently creates an empty database when the file is
    # missing, which would only surface later as a confusing "no such table"
    # error and leave a stray file behind.
    db_path = Path(db_path)
    if not db_path.is_file():
        raise FileNotFoundError(f"SQLite database not found: {db_path}")

    conn = sqlite3.connect(db_path)
    try:
        # The default list is only read, never mutated, so sharing it is safe.
        return tuple(
            pd.read_sql_query(f"SELECT * FROM {table}", conn) for table in tables
        )
    finally:
        # Close the connection even when one of the queries fails.
        conn.close()

In [None]:
# Dataset to analyse. The loading cell below recognises "internal" (read from
# the SQLite database), "CoRal-project/coral" and "CoRal-project/coral-v2".
dataset_name = "CoRal-project/coral-v2"

# Path to the public database. The directory defaults to the original mount
# point but can be overridden with the CORAL_DB_DIR environment variable so the
# notebook also runs on machines where the volume is mounted elsewhere.
path_database_public = (
    Path(os.environ.get("CORAL_DB_DIR", "/Volumes/CoRal/raw")) / "CoRal_public.db"
)

## Load base data

In [None]:
if dataset_name == "internal":
    # Internal data comes straight from the SQLite database.
    df_recordings, df_conversations = load_tables_from_db(
        path_database_public, ["Recordings", "Conversations"]
    )

    # Expand the conversations table to two rows per conversation (one per
    # speaker), then drop the columns that only describe the original
    # two-speaker layout.
    df_conversation_expanded = pd.concat(
        [
            df_conversations.assign(id_speaker=df_conversations[speaker_column])
            for speaker_column in ("id_speaker_a", "id_speaker_b")
        ],
        ignore_index=True,
    ).drop(columns=["id_speaker_a", "id_speaker_b", "id_recorder"])

    # Combine the recordings and expanded conversations into a single DataFrame
    df_combined = pd.concat(
        [df_recordings, df_conversation_expanded], ignore_index=True
    )

else:
    # Hugging Face datasets: map each known dataset to the configurations to
    # load. `None` selects the default configuration, so the loop below also
    # works for a dataset without named subsets.
    hf_subsets_by_dataset = {
        "CoRal-project/coral": [None],
        "CoRal-project/coral-v2": ["read_aloud", "conversation"],
    }
    if dataset_name not in hf_subsets_by_dataset:
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected 'internal' or one of "
            f"{sorted(hf_subsets_by_dataset)}"
        )
    subsets = hf_subsets_by_dataset[dataset_name]
    splits = ["train", "val", "test"]  # Define dataset splits

    # Collect one frame per (subset, split) and concatenate once at the end
    # instead of re-copying the growing frame on every iteration.
    split_frames = []
    for subset in subsets:
        dataset = load_dataset(dataset_name, subset)
        for split in splits:
            # The audio column is not needed for the statistics; drop it before
            # converting the split to pandas.
            df = dataset[split].remove_columns(["audio"]).to_pandas()

            # Record where each row came from
            df["split"] = split
            df["subset"] = subset
            split_frames.append(df)

    df_combined = pd.concat(split_frames, ignore_index=True)

### Augment data

In [None]:
# Build a single `id` column: take `id_recording` where present and fall back
# to `id_conversation`. A dataset may carry only one of the two columns
# (presumably e.g. one without a conversation subset - TODO confirm), so only
# use the columns that actually exist instead of failing with a KeyError.
id_source_columns = [
    column
    for column in ("id_recording", "id_conversation")
    if column in df_combined.columns
]
if not id_source_columns:
    raise KeyError(
        "df_combined has neither an 'id_recording' nor an 'id_conversation' column"
    )

sample_ids = df_combined[id_source_columns[0]]
for column in id_source_columns[1:]:
    sample_ids = sample_ids.combine_first(df_combined[column])
df_combined["id"] = sample_ids

In [None]:
df_speakers = load_tables_from_db(path_database_public, ["Speakers"])[0]

# Per-speaker attributes indexed by speaker id. `zip_school` is exposed under
# the name `zipcode`, which is the column name used in df_combined.
speaker_attributes = df_speakers.set_index("id_speaker")[
    ["age", "gender", "dialect", "zip_school"]
].rename(columns={"zip_school": "zipcode"})

# {id_speaker: {"age": ..., "gender": ..., "dialect": ..., "zipcode": ...}}
# Previously this dict was built and then immediately overwritten by a renamed
# DataFrame; it now holds the mapping the name promises.
speaker_map = speaker_attributes.to_dict(orient="index")

# Copy every attribute onto df_combined by looking up each row's speaker id;
# speakers missing from the table yield NaN.
for attribute in speaker_attributes.columns:
    df_combined[attribute] = df_combined["id_speaker"].map(
        speaker_attributes[attribute].to_dict()
    )

In [None]:
(df_duration,) = load_tables_from_db(path_database_public, ["Durations"])
duration_map = df_duration.set_index("id")["duration"].to_dict()

# Look up each row's duration by id and derive the coarser units from it
duration_columns = ["duration_seconds", "duration_minutes", "duration_hours"]
df_combined["duration_seconds"] = df_combined["id"].map(duration_map)
df_combined["duration_minutes"] = df_combined["duration_seconds"].div(60)
df_combined["duration_hours"] = df_combined["duration_seconds"].div(3600)

# Rows that repeat an id keep their duration only on the first occurrence, so
# that summing a duration column counts every id once.
is_repeated_id = df_combined.duplicated(subset="id", keep="first")
df_combined.loc[is_repeated_id, duration_columns] = 0

In [None]:
def _sample_type(sample_id):
    """Classify an id as "conversation" if it contains "conv", else "read_aloud"."""
    if "conv" in sample_id:
        return "conversation"
    return "read_aloud"


df_combined["type"] = df_combined["id"].apply(_sample_type)

In [None]:
# Display the augmented frame (pandas truncates long frames when rendering)
df_combined

## Preprocess data

In [None]:
# Define the range for binning
bin_edges = [0, 25, 50, 200]  # Bins for age ranges: 0-24, 25-50, 50+

# Add a new column for binned data
df_combined["age_binned"] = pd.cut(
    df_combined["age"], bins=bin_edges, right=False, labels=["0-25", "25-50", "50+"]
)

In [None]:
# Coarse dialect group -> the fine-grained dialect names that belong to it.
# Keys of the derived lookup are matched exactly (all lower-case here).
dialect_groups = {
    "Bornholmsk": ["bornholmsk"],
    "Fynsk": [
        "fynsk",
        "østfynsk",
        "vestfynsk (nordvest-, sydvestfynsk)",
        "sydfynsk",
        "langelandsk",
        "tåsingsk (m. thurø)",
        "ærøsk (m. lyø, avernakø, strynø, birkholm, drejø)",
    ],
    "Københavnsk": ["amagermål", "københavnsk"],
    "Sjællandsk": [
        "sjællandsk",
        "nordsjællandsk",
        "østsjællandsk",
        "nordvestsjællandsk",
        "sydvestsjællandsk",
        "sydsjællandsk (sydligt sydsjællandsk)",
    ],
    "Sydømål": [
        "østmønsk",
        "vestmønsk",
        "nordfalstersk",
        "sydfalstersk",
        "lollandsk",
        "sydømål",
    ],
    "Sønderjysk": [
        "sønderjysk",
        "østligt sønderjysk (m. als)",
        "vestlig sønderjysk (m. mandø og rømø)",
        "syd for rigsgrænsen: mellemslesvisk, angelmål, fjoldemål",
        "mellemslesvisk",
    ],
    "Vestjysk": [
        "vestjysk",
        "thybomål",
        "morsingmål",
        "sallingmål",
        "hardsysselsk",
        "fjandbomål",
        "sydvestjysk (m. fanø)",
        # NOTE(review): "sydøstjysk" is grouped under Vestjysk - confirm this is
        # intended rather than Østjysk.
        "sydøstjysk",
    ],
    "Østjysk": [
        "midtøstjysk",
        "ommersysselsk",
        "djurslandsk (nord-, syddjurs m. nord- og sydsamsø, anholt)",
        "østjysk",
    ],
    "Nordjysk": [
        "nørrejysk",
        "vendsysselsk (m. hanherred og læsø)",
        "himmerlandsk",
    ],
    # General fallback
    "Jysk": ["jysk"],
}

# Invert the groups into a fine-grained name -> coarse group lookup
dialect_map = {
    fine_name: group_name
    for group_name, fine_names in dialect_groups.items()
    for fine_name in fine_names
}


def _coarse_dialect(dialect):
    """Return the coarse group for `dialect`, or `dialect` itself if unmapped."""
    return dialect_map.get(dialect, dialect)


df_combined["dialect"] = df_combined["dialect"].apply(_coarse_dialect)

In [None]:
geo_helper = Geography_Helper()

# Two chained lookups: zipcode -> `kommunekod` via getMunicipality, then
# `kommunekod` -> `regionskod` via getRegion.
df_combined["kommunekod"] = df_combined["zipcode"].apply(geo_helper.getMunicipality)
df_combined["regionskod"] = df_combined["kommunekod"].apply(geo_helper.getRegion)

## Define helper functions

In [None]:
import os

import pandas as pd


def _aggregate_distribution(
    df: pd.DataFrame,
    group_by: str,
    agg_col: str,
    agg_func: str,
    stack_col: str,
):
    """Aggregate `df` into the Series (unstacked) or DataFrame (stacked) to plot."""
    counting = agg_func in ("size", "percentage")

    if stack_col:
        grouped = df.groupby([group_by, stack_col])
        if counting:
            aggregated = grouped.size()
        elif agg_func == "sum":
            aggregated = grouped[agg_col].sum()
        else:  # "nunique"
            aggregated = grouped[agg_col].nunique()
        data = aggregated.unstack(fill_value=0)
        if agg_func == "percentage":
            # Share of each stack segment within its own bar
            data = data.div(data.sum(axis=1), axis=0) * 100
        return data

    if counting:
        data = df[group_by].value_counts().sort_index()
    elif agg_func == "sum":
        data = df.groupby(group_by)[agg_col].sum()
    else:  # "nunique"
        data = df.groupby(group_by)[agg_col].nunique()
    if agg_func == "percentage":
        # Share of each bar within the whole
        data = (data / data.sum()) * 100
    return data


def _format_bar_label(value, agg_func: str) -> str:
    """Format a bar value: percent sign for shares, integers for counts."""
    if agg_func == "percentage":
        return f"{value:.2f}%"
    if agg_func in ("size", "nunique"):
        # Counts are whole numbers; do not render them as "12.00".
        return f"{int(value)}"
    return f"{value:.2f}"


def distribution_plot(
    df: pd.DataFrame,
    group_by: str,
    agg_col: str = None,
    agg_func: str = "size",  # 'size', 'sum', 'nunique', 'percentage'
    stack_col: str = None,
    figsize: tuple = (10, 6),
    title: str = "Distribution Plot",
    colormap: str = "tab10",
    xlabel: str = None,
    ylabel: str = None,
    save_dir: str = None,
):
    """Create a (stacked) bar plot from a DataFrame with value labels and total labels.

    Parameters:
        df (pd.DataFrame): Input DataFrame.
        group_by (str): Column to group by on the x-axis.
        agg_col (str): Column to aggregate if using 'sum' or 'nunique'. Not needed
            for 'size' or 'percentage'.
        agg_func (str): Aggregation function: 'size', 'nunique', 'sum', or
            'percentage' (row counts expressed as a share).
        stack_col (str, optional): Column to stack bars by (categorical).
        figsize (tuple): Figure size.
        title (str): Plot title; also used to derive the file name when saving.
        colormap (str): Matplotlib colormap name.
        xlabel (str): Custom label for x-axis. Defaults to `group_by`.
        ylabel (str): Custom label for y-axis. Defaults to a label derived from
            `agg_func`.
        save_dir (str): Directory to save the plot as PNG if provided. It is
            created when missing.

    Returns:
        matplotlib.axes.Axes: The plot axes.

    Raises:
        ValueError: If `agg_func` is not supported, if `agg_col` is missing for
            'sum'/'nunique', or if the aggregation yields nothing to plot.
    """
    valid_funcs = ["size", "sum", "nunique", "percentage"]
    if agg_func not in valid_funcs:
        raise ValueError(f"agg_func must be one of {valid_funcs}")
    if agg_func in ["sum", "nunique"] and agg_col is None:
        raise ValueError(f"agg_col must be specified when using '{agg_func}'")

    data = _aggregate_distribution(df, group_by, agg_col, agg_func, stack_col)
    if data.empty:
        raise ValueError(f"No data to plot after grouping by '{group_by}'")

    # Create the figure once and draw into its axes. Calling plt.figure() and
    # then DataFrame.plot(figsize=...) would open a second figure and leave the
    # first one behind as an empty plot.
    fig, ax = plt.subplots(figsize=figsize)
    data.plot(kind="bar", stacked=bool(stack_col), ax=ax, colormap=colormap)

    if stack_col:
        # Vertical gap between the top of a bar and its total label
        label_gap = max(data.values.max() * 0.01, 1)
        for idx, (_, row) in enumerate(data.iterrows()):
            y_offset = 0
            for value in row:
                # Only non-empty segments get a label (and take up height)
                if value > 0:
                    ax.text(
                        idx,
                        y_offset + value / 2,
                        _format_bar_label(value, agg_func),
                        ha="center",
                        va="center",
                        fontsize=9,
                    )
                    y_offset += value
            ax.text(
                idx,
                y_offset + label_gap,
                _format_bar_label(row.sum(), agg_func),
                ha="center",
                va="bottom",
                fontsize=10,
                fontweight="bold",
            )
    else:
        label_gap = max(data.max() * 0.01, 1)
        for idx, value in enumerate(data.values):
            ax.text(
                idx,
                value + label_gap,
                _format_bar_label(value, agg_func),
                ha="center",
                va="bottom",
                fontsize=10,
            )

    if ylabel:
        y_axis_label = ylabel
    elif agg_func == "percentage":
        y_axis_label = "Percentage"
    elif agg_func == "size":
        y_axis_label = "Count"
    else:
        y_axis_label = agg_func.capitalize()

    ax.set_title(title)
    ax.set_xlabel(xlabel if xlabel else group_by)
    ax.set_ylabel(y_axis_label)
    ax.tick_params(axis="x", labelrotation=45)
    fig.tight_layout()

    # Save the plot if save_dir is provided
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        filename = os.path.join(save_dir, f"{title.replace(' ', '_')}.png")
        fig.savefig(filename, bbox_inches="tight")
        print(f"Plot saved to: {filename}")

    plt.show()
    return ax

In [None]:
def plot_geo_counts(
    df,
    group_by: str,
    geo_level: str,
    agg_col: str = None,
    agg_func: str = "size",
    rename_col: str = None,
    title: str = None,
    save_dir: str = None,
    interactive: bool = False,
    cmap: str = "viridis",
    round_decimals: int = 2,
    **explore_kwargs,
):
    """Plots a geographic distribution (static or interactive choropleth).

    Relies on the notebook-level `geo_helper` object for the geometries.

    Parameters:
        df (pd.DataFrame): Input data.
        group_by (str): Column to group by (should match a column in geo data).
        geo_level (str): Level of geography (passed to geo_helper.get_dfmap()).
        agg_col (str): Column to aggregate on (required for 'sum' and 'nunique').
        agg_func (str): Aggregation function: 'size', 'sum', or 'nunique'.
        rename_col (str): Optionally rename group_by to match geo data.
        title (str): Title for the plot.
        save_dir (str): If provided, saves static plot to this directory. The file
            name is derived from `title`, or from `geo_level` and the value column
            when no title is given.
        interactive (bool): Use interactive map if True.
        cmap (str): Colormap for the plot.
        round_decimals (int): Rounding of numeric output.
        **explore_kwargs: Extra args for GeoDataFrame.explore() if interactive.
            Ignored for static plots.

    Returns:
        The merged and aggregated geodata for static plots, or the object
        returned by `explore()` when `interactive` is True.

    Raises:
        ValueError: If `agg_func` is not supported or `agg_col` is missing for
            'sum'/'nunique'.
    """
    # Validate the arguments before doing any work
    if agg_func not in ("size", "sum", "nunique"):
        raise ValueError("agg_func must be one of: 'size', 'sum', 'nunique'")
    if agg_func != "size" and not agg_col:
        raise ValueError(f"agg_col must be specified for '{agg_func}'")

    # Aggregate per geographic unit; value_col names the resulting column
    grouped = df.groupby(group_by)
    if agg_func == "size":
        value_col = "count"
        data = grouped.size().reset_index(name=value_col)
    elif agg_func == "nunique":
        value_col = f"n_unique_{agg_col}"
        data = grouped[agg_col].nunique().reset_index(name=value_col)
    else:  # "sum"
        value_col = f"sum_{agg_col}"
        data = grouped[agg_col].sum().reset_index(name=value_col)

    # Rename group_by column if needed
    if rename_col:
        data = data.rename(columns={group_by: rename_col})
        group_by = rename_col

    # Left-merge onto the geo data so every geographic unit is kept; units
    # without any rows get a value of 0.
    dfmap = geo_helper.get_dfmap(geo_level)
    dfmap = pd.merge(dfmap, data, how="left", on=group_by)
    dfmap[value_col] = dfmap[value_col].fillna(0).round(round_decimals)

    if interactive:
        if title:
            print(f"🗺️ {title}")
        return dfmap.explore(
            column=value_col,
            cmap=cmap,
            legend=True,
            tooltip=[group_by, value_col],
            **explore_kwargs,
        )

    ax = dfmap.plot(column=value_col, cmap=cmap, legend=True)
    if title:
        ax.set_title(title)
    # Finish styling before saving so the PNG matches what is displayed
    ax.set_axis_off()
    fig = ax.figure
    fig.tight_layout()

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        # title may be None; fall back to a name derived from the plotted data
        file_stem = (title or f"{geo_level}_{value_col}").replace(" ", "_")
        filepath = os.path.join(save_dir, f"{file_stem}.png")
        fig.savefig(filepath, bbox_inches="tight")
        print(f"Map saved to: {filepath}")

    plt.show()
    return dfmap

In [None]:
plt.rcParams["figure.figsize"] = (8, 5)  # Set default plot size to 8x5 inches

In [None]:
# Work on a copy so the plotting cells below do not touch df_combined
df = df_combined.copy()

# Age Distribution

In [None]:
# One plot per measure: (agg_func, agg_col, title suffix, y-axis label)
age_plot_specs = [
    ("size", None, "Samples", "Number of Samples"),
    ("percentage", None, "Percentage of Samples", "Percentage of Samples"),
    ("nunique", "id_speaker", "Speakers", "Number of Unique Speakers"),
    ("sum", "duration_hours", "Hours of Recordings", "Total Duration (Hours)"),
]

# Looping keeps the four calls in sync and avoids echoing the returned Axes
# as cell output.
for agg_func, agg_col, title_suffix, y_label in age_plot_specs:
    distribution_plot(
        df,
        group_by="age_binned",
        agg_func=agg_func,
        agg_col=agg_col,
        title=f"Age Distribution ({title_suffix})",
        xlabel="Age Group",
        ylabel=y_label,
        save_dir="outputs/visualizations",
    )

# Dialect Distribution

In [None]:
# One plot per measure: (agg_func, agg_col, title suffix, y-axis label)
dialect_plot_specs = [
    ("size", None, "Samples", "Number of Samples"),
    ("percentage", None, "Percentage of Samples", "Percentage of Samples"),
    ("nunique", "id_speaker", "Speakers", "Number of Unique Speakers"),
    ("sum", "duration_hours", "Hours of Recordings", "Total Duration (Hours)"),
]

# Looping keeps the four calls in sync and avoids echoing the returned Axes
# as cell output.
for agg_func, agg_col, title_suffix, y_label in dialect_plot_specs:
    distribution_plot(
        df,
        group_by="dialect",
        agg_func=agg_func,
        agg_col=agg_col,
        title=f"Dialect Distribution ({title_suffix})",
        xlabel="Dialect",
        ylabel=y_label,
        save_dir="outputs/visualizations",
    )

# Gender Distribution

In [None]:
# One plot per measure: (agg_func, agg_col, title suffix, y-axis label)
gender_plot_specs = [
    ("size", None, "Samples", "Number of Samples"),
    ("percentage", None, "Percentage of Samples", "Percentage of Samples"),
    ("nunique", "id_speaker", "Speakers", "Number of Unique Speakers"),
    ("sum", "duration_hours", "Hours of Recordings", "Total Duration (Hours)"),
]

# Looping keeps the four calls in sync and avoids echoing the returned Axes
# as cell output.
for agg_func, agg_col, title_suffix, y_label in gender_plot_specs:
    distribution_plot(
        df,
        group_by="gender",
        agg_func=agg_func,
        agg_col=agg_col,
        title=f"Gender Distribution ({title_suffix})",
        xlabel="Gender",
        ylabel=y_label,
        save_dir="outputs/visualizations",
    )

## Geoplots

In [None]:
# One map per measure: (agg_func, agg_col, title suffix)
zipcode_map_specs = [
    ("size", None, "Samples"),
    ("nunique", "id_speaker", "Speakers"),
    ("sum", "duration_hours", "Hours of Recordings"),
]

# The zipcode column is renamed to "postnummer" to match the geo data. Looping
# also avoids dumping the returned GeoDataFrame as cell output.
for agg_func, agg_col, title_suffix in zipcode_map_specs:
    plot_geo_counts(
        df,
        group_by="zipcode",
        geo_level="zipcode",
        rename_col="postnummer",
        agg_func=agg_func,
        agg_col=agg_col,
        title=f"Zipcode Distribution ({title_suffix})",
        save_dir="outputs/visualizations",
    )

In [None]:
# Plot the distribution of recordings by municipality
plot_geo_counts(
    df,
    group_by="kommunekod",
    geo_level="municipality",
    agg_func="size",
    value_col="count",
    title="Municipality Distribution (Recordings)",
    interactive=False,
    save_dir="outputs/visualizations",
)

plot_geo_counts(
    df,
    group_by="kommunekod",
    geo_level="municipality",
    agg_func="nunique",
    agg_col="id_speaker",
    value_col="count",
    title="Municipality Distribution (Unique Speakers)",
    interactive=False,
    save_dir="outputs/visualizations",
)

plot_geo_counts(
    df,
    group_by="kommunekod",
    geo_level="municipality",
    agg_func="sum",
    agg_col="duration_hours",
    value_col="count",
    title="Municipality Distribution (Total Duration in Hours)",
    interactive=False,
    save_dir="outputs/visualizations",
)

In [None]:
# One map per measure: (agg_func, agg_col, title suffix)
region_map_specs = [
    ("size", None, "Recordings"),
    ("nunique", "id_speaker", "Unique Speakers"),
    ("sum", "duration_hours", "Total Duration in Hours"),
]

# Looping keeps the calls in sync and avoids dumping the returned GeoDataFrame
# as cell output.
for agg_func, agg_col, title_suffix in region_map_specs:
    plot_geo_counts(
        df,
        group_by="regionskod",
        geo_level="region",
        agg_func=agg_func,
        agg_col=agg_col,
        title=f"Region Distribution ({title_suffix})",
        interactive=False,
        save_dir="outputs/visualizations",
    )

In [None]:
# In interactive mode plot_geo_counts returns the result of explore(); keep it
# as the cell's last expression so the notebook renders it.
interactive_municipality_map = plot_geo_counts(
    df,
    group_by="kommunekod",
    geo_level="municipality",
    agg_func="size",
    title="Municipality Distribution (Recordings)",
    interactive=True,
)
interactive_municipality_map

In [None]:
# Total hours per sample type (transposing a Series is a no-op, so it is omitted)
hours_by_type = df_combined.groupby("type")["duration_hours"].sum()
hours_by_type

In [None]:
# Total hours overall; rows repeating an id had their duration set to 0 above,
# so each id contributes once.
df_combined["duration_hours"].sum()

In [None]:
# Number of distinct speakers in the combined data
df_combined["id_speaker"].nunique()