In [1]:
from src.perch.inference import models
from ml_collections import config_dict
import numpy as np

In [2]:
from src.datamodule.gadme_datamodule import GADMEDataModule
from src.datamodule.base_datamodule import DatasetConfig, LoadersConfig, LoaderConfig
from src.datamodule.components.transforms import TransformsWrapper, PreprocessingConfig, BatchTransformer
from src.datamodule.components.event_mapping import XCEventMapping
from src.datamodule.components.event_decoding import EventDecoding
from src.datamodule.components.feature_extraction import DefaultFeatureExtractor
from omegaconf import DictConfig

In [3]:
model_choice = "birdnet"
model_choice = "perch"
model_choice = "vggish"
model_choice = "yamnet"

In [4]:
perch_tfhub_version = 4
perch_model_path = ''
birdnet_model_path = '/Users/moritzrichert/Downloads/V2.4/BirdNET_GLOBAL_6K_V2.4_Model_FP32.tflite'

Yamnet and VGGISH s embeddings are different from perches and birdnets:

they are of shape (N, n, embed)

where N is the batch-size,

embed is the embedding dimension (1024/128)

and n is the number of frames that the model found (one for every 0.48 / 0.96 seconds)

In [5]:

model_config = config_dict.ConfigDict()
if model_choice == "perch":
    embedding_model_key = "taxonomy_model_tf"
    model_config.window_size_s = 5.0
    model_config.hop_size_s = 5.0
    model_config.sample_rate = 32000
    model_config.tfhub_version = perch_tfhub_version
    model_config.model_path = perch_model_path
    num_classes = 10932
    embed_dim = 1280
if model_choice == "birdnet":
    embedding_model_key = "birdnet"
    model_config.window_size_s = 3.0
    model_config.hop_size_s = 3.0
    model_config.sample_rate = 48000
    model_config.model_path = birdnet_model_path
    model_config.class_list_name = 'birdnet_v2_4'
    model_config.num_tflite_threads = 4
    num_classes = 6522
    embed_dim = 1024
if model_choice == "yamnet":
    embedding_model_key = "tfhub_model"
    num_classes = 521
    embed_dim = 1024
    model_config = models.TFHubModel.yamnet_cfg()
    model_config.window_size_s = 3
if model_choice == "vggish":
    embedding_model_key = "tfhub_model"
    num_classes = 128
    embed_dim = 128
    model_config = models.TFHubModel.vggish_cfg()
    model_config.window_size_s = 3



In [6]:
model_config.sample_rate

16000

In [7]:
model_choice

'yamnet'

In [8]:
embedding_model_key

'tfhub_model'

In [9]:
model_class = models.model_class_map()[embedding_model_key]
embedding_model = model_class.from_config(model_config)

In [10]:
embedding_model

TFHubModel(sample_rate=16000, model=<tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject object at 0x2ebab1120>, model_url='https://tfhub.dev/google/yamnet/1', embedding_index=1, logits_index=0, window_size_s=3)

In [11]:
dataset_name = "DBD-research-group/gadme_v1"
cache_dir = "/Volumes/BigChongusF/Datasets/Huggingface/gadme_v1/data"
dataset_config = DatasetConfig(cache_dir, "high_sierras", dataset_name, "high_sierras", 42, num_classes, 3, 0.2, "multiclass", None, model_config.sample_rate)
loaders_config = LoadersConfig()
loaders_config.train = LoaderConfig(12, True, 6, True, False, True, 2)
loaders_config.valid = LoaderConfig(12, False)
loaders_config.test = LoaderConfig(12, False)
# transforms_wrapper = TransformsWrapper(
#     task = "multiclass",
#     sampling_rate=model_config.sample_rate,
#     model_type="waveform",
#     max_length=model_config.window_size_s,
#     preprocessing=PreprocessingConfig(use_spectrogram=False,
#                                       hop_length=model_config.hop_size_s,
#                                       n_mels=128,
#                                       db_scale=False,
#                                       target_height=None,
#                                       target_width=1024,
#                                       normalize=False,
#                                       normalize_waveform="instance_min_max"),
#     waveform_augmentations=DictConfig({}),
#     spectrogram_augmentations=DictConfig({}),
#     decoding=EventDecoding(min_len=0, 
#                            max_len=model_config.window_size_s, 
#                            sampling_rate=model_config.sample_rate),
#     feature_extractor=DefaultFeatureExtractor(feature_size=1,
#                                               sampling_rate=model_config.sample_rate,
#                                               padding_value=0,
#                                               return_attention_mask=False)
#     )
# transforms_wrapper = EventDecoding(0, 10, 48000)
mapper = XCEventMapping(biggest_cluster=True,
                        event_limit=5,
                        no_call=True)
transforms_wrapper = BatchTransformer(
    task = "multiclass",
    sampling_rate=model_config.sample_rate,
    max_length=model_config.window_size_s)
dm = GADMEDataModule(dataset_config, loaders_config, transforms_wrapper, mapper)

In [12]:
dm.prepare_data()
dm.setup("fit")

