In [1]:
from src.datamodule.gadme_datamodule import GADMEDataModule
from src.datamodule.base_datamodule import DatasetConfig, LoadersConfig, LoaderConfig
from src.datamodule.components.transforms import TransformsWrapper, PreprocessingConfig, BaseTransforms
from src.datamodule.components.event_mapping import XCEventMapping
from src.datamodule.components.event_decoding import EventDecoding
from src.datamodule.components.feature_extraction import DefaultFeatureExtractor
from omegaconf import DictConfig
import pandas as pd
from datasets import load_dataset, load_from_disk, Audio, DatasetDict, Dataset, IterableDataset, IterableDatasetDict
import torch
from tqdm import tqdm


import src.modules.models.embedding_models.perch_tf_embedding_model as embed
import src.modules.models.embedding_models.embedding_classifier_model as classifier

In [2]:
# Build the BirdNET TensorFlow embedding model once; later cells call
# `birdnet.forward_embed(...)` through this module-level name.
birdnet = embed.BirdNetTfEmbeddingModel()

INFO: Created TensorFlow Lite XNNPACK delegate for CPU.


In [3]:
# Experiment settings shared by the cells below.
num_classes = 22  # handed positionally to DatasetConfig
sample_rate = 48000  # last positional argument of DatasetConfig and `sampling_rate` of BaseTransforms
window_size_s = 3.0  # used as `max_length` of BaseTransforms (presumably seconds — TODO confirm)
learning_rate = 1e-3  # NOTE(review): not referenced by any cell visible in this notebook
num_epochs = 5  # NOTE(review): not referenced by any cell visible in this notebook

In [4]:
from src.datamodule.base_datamodule import DatasetConfig, LoadersConfig
from src.datamodule.components.event_mapping import XCEventMapping
from src.datamodule.components.transforms import BaseTransforms


class GADMEDatasetDataModule(GADMEDataModule):
    """GADME data module that keeps a handle on the data it loaded.

    Construction is forwarded unchanged to ``GADMEDataModule``; the only added
    behaviour is that ``_load_data`` stores its result on ``self.dataset``.
    """

    def __init__(self, dataset: DatasetConfig = ..., loaders: LoadersConfig = ..., transforms: BaseTransforms = ..., mapper: XCEventMapping = ...):
        """Forward all four arguments positionally to the parent constructor.

        NOTE(review): the ``...`` defaults look like generated placeholders;
        omitting an argument hands ``Ellipsis`` to the parent, so callers
        should always pass all four — confirm this is intended.
        """
        super().__init__(dataset, loaders, transforms, mapper)

    def _load_data(self, decode: bool = False):
        """Load data through the parent implementation and remember it.

        Args:
            decode: Forwarded unchanged to the parent ``_load_data``.

        Returns:
            Whatever the parent ``_load_data`` returned; the same object is
            also stored on ``self.dataset``.
        """
        loaded = super()._load_data(decode)
        self.dataset = loaded
        return loaded

In [5]:
from typing import Literal


from src.datamodule.components.event_decoding import EventDecoding
from src.datamodule.components.feature_extraction import DefaultFeatureExtractor


class EmbeddingTransforms(BaseTransforms):
    """``BaseTransforms`` variant whose waveform hook returns model embeddings."""

    def __init__(
        self,
        embedding_model: embed.TfEmbeddingModel,
        task: Literal['multiclass', 'multilabel'] = "multiclass",
        sampling_rate: int = 3200,
        max_length: int = 5,
        decoding: EventDecoding | None = None,
        feature_extractor: DefaultFeatureExtractor | None = None,
    ) -> None:
        """Initialise the base transforms, then attach the embedding model.

        Args:
            embedding_model: Callable model applied to each waveform batch.
            task: Forwarded to ``BaseTransforms``.
            sampling_rate: Forwarded to ``BaseTransforms``.
                NOTE(review): the default of 3200 differs from the 48000 used
                elsewhere in this notebook and may be a typo — confirm.
            max_length: Forwarded to ``BaseTransforms``.
            decoding: Forwarded to ``BaseTransforms``.
            feature_extractor: Forwarded to ``BaseTransforms``.
        """
        super().__init__(task, sampling_rate, max_length, decoding, feature_extractor)
        self.embedding_model = embedding_model

    def augment_waveform_batch(self, waveform_batch, attention_mask, batch):
        """Return the embedding model's output for ``waveform_batch``.

        ``attention_mask`` and ``batch`` are accepted but not used here.
        """
        return self.embedding_model(waveform_batch)

In [6]:
from typing import Any
import lightning as L
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler

