# Evaluation of the first version of the classifier

Generate the output on:
* GBIF validation dataset
* Flemming dataset?

Metrics:
* Accuracy, precision, recall, F1 --> mini_metrics

## Normalize the data

In [None]:
from pathlib import Path
import pandas as pd
from fastai.vision.all import load_learner
import numpy as np

In [None]:
# Locations used by the normalization section: the raw predictions CSV (read below),
# the exported fastai learner, and the normalized CSV written at the end of the section.
# NOTE(review): absolute, user-specific paths — the notebook will not run elsewhere as-is.
pred_path = Path("/home/george/codes/lepinet/data/flemming/preds_04-lepi-prod_model1_label.csv")
lrn_path = Path("/home/george/codes/lepinet/data/lepi/models/04-lepi-prod_model1-save-hierarchy-id2name")
out_path = Path("/home/george/codes/lepinet/data/flemming/preds_04-lepi-prod_model1_label_normalized.csv")

In [None]:
# Load the raw predictions table from pred_path.
df=pd.read_csv(pred_path)

In [None]:
# Inspect which columns the predictions file provides.
df.columns

In [None]:
# Format the table: instance_id,filename,level,label,prediction,confidence,threshold, known_label

# add instance_id column: each run of 3 consecutive rows shares one id
# (assumes the CSV holds exactly 3 consecutive rows per image and a default
#  0..n-1 index — presumably one row per taxonomic level; TODO confirm)
df['instance_id'] = df.index // 3

# add threshold column (the same constant for every row)
df['threshold'] = 0.5

In [None]:
# Load the exported fastai learner from lrn_path.
learn=load_learner(lrn_path)

In [None]:
# Number of entries in the learner's vocabulary.
len(learn.dls.vocab)

In [None]:
# Number of distinct keys in the hierarchy mapping attached to the learner.
len(np.unique(list(learn.hierarchy.keys())))

In [None]:
# Flag rows whose ground-truth label (compared as a string) appears in the learner vocabulary.
# NOTE(review): assumes learn.dls.vocab is a flat collection of string labels — TODO confirm.
df['known_label'] = df['label'].astype(str).isin(learn.dls.vocab)

In [None]:
# Keep only the columns of the normalized format, in the order consumers expect.
new_order = [
    "instance_id", "filename", "level", "label",
    "prediction", "confidence", "threshold", "known_label",
]
df = df.loc[:, new_order]

In [None]:
# Write the normalized table to out_path, without the index column.
df.to_csv(out_path, index=False)

## Parenthesis: check the number of parameters in the model

In [None]:
import importlib
from fastai.vision.all import load_learner, CategoryMap, vision_learner
import yaml 
from pathlib import Path

# Training configuration; its 'train' section is passed to gen_dls to rebuild the dataloaders.
config_path = Path("/home/george/codes/lepinet/configs/20251106_1_test_ece.yaml")

with open(config_path) as f:
    config=yaml.safe_load(f)

# The module name starts with a digit, so it cannot be used in a regular
# `import` statement; importlib resolves it from the string instead.
gen_dls = getattr(importlib.import_module('011_lepi_large_prod_v2'), 'gen_dls')

dls,hierarchy=gen_dls(**config['train'])
# Look up the architecture constructor by the name given in the config.
model_arch = getattr(importlib.import_module('fastai.vision.all'), config['train']['model_arch_name'])
# NOTE: rebinds `learn` (previously the learner loaded from lrn_path) to a newly built learner.
learn = vision_learner(dls, model_arch)
learn.model

In [None]:
# Walk the parameters once, accumulating the overall element count and the
# count restricted to parameters that currently require gradients.
total_params = 0
trainable_params = 0
for param in learn.model.parameters():
    n_elements = param.numel()
    total_params += n_elements
    if param.requires_grad:
        trainable_params += n_elements

print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")

In [None]:
# Unfreeze the learner (fastai `Learner.unfreeze`).
learn.unfreeze()

In [None]:
import torch
# Weight-matrix sizes (in_features * out_features, biases excluded) of the
# Linear layers found directly inside the second top-level module of the model.
linear_layers = [layer for layer in learn.model[1] if isinstance(layer, torch.nn.Linear)]
sum([layer.in_features * layer.out_features for layer in linear_layers])

## Get some graphs

First, get some info about the training set:
- Number of species, genus, family
- Distribution
- Ten most common species
- Ten most common species in each country

Second, get some metrics on the testing set:
- Performance of the model on the ten most common species, overall and per country

In [None]:
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os
from pprint import pprint

In [None]:
# Inputs and outputs for the graphs section.
# NOTE(review): absolute, user-specific paths — adjust before running on another machine.
parquet_path = Path("/home/george/codes/lepinet/data/lepi/0061420-241126133413365_sampled_processing_metadata_postprocessed.parquet")  # occurrence metadata table
preds_path = Path("/home/george/codes/lepinet/data/lepi/preds/00_eulepi.csv")  # model predictions CSV
fig_path = Path("/home/george/codes/lepinet/data/lepi/figures")
ami_path = Path("/home/george/codes/lepinet/data/ami/parquets")  # directory globbed for per-country AMI parquet files
output_dir  = fig_path / "figures_accuracy"  # where the accuracy/distribution figures are saved

In [None]:
# Load the occurrence metadata; rebinds `df` (previously the predictions table of the first section).
df=pd.read_parquet(parquet_path)

In [None]:
# Inspect the columns of the occurrence metadata.
df.columns

In [None]:
# Show a single example row.
df.head(1)

Training set: every row whose `set` value is not '0' (this includes the validation rows marked '1')

Validation set: everything marked as '1'

