In [None]:
# !pip install --force-reinstall --no-cache-dir "numpy<2" scipy seaborn matplotlib pandas pyarrow scikit-learn

import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.utils import shuffle

# When the kernel was started inside a directory named "notebooks", move one
# level up so that the relative data path used below resolves.
if os.path.split(os.getcwd())[1] == "notebooks":
    os.chdir(os.pardir)

# Load the label table that every later cell works on.
df = pd.read_csv(filepath_or_buffer="./data/final/poleno/poleno_labels_clean.csv")

In [None]:
# sample_counts = df.value_counts(["genus", "species", "dataset_id"]).reset_index()
# sample_counts.to_excel("available_dataset_ids.xlsx")

Drop samples that contain NaN values

In [None]:
# Rows with at least one missing value in any column ...
null_rows = df.loc[df.isna().any(axis="columns")]
# ... counted per dataset_id (displayed as the cell output).
null_rows["dataset_id"].value_counts()

In [None]:
# Remember the row count before cleaning so the number of dropped rows can be reported.
len_before_dropna = df.shape[0]
# Remove every row that has a missing value in any column.
df = df.dropna(axis="index", how="any")
print(f"Samples in original dataset: {len_before_dropna}\nDropping samples with nan : {len_before_dropna-len(df)}\nLength after dropping : {len(df)}")

In [None]:
# Columns whose number of distinct values is reported below.
columns = ["dataset_id", "species", "genus"]

for col in columns:
    # The column name is left-aligned and padded to the length of the longest
    # name so that the colons line up; the count is the length of value_counts().
    print(f"Nr unique {col:<{max(len(c) for c in columns)}} : {len(df[col].value_counts())}")

#### Number of images per event

In [None]:
# Number of rows recorded for every event_id.
samples_per_event_id = df["event_id"].value_counts()
# Event ids that occur exactly twice; used by the next filtering step.
valid_samples_per_event_id = samples_per_event_id.loc[samples_per_event_id.eq(2)].index
# Distribution of the per-event row counts (displayed as the cell output).
samples_per_event_id.value_counts()

Drop rows whose `event_id` does not have exactly two samples

In [None]:
# Keep only the rows whose event_id is in the set of valid event ids.
df = df[df["event_id"].isin(valid_samples_per_event_id)]

Drop events that don't have exactly one image with `image_nr` 0 and one with `image_nr` 1

In [None]:
# An event is valid when it has exactly one row with image_nr == 0 and exactly
# one row with image_nr == 1. The per-event counts are computed with vectorised
# group sums instead of calling a Python lambda once per event.
event_ids = df["event_id"]
has_single_image_0 = df["image_nr"].eq(0).groupby(event_ids).sum().eq(1)
has_single_image_1 = df["image_nr"].eq(1).groupby(event_ids).sum().eq(1)

# Boolean Series indexed by event_id: True for valid events.
valid_ids = has_single_image_0 & has_single_image_1

len_before_filtering = len(df)
df = df[df["event_id"].isin(valid_ids[valid_ids].index)]

# Report the number of removed rows as a positive, labelled count.
print(
    f"Dropped {len_before_filtering - len(df)} rows from events without "
    "exactly one image_nr 0 and one image_nr 1"
)

In [None]:
def show_images_per_label(df, column="species", min_count=2000, highlight_column=None, highlight_labels=None):
    """
    Plot the number of rows per label as a bar chart.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data; must contain `column` (and `highlight_column` if given).
    column : str, optional
        Column whose values are shown on the x-axis (one bar per counted row).
    min_count : int, optional
        Height of the red dotted horizontal reference line.
    highlight_column : str, optional
        Column whose values decide whether a bar is highlighted.
        Defaults to `column`.
    highlight_labels : iterable, optional
        Values of `highlight_column` whose bars are drawn in orange; all other
        bars are sky blue. No bar is highlighted when this is empty or None.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``.
    """
    highlight_column = column if highlight_column is None else highlight_column

    # De-duplicate while keeping a deterministic order (`column` first).
    sample_columns = list(dict.fromkeys([column, highlight_column]))

    # Name the count column explicitly: before pandas 2.0 the result of
    # value_counts() is unnamed, so reset_index() would not create "count".
    samples_per_label = df.value_counts(sample_columns).reset_index(name="count")

    # Orange for highlighted labels, sky blue for everything else.
    highlight_set = set(highlight_labels) if highlight_labels else set()
    colors = [
        "orange" if label in highlight_set else "skyblue"
        for label in samples_per_label[highlight_column]
    ]

    plt.figure(figsize=(14, 4))
    plt.bar(x=samples_per_label[column], height=samples_per_label["count"], color=colors)
    plt.axhline(min_count, color='red', ls='dotted')
    plt.xticks(rotation=90)
    plt.xlabel(column)
    plt.ylabel("Nr of samples")
    plt.title(f"Nr of samples per {column}")
    plt.show()