class DummyModule(L.LightningModule):
    """Placeholder ``LightningModule`` whose hooks all defer to the base class."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Pass every argument straight through to ``L.LightningModule``."""
        super().__init__(*args, **kwargs)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Return the result of the inherited ``forward``."""
        output = super().forward(*args, **kwargs)
        return output

    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Return the result of the inherited ``training_step``."""
        step_output = super().training_step(*args, **kwargs)
        return step_output

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Return the result of the inherited ``configure_optimizers``."""
        optimizers = super().configure_optimizers()
        return optimizers

In [7]:
# Dataset identifier and local cache directory.
# NOTE(review): `cache_dir` is an absolute, machine-specific path — adjust it
# before running this notebook on another machine.
dataset_name = "DBD-research-group/gadme_v1"
cache_dir = "/Volumes/BigChongusF/Datasets/Huggingface/gadme_v1/data"

# DatasetConfig is filled positionally; the meaning of each slot depends on
# the field order declared in src.datamodule.base_datamodule (not visible here).
dataset_config = DatasetConfig(
    cache_dir,
    "high_sierras",
    dataset_name,
    "high_sierras",
    2,
    num_classes,
    3,
    0.2,
    "multiclass",
    None,
    sample_rate,
)

# One loader configuration per split; valid and test get separate but equal configs.
loaders_config = LoadersConfig()
loaders_config.train = LoaderConfig(12, True, 6, True, False, True, 2)
loaders_config.valid = LoaderConfig(12, False)
loaders_config.test = LoaderConfig(12, False)

mapper = XCEventMapping(biggest_cluster=True, event_limit=5, no_call=True)
transforms_wrapper = BaseTransforms(task="multiclass", sampling_rate=sample_rate, max_length=window_size_s)

# Assemble the data module from the four pieces configured above.
dm = GADMEDatasetDataModule(dataset_config, loaders_config, transforms_wrapper, mapper)

In [8]:
# Run the data module's preparation step, then its setup for the "fit" stage.
# The "Saving the dataset" progress bars in this cell's output come from these calls.
dm.prepare_data()
dm.setup("fit")

Saving the dataset (0/1 shards):   0%|          | 0/21650 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5413 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/10296 [00:00<?, ? examples/s]

In [9]:
def birdnet_forward(x):
    """Replace a batch's ``input_values`` with BirdNET embeddings.

    NOTE(review): a later cell defines another ``birdnet_forward`` that
    shadows this one once it has been executed.

    Args:
        x: Mutable mapping with an ``"input_values"`` entry. It is modified
            in place.

    Returns:
        The same object ``x``, whose ``"input_values"`` entry now holds the
        result of ``birdnet.forward_embed``.
    """
    # The second positional argument is None (presumably the device; the later
    # definition of this function passes `device=None` — TODO confirm).
    x["input_values"] = birdnet.forward_embed(x["input_values"], None)
    return x

In [10]:
# Take the training split from the data module; the bare name on the last
# line makes the notebook display its summary (features and row count).
train_dataset = dm.train_dataset
train_dataset

Dataset({
    features: ['filepath', 'labels', 'detected_events', 'start_time', 'end_time', 'no_call_events'],
    num_rows: 21650
})

In [11]:
def birdnet_forward(x):
    """Compute BirdNET embeddings for one batch.

    NOTE(review): this redefines the ``birdnet_forward`` from an earlier cell.

    Args:
        x: Mapping providing the model input under ``"input_values"``.

    Returns:
        A new dict with the single key ``"embeddings"`` holding the output of
        ``birdnet.forward_embed`` called with ``device=None``.
    """
    return {"embeddings": birdnet.forward_embed(x["input_values"], device=None)}

In [19]:
# Apply birdnet_forward to the training split in batches of 100. The displayed
# result shows the original columns plus the new "embeddings" column.
new_train = train_dataset.map(birdnet_forward, batch_size=100, batched=True)
new_train

Map:   0%|          | 0/21650 [00:00<?, ? examples/s]

Dataset({
    features: ['filepath', 'labels', 'detected_events', 'start_time', 'end_time', 'no_call_events', 'embeddings'],
    num_rows: 21650
})

In [23]:
# Switch new_train's output format to "np"; called for its side effect on
# new_train (the return value is not used).
new_train.set_format("np")

In [24]:
# Persist the training split with its embeddings to disk.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
path = "/Volumes/BigChongusF/Datasets/gadme_embeddings/high_sierras"
new_train.save_to_disk(path)

Saving the dataset (0/1 shards):   0%|          | 0/21650 [00:00<?, ? examples/s]

In [25]:
# Reload the dataset that was just written to `path`, so the next cell can
# inspect what was actually stored.
new_ds = load_from_disk(path)

In [29]:
# Sanity check: length of the first example's embedding vector.
# Index the row first and the column second, so only one example is read
# instead of materialising the whole "embeddings" column.
len(new_ds[0]["embeddings"])

1024

In [13]:
# train_dataloader = dm.train_dataloader()
# train_dataloader

In [14]:
# model = birdnet

In [15]:
# birdnet = torch.compile(birdnet)

In [16]:
# def generate(data_loader, model):
#     model = torch.compile(model)
#     for batch in tqdm(data_loader):
#         input_values = batch["input_values"]
#         with torch.no_grad():
#             embeddings = model.forward_embed(input_values, None)
#         for i, e in zip(input_values, embeddings):
#             yield {"input_values": i, "embeddings": e}

In [17]:
# dss = Dataset.from_generator(generator=generate, gen_kwargs={"data_loader": train_dataloader, "model": birdnet})
# dss