In [None]:
#%% Download dataset
from datasets import load_dataset

# Download NBMSet24 from the Hugging Face Hub (reusing the local cache when present).
# num_proc: None prepares the dataset in a single process; put an integer
# (number of workers) here if you want to use multiprocessing.
# trust_remote_code=True allows `datasets` to run the repository's own loading code.
dataset = load_dataset(
    "ppeyret/NBMSet24",  # path (Hub repository id)
    "NBMSet24",          # name (dataset configuration)
    cache_dir="D:/NBMSet24",
    num_proc=None,
    trust_remote_code=True,
)


In [None]:
# Local data locations. Override them through environment variables instead of
# editing the notebook; the defaults are the original author's Windows paths.
import os

training_dataset_path = os.environ.get("NBMSET24_DIR", "D:/NBMSet24")
bg_noise_datapath = os.environ.get("NBMSET24_BG_NOISE_DIR", "D:/Birdset/background_noise")
# bg_noise_datapath=r"D:\MNHN_no_bird_call\Annotations_Yves_pourPaulPeyret"

LOAD TRANSFORMS

In [None]:

from birdset.datamodule.components.transforms import (
    BirdSetTransformsWrapper,
    PreprocessingConfig,
)
from custom_event_decoding import CustomEventDecoding
from birdset.datamodule.components.event_decoding import EventDecoding
from birdset.datamodule.components.feature_extraction import DefaultFeatureExtractor
from birdset.datamodule.components.augmentations import (
    NoCallMixer,
    MultilabelMix,
    AddBackgroundNoise,
    PowerToDB,
)
from birdset.datamodule.components.resize import Resizer
from torch_audiomentations import AddColoredNoise, Gain
from torchaudio.transforms import Spectrogram, MelScale, FrequencyMasking, TimeMasking
from torchvision.transforms import RandomApply
import os


# EVENT DECODING (CustomEventDecoding is project-local; this summary restates the
# author's notes and is not verified here):
#   * loads audio from file, cutting a segment from event timestamps or from
#     manually provided start/end times;
#   * enforces a minimum / maximum segment length (min_len / max_len, presumably seconds);
#   * extends events that are too short to an `extension_time` window centred
#     on the event, then extracts a fixed-length random segment
#     (`extracted_interval`) from it;
#   * resamples the audio to `sampling_rate`.
decoder = CustomEventDecoding(
    min_len=1,
    max_len=5,
    sampling_rate=32000,
    extension_time=8,
    extracted_interval=5,
)

# Waveform feature extractor: 32 kHz input, zero padding, no attention mask
# (feature_size=1 presumably means mono audio — TODO confirm).
feature_extractor = DefaultFeatureExtractor(
    feature_size=1,
    sampling_rate=32000,
    padding_value=0.0,
    return_attention_mask=False,
)

# "No call" mixer drawing background clips from bg_noise_datapath.
# NOTE: it only takes effect if passed as `nocall_sampler` to
# BirdSetTransformsWrapper; that argument is commented out for training.
nocall = NoCallMixer(
    directory=bg_noise_datapath,
    p=0.075,
    sampling_rate=32000,
    length=5,
)

# Waveform-level augmentations applied during training only.
# Uncomment an entry to enable it; only random gain is active.
wav_transforms = dict(
    # multilabel_mix=MultilabelMix(
    #     p=0.7, min_snr_in_db=3.0, max_snr_in_db=30.0, mix_target="union"
    # ),
    # add_background_noise=AddBackgroundNoise(
    #     p=0.5,
    #     min_snr_in_db=3,
    #     max_snr_in_db=30,
    #     sample_rate=32000,
    #     target_rate=32000,
    #     background_paths=bg_noise_datapath,
    # ),
    # add_colored_noise=AddColoredNoise(
    #     p=0.2, max_f_decay=2, min_f_decay=-2, max_snr_in_db=30, min_snr_in_db=3
    # ),
    gain=Gain(p=0.2, min_gain_in_db=-18, max_gain_in_db=6),
)

# Spectrogram pipeline: power STFT (n_fft=1024, hop 320 samples = 10 ms at
# 32 kHz) -> 128 mel bands -> dB -> normalisation with fixed mean / std.
preprocessing = PreprocessingConfig(
    spectrogram_conversion=Spectrogram(n_fft=1024, hop_length=320, power=2.0),
    resizer=Resizer(db_scale=True, target_height=None, target_width=None),
    # n_stft must equal n_fft // 2 + 1 = 513 frequency bins.
    melscale_conversion=MelScale(n_mels=128, sample_rate=32000, n_stft=513),
    dbscale_conversion=PowerToDB(),
    normalize_spectrogram=True,
    mean=-4.268,
    std=4.569,
)

