In [1]:
from src.perch.inference import models
from ml_collections import config_dict

In [2]:
from src.datamodule.gadme_datamodule import GADMEDataModule
from src.datamodule.base_datamodule import DatasetConfig, LoadersConfig, LoaderConfig
from src.datamodule.components.transforms import TransformsWrapper, PreprocessingConfig, BatchTransformer
from src.datamodule.components.event_mapping import XCEventMapping
from src.datamodule.components.event_decoding import EventDecoding
from src.datamodule.components.feature_extraction import DefaultFeatureExtractor
from omegaconf import DictConfig

In [3]:
model_choice = "perch"
model_choice = "birdnet"

In [4]:
perch_tfhub_version = 2
perch_model_path = ''
birdnet_model_path = '/Users/moritzrichert/Downloads/V2.4/BirdNET_GLOBAL_6K_V2.4_Model_FP32.tflite'

In [5]:
embedding_model_key = "birdnet"

model_config = config_dict.ConfigDict()
if model_choice == "perch":
    model_config.window_size_s = 5.0
    model_config.hop_size_s = 5.0
    model_config.sample_rate = 32000
    model_config.tfhub_version = perch_tfhub_version
    model_config.model_path = perch_model_path
if model_choice == "birdnet":
    model_config.window_size_s = 3.0
    model_config.hop_size_s = 3.0
    model_config.sample_rate = 48000
    model_config.model_path = birdnet_model_path
    model_config.class_list_name = 'birdnet_v2_4'
    model_config.num_tflite_threads = 4


In [6]:
model_config.sample_rate

48000

In [7]:
model_class = models.model_class_map()[embedding_model_key]
embedding_model = model_class.from_config(model_config)

INFO: Created TensorFlow Lite XNNPACK delegate for CPU.


In [8]:
embedding_model

