In [None]:
# !pip install --force-reinstall --no-cache-dir "numpy<2" scipy seaborn matplotlib pandas pyarrow

import os
import math
import json
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.utils import shuffle

# Run from the repository root so the relative ./data paths and the
# `notebooks.` package import below resolve when Jupyter was started in notebooks/.
if os.path.split(os.getcwd())[1] == "notebooks":
    os.chdir(os.pardir)

from notebooks.notebook_helpers import load_train_and_test_ids


## 1 Load Data

In [None]:
# Label table, one row per sample; the first CSV column becomes the index.
labels_csv_path = "./data/final/poleno/poleno_labels_clean.csv"
df = pd.read_csv(labels_csv_path, index_col=0)

# Dataset-level train/test assignment, parsed by the project helper.
dataset_ids_xlsx_path = "./data/raw/available_dataset_ids_swisens.xlsx"
train_ids, test_ids = load_train_and_test_ids(dataset_ids_xlsx_path)

## 2 Pre-Process

### 2.1 Drop Invalid Entries

Drop old samples: keep only datasets whose `throw-out` flag is `False` in the dataset-id table.

In [None]:
# Drop old samples: keep only datasets whose "throw-out" flag is False.
# NOTE(review): only train_ids is consulted here. If the test datasets are not
# also listed in train_ids, they are removed too, and the isolated split in
# section 3.1 ends up empty. Confirm against load_train_and_test_ids.
is_kept = train_ids["throw-out"].eq(False)
keep_datasets = train_ids.loc[is_kept, "dataset_id"].unique()

before_count = len(df)
df = df.loc[df["dataset_id"].isin(keep_datasets)]
print(f"Removed {before_count - len(df)} rows from thrown-out datasets.")

Drop samples with missing (NaN) values

In [None]:
# Rows with at least one missing value, tallied per dataset
null_rows = df.loc[df.isna().any(axis=1)]
null_rows["dataset_id"].value_counts()

In [None]:
# Remove every row that has any missing value and report how many were lost.
len_before_dropna = len(df)
df = df.dropna()
print(f"Samples in original dataset: {len_before_dropna}")
print(f"Dropping samples with nan : {len_before_dropna - len(df)}")
print(f"Length after dropping : {len(df)}")

In [None]:
columns = ["dataset_id", "species", "genus"]
label_width = max(map(len, columns))  # pad names so the counts line up

for col in columns:
    print(f"Nr unique {col:<{label_width}} : {df[col].nunique()}")

#### Number of images per event

In [None]:
# Rows per event_id; events with exactly two rows are the ones kept below.
samples_per_event_id = df["event_id"].value_counts()
is_pair = samples_per_event_id.eq(2)
valid_samples_per_event_id = samples_per_event_id[is_pair].index

# Distribution of rows-per-event (how many events have 1, 2, 3, ... rows)
samples_per_event_id.value_counts()

Keep only events with exactly two samples per `event_id`; rows of all other events (fewer or more samples) are dropped.

In [None]:
# Keep only the rows of events that contributed exactly two samples.
df = df[df["event_id"].isin(valid_samples_per_event_id)]

Drop events that do not have exactly one image with `image_nr` 0 and exactly one with `image_nr` 1

In [None]:
# Count rows per (event_id, image_nr); one column per image_nr value.
counts = (
    df.groupby(["event_id", "image_nr"])
      .size()
      .unstack(fill_value=0)
)

# Keep event_ids with exactly one image_nr 0 and exactly one image_nr 1.
# reindex adds a zero-filled column when 0 or 1 never occurs at all (or df is
# empty), where indexing counts[0] / counts[1] directly would raise KeyError.
has_one_of_each = counts.reindex(columns=[0, 1], fill_value=0).eq(1).all(axis=1)
valid_ids = counts[has_one_of_each].index

len_before = len(df)
df = df[df["event_id"].isin(valid_ids)]
print(f"Removed {len_before - len(df)} rows without exactly one image_nr 0 and one image_nr 1.")

In [None]:
# def show_images_per_label(df, column="species", min_count=2000, highlight_column=None, highlight_labels=None):

#     highlight_column = column if highlight_column is None else highlight_column

#     sample_columns = list(set([column, highlight_column ]))

#     samples_per_label = df.value_counts(sample_columns).reset_index()

#     # Create color list (default all blue)
#     colors = ["skyblue"] * len(samples_per_label)


#     if highlight_labels:
#         highlight_set = set(highlight_labels)
#         colors = [
#             "orange" if label in highlight_set else "skyblue"
#             for label in samples_per_label[highlight_column]
#         ]

#     plt.figure(figsize=(14, 4))
#     plt.bar(x=samples_per_label[column], height=samples_per_label["count"], color=colors)
#     plt.axhline(min_count, color='red', ls='dotted')
#     plt.xticks(rotation=90)
#     plt.title(f"NR of samples for {column}")
#     plt.show()

In [None]:
# highlight_labels = ["Acer", "Chenopodium", "Platanus", "Holcus", "Cynosurus", "Anthoxanthum"]
# show_images_per_label(df, column="genus", min_count=10000, highlight_labels=highlight_labels)

In [None]:
# show_images_per_label(df, column="species", min_count=2000, highlight_column="genus", highlight_labels=highlight_labels)

In [None]:
# highlight_labels=[]
# show_images_per_label(df, column="dataset_id", min_count=2000, highlight_labels=highlight_labels)

### 2.2 Create additional columns

In [None]:
def extract_common_root(df, root_col="root"):
    """
    Strip the directory prefix shared by every path in ``df[root_col]``.

    Paths are normalised with ``os.path.normpath`` and compared component by
    component after splitting on ``os.sep``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame holding the paths. It is modified in place: a column is added.
    root_col : str, default "root"
        Column with the paths. Values are converted with ``str``.

    Returns
    -------
    common_root : str
        The shared leading components re-joined with ``os.sep``. This is "" if
        there are none or ``df`` is empty.
    df : pandas.DataFrame
        The same frame with a new ``intermediate_path`` column. Each value is
        the row's path minus the common prefix, or "" when nothing remains.
    """
    root_strings = df[root_col].astype(str).tolist()

    # The same root repeats across many rows, so normalise/split each distinct value once.
    components = {
        root: os.path.normpath(root).split(os.sep) for root in set(root_strings)
    }

    # Walk depth by depth while every distinct path agrees on the component.
    shared = []
    if components:
        candidates = list(components.values())
        reference = candidates[0]
        max_depth = min(len(parts) for parts in candidates)
        for depth in range(max_depth):
            if any(parts[depth] != reference[depth] for parts in candidates):
                break
            shared.append(reference[depth])

    prefix_len = len(shared)
    suffix_of = {
        root: os.sep.join(parts[prefix_len:]) for root, parts in components.items()
    }
    df["intermediate_path"] = [suffix_of[root] for root in root_strings]

    return os.sep.join(shared), df


def build_img_path(df, dataset_id_col="dataset_id", rec_path_col="rec_path",
                   intermediate_col="intermediate_path"):
    """
    Add an ``img_path`` column that joins, per row::

        <intermediate_path> / <dataset_id> / <rec_path>

    The parts are joined with ``os.path.join``. An empty intermediate part is
    therefore skipped, and an absolute ``rec_path`` discards the preceding parts.
    Every value is converted with ``str`` first.

    Parameters
    ----------
    df : pandas.DataFrame
        Modified in place: ``img_path`` is added or overwritten.
    dataset_id_col, rec_path_col : str
        Columns holding the dataset id and the per-recording relative path.
    intermediate_col : str, default "intermediate_path"
        Column holding the leading directory part.

    Returns
    -------
    pandas.DataFrame
        The same ``df``. An empty frame yields an empty ``img_path`` column.
    """
    # Column-wise zip instead of df.apply(axis=1): apply on an empty frame
    # returns a DataFrame, which cannot be assigned to a single column.
    df["img_path"] = [
        os.path.join(str(intermediate), str(dataset_id), str(rec_path))
        for intermediate, dataset_id, rec_path in zip(
            df[intermediate_col], df[dataset_id_col], df[rec_path_col]
        )
    ]
    return df


# Bind the shared prefix to a real name. Assigning to `_` at notebook top level
# makes IPython stop updating its `_` / `__` / `___` output-history variables.
common_root, df = extract_common_root(df, root_col="root")
df = build_img_path(df, dataset_id_col="dataset_id", rec_path_col="rec_path")


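Overwrite `species` with the dataset-to-species mapping from `dataset_ids.json` (for the datasets listed there) and re-derive `genus` from it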

In [None]:
# JSON is UTF-8 by specification; without an explicit encoding the platform
# default (e.g. cp1252 on Windows) would be used.
with open("data/processed/Poleno_25/dataset_ids.json", "r", encoding="utf-8") as f:
    dataset_ids_to_species = json.load(f)

# Rows whose dataset_id appears in the mapping get the mapped species.
has_mapping = df["dataset_id"].isin(dataset_ids_to_species)
df.loc[has_mapping, "species"] = df.loc[has_mapping, "dataset_id"].map(dataset_ids_to_species)

# Genus = first whitespace-separated word of the species name. The .str
# accessor yields NaN for blank values instead of raising IndexError.
df["genus"] = df["species"].str.split().str[0]

## 3 Dataset Split

### 3.1 Split into base, collection and zero dataset

In [None]:
def split_dataframe(df, column, values):
    """
    Split a DataFrame into two parts based on matching values in a given column.

    Parameters
    ----------
    df : pandas.DataFrame
        The input DataFrame. It is not modified.
    column : str
        Column name to filter on (e.g., 'dataset_id', 'species', etc.).
    values : list, tuple, set, frozenset, pandas.Index or pandas.Series
        Values to select from the specified column.

    Returns
    -------
    tuple of (df_selected, df_remaining)
        df_selected  -> rows where df[column] is in values (independent copy)
        df_remaining -> all other rows (independent copy)

    Raises
    ------
    KeyError
        If ``column`` is not a column of ``df``.
    TypeError
        If ``values`` is not one of the accepted collection types.
    """
    if column not in df.columns:
        raise KeyError(f"Column '{column}' not found in DataFrame.")

    # Strings are deliberately rejected: isin("abc") is almost always a mistake.
    if not isinstance(values, (list, tuple, set, frozenset, pd.Index, pd.Series)):
        raise TypeError(
            "Parameter 'values' must be a list, tuple, set, frozenset, pandas Index or Series."
        )

    # No up-front df.copy(): the two .copy() calls below already make the
    # results independent of the input, so a third full copy only costs memory.
    mask = df[column].isin(values)
    df_selected = df[mask].copy()
    df_remaining = df[~mask].copy()

    return df_selected, df_remaining


# isolate_species = ['Urtica sp. 1', 'Urtica sp. 1', 'Juniperus communis', 'Lolium rigidum',]

# Datasets listed in test_ids are held out entirely ("isolated"); the rest form the base set.
isolate_dataset = test_ids["dataset_id"].tolist()

# df_species_isolated, df_base = split_dataframe(df, column="species", values=isolate_species)
df_dataset_isolated, df_base = split_dataframe(df, column="dataset_id", values=isolate_dataset)

# Order rows by event so that rows sharing an event_id are adjacent.
df_base, df_dataset_isolated = [
    part.sort_values(by="event_id") for part in (df_base, df_dataset_isolated)
]

### 3.2 Split into train, val and test datasets

In [None]:
def split_train_val_test_grouped(
    df,
    stratify_col,
    group_col="event_id",
    test_size_per_class=50,
    val_size_per_class=20,
    test_prop_per_class=0,
    val_prop_per_class=0,
    random_state=42,
):
    """
    Split ``df`` into train/val/test subsets by whole groups, stratified per class.

    All rows sharing a ``group_col`` value land in the same subset. Each group
    is assigned to the class found in ``stratify_col`` on its first row.

    For a class with ``N`` groups:
      * test gets ``max(test_size_per_class, ceil(test_prop_per_class * N))``
        groups, clamped to ``N``;
      * val gets ``max(val_size_per_class, ceil(val_prop_per_class * N))``
        groups, clamped to the groups left after test;
      * train gets the rest, which may be empty for small classes.

    Parameters
    ----------
    df : pandas.DataFrame
        Input rows. Not modified.
    stratify_col : str
        Column whose values define the classes.
    group_col : str, default "event_id"
        Column whose values must never be split across subsets.
    test_size_per_class, val_size_per_class : int
        Minimum number of groups per class for test / val.
    test_prop_per_class, val_prop_per_class : float
        Minimum fraction of a class's groups for test / val.
    random_state : int, default 42
        Seed for the group shuffle and the per-class draws.

    Returns
    -------
    tuple of (df_train, df_val, df_test)
        Row subsets of ``df`` with the original index. All three are empty
        when ``df`` has no rows.

    Raises
    ------
    ValueError
        If ``val_prop_per_class + test_prop_per_class >= 1.0``.
    KeyError
        If ``stratify_col`` or ``group_col`` is missing.
    """
    df = df.copy()

    if val_prop_per_class + test_prop_per_class >= 1.0:
        raise ValueError("Sum of val_prop_per_class and test_prop_per_class must be less than 1.0")

    if stratify_col not in df.columns:
        raise KeyError(f"Column '{stratify_col}' not found in DataFrame.")
    if group_col not in df.columns:
        raise KeyError(f"Column '{group_col}' not found in DataFrame.")

    # One row per group with its class label, in a seeded random order.
    # NOTE(review): sample(frac=1, random_state=<int>) should draw the same
    # permutation as the sklearn.utils.shuffle call it replaces (both shuffle
    # np.arange(n) with RandomState(seed)). Verify against previously saved
    # splits if bit-for-bit reproducibility matters.
    event_info = (
        df[[group_col, stratify_col]]
        .drop_duplicates(subset=group_col)
        .sample(frac=1, random_state=random_state)
        .reset_index(drop=True)
    )

    test_keys, val_keys, train_keys = set(), set(), set()

    # dropna=False: groups whose class label is NaN form their own class.
    # With the default, they would silently vanish from every subset.
    for _, group in event_info.groupby(stratify_col, dropna=False):
        total_groups = len(group)

        # TEST: at least the fixed minimum or the proportion, never more than exists.
        n_test = max(test_size_per_class, math.ceil(test_prop_per_class * total_groups))
        n_test = min(n_test, total_groups)
        test_part = group.sample(n=n_test, random_state=random_state)
        remaining = group.drop(test_part.index)

        # VAL: same rule, clamped to what TEST left over.
        n_val = max(val_size_per_class, math.ceil(val_prop_per_class * total_groups))
        n_val = min(n_val, len(remaining))
        val_part = remaining.sample(n=n_val, random_state=random_state)
        train_part = remaining.drop(val_part.index)

        test_keys.update(test_part[group_col])
        val_keys.update(val_part[group_col])
        train_keys.update(train_part[group_col])

    # Collecting keys in sets (instead of pd.concat of the parts) also works
    # when there are no groups at all, where pd.concat([]) would raise.
    df_test = df[df[group_col].isin(test_keys)]
    df_val = df[df[group_col].isin(val_keys)]
    df_train = df[df[group_col].isin(train_keys)]

    return df_train, df_val, df_test

In [None]:
basic_config = {
    "name": "basic",
    "data": df_base,
    "val_prop_per_class": 0.05,
    "min_val_events_per_class": 20,
    "test_prop_per_class": 0.1,
    "min_test_events_per_class": 20,
}

# separate_species_config = {
#     "data": df_species_isolated,
#     "val_prop_per_class": 0.05,
#     "min_val_events_per_class": 25,
#     "test_prop_per_class": 0.1,
#     "min_test_events_per_class": 25,
# }

separate_dataset_config = {
    "name": "isolated",
    "data": df_dataset_isolated,
    "val_prop_per_class": 0.05,
    "min_val_events_per_class": 20,
    "test_prop_per_class": 0.1,
    "min_test_events_per_class": 20,
}

split_configs = [basic_config, separate_dataset_config]

save_in = "data/final/poleno/"
save_csv = False

# (train, val, test) per config name. Without this, only the last config's
# split would survive the loop; train_df/val_df/test_df still hold the last one.
splits = {}

for config in split_configs:

    train_df, val_df, test_df = split_train_val_test_grouped(
        df=config["data"],
        stratify_col="dataset_id",
        group_col="event_id",
        test_size_per_class=config["min_test_events_per_class"],
        test_prop_per_class=config["test_prop_per_class"],
        val_size_per_class=config["min_val_events_per_class"],
        val_prop_per_class=config["val_prop_per_class"],
        random_state=42,
    )
    splits[config["name"]] = (train_df, val_df, test_df)

    # Leakage guard: an event may only ever appear in one subset.
    train_event_ids = set(train_df["event_id"])
    val_event_ids = set(val_df["event_id"])
    test_event_ids = set(test_df["event_id"])
    assert train_event_ids.isdisjoint(val_event_ids), f"{config['name']}: train/val share events"
    assert train_event_ids.isdisjoint(test_event_ids), f"{config['name']}: train/test share events"
    assert val_event_ids.isdisjoint(test_event_ids), f"{config['name']}: val/test share events"

    # Label the counts so the printouts of the different configs can be told apart.
    print(f"--- {config['name']} ---")
    print("Train events:", train_df["event_id"].nunique())
    print("Val events:", val_df["event_id"].nunique())
    print("Test events:", test_df["event_id"].nunique())

    if save_csv:
        os.makedirs(save_in, exist_ok=True)
        for subset_name, subset_df in (("train", train_df), ("val", val_df), ("test", test_df)):
            subset_df.to_csv(os.path.join(save_in, config["name"] + f"_{subset_name}.csv"), index=False)

In [None]:
# root = r"Z:\marvel\marvel-fhnw\data\Poleno25"

# # get parquet file to check
# parquet_files = [f for f in os.listdir(r"Z:\marvel\marvel-fhnw\data\Poleno25") if f.endswith(".parquet")]

# complete_id_to_genus = {}
# complete_id_to_species = {}
# for file in parquet_files:
#     print(f"Processing file: {file}")
#     parquet = pd.read_parquet(os.path.join(root, file))
#     id_to_genus = dict(zip(parquet["dataset_id"], parquet["genus"]))
#     id_to_species = dict(zip(parquet["dataset_id"], parquet["species"]))
#     complete_id_to_genus.update(id_to_genus)
#     complete_id_to_species.update(id_to_species)


## Check species

In [None]:
# dataset_id = "11edffad-46fc-1c2c-9d9c-66f2ec8a65cb"
# df.loc[df["dataset_id"] == dataset_id]

In [None]:
# str(pd.read_parquet(f"Z:/marvel/marvel-fhnw/data/Poleno25/{dataset_id}.parquet")["species"].unique())

In [None]:
# dataset_ids_to_species[dataset_id]

In [None]:
def find_dataset_genus_mismatches(df, dataset_col, genus_col, mapping):
    """
    Compare each unique (dataset, label) pair in ``df`` with an expected mapping.

    Despite the name, this works for any label column (the notebook calls it
    with ``species``). The printed report therefore uses the actual column names.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows to check.
    dataset_col : str
        Column holding the dataset identifier.
    genus_col : str
        Column holding the observed label (genus, species, ...).
    mapping : dict
        dataset identifier -> expected label. Datasets missing from it are skipped.

    Returns
    -------
    list of tuple
        ``(dataset, observed_label, expected_label)`` for every mismatching
        pair, in order of first appearance in ``df``. Also printed.
    """
    # Unique pairs keep the check proportional to the number of datasets, not rows.
    unique_pairs = df[[dataset_col, genus_col]].drop_duplicates()

    mismatches = [
        (dataset, label, mapping[dataset])
        for dataset, label in zip(unique_pairs[dataset_col], unique_pairs[genus_col])
        if dataset in mapping and label != mapping[dataset]
    ]

    if not mismatches:
        print("No mismatches found.")
    else:
        print("Mismatches:")
        for ds, g, exp in mismatches:
            print(f"  {dataset_col}={ds} | {genus_col}={g} | expected={exp}")

    return mismatches

# Report datasets whose species label disagrees with dataset_ids.json.
mismatches = find_dataset_genus_mismatches(
    df,
    dataset_col="dataset_id",
    genus_col="species",
    mapping=dataset_ids_to_species,
)

In [None]:
def check_dataset_species(df, dataset_col="dataset_id", species_col="species"):
    """
    Find values of ``dataset_col`` that occur with more than one distinct
    ``species_col`` value, and print a short report.

    Prints "ok ✅" when every dataset maps to a single species. Otherwise it
    prints each offending id followed by the species it appears with.

    Returns
    -------
    list
        Offending ids, in ``value_counts`` order (most distinct species
        first). Empty when there are none.
    """
    # Distinct (dataset, species) pairs: a dataset listed more than once here has 2+ species.
    distinct_pairs = df[[dataset_col, species_col]].drop_duplicates()
    species_per_dataset = distinct_pairs[dataset_col].value_counts()
    inconsistent = species_per_dataset.loc[species_per_dataset.gt(1)].index.tolist()

    if not inconsistent:
        print("ok ✅")
        return []

    print("Inconsistent dataset_ids:")
    for ds in inconsistent:
        seen_species = df.loc[df[dataset_col] == ds, species_col].unique()
        print(f"  {ds} -> {seen_species}")
    return inconsistent


# Bind the result to a real name. Assigning to `_` at notebook top level makes
# IPython stop updating its `_` / `__` / `___` output-history variables.
inconsistent_dataset_ids = check_dataset_species(df, dataset_col="dataset_id", species_col="species")

# check_dataset_species(df, dataset_col="species", species_col="dataset_id")