# Spectrogram-level augmentations (all disabled). To enable, add e.g.:
#   frequency_masking=RandomApply(
#       p=0.5, transforms=[FrequencyMasking(freq_mask_param=100, iid_masks=True)]
#   ),
#   time_masking=RandomApply(
#       p=0.5, transforms=[TimeMasking(time_mask_param=100, iid_masks=True)]
#   ),
spec_transforms = dict()

# Full training transform: decode 5 s windows, augment the waveform, convert to
# a normalised mel spectrogram ("vision" input), then augment the spectrogram.
birdset_transforms = BirdSetTransformsWrapper(
    task="multilabel",
    sampling_rate=32000,
    model_type="vision",
    max_length=5,
    decoding=decoder,
    feature_extractor=feature_extractor,
    # nocall_sampler=nocall,  # enable to mix in background-noise samples
    waveform_augmentations=wav_transforms,
    preprocessing=preprocessing,
    spectrogram_augmentations=spec_transforms,
)

In [None]:
# Peek at one training example before any label remapping or transform is applied.
dataset["train"][0]

Filter the dataset to a list of target species

In [None]:
import datasets
# Target species, given as scientific names. The list order defines each
# species' integer class index in the ClassLabel built from it, so keep the
# order stable across runs. Smaller example subset:
# target_species = ["Turdus iliacus", "Turdus philomelos", "Turdus merula"]
target_species = [
    "Alauda arvensis",        # 0
    "Motacilla alba",         # 1
    "Motacilla flava",        # 2
    "Branta bernicla",        # 3
    "Nycticorax nycticorax",  # 4
    "Calidris alpina",        # 5
    "Gallinago gallinago",    # 6
    "Coturnix coturnix",      # 7
    "Anas platyrhynchos",     # 8
    "Carduelis carduelis",    # 9
    "Tringa nebularia",       # 10
    "Tringa ochropus",        # 11
    "Actitis hypoleucos",     # 12
    "Strix aluco",            # 13
    "Corvus corone",          # 14
    "Numenius arquata",       # 15
    "Numenius phaeopus",      # 16
    "Gallinula chloropus",    # 17
    "Charadrius hiaticula",   # 18
    "Turdus iliacus",         # 19
    "Turdus philomelos",      # 20
    "Haematopus ostralegus",  # 21
    "Ardea cinerea",          # 22
    "Melanitta nigra",        # 23
    "Turdus merula",          # 24
    "Passer domesticus",      # 25
    "Charadrius dubius",      # 26
    "Fringilla coelebs",      # 27
    "Anthus pratensis",       # 28
    "Erithacus rubecula",     # 29
]

# ClassLabel over the target species: str2int maps a species name to its
# position in target_species. Used to encode "target_label" and to cast that column.
target_class_label = datasets.ClassLabel(names=target_species)

In [None]:

def assign_target_label(example):
    """Attach target-species label columns to a single dataset example.

    When ``example["label"]`` is one of ``target_species``, its class index
    (from ``target_class_label.str2int``) is stored in ``target_label`` and
    ``ebird_code``, and as a one-element list in ``ebird_code_multilabel``.
    Any other label gets the sentinel -1 in all three fields so the row can be
    filtered out.

    Args:
        example: one row as a dict (the per-example form passed by ``datasets.map``).

    Returns:
        The same dict, modified in place.
    """
    species = example["label"]
    if species not in target_species:
        class_idx = -1  # sentinel for non-target species
    else:
        class_idx = target_class_label.str2int(species)

    example["target_label"] = class_idx
    example["ebird_code"] = class_idx
    example["ebird_code_multilabel"] = [class_idx]
    return example

# Keep only rows whose species is in target_species, then attach the target
# label columns and cast "target_label" to the ClassLabel feature.
# Filtering first means assign_target_label only runs on (and only writes cache
# files for) the kept rows. input_columns="label" hands the filter just that
# column instead of formatting every column of every row. The resulting rows
# are the same as mapping first and dropping target_label == -1.
target_species_set = set(target_species)
dataset = (
    dataset
    .filter(lambda label: label in target_species_set, input_columns="label")
    .map(assign_target_label)
    .cast_column("target_label", target_class_label)
)


In [None]:
# # Define the list of labels you want to keep
# from datasets import DatasetDict
# target_labels = {"Turdus iliacus", "Turdus philomelos", "Turdus merula"}  # Use a set for faster lookup

