# Helper

## Dataset

In [None]:
from datasets import load_dataset, Dataset

def load_birdset_train_split(name: str, *, cache_root: str = "/home/rantjuschin/data_birdset"):
    """
    Load the ``train`` split of one BirdSet configuration from the Hugging Face Hub.

    Args:
        name: BirdSet configuration name (e.g. ``"PER"``, ``"HSN"``).
        cache_root: Directory under which the per-configuration cache folder
            ``<cache_root>/<name>`` is used. The default is the original
            cache location, so existing calls behave exactly as before.

    Returns:
        The object returned by ``datasets.load_dataset`` for ``split="train"``.
    """
    return load_dataset(
        path="DBD-research-group/BirdSet",
        name=name,
        cache_dir=f"{cache_root}/{name}",
        # allow the Hub repository's own loading code to execute
        trust_remote_code=True,
        split="train",
    )

## Data Selection

In [None]:
from datasets import Dataset
import soundfile as sf

def get_all_legal_indeces(dataset: Dataset, condition: callable) -> list[int]:
    """
    Return the indices of every row of ``dataset`` that satisfies ``condition``.

    Args:
        dataset: Any sized, indexable collection of rows.
        condition: Predicate called as ``condition(dataset, idx)``; a truthy
            result keeps ``idx``.

    Returns:
        Matching indices in ascending order.
    """
    return [row_idx for row_idx in range(len(dataset)) if condition(dataset, row_idx)]


def strict_condition(dataset: Dataset, idx: int, *, max_duration: float = 5) -> bool:
    """
    Accept only recordings that are at most ``max_duration`` seconds long.

    Short files are admitted so that no event detection has to occur when
    sampling from them.

    Args:
        dataset: Dataset whose rows contain a ``"filepath"`` readable by
            ``soundfile``.
        idx: Row index into ``dataset``.
        max_duration: Upper bound on the file duration in seconds (default 5).

    Returns:
        ``True`` if the file's duration is <= ``max_duration``, else ``False``.
    """
    file_info = sf.info(dataset[idx]["filepath"])
    # always return a real bool (the annotation promises one) rather than falling through to None
    return file_info.duration <= max_duration
    

def lenient_condition(dataset: Dataset, idx: int, *, max_duration: float = 10) -> bool:
    """
    Accept recordings up to ``max_duration`` seconds that carry a single bird label.

    "Single label" means the row has no ``ebird_code_secondary`` entries and
    exactly one entry in ``ebird_code_multilabel``.

    Args:
        dataset: Dataset whose rows contain ``"filepath"``,
            ``"ebird_code_secondary"`` and ``"ebird_code_multilabel"``.
        idx: Row index into ``dataset``.
        max_duration: Upper bound on the file duration in seconds (default 10).

    Returns:
        ``True`` if all three checks pass, else ``False``.
    """
    # fetch the row once instead of re-indexing (and re-decoding) it per field
    row = dataset[idx]
    file_info = sf.info(row["filepath"])
    return bool(
        file_info.duration <= max_duration
        and not row["ebird_code_secondary"]
        and len(row["ebird_code_multilabel"]) == 1
    )

    
def count_labels_from_legal_indeces(dataset: Dataset, indeces: list[int]) -> dict:
    """
    Count how many of the selected rows carry each ``ebird_code`` label.

    Every label present anywhere in ``dataset["ebird_code"]`` appears in the
    result, with 0 if none of the selected rows has it.

    Args:
        dataset: Dataset with an ``"ebird_code"`` column.
        indeces: Row indices to count (duplicates are counted each time).

    Returns:
        ``{label: count}`` in order of each label's first appearance in the column.
    """
    # read the label column once rather than materializing a full row per index
    labels = list(dataset["ebird_code"])
    unique_label_counts = dict.fromkeys(labels, 0)
    for i in indeces:
        unique_label_counts[labels[i]] += 1
    return unique_label_counts



## Few-Shot analysis