Test set: everything marked as '0'

In [None]:
# Row count and number of distinct species / genus / family keys in the full table.
len(df), df['speciesKey'].nunique(), df['genusKey'].nunique(), df['familyKey'].nunique()

In [None]:
# Split on the `set` column, whose values are compared as strings.
# `ts` keeps every row not marked '0', so it still contains the validation rows;
# `valset` is the subset of `ts` marked '1'; `tts` holds the rows marked '0'.
ts = df[~df['set'].isin(['0'])] # train set
valset = ts[ts['set'].isin(['1'])]
tts = df[df['set'].isin(['0'])] # test set

In [None]:
# Same counts as above, restricted to the training split.
len(ts), ts['speciesKey'].nunique(), ts['genusKey'].nunique(), ts['familyKey'].nunique()

In [None]:
# NOTE(review): hard-coded numbers — presumably the species, genus and family counts
# printed by the cell above, summed to get the total number of classes; TODO confirm.
6247 + 1724 + 86

In [None]:
# Peek at the country code column of the training split.
ts['countryCode']

In [None]:
# Number of training rows recorded in the listed country codes.
len(ts[ts['countryCode'].isin(["BE", "BG", "HR", "CZ", "DK", "FI", "FR", "IE", "IL", "IT", "PT", "SK", "ES", "SE", "NL"])])

How many species have 500 images?

In [None]:
# Largest number of training images observed for a single species.
max(ts['speciesKey'].value_counts())

In [None]:
# Number of species that have exactly 450 training images.
# NOTE(review): the markdown heading above asks about 500 images, but the code checks 450 —
# presumably 450 is the maximum reported by the previous cell; TODO confirm which is intended.
ts[ts.groupby('speciesKey')['speciesKey'].transform('count') == 450]['speciesKey'].nunique()

Ten most populated species/genus/families

In [None]:
# Distinct country codes present in the training split.
ts.countryCode.unique()

In [None]:
# Country-code lists. `countries` is the list iterated by the per-country plotting cells;
# countries_ias / countries_abms are not referenced elsewhere in the visible notebook.
countries_ias = ["BE","BG","CZ","DK","FR","IE","IL","IT","PT","SK","SE"]
countries_abms = ["BE","BG","CZ","DK","FI","IE","IL","IT","SK","SE","NL"]
# countries = ["BE","BG","CZ","DK","FI","FR","IE","IL","IT","PT","SK","SE","NL", "HR"]
# Active list: same as the commented variant above but without "IL".
countries = ["BE","BG","CZ","DK","FI","FR","IE","IT","PT","SK","SE","NL", "HR"]
print(len(countries))

# Country code -> country name as embedded in the AMI parquet file names
# (the name is what get_df interpolates into its glob pattern).
countries_map = {
    "BE": "belgium",
    "BG": "bulgaria",
    "CZ": "czech_republic",
    "DK": "denmark",
    "FI": "finland",
    "FR": "france",
    "IE": "ireland",
    "IL": "israel",
    "IT": "italy",
    "PT": "portugal",
    "SK": "slovakia",
    "SE": "sweden",
    "NL": "the_netherlands",
    "HR": "croatia"
}

def get_df(ami_path, country):
    """Load the confident species-level detections for one country.

    For each project ('ias', 'abms') the first parquet file (in sorted name
    order, so the choice is deterministic) matching
    ``*{project}_{country}_2025*.parquet`` under ``ami_path`` is read; the
    tables are concatenated and filtered to rows with
    ``taxonlevel == 'species'``, ``score > 0.5`` and
    ``algorithm == 'fastai-species'``.

    Parameters
    ----------
    ami_path : pathlib.Path
        Directory containing the AMI parquet files.
    country : str
        Country name as it appears in the file names (e.g. ``'denmark'``).

    Returns
    -------
    pandas.DataFrame
        The filtered detections. When no file matches, an empty frame with the
        columns used by the helpers in this notebook is returned, so callers
        can test ``.empty`` instead of hitting the ``ValueError`` that
        ``pd.concat`` raises on an empty list.
    """
    frames = []
    for project in ('ias', 'abms'):
        # Glob once per project (the pattern used to be evaluated twice).
        matches = sorted(ami_path.glob(f"*{project}_{country}_2025*.parquet"))
        if matches:
            frames.append(pd.read_parquet(matches[0]))
    if not frames:
        return pd.DataFrame(columns=['taxonlevel', 'score', 'algorithm', 'label', 'labelid'])
    df = pd.concat(frames, ignore_index=True)
    return df[(df['taxonlevel']=='species') & (df['score']>0.5) & (df['algorithm']=='fastai-species')]

def get_top_c(ami_path, country, n=10):
    """Return the `labelid`s of the `n` most frequently detected species in a country.

    The detections come from ``get_df(ami_path, country)``. As a side effect the
    corresponding `label` values (most frequent first) are printed.

    Parameters
    ----------
    ami_path : pathlib.Path
        Directory containing the AMI parquet files.
    country : str
        Country name as used in the parquet file names.
    n : int, default 10
        Number of species to return.

    Returns
    -------
    list
        Up to `n` `labelid` values, ordered by decreasing detection count.
    """
    df=get_df(ami_path, country)
    # Report the country actually queried; the message used to read the global
    # loop variable `c`, which is stale (or undefined) outside the plotting loops.
    print(f'Most predicted species in {country}')
    print(df['label'].value_counts().head(n).index.tolist())
    return df['labelid'].value_counts().head(n).index.tolist()