In [None]:
#  df.value_counts("genus").reset_index()

In [None]:
# Genera whose bars are highlighted in the overview plots.
highlight_labels = [
    "Acer",
    "Phytophthora",
    "Chenopodium",
    "Platanus",
    "Holcus",
    "Cynosurus",
    "Anthoxanthum",
]
show_images_per_label(
    df,
    column="genus",
    min_count=10000,
    highlight_labels=highlight_labels,
)

In [None]:
# Per-species counts; a bar is highlighted when its genus is in highlight_labels.
show_images_per_label(
    df,
    column="species",
    min_count=2000,
    highlight_column="genus",
    highlight_labels=highlight_labels,
)

In [None]:
# Per-dataset counts without any highlighted bars.
highlight_labels = []
show_images_per_label(
    df,
    column="dataset_id",
    min_count=2000,
    highlight_labels=highlight_labels,
)

In [None]:
# Count the rows per dataset_id within every species and show the result as a flat table.
df.groupby("species").value_counts(["dataset_id"]).reset_index()

## Dataset Split

### Split into base, collection and zero dataset

In [None]:
def split_dataframe(df, column, values):
    """
    Partition a DataFrame by whether one column's value is in a given collection.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to partition; it is not modified.
    column : str
        Name of the column that is compared against `values`.
    values : list, tuple or set
        Values that mark a row as "selected".

    Returns
    -------
    tuple of (df_selected, df_remaining)
        df_selected  -> copy of the rows where df[column] is in values
        df_remaining -> copy of all other rows

    Raises
    ------
    KeyError
        If `column` is not a column of `df`.
    TypeError
        If `values` is not a list, tuple, or set.
    """
    # Validate the arguments up front: the column first, then the container type.
    column_exists = column in df.columns
    if not column_exists:
        raise KeyError(f"Column '{column}' not found in DataFrame.")

    accepted_containers = (list, tuple, set)
    if not isinstance(values, accepted_containers):
        raise TypeError("Parameter 'values' must be a list, tuple, or set.")

    # Boolean row selector; its negation yields the complementary part.
    is_selected = df[column].isin(values)
    return df.loc[is_selected].copy(), df.loc[~is_selected].copy()


# Species that are separated from the base data.
# NOTE(review): 'Urtica sp. 1' is listed twice. The duplicate does not change
# the isin-based selection, but confirm whether a different species was intended.
isolate_species = ['Urtica sp. 1', 'Urtica sp. 1', 'Juniperus communis', 'Lolium rigidum',]

# Dataset id that is separated from the base data.
isolate_dataset = ["11eed16c-db23-b464-8bde-1e119433b62f"]

# Split off the isolated species first, then split the isolated dataset from
# what remains, so df_base ends up containing neither of them.
df_species_isolated, df_base = split_dataframe(df, column="species", values=isolate_species)
df_dataset_isolated, df_base = split_dataframe(df_base, column="dataset_id", values=isolate_dataset)

### Split into train, val, test dataset

