# Label counting to determine dataset size

## Label mapping from ontology to datasets


In [None]:
import pandas as pd
import numpy as np
import faiss
# from sentence_transformers import SentenceTransformer


Setting up embeddings

In [None]:
# model = SentenceTransformer("google/embeddinggemma-300m")

from transformers import ClapTextModelWithProjection, AutoTokenizer

# Load the CLAP tokenizer and the text tower only (text encoder + projection
# head, ~110M params, roughly 220MB when held in float16).
# NOTE(review): `dtype=` is the newer transformers spelling of this keyword;
# older releases expect `torch_dtype=` — confirm the installed version honors it.
tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
model = ClapTextModelWithProjection.from_pretrained(
    "laion/clap-htsat-unfused",
    dtype="float16",  # half precision keeps RAM use around ~220MB
)




def embed_texts(texts, batch_size: int = 64) -> np.ndarray:
    """Embed text with the CLAP text encoder and return unit-length vectors.

    Args:
        texts: A single string or a sequence of strings.
        batch_size: Number of texts tokenized and encoded per forward pass;
            bounds peak memory when embedding long label lists.

    Returns:
        A ``float32`` array with one L2-normalized row per input text, so
        inner products between rows equal cosine similarities.

    Raises:
        ValueError: If ``texts`` is empty or ``batch_size`` is not positive.
    """
    # Local import: torch is only needed here, to disable autograd.
    import torch

    if isinstance(texts, str):
        texts = [texts]
    texts = list(texts)
    if not texts:
        raise ValueError("embed_texts() needs at least one text")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    chunks = []
    # Embedding needs no gradient graph; skipping it saves memory and time.
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            inputs = tokenizer(
                texts[start:start + batch_size],
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            # The model is loaded in float16; upcast before normalizing since
            # FAISS works on float32 and half-precision norms lose accuracy.
            chunks.append(model(**inputs).text_embeds.float().cpu().numpy())

    vectors = np.concatenate(chunks, axis=0)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    # Guard against an all-zero embedding, which would otherwise become NaN.
    vectors = vectors / np.maximum(norms, np.finfo(np.float32).tiny)
    return vectors.astype(np.float32, copy=False)

def build_index(texts: list[str]) -> faiss.IndexFlatIP:
    """Embed *texts* and store them in an exact inner-product FAISS index.

    ``embed_texts`` returns unit-length vectors, so inner-product search over
    this index ranks entries by cosine similarity.
    """
    embeddings = embed_texts(texts).astype('float32')
    flat_index = faiss.IndexFlatIP(embeddings.shape[1])
    flat_index.add(embeddings)
    return flat_index

Load the ontology labels

In [None]:
from rdflib.namespace import OWL, RDFS
import rdflib

g = rdflib.Graph()
g.parse("../../anthropogenic_ontology.ttl", format="turtle")

BASE = ("http://www.semanticweb.org/dbotteld/ontologies/2025/6/sound_ontology#")

# Leaf classes only: OWL classes that no other class names as its superclass.
# Classes without an rdfs:label are still returned, with ?label unbound.
query = """
SELECT ?cls ?label
WHERE {
  ?cls a owl:Class .
  FILTER NOT EXISTS { ?sub rdfs:subClassOf ?cls . }
  OPTIONAL { ?cls rdfs:label ?label . }
}
"""

# Labels excluded from the set of query labels.
no_label_list = ["AudioSet"]


def _local_name(iri) -> str:
    """Return the part of an IRI after its last '#' or '/' (e.g. '...#plane' -> 'plane')."""
    return str(iri).rsplit("#", 1)[-1].rsplit("/", 1)[-1]


labels = []
_seen_labels = set()
for row in g.query(query, initNs={"owl": OWL, "rdfs": RDFS, "base": BASE}):
    # Prefer rdfs:label; otherwise use the class's local name so the text
    # encoder sees a word rather than a full IRI (and so the exclusion list
    # also applies to unlabelled classes).
    label = str(row.label) if row.label else _local_name(row.cls)
    # A class with several labels, or classes sharing a label, would otherwise
    # produce duplicate query rows in every similarity table.
    if label in no_label_list or label in _seen_labels:
        continue
    _seen_labels.add(label)
    labels.append(label)

label_vectors = embed_texts(labels)
labels

### Similarity functions

In [None]:
def build_similarity_dataframe(
    query_labels: list[str],
    dataset_labels: list[str],
    top_k: int = 3
) -> pd.DataFrame:
    """
    Build a DataFrame with similarity scores between ontology labels and dataset labels.

    Args:
        query_labels: List of ontology labels to query
        dataset_labels: List of dataset labels to build the index from
        top_k: Number of top similar labels to retrieve; clamped to the number
            of dataset labels.

    Returns:
        DataFrame with columns: ontology_label, dataset_labels, dataset_indices,
        similarity_scores (one row per query label; the last three columns hold
        parallel lists ordered from most to least similar). The columns exist
        even when either input is empty.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    columns = ['ontology_label', 'dataset_labels', 'dataset_indices', 'similarity_scores']
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")
    if len(query_labels) == 0 or len(dataset_labels) == 0:
        return pd.DataFrame(columns=columns)

    dataset_labels = list(dataset_labels)
    # Same (float32, inner-product) index construction as build_index().
    index = build_index(dataset_labels)

    # When k exceeds the index size FAISS pads results with index -1, and
    # dataset_labels[-1] would silently return the last label. Clamp instead.
    k = min(top_k, index.ntotal)

    # Embed all queries in one call rather than one model call per label.
    query_vectors = np.ascontiguousarray(embed_texts(list(query_labels)), dtype=np.float32)
    scores, indices = index.search(query_vectors, k)

    mapping_data = [
        {
            'ontology_label': ontology_label,
            'dataset_labels': [dataset_labels[i] for i in row_indices],
            'dataset_indices': [int(i) for i in row_indices],
            'similarity_scores': [round(float(score), 4) for score in row_scores],
        }
        for ontology_label, row_scores, row_indices in zip(query_labels, scores, indices)
    ]
    return pd.DataFrame(mapping_data, columns=columns)

## Class of interest 1: Plane

### audioset

In [None]:
# Lookup table from AudioSet machine IDs (MIDs) to display names.
mid_to_label_map = pd.read_csv(
    "../../data/metadata/audioset/mid_to_display_name.tsv",
    sep="\t",
    header=None,
    names=["mid", "display_name"],
)
mid_to_label_dict = mid_to_label_map.set_index("mid")["display_name"].to_dict()

mid_to_label_map.head()

In [None]:
# Training strong labels; translate each MID in 'label' to its display name.
# An unmapped MID raises KeyError here.
audioset_labels = pd.read_csv(
    "../../data/metadata/audioset/audioset_train_strong.tsv",
    sep="\t",
)
audioset_labels['label'] = audioset_labels['label'].map(mid_to_label_dict.__getitem__)
audioset_labels

In [None]:
# Distinct AudioSet display names (first-appearance order), matched against
# every ontology label.
audioset_label_list = audioset_labels['label'].drop_duplicates().tolist()

similarity_df = build_similarity_dataframe(labels, audioset_label_list, top_k=6)
similarity_df

In [None]:
# Union of every AudioSet label that appears in any ontology label's top-k list.
# Shown sorted: set iteration order for strings varies between runs.
all_found_labels = set()
for matched in similarity_df['dataset_labels']:
    all_found_labels.update(matched)
sorted(all_found_labels)

#### Count all labels similar to "plane"

Labels most similar to "plane":

In [None]:
# Top-k AudioSet labels retrieved for the ontology label "plane".
plane_match = similarity_df.loc[similarity_df['ontology_label'] == 'plane', 'dataset_labels']
planes = plane_match.iloc[0]
planes


In [None]:
# Strong-label TSVs for train and eval share one column layout.
# header=0 together with names= replaces the file's own header row instead of
# reading it as data (with header=None the header line became a spurious row
# whose 'label' had no MID mapping).
# NOTE(review): assumes both files start with a header row — the cell that
# builds audioset_labels reads the train file's 'label' column by name;
# confirm for the eval file.
strong_label_columns = ["filename", "start_time", "end_time", "label", "split", "caption"]
strong_audioset_labels_train = pd.read_csv(
    "../../data/metadata/audioset/audioset_train_strong.tsv",
    sep="\t",
    header=0,
    names=strong_label_columns,
)
strong_audioset_labels_eval = pd.read_csv(
    "../../data/metadata/audioset/audioset_eval_strong.tsv",
    sep="\t",
    header=0,
    names=strong_label_columns,
)

# Merge the two dataframes into one, tagging each row with its source split
# (overwriting the file's 'split' column) and resetting the index.
strong_audioset_labels = pd.concat(
    [
        strong_audioset_labels_train.assign(split="train"),
        strong_audioset_labels_eval.assign(split="eval"),
    ],
    ignore_index=True,
)

# quick check
strong_audioset_labels.shape, strong_audioset_labels.head()


In [None]:
# Swap MID codes for display names; values with no mapping keep their raw value.
mapped_names = strong_audioset_labels['label'].map(mid_to_label_dict)
strong_audioset_labels['label'] = mapped_names.where(mapped_names.notna(), strong_audioset_labels['label'])
strong_audioset_labels.head()


In [None]:
# Rows whose label is one of the plane-related AudioSet labels.
is_plane_row = strong_audioset_labels['label'].isin(planes)
plane_counts = is_plane_row.value_counts().to_dict()
print(plane_counts)

# Per-label counts, with train / eval as columns.
(
    strong_audioset_labels.loc[is_plane_row]
    .groupby(['label', 'split'])
    .size()
    .unstack(fill_value=0)
)

### urban_sound_8k

(No airplane sounds in this dataset; it contains only 10 classes.)

### ESC-50

Load the labels from ESC-50.

In [None]:
# ESC-50 metadata; later cells use its 'category', 'fold' and 'filename' columns.
esc50_labels = pd.read_csv(
    "../../data/metadata/esc50/esc50.csv"
)
esc50_labels.head(n=5)

In [None]:
# Distinct ESC-50 category names, in order of first appearance.
unique_esc50_labels = esc50_labels['category'].drop_duplicates().tolist()
unique_esc50_labels

In [None]:
# Match every ontology label against the ESC-50 categories (top 6 each).
esc_similarity_df = build_similarity_dataframe(labels, unique_esc50_labels, top_k=6)
esc_similarity_df

We now have all the information needed to count the ESC-50 samples similar to "plane".

In [None]:
esc_similarity_df.loc[esc_similarity_df['ontology_label'] == 'plane', 'dataset_labels'].iloc[0]  # ESC-50 matches for "plane"

Only "airplane" is kept as a match for "plane", since helicopter and engine already have their own ontology labels.

In [None]:
# Number of 'airplane' rows in each ESC-50 fold.
airplane_mask = esc50_labels['category'] == 'airplane'
esc50_airplane_counts = esc50_labels.loc[airplane_mask].groupby('fold').size().sort_index()
esc50_airplane_counts

Unique airplane samples in ESC-50:

In [None]:
# Distinct filenames among the ESC-50 rows labelled 'airplane'.
airplane_rows = esc50_labels.loc[esc50_labels['category'] == 'airplane']
unique_airplane_filenames = pd.unique(airplane_rows['filename'])
len(unique_airplane_filenames)

In [None]:
# Airplane rows beyond one per file; non-zero means some files are listed repeatedly.
len(esc50_labels.loc[esc50_labels["category"] == "airplane"]) - len(unique_airplane_filenames)

### FSD50K

In [None]:
# FSD50K vocabulary file (headerless) read into 'label' and 'mid' columns.
fsd_vocab_labels = pd.read_csv(
    "../../data/metadata/fsd50k_labels/vocabulary.csv",
    header=None,
    names=['label', 'mid'],
)
fsd_vocab_labels.head()

In [None]:
# Match every ontology label against the FSD50K vocabulary (top 6 each).
fsd_vocabulary = fsd_vocab_labels['label'].tolist()
fsd_similarity_df = build_similarity_dataframe(labels, fsd_vocabulary, top_k=6)
fsd_similarity_df

In [None]:
# Keep only the two best FSD50K matches for "plane" (manual cut-off).
fsd_plane_match = fsd_similarity_df.loc[fsd_similarity_df['ontology_label'] == 'plane', 'dataset_labels']
plane_labels = fsd_plane_match.iloc[0][:2]
plane_labels

In [None]:
# Concatenate FSD50K dev.csv (tagged "train") and eval.csv (tagged "eval").
# NOTE(review): assign(split=...) overwrites any 'split' column already in
# dev.csv — confirm that merging its sub-splits into "train" is intended.
# fsd_labels_dev holds the *eval* file despite its name.
fsd_labels_train = pd.read_csv("../../data/metadata/fsd50k_labels/dev.csv")
fsd_labels_dev = pd.read_csv("../../data/metadata/fsd50k_labels/eval.csv")
tagged_parts = [
    fsd_labels_train.assign(split="train"),
    fsd_labels_dev.assign(split="eval"),
]
fsd_labels = pd.concat(tagged_parts, ignore_index=True)
fsd_labels.head()

In [None]:
# Flag clips where at least one comma-separated label is a plane label, then
# count flagged clips per split.
plane_set = set(plane_labels)
fsd_labels['is_plane'] = fsd_labels['labels'].str.split(',').map(
    lambda clip_labels: not plane_set.isdisjoint(clip_labels)
)

plane_counts_fsd = fsd_labels.loc[fsd_labels['is_plane']].groupby('split').size().sort_index()
plane_counts_fsd

### CAPTDURE (captioned sound dataset)

(No planes: single-source sounds, mostly indoors.)

### SoundDescs (captioned BBC sound effects dataset)

In [None]:
# SoundDescs categories; used below as a mapping from filename to a list of
# category names (presumably a dict — confirm).
# NOTE: unpickling can execute arbitrary code; only load this file from a trusted source.
sounddescs_categories = pd.read_pickle(
    "../../data/metadata/sounddesc/sounddescs_categories.pkl"
)
sounddescs_categories

Captions for later use, specifically with CLAP embedding models.

In [None]:
# SoundDescs free-text descriptions (not used further in this notebook).
# NOTE: unpickling can execute arbitrary code; only load this file from a trusted source.
sounddescs_descriptions = pd.read_pickle(
    "../../data/metadata/sounddesc/sounddescs_descriptions.pkl"
)
sounddescs_descriptions

In [None]:
# All distinct categories found across the per-file category lists.
# Sorted because string set order changes between interpreter runs (hash
# randomization), which would reshuffle the FAISS index and dataset_indices.
unique_sounddesc = sorted(
    {category for categories in sounddescs_categories.values() for category in categories},
    key=str,
)


sounddescs_similarity_df = build_similarity_dataframe(
    query_labels=labels,
    dataset_labels=unique_sounddesc,
    top_k=6
)
sounddescs_similarity_df

In [None]:
# SoundDescs categories retrieved for "plane"; only the first, "Aircraft",
# is really relevant.
plane_rows = sounddescs_similarity_df['ontology_label'] == 'plane'
sounddescs_similarity_df.loc[plane_rows, 'dataset_labels'].iloc[0]

Count the number of samples labelled "Aircraft" in SoundDescs.

Load the cleaned/grouped SoundDescs splits.

In [None]:
# Grouped (cleaned) SoundDescs split lists: one filename per line per file.
# Concatenated in test, train, val order.
_SOUNDDESC_SPLIT_FILES = [
    ("test", "./metadata/splits_sounddesc/group_filtered_split01/test_list.txt"),
    ("train", "./metadata/splits_sounddesc/group_filtered_split01/train_list.txt"),
    ("val", "./metadata/splits_sounddesc/group_filtered_split01/val_list.txt"),
]


def _read_split_list(path: str) -> list[str]:
    """Return the stripped, non-blank lines of a split list file."""
    # Explicit encoding so the result does not depend on the platform default;
    # blank lines are skipped so they do not become empty filenames.
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


grouped_sounddesc_split = pd.concat(
    [
        pd.DataFrame({'filename': _read_split_list(path), 'split': split_name})
        for split_name, path in _SOUNDDESC_SPLIT_FILES
    ],
    ignore_index=True,
)

grouped_sounddesc_split.head()

Attach the categories to the filenames in each grouped split.

In [None]:
# Look up each file's category list; filenames missing from the mapping get NaN.
grouped_sounddesc_split = grouped_sounddesc_split.assign(
    categories=grouped_sounddesc_split['filename'].map(sounddescs_categories)
)
grouped_sounddesc_split.head()

In [None]:
# Number of samples with "Aircraft" among their categories, per split.
# explode() yields one row per (file, category) pair before filtering.
aircraft_counts = (
    grouped_sounddesc_split
    .explode('categories')
    .loc[lambda frame: frame['categories'] == 'Aircraft']
    .groupby('split')
    .size()
    .sort_index()
)
aircraft_counts



In [None]:
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

# Collect data from each dataset with consistent train/test splits.
# NOTE(review): every count below is a number of table rows. AudioSet strong
# labels carry a start/end time per row, so a clip with several plane events is
# presumably counted more than once — confirm whether clips or events are the
# intended unit.

# AudioSet - combine 'eval' into 'test'.
# .copy() gives an independent frame (as for the other datasets), so adding
# 'unified_split' does not write into a slice of strong_audioset_labels.
audioset_plane_df = strong_audioset_labels[strong_audioset_labels["label"].isin(planes)].copy()
audioset_split_mapping = {'train': 'train', 'eval': 'test'}
audioset_plane_df['unified_split'] = audioset_plane_df['split'].map(audioset_split_mapping)
audioset_counts = audioset_plane_df.groupby("unified_split").size().to_dict()

# ESC-50 - folds 1-4 as train, fold 5 as test
esc50_airplane_df = esc50_labels[esc50_labels['category'] == 'airplane'].copy()
esc50_airplane_df['unified_split'] = esc50_airplane_df['fold'].apply(lambda x: 'train' if x <= 4 else 'test')
esc50_counts = esc50_airplane_df.groupby('unified_split').size().to_dict()

# FSD50K - 'eval' becomes 'test', 'train' stays 'train'
fsd_plane_df = fsd_labels[fsd_labels['is_plane']].copy()
fsd_split_mapping = {'train': 'train', 'eval': 'test'}
fsd_plane_df['unified_split'] = fsd_plane_df['split'].map(fsd_split_mapping)
fsd_plane_counts = fsd_plane_df.groupby('unified_split').size().to_dict()

# SoundDescs - combine 'val' with 'train', 'test' stays 'test'
sounddesc_aircraft_df = grouped_sounddesc_split.explode('categories')
sounddesc_aircraft_df = sounddesc_aircraft_df[sounddesc_aircraft_df['categories'] == 'Aircraft'].copy()
sounddesc_split_mapping = {'train': 'train', 'val': 'train', 'test': 'test'}
sounddesc_aircraft_df['unified_split'] = sounddesc_aircraft_df['split'].map(sounddesc_split_mapping)
sounddesc_counts = sounddesc_aircraft_df.groupby('unified_split').size().to_dict()


def _count_bar(counts, title, *, x_label="Split", y_label="Count", width=600, height=500, **bar_kwargs):
    """Return a bar chart of a ``{category: count}`` dict with the counts printed inside the bars.

    Extra keyword arguments (e.g. ``color=``) are forwarded to ``px.bar``.
    """
    fig = px.bar(
        x=list(counts.keys()),
        y=list(counts.values()),
        title=title,
        labels={"x": x_label, "y": y_label},
        text=list(counts.values()),
        **bar_kwargs,
    )
    fig.update_traces(textposition="inside", textangle=0)
    fig.update_layout(width=width, height=height)
    return fig


# Individual bar charts for each dataset
fig_audioset = _count_bar(audioset_counts, "AudioSet - Airplane Samples by Split")
fig_audioset.show()

fig_esc50 = _count_bar(esc50_counts, "ESC-50 - Airplane Samples by Split")
fig_esc50.show()

fig_fsd = _count_bar(fsd_plane_counts, "FSD50K - Airplane Samples by Split")
fig_fsd.show()

fig_sounddesc = _count_bar(sounddesc_counts, "SoundDescs - Aircraft Samples by Split")
fig_sounddesc.show()

per_dataset_counts = [
    ("AudioSet", audioset_counts),
    ("ESC-50", esc50_counts),
    ("FSD50K", fsd_plane_counts),
    ("SoundDescs", sounddesc_counts),
]

# Combined comparison - total samples per dataset
total_samples = {dataset: sum(counts.values()) for dataset, counts in per_dataset_counts}

fig_total = _count_bar(
    total_samples,
    "Total Airplane/Aircraft Samples Across All Datasets",
    x_label="Dataset",
    y_label="Total Count",
    color=list(total_samples.keys()),
)
fig_total.update_layout(showlegend=False)
fig_total.show()

# Grand total
grand_total = sum(total_samples.values())
print(f"\nGrand Total Airplane Samples: {grand_total}")

# Stacked bar chart showing splits across datasets
split_df = pd.DataFrame(
    [
        {"Dataset": dataset, "Split": split, "Count": count}
        for dataset, counts in per_dataset_counts
        for split, count in counts.items()
    ],
    columns=["Dataset", "Split", "Count"],
)

fig_stacked = px.bar(
    split_df,
    x="Dataset",
    y="Count",
    color="Split",
    title="Airplane Samples Distribution by Dataset and Split",
    barmode="stack",
    text="Count",
)
fig_stacked.update_traces(textposition="inside", textangle=0)
fig_stacked.update_layout(width=500, height=600)
fig_stacked.show()

In [None]:
# Pie chart of the overall train/test distribution, pooling the unified
# per-split counts of all four datasets (missing splits count as 0).
all_split_counts = [audioset_counts, esc50_counts, fsd_plane_counts, sounddesc_counts]
overall_train = sum(counts.get('train', 0) for counts in all_split_counts)
overall_test = sum(counts.get('test', 0) for counts in all_split_counts)

fig_pie = go.Figure(data=[go.Pie(
    labels=['Train', 'Test'],
    values=[overall_train, overall_test],
    text=[overall_train, overall_test],
    textposition='inside',
    textinfo='label+value+percent',
    marker=dict(colors=['#636EFA', '#EF553B'])
)])

fig_pie.update_layout(
    title="Overall Train/Test Distribution Across All Datasets",
    width=600,
    height=500
)
fig_pie.show()

print(f"\nTotal Train Samples: {overall_train}")
print(f"Total Test Samples: {overall_test}")
# Guard the ratio: with no test samples the division would raise ZeroDivisionError.
if overall_test:
    print(f"Train/Test Ratio: {overall_train/overall_test:.2f}")
else:
    print("Train/Test Ratio: undefined (no test samples)")

In [None]:
# Build a combined mapping DataFrame and create visualizations (heatmap + sankey).
# This cell flattens existing similarity DataFrames and draws interactive plots.
import plotly.express as px
import plotly.graph_objects as go
from collections import OrderedDict

def flatten_similarity_df(sim_df, dataset_name):
    """Explode a similarity DataFrame into one row per (ontology label, match).

    Args:
        sim_df: Frame with columns ``ontology_label`` plus the parallel list
            columns ``dataset_labels``, ``dataset_indices`` and
            ``similarity_scores`` (the layout built by
            ``build_similarity_dataframe``).
        dataset_name: Value written to the ``dataset`` column of every row.

    Returns:
        DataFrame with columns ontology_label, dataset_label, dataset_index,
        score, dataset. The columns are present even when ``sim_df`` has no
        rows, so callers can concatenate and select columns unconditionally.
    """
    columns = ['ontology_label', 'dataset_label', 'dataset_index', 'score', 'dataset']
    rows = [
        {
            'ontology_label': r['ontology_label'],
            'dataset_label': matched_label,
            'dataset_index': int(idx),
            'score': float(score),
            'dataset': dataset_name,
        }
        for _, r in sim_df.iterrows()
        for matched_label, idx, score in zip(
            r['dataset_labels'], r['dataset_indices'], r['similarity_scores']
        )
    ]
    return pd.DataFrame(rows, columns=columns)

# Flatten every available similarity table into combined_df; rebuild a table
# (top_k=3) from the raw label list only when its cell has not been run.
dfs = []
# AudioSet
if 'similarity_df' in globals():
    dfs.append(flatten_similarity_df(similarity_df, 'AudioSet'))
else:
    try:
        audioset_label_list
    except NameError:
        audioset_labels_tmp = pd.read_csv('../../data/metadata/audioset/class_labels_audioset.csv')
        audioset_label_list = audioset_labels_tmp['display_name'].unique().tolist()
    dfs.append(flatten_similarity_df(build_similarity_dataframe(labels, audioset_label_list, top_k=3), 'AudioSet'))

# ESC-50
if 'esc_similarity_df' in globals():
    dfs.append(flatten_similarity_df(esc_similarity_df, 'ESC-50'))
else:
    try:
        unique_esc50_labels
    except NameError:
        esc50_tmp = pd.read_csv('../../data/metadata/esc50/esc50.csv')
        unique_esc50_labels = esc50_tmp['category'].unique().tolist()
    dfs.append(flatten_similarity_df(build_similarity_dataframe(labels, unique_esc50_labels, top_k=3), 'ESC-50'))

# FSD50K
if 'fsd_similarity_df' in globals():
    dfs.append(flatten_similarity_df(fsd_similarity_df, 'FSD50K'))
else:
    try:
        fsd_vocab_labels
    except NameError:
        fsd_vocab_labels = pd.read_csv('../../data/metadata/fsd50k_labels/vocabulary.csv', header=None, names=['label','mid'])
    dfs.append(flatten_similarity_df(build_similarity_dataframe(labels, fsd_vocab_labels['label'].tolist(), top_k=3), 'FSD50K'))

# SoundDescs
if 'sounddescs_similarity_df' in globals():
    dfs.append(flatten_similarity_df(sounddescs_similarity_df, 'SoundDescs'))
else:
    try:
        unique_sounddesc
    except NameError:
        unique_sounddesc = []
    if len(unique_sounddesc) > 0:
        dfs.append(flatten_similarity_df(build_similarity_dataframe(labels, unique_sounddesc, top_k=3), 'SoundDescs'))

if len(dfs) == 0:
    raise RuntimeError('No similarity DataFrames or label lists found. Run the similarity cells first.')

combined_df = pd.concat(dfs, ignore_index=True)
combined_df['ontology_label'] = combined_df['ontology_label'].astype(str)

# Aggregate duplicates by max score (keeps strongest match per pair)
agg_df = combined_df.groupby(['ontology_label','dataset','dataset_label'], as_index=False).agg({'score':'max'})

# Heatmap: pivot ontology x dataset_label (shows strongest similarity)
fig_heat = None  # stays None when there is nothing to plot
heat_pivot = agg_df.pivot_table(index='ontology_label', columns='dataset_label', values='score', aggfunc='max', fill_value=0)
if heat_pivot.shape[0] > 0 and heat_pivot.shape[1] > 0:
    fig_heat = px.imshow(heat_pivot, labels=dict(x='Dataset Label', y='Ontology Label', color='Similarity'),
                         aspect='auto', title='Ontology → Dataset Label Similarity Heatmap')
    fig_heat.update_layout(height=800, width=1000)
    fig_heat.show()
else:
    print('Heatmap: no data to display')

# Sankey: ontology labels on the left, "<dataset>::<label>" nodes on the right,
# link width = similarity score (negative scores clamped to 0).
sankey_df = agg_df.copy()
sankey_df['value'] = sankey_df['score']

target_keys = sankey_df['dataset'] + '::' + sankey_df['dataset_label']
ont_nodes = list(OrderedDict.fromkeys(sankey_df['ontology_label'].tolist()))
target_nodes = list(OrderedDict.fromkeys(target_keys.tolist()))
nodes = ont_nodes + target_nodes
node_index = {n: i for i, n in enumerate(nodes)}

# Built directly as lists; in particular the global ontology `labels` list
# must not be reassigned here, or later similarity cells would query nothing.
sources = [node_index[ont] for ont in sankey_df['ontology_label']]
targets = [node_index[key] for key in target_keys]
values = [max(v, 0.0) for v in sankey_df['value']]

fig_sankey = None  # stays None when there are no links
if len(sources) > 0:
    link = dict(source=sources, target=targets, value=values)
    fig_sankey = go.Figure(data=[go.Sankey(node=dict(label=nodes, pad=15, thickness=15), link=link)])
    fig_sankey.update_layout(title_text='Ontology → Dataset Labels (Sankey)', font_size=10, height=800, width=1200)
    fig_sankey.show()
else:
    print('Sankey: no links to display')

# Save whichever figures were produced: interactive HTML for sharing plus a
# static JPEG (write_image needs the kaleido package).
saved_files = []
if fig_sankey is not None:
    fig_sankey.write_html('ontology_dataset_sankey.html')
    fig_sankey.write_image('ontology_dataset_sankey.jpeg')
    saved_files += ['ontology_dataset_sankey.html', 'ontology_dataset_sankey.jpeg']
if fig_heat is not None:
    fig_heat.write_html('ontology_dataset_heatmap.html')
    fig_heat.write_image('ontology_dataset_heatmap.jpeg')
    saved_files += ['ontology_dataset_heatmap.html', 'ontology_dataset_heatmap.jpeg']
print('Saved: ' + (', '.join(saved_files) if saved_files else 'nothing'))

# Quick check output
print('Combined mapping sample:')
print(agg_df.sort_values(['ontology_label', 'score'], ascending=[True, False]).head(20).to_string(index=False))

In [None]:
# Example heatmap: sample 6 ontology labels and 6 dataset labels from combined_df
# (built by the combined-mapping cell) and show their similarity scores.
# 'plane' and known aircraft-related dataset labels are included when present
# so the example always contains a meaningful match.
import random
import pandas as pd
import plotly.express as px

# Ensure combined_df exists
if 'combined_df' not in globals():
    raise RuntimeError('combined_df not found. Run the combined mapping cell first.')

df_full = combined_df.copy()

# Fixed seed so the example figure is reproducible; change it to draw another sample.
example_seed = 42
rng = random.Random(example_seed)

# If any AudioSet labels are MIDs, map them to display names using mid_to_label_dict if available and dataset column exists
if 'mid_to_label_dict' in globals() and 'dataset' in df_full.columns:
    mask = df_full['dataset'] == 'AudioSet'
    df_full.loc[mask, 'dataset_label'] = df_full.loc[mask, 'dataset_label'].map(lambda x: mid_to_label_dict.get(x, x))

# Get all ontology labels and all unique dataset labels from combined_df
all_ontology = df_full['ontology_label'].astype(str).unique().tolist()
all_dataset = sorted(df_full['dataset_label'].astype(str).unique().tolist())

# Ontology sample: draw sample_size-1 labels, then use the last slot for 'plane'
# if it exists and was not drawn, otherwise for one more random label.
sample_size = 6
if len(all_ontology) <= sample_size:
    sampled_ontology = all_ontology.copy()
else:
    sampled_ontology = rng.sample(all_ontology, sample_size - 1)
    if 'plane' in all_ontology and 'plane' not in sampled_ontology:
        sampled_ontology.append('plane')
    else:
        remaining = [o for o in all_ontology if o not in sampled_ontology]
        if remaining:
            sampled_ontology.append(rng.choice(remaining))

# Dataset sample: random labels plus the known aircraft-related labels that exist (total <= sample_size)
preferred = ["Aircraft engine", "Fixed-wing aircraft, airplane", "Aircraft", "airplane"]
extras = [p for p in preferred if p in all_dataset]
if len(all_dataset) <= sample_size:
    sampled_dataset = all_dataset.copy()
else:
    pool = [d for d in all_dataset if d not in extras]
    n_random = max(0, sample_size - len(extras))
    sampled_dataset = rng.sample(pool, min(n_random, len(pool)))
    for e in extras:
        if len(sampled_dataset) < sample_size:
            sampled_dataset.append(e)

# Final safety: ensure lengths >0
if len(sampled_ontology) == 0 or len(sampled_dataset) == 0:
    raise RuntimeError('Not enough ontology or dataset labels to sample for example.')

# Similarity matrix: rows=ontology, cols=dataset label, value=max score over all
# datasets (cells start at 0.0, so pairs that were never matched stay 0).
ontology_lookup = set(sampled_ontology)
dataset_lookup = set(sampled_dataset)
sim_matrix = pd.DataFrame(0.0, index=sampled_ontology, columns=sampled_dataset)
for ont, lbl, score in zip(
    df_full['ontology_label'].astype(str),
    df_full['dataset_label'].astype(str),
    df_full['score'].astype(float),
):
    if ont in ontology_lookup and lbl in dataset_lookup:
        sim_matrix.loc[ont, lbl] = max(sim_matrix.loc[ont, lbl], score)

fig = px.imshow(
    sim_matrix,
    text_auto=True,
    color_continuous_scale='Blues',
    labels=dict(x='Dataset Label', y='Ontology Label', color='Similarity'),
    title='Sampled Ontology to Dataset Label Mapping (Comprehensive Example)'
)
fig.update_layout(width=800, height=520)
fig.show()

print('Sampled ontology labels:', list(sim_matrix.index))
print('Sampled dataset labels:', list(sim_matrix.columns))