In [None]:
import os.path as osp
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from fontTools.misc.psOperators import ps_dict

from sklearn.model_selection import train_test_split

from sklearn import preprocessing


In [None]:
# Directory (relative to the notebook) that holds the metadata CSV files.
DATA_ROOT_PATH = "../metadata"

# One seed shared by every source of randomness used in this notebook.
SEED = 777
np.random.seed(SEED)
# The standard-library generator is used later to sample observations for
# visual inspection; seed it as well so those samples repeat between runs.
random.seed(SEED)

# Handed to train_test_split as `train_size` when observations are split.
SPLIT_RATIO = 0.9


data_df = pd.read_csv(osp.join(DATA_ROOT_PATH, "FungiCLEF2023_train_metadata_PRODUCTION.csv"))


In [None]:
def make_mini() -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Load the DF20M train and public-test metadata files.

    Both CSV files are read from ``DATA_ROOT_PATH`` so the metadata location
    is configured in a single place.

    Returns:
        A ``(df20m, train_mini, test_mini)`` tuple. ``df20m`` holds the train
        rows followed by the test rows under a fresh 0..n-1 index;
        ``train_mini`` and ``test_mini`` are the two files exactly as read.
    """
    train_mini = pd.read_csv(osp.join(DATA_ROOT_PATH, "DF20M-train_metadata_PROD.csv"))
    test_mini = pd.read_csv(osp.join(DATA_ROOT_PATH, "DF20M-public_test_metadata_PROD.csv"))

    # ignore_index=True renumbers the rows, so index labels from the two
    # source files cannot collide in the combined frame.
    df20m = pd.concat([train_mini, test_mini], ignore_index=True)
    return df20m, train_mini, test_mini

# Load the DF20M metadata once; the combined frame and both original splits
# are reused by the cells below.
df20m, train_mini, test_mini = make_mini()

In [None]:
# Keep only the rows whose genus also occurs somewhere in DF20M.
unique_genus = pd.unique(df20m["genus"])
data_df = (
    data_df.loc[data_df["genus"].isin(unique_genus)]
    .reset_index(drop=True)
)

In [None]:
# Fit the encoder on the scientific names, then store the encoded value of
# each row as a 64-bit integer `class_id`.
le = preprocessing.LabelEncoder().fit(data_df["scientificName"])
data_df["class_id"] = le.transform(data_df["scientificName"]).astype(np.int64)

# Preview the frame, now including the new column.
data_df.head()

In [None]:
# Image IDs belonging to the original DF20M train / public-test splits.
original_train_images = set(train_mini["ImageUniqueID"])
original_test_images = set(test_mini["ImageUniqueID"])

# Distinct observations that own at least one image of the respective split.
train_mini_observations = data_df.loc[
    data_df["ImageUniqueID"].isin(original_train_images), "observationID"
].unique()
test_mini_observations = data_df.loc[
    data_df["ImageUniqueID"].isin(original_test_images), "observationID"
].unique()

# Number of distinct observations per split.
len(train_mini_observations), len(test_mini_observations)

In [None]:
# True when the number of rows found in the original train split plus the
# number found in the original test split equals the number of rows kept.
(
    data_df["ImageUniqueID"].isin(original_train_images).sum()
    + data_df["ImageUniqueID"].isin(original_test_images).sum()
) == len(data_df)

In [None]:
# Observations that appear in BOTH original splits, in the order they occur in
# `test_mini_observations`.
# Membership is tested against a set: `x in ndarray` scans the whole array on
# every iteration, which makes the original loop quadratic.
train_observation_set = set(train_mini_observations)
overflown_obs = [
    test_observation
    for test_observation in test_mini_observations
    if test_observation in train_observation_set
]

len(overflown_obs)

In [None]:
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg


# Directory that is joined with each image's file name when images are loaded
# for plotting.
LOCAL_DIR = "../data/DF20M"

def plot_images_based_on_observationID(df: pd.DataFrame, obs: int):
    """Show every image of one observation in a grid of at most three columns.

    Each image is read from ``LOCAL_DIR`` using only the file name of its
    ``image_path`` entry. ``df`` itself is left unmodified.

    Args:
        df: Metadata frame with ``observationID`` and ``image_path`` columns.
        obs: Observation whose images should be displayed.

    Returns:
        None. Nothing is drawn when ``df`` has no row for ``obs``.
    """
    # Collect the local paths in a plain list instead of assigning a column on
    # a filtered slice of ``df`` (that assignment triggers pandas'
    # SettingWithCopyWarning).
    image_paths = [
        os.path.join(LOCAL_DIR, os.path.basename(path))
        for path in df.loc[df["observationID"] == obs, "image_path"]
    ]
    if not image_paths:
        # A 0x0 grid cannot be created, so there is nothing to show.
        return

    num_cols = min(len(image_paths), 3)
    num_rows = (len(image_paths) + 2) // 3  # ceil(len / 3)

    # squeeze=False always returns a 2-D array of Axes. With the default, a
    # 1x1 grid yields a bare Axes object that has no .flatten().
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 8), squeeze=False)
    axes = axes.flatten()

    # Turn the axis decorations off for every cell first, so grid cells that
    # receive no image stay blank.
    for ax in axes:
        ax.axis('off')

    for ax, image_path in zip(axes, image_paths):
        ax.imshow(mpimg.imread(image_path))

    # Adjust layout to prevent overlapping
    plt.tight_layout()

    plt.show()
    # This function is called many times in a loop; close the figure so
    # figures do not accumulate.
    plt.close(fig)

In [None]:
import random


# Number of randomly drawn shared observations to display (drawn with
# replacement, so the same observation may show up more than once).
NUM_PREVIEWS = 100

# When the two splits share no observation there is nothing to draw from;
# run zero iterations instead of letting the random draw raise on an empty list.
for i in range(NUM_PREVIEWS if overflown_obs else 0):
    obs_id = random.choice(overflown_obs)

    # Sanity check before the (slower) plotting: the observation really is
    # present in both original splits.
    assert obs_id in train_mini_observations and obs_id in test_mini_observations

    print(f"ObservationID: {obs_id}", data_df[data_df["observationID"] == obs_id]["scientificName"].unique())
    plot_images_based_on_observationID(data_df, obs_id)

""" Interesting observationIDs
2238404539, 2238460869, 2424122260