In [None]:
def split_train_val_test_grouped(
    df,
    stratify_col,
    group_col="event_id",
    test_size_per_class=50,
    val_size_per_class=20,
    random_state=42,
):
    """
    Split a DataFrame into train, validation, and test sets while:
    - keeping all rows from the same group (e.g., event_id) together,
    - sampling a fixed number of groups per class for both test and validation sets.

    The test set is filled first, then the validation set; whatever is left of
    a class goes to the train set. A class with no more groups than
    `test_size_per_class` therefore ends up entirely in the test set.

    Parameters
    ----------
    df : pandas.DataFrame
        Input dataset. An empty frame yields three empty frames.
    stratify_col : str
        Column used for stratification (e.g., 'species'). The class of a group
        is taken from the first row of that group.
    group_col : str, optional
        Column that defines grouping (e.g., 'event_id').
    test_size_per_class : int, optional
        Number of groups (not rows) per class to assign to the test set.
    val_size_per_class : int, optional
        Number of groups (not rows) per class to assign to the validation set.
    random_state : int, optional
        Random seed used for the shuffle and for every per-class sample.

    Returns
    -------
    (df_train, df_val, df_test)
        Independent copies of the matching rows of `df`, with consistent
        event grouping.

    Raises
    ------
    KeyError
        If `stratify_col` or `group_col` is not a column of `df`.
    """

    if stratify_col not in df.columns:
        raise KeyError(f"Column '{stratify_col}' not found in DataFrame.")
    if group_col not in df.columns:
        raise KeyError(f"Column '{group_col}' not found in DataFrame.")

    # Nothing to split: return empty copies instead of failing further down.
    if df.empty:
        return df.copy(), df.copy(), df.copy()

    # Map each event to its class (first occurrence wins)
    event_info = df[[group_col, stratify_col]].drop_duplicates(subset=group_col)

    # Shuffle events to randomize selection
    event_info = shuffle(event_info, random_state=random_state).reset_index(drop=True)

    # Group ids are collected in plain lists, so a frame without any class
    # groups simply produces three empty splits.
    test_ids = []
    val_ids = []
    train_ids = []

    # --- Fixed-size sampling per class for test and val ---
    for _, group in event_info.groupby(stratify_col):
        # If too small, use whatever is available
        n_test = min(test_size_per_class, len(group))
        test_part = group.sample(n=n_test, random_state=random_state)
        remaining = group.drop(test_part.index)

        n_val = min(val_size_per_class, len(remaining))
        val_part = remaining.sample(n=n_val, random_state=random_state)
        train_part = remaining.drop(val_part.index)

        test_ids.extend(test_part[group_col].tolist())
        val_ids.extend(val_part[group_col].tolist())
        train_ids.extend(train_part[group_col].tolist())

    # Map back to full rows (include all images per event). Copies are returned
    # so that modifying a split never touches `df`, matching split_dataframe.
    df_test = df[df[group_col].isin(test_ids)].copy()
    df_val = df[df[group_col].isin(val_ids)].copy()
    df_train = df[df[group_col].isin(train_ids)].copy()

    return df_train, df_val, df_test

In [None]:
# One configuration per subset produced by the isolation step. "name" labels
# the printed summary and is the key under which the split is stored.
basic_config = {
    "name": "base",
    "data": df_base,
    "test_events_per_class": 100,
    "val_events_per_class": 100,
}

separate_species_config = {
    "name": "species_isolated",
    "data": df_species_isolated,
    "test_events_per_class": 100,
    "val_events_per_class": 100,
}

separate_dataset_config = {
    "name": "dataset_isolated",
    "data": df_dataset_isolated,
    "test_events_per_class": 100,
    "val_events_per_class": 100,
}

split_configs = [basic_config, separate_species_config, separate_dataset_config]

# Store every split by configuration name so that all of them remain available
# after the loop. train_df / val_df / test_df are reassigned on each iteration
# and therefore hold the split of the last configuration once the loop ends.
splits = {}

for config in split_configs:

    train_df, val_df, test_df = split_train_val_test_grouped(
        df=config["data"],
        stratify_col="dataset_id",
        group_col="event_id",
        test_size_per_class=config["test_events_per_class"],
        val_size_per_class=config["val_events_per_class"],
        random_state=42,
    )

    splits[config["name"]] = {"train": train_df, "val": val_df, "test": test_df}

    print(f"--- {config['name']} ---")
    print("Train events:", train_df["event_id"].nunique())
    print("Val events:", val_df["event_id"].nunique())
    print("Test events:", test_df["event_id"].nunique())

In [None]:
# Rows per dataset_id in the test split (test_df comes from the last
# configuration handled by the split loop).
test_df["dataset_id"].value_counts()

In [None]:
# Rows per dataset_id in the validation split (val_df comes from the last
# configuration handled by the split loop).
val_df["dataset_id"].value_counts()

In [None]:
# Rows per (dataset_id, species) combination in the train split (train_df comes
# from the last configuration handled by the split loop).
train_df[["dataset_id", "species"]].value_counts()