def get_in_ls(ami_path, country, ls):
    """Return the row index labels of a country's detections whose `labelid` is in `ls`.

    Entries of `ls` are converted to strings before being matched against `labelid`.
    """
    detections = get_df(ami_path, country)
    wanted_ids = list(map(str, ls))
    is_wanted = detections['labelid'].isin(wanted_ids)
    return detections.loc[is_wanted, 'label'].index.tolist()

In [None]:
# `df_occ` is an alias of the test split (not a copy): the numeric conversion
# below therefore also changes the key columns of `tts`.
df_occ = tts


sns.set(style="whitegrid")
# --- CONFIG ---
# countries = ["BE","BG","HR","CZ","DK","FI","FR","IE","IL","IT","PT","SK","ES","SE","NL"]
# countries = ["BE","BG","CZ","DK","FI","FR","IE","IT","PT","SK","SE","NL"]
confidence_threshold = 0.5
top_n_country = 10
top_n_overall = 10

# Convert keys to numeric if they are strings representations.
# This must happen BEFORE the name map is built: the plots look names up with
# the (numeric) keys taken from these columns, and a map keyed by the original
# strings would never match, leaving raw IDs on the x-axis.
for k in ['speciesKey','genusKey','familyKey']:
    if k in df_occ.columns:
        df_occ[k] = pd.to_numeric(df_occ[k], errors='coerce')

# Create a mapping: speciesKey → scientificName (for readable x-axis).
# Only the first two words of the scientific name are kept.
species_name_map = (
    df_occ.dropna(subset=['speciesKey', 'scientificName'])
         .drop_duplicates('speciesKey')
         .set_index('speciesKey')['scientificName']
         .apply(lambda x: ' '.join(x.split(" ")[:2]))
         .to_dict()
)

# ---------- 1) Species distribution curves ----------
# Count images per speciesKey
species_counts = df_occ['speciesKey'].value_counts().sort_values(ascending=False)
n_species = species_counts.size
print(f"Number of species (unique speciesKey): {n_species}")
print(f"Max images per species (observed): {species_counts.max()}")

# Rank-abundance curve (linear y, rank x) and log-log
ranks = np.arange(1, len(species_counts) + 1)
counts = species_counts.values

In [None]:


# --- Ensure keys exist (speciesKey, genusKey, familyKey, filename, countryCode) ---

# df_occ = your occurrences dataframe
# df_pred = your predictions dataframe


# plt.figure(figsize=(8,5))
# plt.plot(ranks, counts, marker='.', linewidth=0.8)
# plt.title("Species rank-abundance (linear scale)")
# plt.xlabel("Species rank (1 = most abundant)")
# plt.ylabel("Images per species")
# plt.tight_layout()
# plt.show()

# # Histogram of counts (and log histogram)
# plt.figure(figsize=(8,4))
# plt.hist(counts, bins=50)
# plt.title("Histogram of images per species")
# plt.xlabel("Images per species")
# plt.ylabel("Number of species")
# plt.tight_layout()
# plt.show()

# # Lorenz / cumulative share plot: what share of images are in top X% species
# cum_counts = np.cumsum(counts)
# total = cum_counts[-1]
# cum_share = cum_counts / total
# prop_species = np.arange(1, len(counts)+1)/len(counts)

# plt.figure(figsize=(6,6))
# plt.plot(prop_species, cum_share, linewidth=1.2)
# plt.plot([0,1],[0,1], linestyle='--', color='gray')  # line of equality
# plt.xlabel("Proportion of species (ranked)")
# plt.ylabel("Cumulative proportion of images")
# plt.title("Lorenz-like curve: how images concentrate among species")
# plt.tight_layout()
# plt.show()

# # Optional: print concentration numbers
# # for pct in [0.01, 0.05, 0.1, 0.2]:
# #     k = int(len(counts) * pct)
# #     if k < 1:
# #         continue
# #     share = cum_counts[k-1] / total
# #     print(f"Top {int(pct*100)}% species ({k} species) own {share:.3f} of all images")

# if 'genusKey' in df_occ.columns:
#     genus_counts = df_occ['genusKey'].value_counts().sort_values(ascending=False)
#     ranks_g = np.arange(1, len(genus_counts)+1)
#     counts_g = genus_counts.values

#     plt.figure(figsize=(8,5))
#     plt.plot(ranks_g, counts_g, marker='.', linewidth=0.8)
#     plt.title("Genus rank-abundance (linear scale) — names omitted")
#     plt.xlabel("Genus rank")
#     plt.ylabel("Images per genus")
#     plt.tight_layout()
#     plt.show()

# # Family
# if 'familyKey' in df_occ.columns:
#     family_counts = df_occ['familyKey'].value_counts().sort_values(ascending=False)
#     ranks_f = np.arange(1, len(family_counts)+1)
#     counts_f = family_counts.values

#     plt.figure(figsize=(8,5))
#     plt.plot(ranks_f, counts_f, marker='.', linewidth=0.8)
#     plt.title("Family rank-abundance (linear) — names omitted")
#     plt.xlabel("Family rank")
#     plt.ylabel("Images per family")
#     plt.tight_layout()
#     plt.show()


# # If you want a compact table of top N genus/family:
# print("Top 10 genera (by count):")
# print(genus_counts.head(10))
# print("\nTop 10 families (by count):")
# print(family_counts.head(10))