BirdNet(sample_rate=48000, model_path='/Users/moritzrichert/Downloads/V2.4/BirdNET_GLOBAL_6K_V2.4_Model_FP32.tflite', model=<tensorflow.lite.python.interpreter.Interpreter object at 0x10643dc00>, tflite=True, class_list=ClassList(namespace='birdnet', classes=('Abroscopus albogularis', 'Abroscopus schisticeps', 'Abroscopus superciliaris', 'Aburria aburri', 'Acanthagenys rufogularis', 'Acanthidops bairdi', 'Acanthis cabaret', 'Acanthis flammea', 'Acanthis hornemanni', 'Acanthisitta chloris', 'Acanthiza apicalis', 'Acanthiza chrysorrhoa', 'Acanthiza ewingii', 'Acanthiza inornata', 'Acanthiza lineata', 'Acanthiza nana', 'Acanthiza pusilla', 'Acanthiza reguloides', 'Acanthiza uropygialis', 'Acanthorhynchus superciliosus', 'Acanthorhynchus tenuirostris', 'Accipiter badius', 'Accipiter bicolor', 'Accipiter cirrocephalus', 'Accipiter cooperii', 'Accipiter fasciatus', 'Accipiter gentilis', 'Accipiter gularis', 'Accipiter hiogaster', 'Accipiter melanoleucus', 'Accipiter minullus', 'Accipiter nis

In [9]:
bird_net_classes = len(embedding_model.class_list.classes)
bird_net_classes

6522

In [10]:
dataset_name = "DBD-research-group/gadme_v1"
cache_dir = "/Volumes/BigChongusF/Datasets/Huggingface/gadme_v1/data"
dataset_config = DatasetConfig(cache_dir, "high_sierras", dataset_name, "high_sierras", 42, bird_net_classes, 3, 0.2, "multiclass", None, model_config.sample_rate)
loaders_config = LoadersConfig()
loaders_config.train = LoaderConfig(12, True, 6, True, False, True, 2)
loaders_config.valid = LoaderConfig(12, False)
loaders_config.test = LoaderConfig(12, False)
transforms_wrapper = TransformsWrapper(
    task = "multiclass",
    sampling_rate=model_config.sample_rate,
    model_type="waveform",
    max_length=model_config.window_size_s,
    preprocessing=PreprocessingConfig(use_spectrogram=False,
                                      hop_length=model_config.hop_size_s,
                                      n_mels=128,
                                      db_scale=False,
                                      target_height=None,
                                      target_width=1024,
                                      normalize=False,
                                      normalize_waveform="instance_min_max"),
    waveform_augmentations=DictConfig({}),
    spectrogram_augmentations=DictConfig({}),
    decoding=EventDecoding(min_len=0, 
                           max_len=model_config.window_size_s, 
                           sampling_rate=model_config.sample_rate),
    feature_extractor=DefaultFeatureExtractor(feature_size=1,
                                              sampling_rate=model_config.sample_rate,
                                              padding_value=0,
                                              return_attention_mask=False)
    )
# transforms_wrapper = EventDecoding(0, 10, 48000)
mapper = XCEventMapping(biggest_cluster=True,
                        event_limit=5,
                        no_call=True)
transforms_wrapper = BatchTransformer(
    task = "multiclass",
    sampling_rate=model_config.sample_rate,
    max_length=model_config.window_size_s)
dm = GADMEDataModule(dataset_config, loaders_config, transforms_wrapper, mapper)

In [11]:
dm.prepare_data()
dm.setup("fit")

Saving the dataset (0/1 shards):   0%|          | 0/21650 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5413 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/10296 [00:00<?, ? examples/s]

In [12]:
loader = dm.train_dataloader()
loader

<torch.utils.data.dataloader.DataLoader at 0x2e2ce3910>

In [13]:
x = next(iter(dm.train_dataloader()))
x

{'input_values': tensor([[ 0.0006,  0.0009,  0.0004,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0123, -0.0137, -0.0082,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0005, -0.0026, -0.0050,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0012, -0.0012, -0.0013,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0080,  0.0052,  0.0015,  ..., -0.0117, -0.0109, -0.0065],
         [ 0.0059,  0.0062,  0.0046,  ...,  0.0000,  0.0000,  0.0000]]),
 'labels': tensor([13,  7,  9, 13, 14,  7,  0, 12,  7,  2, 12, 11])}

In [14]:
np_batch = x["input_values"].numpy()
np_batch

array([[ 0.00059441,  0.00087525,  0.00035514, ...,  0.        ,
         0.        ,  0.        ],
       [-0.01229766, -0.01365871, -0.00821954, ...,  0.        ,
         0.        ,  0.        ],
       [-0.00052826, -0.00259265, -0.00501468, ...,  0.        ,
         0.        ,  0.        ],
       ...,
       [ 0.00119571, -0.00121162, -0.00128617, ...,  0.        ,
         0.        ,  0.        ],
       [ 0.00799493,  0.00516469,  0.00149914, ..., -0.01168465,
        -0.01091061, -0.00654386],
       [ 0.00591717,  0.00624451,  0.0045801 , ...,  0.        ,
         0.        ,  0.        ]], dtype=float32)

In [15]:
np_batch.shape

(12, 144000)

In [16]:
len(np_batch)

12

In [17]:
len(np_batch[0])

144000

In [18]:
if len(np_batch.shape) > 2:
    print(len(np_batch[0][0]))

In [19]:
np_batch[0][0]

0.00059440685

In [20]:
240000/48000

5.0

In [21]:
z = model_config.sample_rate * model_config.window_size_s
z

144000.0

In [22]:
embeddings = embedding_model.batch_embed(np_batch)
embeddings

InferenceOutputs(embeddings=array([[[[0.02574717, 0.        , 0.03946407, ..., 0.6262875 ,
          0.        , 1.4018865 ]]],


       [[[0.34200057, 0.0816358 , 0.1005272 , ..., 0.22244847,
          0.04099405, 1.2699491 ]]],


       [[[0.04089956, 0.00187135, 0.        , ..., 0.88778466,
          0.25469065, 1.473136  ]]],


       ...,


       [[[0.09645306, 0.24277925, 0.02428765, ..., 0.17862274,
          0.00747283, 0.0868518 ]]],


       [[[1.0342367 , 0.3493255 , 0.39317372, ..., 0.09325168,
          0.10487243, 1.3692036 ]]],


       [[[0.        , 0.04789897, 0.        , ..., 0.29648256,
          0.07495558, 0.        ]]]], dtype=float32), logits={'birdnet_v2_4': array([[[[-15.359092, -17.526382, -15.720421, ..., -17.054527,
          -14.200952, -16.955507]]],


       [[[-13.314184, -14.549339, -13.92495 , ..., -11.491795,
          -11.827736, -11.909972]]],


       [[[ -8.064534, -14.039677, -10.188151, ..., -11.571695,
          -10.022747, -12.033895]]],


 

In [23]:
len(embeddings.embeddings[0][0][0])

1024

In [24]:
embeddings.embeddings.shape

(12, 1, 1, 1024)

In [25]:
embeddings.embeddings.squeeze().shape

(12, 1024)

In [26]:
embeddings.logits["birdnet_v2_4"].shape

(12, 1, 1, 6522)

In [27]:
from lightning import LightningModule
import torch
from torch import nn
from typing import List, Literal, Tuple
import torchmetrics

In [28]:
from typing import Any


from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler

class EmbeddingModel(nn.Module):
    def __init__(self, tf_model):
        super().__init__()
        self.tf_model = tf_model
    
    def forward(self, x):
        inference_outputs = self.tf_model(x)
        embeddings = inference_outputs.embeddings
        embeddings = embeddings.squeeze()
        return embeddings
        

class LightninLinearEmbeddings(LightningModule):
    def __init__(self, tf_model, num_classes=6522, embed_dim=1024, device = "mps"):
        super().__init__()
        self.tf_model = tf_model
        self.acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.criterion = nn.CrossEntropyLoss()
        self.linear = nn.Linear(in_features=embed_dim, out_features=num_classes, device=device)
        self.d_device = device
    
    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx:int):
        x = batch["input_values"]
        y = batch["labels"]
        x = x.cpu()
        embeddings = self.tf_model(x)        
        embeddings = torch.from_numpy(embeddings).to(self.d_device)
        preds = self.linear(embeddings)
        loss = self.criterion(preds, y)
        acc = self.acc(preds, y)
        print(f"Acc {acc} and Loss {loss}")
        return loss
    
    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        x = batch["input_values"]
        y = batch["labels"]
        x = x.cpu()
        embeddings = self.tf_model(x)
        embeddings = torch.from_numpy(embeddings).to(self.d_device)
        preds = self.linear(embeddings)
        loss = self.criterion(preds, y)
        acc = self.acc(preds, y)
        print("Test")
        print(f"Acc {acc} and Loss {loss}")
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
         

In [29]:
from lightning import Trainer

In [30]:
model = LightninLinearEmbeddings(EmbeddingModel(embedding_model.batch_embed))
model

LightninLinearEmbeddings(
  (tf_model): EmbeddingModel()
  (acc): MulticlassAccuracy()
  (criterion): CrossEntropyLoss()
  (linear): Linear(in_features=1024, out_features=6522, bias=True)
)

In [31]:
trainer = Trainer()
trainer.fit(model = model, train_dataloaders=loader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type               | Params
-------------------------------------------------
0 | tf_model  | EmbeddingModel     | 0     
1 | acc       | MulticlassAccuracy | 0     
2 | criterion | CrossEntropyLoss   | 0     
3 | linear    | Linear             | 6.7 M 
-------------------------------------------------
6.7 M     Trainable params
0         Non-trainable params
6.7 M     Total params
26.740    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

Acc 0.0 and Loss 9.02544116973877
Acc 0.0 and Loss 8.607823371887207
Acc 0.25 and Loss 7.833974361419678
Acc 0.0833333358168602 and Loss 7.852588176727295
Acc 0.25 and Loss 6.947601318359375
Acc 0.25 and Loss 6.696372985839844
Acc 0.1666666716337204 and Loss 6.174665451049805
Acc 0.0 and Loss 5.735023021697998
Acc 0.4166666567325592 and Loss 4.8323073387146
Acc 0.0833333358168602 and Loss 5.1402082443237305
Acc 0.1666666716337204 and Loss 4.236105442047119
Acc 0.25 and Loss 3.1883957386016846
Acc 0.0833333358168602 and Loss 3.9142544269561768
Acc 0.1666666716337204 and Loss 3.55049204826355
Acc 0.1666666716337204 and Loss 3.532536745071411
Acc 0.3333333432674408 and Loss 3.626990556716919
Acc 0.5833333134651184 and Loss 2.239464282989502
Acc 0.1666666716337204 and Loss 3.4831488132476807
Acc 0.3333333432674408 and Loss 1.9347213506698608
Acc 0.25 and Loss 3.780644178390503
Acc 0.5833333134651184 and Loss 2.2031023502349854
Acc 0.3333333432674408 and Loss 3.2066104412078857
Acc 0.166666