"""

In [None]:
# Number of distinct class labels left after the genus filter.
data_df["class_id"].nunique()

In [None]:
train_indexes, val_indexes = [], []

# Split observation-wise inside each class, so all images of one observation
# end up on the same side of the split.
for class_id, single_class_data in data_df.groupby("class_id", sort=False):
    observation_ids = single_class_data["observationID"].unique()

    if len(observation_ids) < 2:
        # train_test_split raises ValueError when one side would be empty,
        # which happens for a class with a single observation. Keep such a
        # class entirely in the training set.
        train_ids, val_ids = observation_ids, observation_ids[:0]
    else:
        train_ids, val_ids = train_test_split(observation_ids, train_size=SPLIT_RATIO, random_state=SEED)

    # Look the observations up in this class's rows only, rather than
    # re-scanning the whole frame for every class.
    # NOTE(review): this selects the same rows as a whole-frame lookup as long
    # as no observationID is shared between classes - confirm for this dataset.
    cls_observations = single_class_data["observationID"]
    cls_train_idxs = list(single_class_data[cls_observations.isin(train_ids)].index)
    cls_val_idxs = list(single_class_data[cls_observations.isin(val_ids)].index)

    # No row may land on both sides of the split.
    assert len(set([*cls_train_idxs, *cls_val_idxs])) == len(cls_train_idxs) + len(cls_val_idxs)

    train_indexes += cls_train_idxs
    val_indexes += cls_val_idxs

# The collected values are index labels, so select by label (.loc) instead of
# by position (.iloc); the two only coincide for a 0..n-1 index.
train_df = data_df.loc[train_indexes]
val_df = data_df.loc[val_indexes]
    

In [None]:
# No observation may contribute images to both the training and validation set.
# isdisjoint builds the training set once; the previous form rebuilt it for
# every validation observation.
assert set(train_df["observationID"]).isdisjoint(val_df["observationID"]), \
    "some observationIDs occur in both train_df and val_df"

In [None]:
# Put both splits back into the row order of data_df. Sorting by index only
# reorders rows, so the columns keep the order they have in data_df.
train_df, val_df = train_df.sort_index(), val_df.sort_index()

# Fraction of all rows that ended up in the training split.
len(train_df) / len(data_df)

In [None]:
from pathlib import Path

image_pairs = pd.concat([train_df, val_df])[["ImageUniqueID", "image_path"]]


def _first_path_per_image(frame: pd.DataFrame) -> dict:
    """Map each ImageUniqueID in `frame` to the image_path of its first row."""
    first_paths = {}
    for image_id, path in zip(frame["ImageUniqueID"], frame["image_path"]):
        # setdefault keeps the first path seen for an ID.
        first_paths.setdefault(image_id, path)
    return first_paths


# Build both lookups once; filtering the full frames inside the loop would
# scan every row for every image.
data_df_paths = _first_path_per_image(data_df)
df20m_paths = _first_path_per_image(df20m)

# Every image ID must equal the file stem of its path in the new splits, in
# the filtered full metadata and in the original DF20M metadata.
for (img_id, img_path) in image_pairs.values:
    assert img_id == Path(img_path).stem, f"{img_id}: split path is {img_path}"

    assert img_id in data_df_paths, f"{img_id} is missing from data_df"
    orig_path = data_df_paths[img_id]
    assert img_id == Path(orig_path).stem, f"{img_id}: data_df path is {orig_path}"

    assert img_id in df20m_paths, f"{img_id} is missing from df20m"
    df20_orig_path = df20m_paths[img_id]
    assert img_id == Path(df20_orig_path).stem, f"{img_id}: df20m path is {df20_orig_path}"

In [None]:
# Write both splits into the metadata directory configured by DATA_ROOT_PATH
# (instead of repeating the directory literal); the index is not written.
train_df.to_csv(osp.join(DATA_ROOT_PATH, "DanishFungi2020M-train_mini-FIX02.csv"), index=False)
val_df.to_csv(osp.join(DATA_ROOT_PATH, "DanishFungi2020M-val_mini-FIX02.csv"), index=False)