# Bar charts of the `top_n_country` most frequent species per country, either
# as one grid figure (use_subplots=True) or as one figure per country.
use_subplots = True
if use_subplots:
    ncols = 3
    n = len(countries)
    nrows = int(np.ceil(n / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(15, 2 * nrows * 3 + 1))
    axes = axes.flatten()

    for i, c in enumerate(countries):
        ax = axes[i]
        df_c = df_occ[df_occ['countryCode'] == c]
        if df_c.empty:
            # Nothing to draw for this country: blank the panel but keep a title.
            ax.axis('off')
            ax.set_title(f"{c} (no data)")
            continue
        # Only the top n species, relabelled with readable names where known.
        sc = df_c['speciesKey'].value_counts().head(top_n_country)
        sc.index = [species_name_map.get(key, str(key)) for key in sc.index]
        sc.plot(kind='bar', ax=ax)
        ax.set_title(f"{c}")
        ax.set_xlabel("")
        ax.set_ylabel("Count")
        # ax.tick_params(axis='x', labelrotation=45)
        ax.tick_params(axis='x')
    # Blank the grid cells left over after the last country.
    for spare_ax in axes[i+1:]:
        spare_ax.axis('off')
    fig.suptitle(f"Top {top_n_country} species per country", fontsize=14)
    plt.tight_layout()
    plt.show()
else:
    for c in countries:
        df_c = df_occ[df_occ['countryCode'] == c]
        if df_c.empty:
            print(f"No records for country {c}")
            continue
        sc = df_c['speciesKey'].value_counts().head(top_n_country)
        sc.index = [species_name_map.get(key, str(key)) for key in sc.index]
        plt.figure(figsize=(8,4))
        sc.plot(kind='bar')
        plt.title(f"Top {top_n_country} speciesKey in {c} (counts). X labels are speciesKey (IDs).")
        plt.xlabel("speciesKey (ID)")
        plt.ylabel("Count")
        plt.tight_layout()
        plt.show()

Explore predictions

In [None]:
# Load the model predictions CSV.
df_pred = pd.read_csv(preds_path)

In [None]:
# Show the first six prediction rows.
df_pred.head(6)

In [None]:
confidence_threshold = 0.5
top_n_country = 10
top_n_overall = 10


# ---------- 4) Model performance on top-N species (overall) ----------
# Determine the top-N species (by speciesKey counts) within family 7015.
# The mask is built from df_occ itself: it used to come from `df`, a different
# frame (pandas had to re-align it, and `df` is rebound later in the notebook).
# The key is compared numerically so it works whether familyKey is str or number.
family_mask = pd.to_numeric(df_occ['familyKey'], errors='coerce') == 7015
species_counts = df_occ.loc[family_mask, 'speciesKey'].value_counts().sort_values(ascending=False)

top_species_keys = species_counts.head(top_n_overall).index.astype(str).tolist()
print("Top speciesKey (overall):", top_species_keys)

# Merge predictions with ground truth taxon keys from the test split via filename.
# One row per filename is kept on the right side so the merge cannot duplicate predictions.
left = df_pred.copy()
right = tts[['filename','speciesKey','genusKey','familyKey','countryCode']].drop_duplicates(subset='filename')
merged = pd.merge(left, right, on='filename', how='left', suffixes=('','_occ'))

# Only keep rows where we have a true taxon key
merged = merged.dropna(subset=['speciesKey','genusKey','familyKey'])
# Map true_label based on level
def get_true_label(row):
    """Return the ground-truth key matching the row's taxonomic level.

    Level 0 reads `speciesKey`, level 1 `genusKey`, level 2 `familyKey`; the
    value is returned as an int. Any other level yields NaN.
    """
    key_column_by_level = {0: 'speciesKey', 1: 'genusKey', 2: 'familyKey'}
    key_column = key_column_by_level.get(int(row['level']))
    if key_column is None:
        return np.nan
    return int(row[key_column])

# Ground-truth key at the taxonomic level of each prediction row.
merged['true_label'] = merged.apply(get_true_label, axis=1)

# Valid prediction = confidence >= threshold (we'll use confidence_threshold)
merged['valid_conf'] = merged['confidence'] >= confidence_threshold

# Correct prediction = valid_conf AND prediction == true_label.
# The comparison uses the numeric form of the prediction: comparing the raw
# column silently yields False everywhere when predictions were read as strings.
merged['prediction_num'] = pd.to_numeric(merged['prediction'], errors='coerce')
merged['correct'] = merged['valid_conf'] & (merged['prediction_num'] == merged['true_label'])

# Focus on rows corresponding to the top species (we care about performance on images belonging to these species).
# Both sides are compared numerically: top_species_keys holds strings while the
# speciesKey column is numeric, so a direct isin() would match nothing.
top_species_numeric = pd.to_numeric(pd.Series(top_species_keys, dtype=object), errors='coerce')
merged_top = merged[pd.to_numeric(merged['speciesKey'], errors='coerce').isin(top_species_numeric)]
level_map = {0: 'species', 1: 'genus', 2: 'family'}

In [None]:
# Compute accuracy per top species and per level

acc = (
    merged_top
    .groupby(['speciesKey', 'level'])
    .agg(n_images=('filename','count'),
         n_valid_preds=('valid_conf','sum'),
         n_correct=('correct','sum'))
    .reset_index()
)
# Accuracy is conditional on valid predictions; groups without any valid
# prediction get NaN instead of a division by zero.
acc['accuracy'] = acc['n_correct'] / acc['n_valid_preds'].replace(0, np.nan)
acc['level_name'] = acc['level'].astype(int).map(level_map)

# Display
print("\nModel performance (top species overall):")
display_cols = ['speciesKey','level_name','n_images','n_valid_preds','n_correct','accuracy']
print(acc[display_cols].sort_values(['speciesKey','level_name']))

# Plot accuracy for each top species across levels.
# pivot to get speciesKey on x and levels as different bars.
# DataFrame.plot opens its own figure, so no plt.figure() call is made first
# (it only produced an extra, empty figure).
pivot = acc.pivot(index='speciesKey', columns='level_name', values='accuracy')
ax = pivot.plot(kind='bar', figsize=(12,5))
ax.set_ylabel('Accuracy (correct / valid predictions)')
ax.set_title(f"Model accuracy on top {top_n_overall} species (confidence >= {confidence_threshold})")
ax.set_ylim(0,1)
plt.tight_layout()
plt.show()

# Also show number of valid predictions vs images (to see coverage)
pivot_n = acc.pivot(index='speciesKey', columns='level_name', values='n_valid_preds').fillna(0)
ax = pivot_n.plot(kind='bar', figsize=(12,5))
ax.set_ylabel('Number of valid predictions (confidence >= threshold)')
ax.set_title(f"Valid predictions (per level) on top {top_n_overall} species")
plt.tight_layout()
plt.show()


In [None]:
# ---------- 5) (Optional) Performance per country top-10 species ----------
# For each country: take its top species (by count in df_occ) and report the
# model accuracy on the merged rows belonging to those species.
# NOTE(review): `merged` is filtered by species only, not by country, so the
# accuracy covers images of those species from every country — confirm intent.

# Numeric view of the true species key, computed once for all countries. Keys are
# matched numerically because matching the numeric column against string keys
# selects nothing and silently skips every country.
merged_species_keys = pd.to_numeric(merged['speciesKey'], errors='coerce')

for c in countries:
    df_c = df_occ[df_occ['countryCode'] == c]
    if df_c.empty:
        continue
    top_c = (
        pd.to_numeric(df_c['speciesKey'], errors='coerce')
        .value_counts()
        .head(top_n_country)
        .index
        .tolist()
    )
    merged_c = merged[merged_species_keys.isin(top_c)]
    acc_c = (
        merged_c
        .groupby(['speciesKey','level'])
        .agg(n_images=('filename','count'),
             n_valid_preds=('valid_conf','sum'),
             n_correct=('correct','sum'))
        .reset_index()
    )
    if acc_c.empty:
        continue
    # NaN (not a ZeroDivisionError/inf) when a group has no valid prediction.
    acc_c['accuracy'] = acc_c['n_correct'] / acc_c['n_valid_preds'].replace(0, np.nan)
    acc_c['level_name'] = acc_c['level'].astype(int).map(level_map)
    print(f"\nCountry {c} - model accuracy on top {top_n_country} species (speciesKey):")
    print(acc_c[['speciesKey','level_name','n_images','n_valid_preds','n_correct','accuracy']])
    # small plot
    pivot_c = acc_c.pivot(index='speciesKey', columns='level_name', values='accuracy')
    pivot_c.plot(kind='bar', figsize=(10,4), title=f"Accuracy for top {top_n_country} species in {c}")
    plt.ylim(0,1)
    plt.tight_layout()
    plt.show()

In [None]:
# left = df_pred.copy()
# right = tts[['filename','speciesKey','genusKey','familyKey','countryCode']].drop_duplicates(subset='filename')
# merged = pd.merge(left, right, on='filename', how='left', suffixes=('','_occ'))
# merged

In [None]:
# Rebuild the speciesKey -> short scientific name lookup (first two words of
# the name), using the first row seen for each key.
_named_species = (
    df_occ.dropna(subset=['speciesKey', 'scientificName'])
          .drop_duplicates('speciesKey')
)
species_name_map = {
    key: ' '.join(full_name.split(" ")[:2])
    for key, full_name in zip(_named_species['speciesKey'], _named_species['scientificName'])
}

In [None]:
# custom_species_keys_country = ["1877988", "1876542", "1316908", "1857068", "1325561", "1322635", "4425774", "1872901"]
# custom_species_keys_country = [1877988, 1876542, 1316908, 1857068,1325561, 1322635, 4425774, 1872901]
# Species keys of interest, in their original order. Repeated keys are kept
# as given; they make no difference to the isin() filters that use the list.
custom_species_keys_country = [
    1651430, 1877988, 1779258, 2034536, 4531353,
    1785185, 1785185,
    4530608, 1861251,
    4532122, 4532122, 4532122, 4532122, 4532122, 4532122, 4532122, 4532122,
    11149183, 1882905, 1879452, 1803108,
    1737131, 1737131,
    2042933, 1876542, 10457666,
    4485843, 4485843,
    4989904, 4989904,
    1316908, 5743297, 5743301, 1820406, 1890065, 1857068, 2037925,
    1325561, 1322635, 6133619, 4425774, 1872901, 9264380, 8352161,
    5109855, 5109855,
    5109931, 5109931, 5109931,
    5109882, 1850658,
    5879783, 5879783,
    1824131, 1824131,
    1824113, 1824113, 1824113,
    1315391, 1159172,
]


# Build id -> scientific name for the custom species that occur as species-level
# (level 0) labels in the predictions, and dump the mapping to a text file.
id2name = {}
ids = df_pred[(df_pred['level']==0) & (df_pred['label'].isin(custom_species_keys_country))]['label'].unique()

# Look each name up by its key. The ids and the names used to be zipped from two
# independent `unique()` calls, whose orders (and lengths) are unrelated, so ids
# were paired with the wrong names. Keys are converted to numbers because `ts`
# stores speciesKey as strings while the prediction labels are numeric.
name_by_key = (
    ts.dropna(subset=['speciesKey', 'scientificName'])
      .drop_duplicates('speciesKey')
      .assign(_numeric_key=lambda d: pd.to_numeric(d['speciesKey'], errors='coerce'))
      .dropna(subset=['_numeric_key'])
      .set_index('_numeric_key')['scientificName']
      .to_dict()
)
# names = ts[ts['speciesKey'].isin(custom_species_keys_country)]['scientificName'].unique()
with open("/home/george/codes/lepinet/data/lepi/figures/id2name.txt", 'w') as f:
    for i in ids:
        # Fall back to the id itself when the training split has no name for it.
        n = name_by_key.get(i, str(i))
        id2name[i]=n
        f.write(f"{i}: {n}\n")
pprint(id2name)

# From here on, restrict the custom list to the ids actually present in the predictions.
custom_species_keys_country = list(id2name.keys())

In [None]:
# Number of species-level prediction rows per custom species, indexed by name.
_is_custom_species_row = (df_pred['level'] == 0) & (df_pred['label'].isin(custom_species_keys_country))
nbof_imgs_per_spc = (
    df_pred.loc[_is_custom_species_row, 'label']
    .value_counts()
    .rename(index=id2name)
)
nbof_imgs_per_spc

In [None]:
# How many of the custom species keys occur as species-level (level 0) labels in the predictions.
len(df_pred[(df_pred['level']==0) & (df_pred['label'].isin(custom_species_keys_country))]['label'].unique())

In [None]:
# Peek at the merged predictions / ground-truth table.
merged.head(6)

In [None]:
# ----------------------------------------------
# (5) Model performance per country top species
# ----------------------------------------------
# Plot switches: one grid figure for all countries vs. one figure per country,
# and whether figures are written to `output_dir`.
use_subplots = True
save_figures = True
# custom_species_keys_country = ["1877988", "1876542", "1316908", "1857068", "1325561", "1322635", "4425774", "1872901"]  # <--- optional override list (like before)
# custom_species_keys_country = []  # <--- optional override list (like before)
# custom_species_keys_country = [str(e) for e in custom_species_keys_country]
ncols = 3

# Make sure the output directory exists before anything is saved.
os.makedirs(output_dir, exist_ok=True)

if use_subplots:
    # Enough rows to hold one panel per country.
    nrows = int(np.ceil(len(countries) / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(15, nrows * 5))
    axes = axes.flatten()

# One pass per country: pick the species of interest, aggregate accuracy and
# coverage per (species, level), then draw a grouped bar chart.
for i, c in enumerate(countries):
    df_c = df_occ[df_occ['countryCode'] == c]
    if df_c.empty:
        continue

    # Use custom species list if provided
    if custom_species_keys_country:
        top_c = custom_species_keys_country
    else:
        # top_c = df_c['speciesKey'].value_counts().head(top_n_country).index.astype(str).tolist()
        # Force to plot Noctuidae only:
        # top_c = df_c[df_c['familyKey']==7015]['speciesKey'].value_counts().head(top_n_country).index.astype(str).tolist()
        top_c = get_top_c(ami_path, country=countries_map[c])
        # The AMI label ids are converted to int so they match the numeric speciesKey column.
        top_c = [int(key) for key in top_c]

    # NOTE(review): rows are selected by species only, not by country — confirm intent.
    merged_c = merged[merged['speciesKey'].isin(top_c)]
    acc_c = (
        merged_c.groupby(['speciesKey', 'level'])
        .agg(n_images=('filename', 'count'),
             n_valid_preds=('valid_conf', 'sum'),
             n_correct=('correct', 'sum'))
        .reset_index()
    )

    if acc_c.empty:
        continue

    # accuracy = correct / valid predictions; coverage = valid predictions / images.
    # Zero denominators become NaN instead of raising or producing inf.
    acc_c['accuracy'] = acc_c['n_correct'] / acc_c['n_valid_preds'].replace(0, np.nan)
    acc_c['coverage'] = acc_c['n_valid_preds'] / acc_c['n_images'].replace(0, np.nan)
    acc_c['level_name'] = acc_c['level'].astype(int).map(level_map)

    print(f"\nCountry {c} - model accuracy on top {top_n_country} species:")
    print(acc_c[['speciesKey','level_name','n_images','n_valid_preds','n_correct','accuracy', 'coverage']])
    # NOTE: the same file name is written on every iteration, so only the last country's table remains.
    acc_c[['speciesKey','level_name','n_images','n_valid_preds','n_correct','accuracy', 'coverage']].to_csv(output_dir/"acc_ias.csv")

    # Pivot for plotting
    pivot_c = acc_c.pivot(index='speciesKey', columns='level_name', values='accuracy')

    # Species-level coverage per species, shown next to the name on the x-axis.
    species_cov = acc_c[acc_c['level_name']=='species'].set_index('speciesKey')['coverage']

    # ---- Human-readable x-axis labels
    # Ensure the pivot table follows top_c order
    pivot_c = pivot_c.reindex(top_c)
    # reindex() can introduce species that have no species-level rows at all;
    # .get() returns NaN for them where .loc[] would raise KeyError.
    x_labels = []
    for k in pivot_c.index:
        display_name = species_name_map.get(k, str(k))
        cov = species_cov.get(k, np.nan)
        if pd.isna(cov):
            x_labels.append(display_name)
        else:
            x_labels.append(f"{display_name} ({cov*100:.0f}%)")
    pivot_c.index = x_labels

    if use_subplots:
        ax = axes[i]
        pivot_c.plot(kind='bar', ax=ax, legend=False, width=0.8)
        for container in ax.containers:
            ax.bar_label(container, fmt="%.2f",
                 fontsize=7, padding=2,
                 rotation=90,  # <--- rotate vertically
                 label_type="edge")
        ax.set_title(f"{countries_map[c].replace('_', ' ').title()}")
        ax.set_ylim(0.5, 1.1)  # zoom into the range
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.2f}"))
        # ax.tick_params(axis='x', rotation=45)
        ax.tick_params(axis='x')
        ax.set_xlabel("")
        ax.set_ylabel("Accuracy")
    else:
        ax = pivot_c.plot(kind='bar', figsize=(10, 7), width=0.8,
                        legend=True)
        for container in ax.containers:
            ax.bar_label(container, fmt="%.2f", fontsize=7,
                        padding=2, rotation=90, label_type="edge")
        ax.set_ylim(0, 1.1)
        # Add legend below title, above plot
        ax.legend(title="Taxonomic level", loc="upper center", ncol=3,
                # bbox_to_anchor=(0.35, 0.25), 
                # bbox_to_anchor=(0.25, 0.18), 
                bbox_to_anchor=(0.5, 1.25), 
                frameon=False, 
                # bbox_transform=fig.transFigure
                )

        # Adjust title spacing (slightly higher)
        ax.set_title(f"Accuracy for top {top_n_country} species in {countries_map[c].replace('_', ' ').title()}", pad=50)
        # ax.set_title(f"Accuracy for IAS", pad=50)
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.2f}"))
        ax.set_xlabel("")
        ax.set_ylabel("Accuracy")
        ax.tick_params(axis='x')
        # fig.subplots_adjust(top=0.86)   # reduce/increase this to create more/less room
        plt.tight_layout(rect=[0, 0, 1, 0.95])
        if save_figures:
            filename = output_dir/f"accuracy_{c}.svg"
            # filename = output_dir/f"ias.svg"
            plt.savefig(filename, dpi=300, bbox_inches="tight")
        plt.show()
        # break

