In [1]:
from projects.callbird.src.datasets.load_test_dataset import load_test_dataset
from projects.callbird.src.datasets.load_train_dataset import load_train_dataset

In [2]:
def _process_loaded_multitask_data(dataset, decode: bool = True):
    def add_event_columns(example):
        example["detected_events"] = (example["start_time"], example["end_time"])
        # TODO: Fix event_cluster value
        example["event_cluster"] = [0]
        return example

    dataset = dataset.map(add_event_columns)

    # dataset = self._ensure_train_test_splits(dataset)
    def add_multilabel_column(example):
        example["ebird_code_multilabel"] = example["ebird_code"]
        example["call_type_multilabel"] = example["short_call_type"]
        return example
    
    dataset = dataset.map(add_multilabel_column)

    ebird_labels = set()
    calltype_labels = set()

    for split in dataset.keys():
        ebird_labels.update(dataset[split]["ebird_code"])
        calltype_labels.update(dataset[split]["short_call_type"])

    ebird_labels = sorted(list(ebird_labels))
    calltype_labels = sorted(list(calltype_labels))

    ebird_label_to_id = {lbl: i for i, lbl in enumerate(ebird_labels)}
    calltype_label_to_id = {lbl: i for i, lbl in enumerate(calltype_labels)}

    def label_to_id_fn(batch):
        for i in range(len(batch["ebird_code_multilabel"])):
            batch["ebird_code_multilabel"][i] = ebird_label_to_id[batch["ebird_code_multilabel"][i]]

        for i in range(len(batch["call_type_multilabel"])):
            batch["call_type_multilabel"][i] = calltype_label_to_id[batch["call_type_multilabel"][i]]

        return batch

    dataset = dataset.map(
        label_to_id_fn,
        batched=True,
        batch_size=500,
        load_from_cache_file=True,
        num_proc=1,
    )

    return dataset

In [3]:
test_dataset = load_test_dataset()
train_dataset = load_train_dataset()

# Each dataset is processed independently, so each builds its own sorted label
# vocabulary. The checks below report where the resulting eBird-code ids
# disagree between the two.
test_dataset = _process_loaded_multitask_data(test_dataset)
train_dataset = _process_loaded_multitask_data(train_dataset)


def _collect_ebird_id_map(dataset, dataset_name):
    """Map each eBird code in ``dataset`` to its ``ebird_code_multilabel`` id.

    Reads only the two label columns of each split rather than iterating full
    rows, so other columns are not materialised. A code seen with more than one
    id within this dataset is reported; the first id seen is kept.
    """
    code_to_id = {}
    for split in dataset.keys():
        codes = dataset[split]["ebird_code"]
        ids = dataset[split]["ebird_code_multilabel"]
        for code, label_id in zip(codes, ids):
            first_id = code_to_id.setdefault(code, label_id)
            if first_id != label_id:
                print(
                    f"Conflict within {dataset_name} ({split}) for eBird code {code}: "
                    f"got {label_id}, previously {first_id}"
                )
    return code_to_id


test_multilabel_map = _collect_ebird_id_map(test_dataset, "test")
train_multilabel_map = _collect_ebird_id_map(train_dataset, "train")

# Check differences in map. Sorting makes the report order reproducible;
# iteration order of a set of strings varies between interpreter runs.
for ebird_code in sorted(test_multilabel_map.keys() | train_multilabel_map.keys(), key=str):
    test_label = test_multilabel_map.get(ebird_code)
    train_label = train_multilabel_map.get(ebird_code)
    if test_label != train_label:
        print(f"Difference found for eBird code {ebird_code}: train={train_label}, test={test_label}")


Resolving data files:   0%|          | 0/153 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/56 [00:00<?, ?it/s]