In [None]:
def analyse_with_few_shot_counts(few_shot_counts: list[int], dataset_label_counts: dict) -> dict:
    """
    Summarise how well each label can fill a k-shot budget, per condition.

    For every condition and every budget ``k`` in ``few_shot_counts``, each
    label contributes ``min(k, count)`` samples. A label with fewer than ``k``
    available samples is "incomplete"; its deficit is ``k - count``.

    Args:
        few_shot_counts: Shot budgets ``k`` to evaluate.
        dataset_label_counts: Mapping ``condition name -> {label: available count}``.

    Returns:
        Nested dict ``results[condition][k]`` with keys:
        ``max_possible_sample_count`` (number of labels * k),
        ``actual_sample_count``, ``total_difference`` (sum of deficits),
        ``incomplete_labels``, ``max_difference`` / ``min_difference``
        (largest / smallest deficit, 0 if no label is incomplete), and
        ``avg_difference`` / ``std_deviation`` (mean and sample standard
        deviation of the deficits of incomplete labels, 0 if none).
    """
    results = {}
    for condition, label_counts in dataset_label_counts.items():
        results[condition] = {}
        for few_shot_count in few_shot_counts:
            limited_label_counts = [min(few_shot_count, count) for count in label_counts.values()]

            max_possible_sample_count = len(label_counts) * few_shot_count
            actual_sample_count = sum(limited_label_counts)
            total_difference = max_possible_sample_count - actual_sample_count

            # avg and deviation are only computed on labels that are actually missing samples
            deficits = [few_shot_count - count for count in limited_label_counts if count < few_shot_count]
            incomplete_labels = len(deficits)
            if deficits:
                max_difference = max(deficits)
                min_difference = min(deficits)
                avg_difference = total_difference / incomplete_labels
                # spread of the deficits themselves around their mean (sample std, n-1;
                # a single incomplete label gives 0)
                squared_spread = sum((deficit - avg_difference) ** 2 for deficit in deficits)
                std_deviation = (squared_spread / max(incomplete_labels - 1, 1)) ** 0.5
            else:
                max_difference = 0
                min_difference = 0
                avg_difference = 0
                std_deviation = 0

            results[condition][few_shot_count] = {
                "max_possible_sample_count": max_possible_sample_count,
                "actual_sample_count": actual_sample_count,
                "total_difference": total_difference,
                "incomplete_labels": incomplete_labels,
                "max_difference": max_difference,
                "min_difference": min_difference,
                "avg_difference": avg_difference,
                "std_deviation": std_deviation,
            }
    return results

def print_results(results):
    """
    Print the output of ``analyse_with_few_shot_counts`` as tab-separated tables.

    Each condition gets a heading line, a header row built from the metric
    names of its first entry, and one row per few-shot count with every
    metric formatted to two decimals.
    """
    # metric names that get an extra tab in the header row
    wide_keys = {"max_difference", "min_difference", "avg_difference", "std_deviation"}
    for condition, per_shot in results.items():
        print(f"\n{condition} results:")
        for position, (few_shot_count, values) in enumerate(per_shot.items()):
            if position == 0:
                header_cells = [key + ("\t\t" if key in wide_keys else "\t") for key in values]
                print("\t" + "".join(header_cells))
            value_cells = "".join(f"{value:.2f}\t\t\t" for value in values.values())
            print(f"{few_shot_count}\t\t\t" + value_cells)

# Analysis

## PER

In [None]:
# Load the PER training split; the bare name on the last line displays it.
PER = load_birdset_train_split(name="PER")
PER

In [None]:
# Per-label counts of PER recordings that pass each data-selection condition.
PER_label_counts = {
    condition_name: count_labels_from_legal_indeces(PER, get_all_legal_indeces(PER, condition_fn))
    for condition_name, condition_fn in (
        ("strict_condition", strict_condition),
        ("lenient_condition", lenient_condition),
    )
}

In [None]:
# Few-shot budgets 2, 4, ..., 128 (powers of two).
results = analyse_with_few_shot_counts([2 ** exponent for exponent in range(1, 8)], PER_label_counts)
print_results(results)

In [None]:
# Count how many classes are completely empty (zero passing recordings) per condition.
strict_empty, lenient_empty = [
    sum(1 for count in PER_label_counts[condition].values() if count == 0)
    for condition in ("strict_condition", "lenient_condition")
]
print(strict_empty, lenient_empty)

## NES

In [None]:
# Load the NES training split; the bare name on the last line displays it.
NES = load_birdset_train_split(name="NES")
NES

In [None]:
# Per-label counts of NES recordings that pass each data-selection condition.
NES_label_counts = {
    condition_name: count_labels_from_legal_indeces(NES, get_all_legal_indeces(NES, condition_fn))
    for condition_name, condition_fn in (
        ("strict_condition", strict_condition),
        ("lenient_condition", lenient_condition),
    )
}

In [None]:
# Few-shot budgets 2, 4, ..., 128 (powers of two).
results = analyse_with_few_shot_counts([2 ** exponent for exponent in range(1, 8)], NES_label_counts)
print_results(results)

In [None]:
# Count how many classes are completely empty (zero passing recordings) per condition.
strict_empty, lenient_empty = [
    sum(1 for count in NES_label_counts[condition].values() if count == 0)
    for condition in ("strict_condition", "lenient_condition")
]
print(strict_empty, lenient_empty)

## UHH

In [None]:
# Load the UHH training split; the bare name on the last line displays it.
UHH = load_birdset_train_split(name="UHH")
UHH

In [None]:
# Per-label counts of UHH recordings that pass each data-selection condition.
UHH_label_counts = {
    condition_name: count_labels_from_legal_indeces(UHH, get_all_legal_indeces(UHH, condition_fn))
    for condition_name, condition_fn in (
        ("strict_condition", strict_condition),
        ("lenient_condition", lenient_condition),
    )
}

In [None]:
# Few-shot budgets 2, 4, ..., 128 (powers of two).
results = analyse_with_few_shot_counts([2 ** exponent for exponent in range(1, 8)], UHH_label_counts)
print_results(results)

In [None]:
# Count how many classes are completely empty (zero passing recordings) per condition.
strict_empty, lenient_empty = [
    sum(1 for count in UHH_label_counts[condition].values() if count == 0)
    for condition in ("strict_condition", "lenient_condition")
]
print(strict_empty, lenient_empty)