if use_subplots:
    # Blank the grid panels that come after the last country.
    for spare_ax in axes[i+1:]:
        spare_ax.axis("off")

    # One legend for the whole figure, taken from the first panel.
    handles, labels = axes[0].get_legend_handles_labels()

    # Global title, placed slightly above the figure area.
    fig.suptitle(
        f"Model accuracy per country (top {top_n_country} species)",
        fontsize=14,
        y=1.01,
    )

    # Shared legend just below the global title.
    fig.legend(
        handles,
        labels,
        loc="upper center",
        ncol=3,
        title="Taxonomic level",
        bbox_to_anchor=(0.5, 1.0),
    )

    # Keep the top 5% of the figure free for the title and legend.
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    if save_figures:
        filename = output_dir/"accuracy_all_countries.svg"
        fig.savefig(filename, dpi=300, bbox_inches="tight")
    plt.show()

In [None]:
# Inspect the last rows of the accuracy table left over from the final loop iteration.
acc_c[['speciesKey','level_name','n_images','n_valid_preds','n_correct','accuracy', 'coverage']].tail(50)


Use the AMI data

In [None]:
# Collect every 2025 parquet file under ami_path. `all_parquet` is a generator
# and is exhausted once it has been turned into the list.
all_parquet = ami_path.glob("*2025*.parquet")
ls_parquet = list(all_parquet)