Saving the dataset (0/1 shards):   0%|          | 0/21650 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5413 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/10296 [00:00<?, ? examples/s]

In [13]:
loader = dm.train_dataloader()
loader

<torch.utils.data.dataloader.DataLoader at 0x2eb89dbd0>

In [14]:
x = next(iter(dm.train_dataloader()))
x

{'input_values': tensor([[-0.0012, -0.0033, -0.0088,  ...,  0.0044,  0.0056, -0.0010],
         [ 0.0029,  0.0037,  0.0020,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0008, -0.0010, -0.0006,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0035,  0.0014, -0.0048,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0015,  0.0012, -0.0003,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0005, -0.0007, -0.0024,  ...,  0.0000,  0.0000,  0.0000]]),
 'labels': tensor([ 0,  0, 13, 11, 19, 12, 13, 14, 19,  3, 21,  0])}

In [15]:
np_batch = x["input_values"].numpy()
np_batch

array([[-0.00121539, -0.00328653, -0.00875196, ...,  0.00439776,
         0.00556342, -0.00104277],
       [ 0.00285388,  0.00370478,  0.00199397, ...,  0.        ,
         0.        ,  0.        ],
       [-0.00077894, -0.00100327, -0.00060368, ...,  0.        ,
         0.        ,  0.        ],
       ...,
       [ 0.00350025,  0.00141034, -0.0047687 , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.00153291,  0.00120396, -0.00032726, ...,  0.        ,
         0.        ,  0.        ],
       [-0.00049899, -0.00074622, -0.00235357, ...,  0.        ,
         0.        ,  0.        ]], dtype=float32)

In [16]:
np_batch.shape

(12, 48000)

In [17]:
len(np_batch)

12

In [18]:
len(np_batch[0])

48000

In [19]:
if len(np_batch.shape) > 2:
    print(len(np_batch[0][0]))

In [20]:
np_batch[0][0]

-0.0012153923

In [21]:
240000/48000

5.0

In [22]:
z = model_config.sample_rate * model_config.window_size_s
z

48000

In [23]:
embeddings = embedding_model.batch_embed(np_batch)
embeddings