# # Filter train and test splits
# filtered_train = dataset["train"].filter(lambda example: example["label"] in target_labels)
# filtered_test = dataset["test"].filter(lambda example: example["label"] in target_labels)

# # Replace the original dataset with filtered subsets
# filtered_dataset = DatasetDict()
# filtered_dataset["train"] = filtered_train
# filtered_dataset["test"] = filtered_test

# # Verify the filtering
# print(filtered_dataset)

In [None]:
# columns_to_keep = {"filepath", "labels", "start_time", "end_time"}

# removable_train_columns = [
#     column for column in dataset["train"].column_names if column not in columns_to_keep
# ]
# removable_test_columns = [
#     column for column in dataset["test"].column_names if column not in columns_to_keep
# ]
# print(removable_test_columns, "\n", removable_train_columns)
# # %%
# dataset["train"] = dataset["train"].remove_columns(removable_train_columns)
# dataset["test"] = dataset["test"].remove_columns(removable_test_columns)
# # %%
# print(dataset)

In [None]:
# Sanity-check the species filter: list the species present in each split and
# fail loudly if anything outside target_species slipped through.
unique_labels_train = set(dataset["train"]["label"])
unique_labels_test = set(dataset["test"]["label"])

# Get all unique labels across both splits
all_unique_labels = unique_labels_train | unique_labels_test

# Print sorted lists. String set order depends on the per-process hash seed,
# so raw sets would print in a different order on every kernel start.
print("Unique labels in train set:", sorted(unique_labels_train))
print("Unique labels in test set:", sorted(unique_labels_test))
print("All unique labels:", sorted(all_unique_labels))

unexpected_labels = all_unique_labels - set(target_species)
assert not unexpected_labels, f"labels outside target_species: {sorted(unexpected_labels)}"

# A target species with no training rows can never be learned; surface it early.
missing_in_train = set(target_species) - unique_labels_train
if missing_in_train:
    print("Target species with no training examples:", sorted(missing_in_train))

In [None]:
# Expose the integer target label under the column name "labels" (presumably the
# name classes_one_hot and the datamodule read — TODO confirm).
dataset = dataset.rename_column("target_label", "labels")

In [None]:
# Convert to one hot
from toolkit import classes_one_hot

# Batched map: the project helper classes_one_hot receives up to 300 rows at a
# time and encodes them over len(target_species) classes, in a single process.
dataset = dataset.map(
    function=lambda batch: classes_one_hot(batch, num_classes=len(target_species)),
    desc="One-hot-encoding labels.",
    batched=True,
    batch_size=300,
    num_proc=1,
    load_from_cache_file=True,
)

In [None]:
# Inspect one training example after label remapping and one-hot encoding.
dataset["train"][0]

In [None]:
# Independent, untransformed copy of the splits, kept for inspecting raw columns
# such as "filepath". A plain `dataset_table = dataset` would only alias the same
# Dataset objects. The later in-place
# `dataset["train"].set_transform(..., output_all_columns=False)` would then strip
# this view's columns too. with_format(type=None) returns a new DatasetDict with
# no format/transform attached.
dataset_table = dataset.with_format(type=None)

In [None]:
# Source file path of the first training example (presumably the audio recording).
dataset_table['train'][0]['filepath']

In [None]:
# Attach the training pipeline (decoding, augmentations, spectrogram) as an
# on-the-fly transform. set_transform modifies dataset["train"] in place.
# With output_all_columns=False, indexing it returns only the transform's outputs.
dataset["train"].set_transform(birdset_transforms, output_all_columns=False)

In [None]:
# Evaluation pipeline: same decoding, feature extraction and preprocessing as for
# training, but without waveform/spectrogram augmentations or no-call mixing.
# If the test split is already sliced into fixed-length segments, you may want
# to remove the event decoding step from this transform.
# NOTE(review): `decoder` extracts a random segment from extended events, so
# test-time inputs (and scores) may vary between runs. Confirm whether a
# deterministic crop is available for evaluation.
test_transforms = BirdSetTransformsWrapper(
    task="multilabel",
    sampling_rate=32000,
    model_type="vision",
    max_length=5,
    decoding=decoder,
    feature_extractor=feature_extractor,
    nocall_sampler=None,
    # Empty dicts, matching the dict type used for the training augmentations.
    waveform_augmentations={},
    preprocessing=preprocessing,
    spectrogram_augmentations={},
)

dataset["test"].set_transform(test_transforms, output_all_columns=False)