In [None]:
# Sanity check: report (project, country) combinations that have no parquet
# file, and list the files when a combination matches more than one.
project=['ias','abms']
# Country names as they appear in the file names. Kept in its own variable:
# reusing `countries` here replaced the list of country codes that the plotting
# cells index `countries_map` with, which made them fail with KeyError.
# ('israel' was misspelled 'isreal', so it was always reported as missing.)
ami_countries=['denmark','croatia', 'finland', 'sweden', 'france', 'slovakia', 'the_netherlands', 'portugal', 'belgium', 'ireland', 'bulgaria', 'italy', 'israel', 'czech_republic']
for p in project:
    for c in ami_countries:
        matches = list(ami_path.glob(f"*{p}_{c}_2025*.parquet"))
        nb_pq = len(matches)
        if nb_pq == 0:
            print(f"The combition {p} and {c} does not exist.")
        if nb_pq > 1:
            print(matches)

In [None]:
# Show the parquet files that were found.
ls_parquet

In [None]:
# Load the first AMI parquet file.
# NOTE: this rebinds `df`, which earlier cells use for the occurrence metadata table.
df = pd.read_parquet(ls_parquet[0])

In [None]:
# Number of rows in the loaded AMI file.
len(df)

In [None]:
# Show the first ten rows.
df.head(10)