InferenceOutputs(embeddings=array([[[[0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
          0.0000000e+00, 0.0000000e+00, 0.0000000e+00]],

        [[0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
          0.0000000e+00, 0.0000000e+00, 0.0000000e+00]],

        [[0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
          0.0000000e+00, 0.0000000e+00, 0.0000000e+00]],

        [[0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
          0.0000000e+00, 0.0000000e+00, 0.0000000e+00]],

        [[0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
          0.0000000e+00, 0.0000000e+00, 0.0000000e+00]],

        [[0.0000000e+00, 0.0000000e+00, 4.3896177e-01, ...,
          0.0000000e+00, 0.0000000e+00, 0.0000000e+00]]],


       [[[0.0000000e+00, 3.3383083e-02, 0.0000000e+00, ...,
          0.0000000e+00, 0.0000000e+00, 0.0000000e+00]],

        [[0.0000000e+00, 1.8367519e-03, 6.9607049e-04, ...,
          0.0000000e+00, 0.0000000e+00, 0.0000000e+00]],

        [[0.0000000e+00, 0

In [24]:
if len(embeddings.embeddings.shape)>3:
    print(len(embeddings.embeddings[0][0][0]))

1024


In [25]:
embeddings.embeddings.shape

(12, 6, 1, 1024)

In [26]:
embeddings.embeddings.squeeze().shape

(12, 6, 1024)

In [27]:
3*128

384

In [28]:
shp = embeddings.embeddings.squeeze().shape
shp

(12, 6, 1024)

In [29]:
embeddings.embeddings[0]

array([[[0.        , 0.        , 0.        , ..., 0.        ,
         0.        , 0.        ]],

       [[0.        , 0.        , 0.        , ..., 0.        ,
         0.        , 0.        ]],

       [[0.        , 0.        , 0.        , ..., 0.        ,
         0.        , 0.        ]],

       [[0.        , 0.        , 0.        , ..., 0.        ,
         0.        , 0.        ]],

       [[0.        , 0.        , 0.        , ..., 0.        ,
         0.        , 0.        ]],

       [[0.        , 0.        , 0.43896177, ..., 0.        ,
         0.        , 0.        ]]], dtype=float32)

In [30]:
logit = None
if type(embeddings.logits) is dict:
    logit = list(embeddings.logits.keys())[0]

In [31]:
if logit is not None:
    print(embeddings.logits[logit].shape)

(12, 6, 521)


In [32]:
embeddings.embeddings[0].shape

(6, 1, 1024)

In [33]:
from lightning import LightningModule
import torch
from torch import nn
from typing import List, Literal, Tuple
import torchmetrics

In [34]:
a = np.arange(12*128).reshape(12, 128)
a

array([[   0,    1,    2, ...,  125,  126,  127],
       [ 128,  129,  130, ...,  253,  254,  255],
       [ 256,  257,  258, ...,  381,  382,  383],
       ...,
       [1152, 1153, 1154, ..., 1277, 1278, 1279],
       [1280, 1281, 1282, ..., 1405, 1406, 1407],
       [1408, 1409, 1410, ..., 1533, 1534, 1535]])

In [35]:
a.mean(axis=(1))

array([  63.5,  191.5,  319.5,  447.5,  575.5,  703.5,  831.5,  959.5,
       1087.5, 1215.5, 1343.5, 1471.5])

In [36]:

from typing import Any


from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler


class AudioSetEmbeddingModel(nn.Module):
    def __init__(self, tf_model):
        super().__init__()
        self.tf_model = tf_model
    
    def forward(self, x):
        inference = self.tf_model(x)
        embeddings = inference.embeddings
        embeddings = embeddings.squeeze()
        embeddings = embeddings.mean(axis=(1))
        return embeddings
        

class EmbeddingModel(nn.Module):
    def __init__(self, tf_model):
        super().__init__()
        self.tf_model = tf_model
    
    def forward(self, x):
        inference_outputs = self.tf_model(x)
        embeddings = inference_outputs.embeddings
        embeddings = embeddings.squeeze()
        return embeddings
        

class LightninLinearEmbeddings(LightningModule):
    def __init__(self, tf_model, num_classes=6522, embed_dim=1024, device = "mps"):
        super().__init__()
        self.tf_model = tf_model
        self.acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.criterion = nn.CrossEntropyLoss()
        self.linear = nn.Linear(in_features=embed_dim, out_features=num_classes, device=device)
        self.d_device = device
    
    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx:int):
        x = batch["input_values"]
        y = batch["labels"]
        x = x.cpu()
        embeddings = self.tf_model(x)        
        embeddings = torch.from_numpy(embeddings).to(self.d_device)
        preds = self.linear(embeddings)
        loss = self.criterion(preds, y)
        acc = self.acc(preds, y)
        print(f"Acc {acc} and Loss {loss}")
        return loss
    
    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        x = batch["input_values"]
        y = batch["labels"]
        x = x.cpu()
        embeddings = self.tf_model(x)
        embeddings = torch.from_numpy(embeddings).to(self.d_device)
        preds = self.linear(embeddings)
        loss = self.criterion(preds, y)
        acc = self.acc(preds, y)
        print("Test")
        print(f"Acc {acc} and Loss {loss}")
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
         

In [37]:
from lightning import Trainer

In [38]:
model = LightninLinearEmbeddings(EmbeddingModel(embedding_model.batch_embed), num_classes=num_classes, embed_dim=embed_dim)
model

LightninLinearEmbeddings(
  (tf_model): EmbeddingModel()
  (acc): MulticlassAccuracy()
  (criterion): CrossEntropyLoss()
  (linear): Linear(in_features=1024, out_features=521, bias=True)
)

In [39]:
model = LightninLinearEmbeddings(AudioSetEmbeddingModel(embedding_model.batch_embed), num_classes=num_classes, embed_dim=embed_dim)
model

LightninLinearEmbeddings(
  (tf_model): AudioSetEmbeddingModel()
  (acc): MulticlassAccuracy()
  (criterion): CrossEntropyLoss()
  (linear): Linear(in_features=1024, out_features=521, bias=True)
)

In [40]:
num_classes

521

In [41]:
trainer = Trainer()
trainer.fit(model = model, train_dataloaders=loader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type                   | Params
-----------------------------------------------------
0 | tf_model  | AudioSetEmbeddingModel | 0     
1 | acc       | MulticlassAccuracy     | 0     
2 | criterion | CrossEntropyLoss       | 0     
3 | linear    | Linear                 | 534 K 
-----------------------------------------------------
534 K     Trainable params
0         Non-trainable params
534 K     Total params
2.136     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

Acc 0.0 and Loss 6.623044490814209
Acc 0.0 and Loss 5.742731094360352
Acc 0.4166666567325592 and Loss 4.735328674316406
Acc 0.3333333432674408 and Loss 4.602878093719482
Acc 0.0833333358168602 and Loss 4.3408660888671875
Acc 0.25 and Loss 4.866981029510498
Acc 0.1666666716337204 and Loss 4.550601959228516
Acc 0.25 and Loss 4.8293938636779785
Acc 0.25 and Loss 3.7563722133636475
Acc 0.25 and Loss 4.516570091247559
Acc 0.0833333358168602 and Loss 5.568239212036133
Acc 0.3333333432674408 and Loss 5.132167339324951
Acc 0.1666666716337204 and Loss 4.660449504852295
Acc 0.25 and Loss 3.95861554145813
Acc 0.1666666716337204 and Loss 4.7371978759765625
Acc 0.25 and Loss 3.798006772994995
Acc 0.3333333432674408 and Loss 3.372311592102051
Acc 0.25 and Loss 3.3003361225128174
Acc 0.1666666716337204 and Loss 5.242767810821533
Acc 0.25 and Loss 4.877814769744873
Acc 0.3333333432674408 and Loss 4.004086971282959
Acc 0.1666666716337204 and Loss 4.807047367095947
Acc 0.1666666716337204 and Loss 4.3561