In [None]:
# The validation split is the very same Dataset object as the test split (same
# rows, same transform). ModelCheckpoint monitors a "val/" loss, so the epoch it
# selects is chosen on test data. The test score of that checkpoint is therefore
# optimistically biased.
# TODO: carve out a real validation split (e.g. from train) and keep test held out.
dataset["valid"]=dataset["test"]

In [None]:

from toolkit import CustomDatamodule
# Wrap the splits in the project's datamodule: batches of 32, num_workers=0
# (presumably forwarded to the DataLoaders, i.e. loading in the main process),
# one output per target species, multilabel task.
datamodule = CustomDatamodule(
    dataset=dataset, batch_size=32, num_workers=0, num_classes=len(target_species), task="multilabel"
)

datamodule.dataset

In [None]:
# Training DataLoader, used below to visualise one transformed batch.
dl = datamodule.train_dataloader()

In [None]:
import matplotlib.pyplot as plt
import numpy as np


# Visualise the first spectrogram of one transformed training batch.
sample = next(iter(dl))["input_values"][0]

# Assumes sample is (1, n_mels, frames) — torchaudio's (..., freq, time) layout.
# origin="lower" draws mel bin 0 at the bottom and keeps the y tick labels equal
# to the bin indices. aspect="auto" keeps the long time axis from squashing the
# image into a thin strip.
fig, ax = plt.subplots()
ax.imshow(sample.squeeze().detach().cpu().numpy(), origin="lower", aspect="auto")
ax.set(title="Training sample (mel spectrogram)", xlabel="Time frame", ylabel="Mel bin")
plt.show()

In [None]:
# # EXPORT ALL TENSORS TO PNG
# from toolkit import save_dataloader_tensors_as_png

# save_dataloader_tensors_as_png(dl,output_dir="tensor_plots")

IMPORT THE MODEL

In [None]:
# Training configuration.
N_EPOCH = 10                       # passed to both the model and the Trainer
num_classes = len(target_species)  # one output per target species

In [None]:
from lightning_module import ConvNextClassifierLightningModule

# Project-local ConvNeXt classifier LightningModule. num_epochs is presumably used
# by its optimizer / LR schedule — TODO confirm.
model = ConvNextClassifierLightningModule(num_classes=num_classes, num_epochs=N_EPOCH)





In [None]:
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import RichModelSummary
from lightning.pytorch.loggers import MLFlowLogger

# Experiment tracking: MLflow runs stored under mlruns/ (a local path relative
# to the working directory).
mlflow_logger = MLFlowLogger(
    experiment_name="first_model", tracking_uri="mlruns/"
)

# Keep only the single best checkpoint by validation BCE loss (lower is better),
# checked once per epoch. The metric name must match what the model logs.
# NOTE: dataset["valid"] is currently the test split.
model_checkpoint = ModelCheckpoint(
    dirpath="./callback_checkpoints",
    monitor="val/BCEWithLogitsLoss",
    verbose=False,
    save_last=False,
    save_top_k=1,
    mode="min",
    auto_insert_metric_name=False,
    save_weights_only=False,
    every_n_train_steps=None,
    train_time_interval=None,
    every_n_epochs=1,
    save_on_train_epoch_end=None,
)

rich_model_summary = RichModelSummary(max_depth=1)

trainer = L.Trainer(
    min_epochs=1,
    max_epochs=N_EPOCH,
    gradient_clip_val=0.5,
    # "16-mixed" is the Lightning 2.x spelling of the legacy `precision=16`
    # (fp16 autocast). The bare int is discouraged and triggers a warning.
    precision="16-mixed",
    accumulate_grad_batches=1,
    callbacks=[model_checkpoint, rich_model_summary],
    logger=mlflow_logger,
)

In [None]:
# Train for up to N_EPOCH epochs; ModelCheckpoint saves the best epoch by val loss.
trainer.fit(datamodule=datamodule, model=model)

In [None]:
# Latest metrics logged during fit (as seen by callbacks).
trainer.callback_metrics

In [None]:
# Path of the best checkpoint selected by ModelCheckpoint (Lightning leaves it
# as an empty string if no checkpoint was saved).
ckpt_path = trainer.checkpoint_callback.best_model_path
print(ckpt_path)

In [None]:
# Evaluate the in-memory weights, i.e. the state after the final training epoch.
trainer.test(datamodule=datamodule, model=model)

In [None]:
# Evaluate the best checkpoint: weights are reloaded from ckpt_path before testing.
trainer.test(datamodule=datamodule, model=model, ckpt_path=ckpt_path)