In [None]:
# Keep species-level rows from the 'fastai-species' algorithm with score above 0.5
# (the same filter that get_df applies).
ft_prd=df[(df['taxonlevel']=='species') & (df['score']>0.5) & (df['algorithm']=='fastai-species')]

In [None]:
# Peek at the filtered detections.
ft_prd.head()

In [None]:
# The ten most frequent labels among the filtered detections.
ft_prd['label'].value_counts().head(10).index

In [None]:
# Country code -> country name as embedded in the AMI parquet file names
# (the name is what get_df interpolates into its glob pattern).
countries_map = {
    "BE": "belgium",
    "BG": "bulgaria",
    "CZ": "czech_republic",
    "DK": "denmark",
    "FI": "finland",
    "FR": "france",
    "IE": "ireland",
    "IL": "israel",
    "IT": "italy",
    "PT": "portugal",
    "SK": "slovakia",
    "SE": "sweden",
    "NL": "the_netherlands",
    "HR": "croatia"
}

def get_df(ami_path, country):
    """Load the confident species-level detections for one country.

    For each project ('ias', 'abms') the first parquet file (in sorted name
    order, so the choice is deterministic) matching
    ``*{project}_{country}_2025*.parquet`` under ``ami_path`` is read; the
    tables are concatenated and filtered to rows with
    ``taxonlevel == 'species'``, ``score > 0.5`` and
    ``algorithm == 'fastai-species'``.

    Parameters
    ----------
    ami_path : pathlib.Path
        Directory containing the AMI parquet files.
    country : str
        Country name as it appears in the file names (e.g. ``'denmark'``).

    Returns
    -------
    pandas.DataFrame
        The filtered detections. When no file matches, an empty frame with the
        columns used by the helpers in this notebook is returned, so callers
        can test ``.empty`` instead of hitting the ``ValueError`` that
        ``pd.concat`` raises on an empty list.
    """
    frames = []
    for project in ('ias', 'abms'):
        # Glob once per project (the pattern used to be evaluated twice).
        matches = sorted(ami_path.glob(f"*{project}_{country}_2025*.parquet"))
        if matches:
            frames.append(pd.read_parquet(matches[0]))
    if not frames:
        return pd.DataFrame(columns=['taxonlevel', 'score', 'algorithm', 'label', 'labelid'])
    df = pd.concat(frames, ignore_index=True)
    return df[(df['taxonlevel']=='species') & (df['score']>0.5) & (df['algorithm']=='fastai-species')]