## HSN

In [None]:
# Load the HSN training split; the bare name on the last line displays it.
HSN = load_birdset_train_split(name="HSN")
HSN

In [None]:
# Per-label counts of HSN recordings that pass each data-selection condition.
HSN_label_counts = {
    condition_name: count_labels_from_legal_indeces(HSN, get_all_legal_indeces(HSN, condition_fn))
    for condition_name, condition_fn in (
        ("strict_condition", strict_condition),
        ("lenient_condition", lenient_condition),
    )
}

In [None]:
# Few-shot budgets 2, 4, ..., 128 (powers of two).
results = analyse_with_few_shot_counts([2 ** exponent for exponent in range(1, 8)], HSN_label_counts)
print_results(results)

## NBP

In [None]:
# Load the NBP training split; the bare name on the last line displays it.
NBP = load_birdset_train_split(name="NBP")
NBP

In [None]:
# Per-label counts of NBP recordings that pass each data-selection condition.
NBP_label_counts = {
    condition_name: count_labels_from_legal_indeces(NBP, get_all_legal_indeces(NBP, condition_fn))
    for condition_name, condition_fn in (
        ("strict_condition", strict_condition),
        ("lenient_condition", lenient_condition),
    )
}

In [None]:
# Few-shot budgets 2, 4, ..., 128 (powers of two).
results = analyse_with_few_shot_counts([2 ** exponent for exponent in range(1, 8)], NBP_label_counts)
print_results(results)

## POW

In [None]:
# Load the POW training split; the bare name on the last line displays it.
POW = load_birdset_train_split(name="POW")
POW

In [None]:
# Per-label counts of POW recordings that pass each data-selection condition.
POW_label_counts = {
    condition_name: count_labels_from_legal_indeces(POW, get_all_legal_indeces(POW, condition_fn))
    for condition_name, condition_fn in (
        ("strict_condition", strict_condition),
        ("lenient_condition", lenient_condition),
    )
}

In [None]:
# Few-shot budgets 2, 4, ..., 128 (powers of two).
results = analyse_with_few_shot_counts([2 ** exponent for exponent in range(1, 8)], POW_label_counts)
print_results(results)

## SSW

In [None]:
# Load the SSW training split; the bare name on the last line displays it.
SSW = load_birdset_train_split(name="SSW")
SSW

In [None]:
# Per-label counts of SSW recordings that pass each data-selection condition.
SSW_label_counts = {
    condition_name: count_labels_from_legal_indeces(SSW, get_all_legal_indeces(SSW, condition_fn))
    for condition_name, condition_fn in (
        ("strict_condition", strict_condition),
        ("lenient_condition", lenient_condition),
    )
}

In [None]:
# Few-shot budgets 2, 4, ..., 128 (powers of two).
results = analyse_with_few_shot_counts([2 ** exponent for exponent in range(1, 8)], SSW_label_counts)
print_results(results)

## SNE

In [None]:
# Load the SNE training split; the bare name on the last line displays it.
SNE = load_birdset_train_split(name="SNE")
SNE

In [None]:
# Per-label counts of SNE recordings that pass each data-selection condition.
SNE_label_counts = {
    condition_name: count_labels_from_legal_indeces(SNE, get_all_legal_indeces(SNE, condition_fn))
    for condition_name, condition_fn in (
        ("strict_condition", strict_condition),
        ("lenient_condition", lenient_condition),
    )
}

In [None]:
# Few-shot budgets 2, 4, ..., 128 (powers of two).
results = analyse_with_few_shot_counts([2 ** exponent for exponent in range(1, 8)], SNE_label_counts)
print_results(results)

# Subset Testing

In [2]:
# Smoke test of the project's few-shot subset builder on HSN.
from resources.utils.few_shot.few_shot_subset_creation import create_few_shot_subset
# NOTE(review): load_dataset is not used in this cell.
from datasets import load_dataset
# Build a subset of HSN with the "lenient" data-selection condition and fill_up disabled.
# NOTE(review): the exact meaning of fill_up lives in create_few_shot_subset; confirm there.
subset_one = create_few_shot_subset(
    name="HSN", 
    cache_dir="/mnt/stud/work/rantjuschin/datasets/HSN",
    data_selection_condition="lenient",
    fill_up=False)
# NOTE(review): the commented-out variant below references `dataset` and `StrictCondition`,
# neither of which is defined in the visible notebook; it will not run as written.
#subset_two = create_few_shot_subset(dataset, data_selection_condition=StrictCondition(), fill_up=True)
print(subset_one)
#print(subset_two)

Loading dataset
Applying condition to training data
Mapping satisfying recordings.
Selecting samples for subset
Selecting relevant columns and renaming
One-hot encoding labels


Map:   0%|          | 0/103 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['filepath', 'labels', 'detected_events', 'start_time', 'end_time'],
        num_rows: 103
    })
    test: Dataset({
        features: ['filepath', 'labels', 'detected_events', 'start_time', 'end_time'],
        num_rows: 12000
    })
})