def get_top_c(ami_path, country, n=10):
    """Return the `labelid`s of the `n` most frequently detected species in a country.

    The detections come from ``get_df(ami_path, country)``. As a side effect the
    corresponding `label` values (most frequent first) are printed.

    Parameters
    ----------
    ami_path : pathlib.Path
        Directory containing the AMI parquet files.
    country : str
        Country name as used in the parquet file names.
    n : int, default 10
        Number of species to return.

    Returns
    -------
    list
        Up to `n` `labelid` values, ordered by decreasing detection count.
    """
    df=get_df(ami_path, country)
    # Report the country actually queried; the message used to read the global
    # loop variable `c`, which is stale (or undefined) outside the plotting loops.
    print(f'Most predicted species in {country}')
    print(df['label'].value_counts().head(n).index.tolist())
    return df['labelid'].value_counts().head(n).index.tolist()

def get_in_ls(ami_path, country, ls):
    """Return the row index labels of a country's detections whose `labelid` is in `ls`.

    Entries of `ls` are converted to strings before being matched against `labelid`.
    """
    detections = get_df(ami_path, country)
    wanted_ids = list(map(str, ls))
    is_wanted = detections['labelid'].isin(wanted_ids)
    return detections.loc[is_wanted, 'label'].index.tolist()

# Try the helpers on Denmark; only the value of the last call is displayed by the cell.
ddf=get_df(ami_path, 'denmark')
get_top_c(ami_path, 'denmark')
get_in_ls(ami_path, 'denmark', custom_species_keys_country)

In [None]:
# Show the custom species keys currently in use.
custom_species_keys_country

In [None]:
# Show the current content of `countries`, sorted.
sorted(countries)

In [None]:
# Distribution of the most frequently detected species per country (AMI data),
# either as one grid figure or as one figure per country.
use_subplots = True
save_figures = True  # <--- enable/disable saving easily
os.makedirs(output_dir, exist_ok=True)
if use_subplots:
    n = len(countries)
    ncols = 3
    nrows = int(np.ceil(n / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(15, 2 * nrows * 3 + 1))
    axes = axes.flatten()

    # `countries` must hold country codes here: they are the keys of countries_map.
    for i, c in enumerate(countries):
        df_c = get_df(ami_path, country=countries_map[c])
        ax = axes[i]
        if df_c.empty:
            ax.axis('off')
            ax.set_title(f"{c} (no data)")
            continue

        # Select only top detections
        top_c = get_top_c(ami_path, country=countries_map[c])
        df_c = df_c[df_c['labelid'].isin(top_c)]

        sc = df_c['label'].value_counts().head(top_n_country) # Only top n
        sc.index = [species_name_map.get(k, str(k)) for k in sc.index]  # <--- show names
        sc.plot(kind='bar', ax=ax)
        ax.set_title(f"{countries_map[c].replace('_', ' ').title()}")
        ax.set_xlabel("")
        ax.set_ylabel("Count")
        # ax.tick_params(axis='x', labelrotation=45)
        ax.tick_params(axis='x')
else:
    for c in countries:
        df_c = df_occ[df_occ['countryCode'] == c]
        if df_c.empty:
            print(f"No records for country {c}")
            continue
        sc = df_c['speciesKey'].value_counts().head(top_n_country)
        sc.index = [species_name_map.get(k, str(k)) for k in sc.index]
        # sc = df_c.groupby('speciesKey').transform('count')['scientificName'].head(top_n_country)
        plt.figure(figsize=(10,7))
        sc.plot(kind='bar')
        plt.title(f"Top {top_n_country} detected species in {countries_map[c].replace('_', ' ').title()}.")
        plt.xlabel("speciesKey (ID)")
        plt.ylabel("Count")
        plt.tight_layout()
        if save_figures:
            filename = output_dir/f"distribution_{c}.svg"
            plt.savefig(filename, dpi=300, bbox_inches="tight")
        plt.show()


if use_subplots:
    # Hide unused axes (as before)
    for j in range(i+1, len(axes)):
        axes[j].axis("off")

    # Global title (higher)
    fig.suptitle(f"Species distribution per country (top {top_n_country} species)",
                fontsize=14, y=1.01)

    # No shared legend: each panel shows a single count series, so the
    # "Taxonomic level" legend copied from the accuracy figure had nothing
    # meaningful to list.

    # Leave room at the top for the global title.
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    if save_figures:
        filename = output_dir/"distribution_all_countries.svg"
        fig.savefig(filename, dpi=300, bbox_inches="tight")
    plt.show()
