# Implementing linear probing pipeline from Ghani
Trying to 100% simulate ghani setup

### 1. Load dataset

In [4]:
from birdset.datamodule.beans_datamodule import BEANSDataModule
from birdset.datamodule.base_datamodule import DatasetConfig

datasetconfig = DatasetConfig(dataset_name='beans_watkins', hf_path='DBD-research-group/beans_watkins', hf_name='default')

datamodule = BEANSDataModule(dataset=datasetconfig)
dataset = datamodule._load_data()
dataset['train']

Dataset({
    features: ['audio', 'labels'],
    num_rows: 1017
})

This can be used to check the class distribution

In [2]:
from collections import Counter

print(dataset['train'][0])
label_counts = dict(Counter(dataset['train']['labels']))
print(label_counts)

{'audio': {'path': 'Mac-3-A-3.wav', 'array': array([ 2.13882100e-04,  4.85118391e-04, -2.17375666e-04, ...,
        9.30107664e-04,  7.66232726e-04,  6.42658269e-05]), 'sampling_rate': 32000}, 'labels': 9}
{9: 51, 5: 13, 2: 43, 8: 20, 7: 44, 0: 91, 1: 31, 4: 25, 6: 27, 3: 70}


(Only use next two cells if intended) This part is for removing specific classes from watkins (for this the conversion from class name to int in beans_datamodule has to be commented out)

In [2]:
from collections import Counter

#! Here we remove all labels that have less than x examples
x = 15 
label_counts = dict(Counter(dataset['train']['labels']))

filtered_labels = {label: count for label, count in label_counts.items() if count < x}
print(filtered_labels)
# Remove additional labels
labels_to_remove = ['Fin,_Finback_Whale', 'Northern_Right_Whale']

for label in labels_to_remove:
    filtered_labels[label] = label_counts[label]
    
# Create a new dataset excluding the filtered labels
dataset = dataset.filter(lambda example: example['labels'] not in filtered_labels)

print(dataset)

{}


KeyError: 'Fin,_Finback_Whale'

In [4]:
# Convert labels back to ids
labels = set()
for split in dataset.keys():
    labels.update(dataset[split]["labels"])

label_to_id = {lbl: i for i, lbl in enumerate(labels)}

def label_to_id_fn(batch):
    for i in range(len(batch['labels'])):
        batch['labels'][i] = label_to_id[batch['labels'][i]]
    return batch


dataset = dataset.map(
    label_to_id_fn,
    batched=True,
    batch_size=500,
    load_from_cache_file=True
)
print(len(labels))

Map:   0%|          | 0/1017 [00:00<?, ? examples/s]

Map:   0%|          | 0/339 [00:00<?, ? examples/s]

Map:   0%|          | 0/339 [00:00<?, ? examples/s]

Map:   0%|          | 0/203 [00:00<?, ? examples/s]

31


### 2. Load model and set parameters

Perch:

In [22]:
from birdset.modules.models.perch import PerchModel
import torch.nn as nn

num_classes = 31 #! Don't forget to change this
sampling_rate = 32_000 # Try 48_000 here
window_length = 5
input_size = 1280

perch_network = PerchModel(num_classes=num_classes, tfhub_version=4, gpu_to_use=0)

BirdNET:

In [20]:
from birdset.modules.models.birdnet import BirdNetModel
import torch.nn as nn

num_classes = 31 #! Don't forget to change this
sampling_rate = 48_000 
window_length = 3
input_size = 1024

perch_network = BirdNetModel(num_classes=num_classes, model_path='../../checkpoints/birdnet/BirdNET_GLOBAL_6K_V2.4_Model', train_classifier=False)



### 3. Batch and Preprocess the dataset 

#### 3.1. k-sample the dataset

In [27]:
print("Number of samples in the training set:", len(dataset['train']))
print("Number of samples in the validation set:", len(dataset['valid']))
print("Number of samples in the testing set:", len(dataset['test']))

Number of samples in the training set: 1017
Number of samples in the validation set: 339
Number of samples in the testing set: 339


In [8]:
import random
from collections import defaultdict
from datasets import concatenate_datasets, DatasetDict, Dataset

# Define the number of samples per class
samples_per_class = 32

# Merge the train, valid, and test splits
merged_data = concatenate_datasets([dataset['train'], dataset['valid'], dataset['test']])
merged_data = merged_data.shuffle()

# Create a dictionary to store the selected samples per class
selected_samples = defaultdict(list)
rest_samples = []
# Iterate over the merged data and select the desired number of samples per class
for sample in merged_data:
    label = sample['labels']
    if len(selected_samples[label]) < samples_per_class:
        selected_samples[label].append(sample)
    else:
        rest_samples.append(sample)    

# Flatten the selected samples into a single list
selected_samples = [sample for samples in selected_samples.values() for sample in samples]

# Split the selected samples into training, validation, and testing sets
test_ratio = 0.5

num_samples = len(rest_samples)
num_test_samples = int(test_ratio * num_samples)

train_data = selected_samples
test_data = rest_samples[:num_test_samples]
val_data = rest_samples[num_test_samples:]

train_data = Dataset.from_dict({key: [sample[key] for sample in train_data] for key in train_data[0]})
test_data = Dataset.from_dict({key: [sample[key] for sample in test_data] for key in test_data[0]})
val_data = Dataset.from_dict({key: [sample[key] for sample in val_data] for key in val_data[0]})

# Print the number of samples in each split
print("Number of samples in the training set:", len(train_data))
print("Number of samples in the validation set:", len(val_data))
print("Number of samples in the testing set:", len(test_data))

# Combine into a DatasetDict
datasett = DatasetDict({
    'train': train_data,
    'valid': val_data,
    'test': test_data
})

Number of samples in the training set: 940
Number of samples in the validation set: 378
Number of samples in the testing set: 377


Preprocess the dataset 

In [16]:
import torch
import torchaudio

# Resample function (#! Move resampler out)
# Get embeddings
def get_embedding(audio):
    # Get waveform and sampling rate
    waveform = torch.tensor(audio['array'], dtype=torch.float32)
    dataset_sampling_rate = audio['sampling_rate']
    # Resample audio
    audio = resample_audio(waveform, dataset_sampling_rate, sampling_rate)
    #print('Audio length:', audio.shape[0]/sampling_rate)
    # Zero-padding
    audio = zero_pad(waveform)
    
    # Check if audio is too long 
    if waveform.shape[0] > window_length * sampling_rate:
        return frame_and_average(waveform)    
    else:
        return perch_network.get_embeddings(audio)[0] # To just use embeddings not logits

# Resample function
def resample_audio(audio, orig_sr, target_sr):
    resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)
    return resampler(audio)

# Zero-padding function
def zero_pad(audio):
    desired_num_samples = window_length * sampling_rate 
    current_num_samples = audio.shape[0]
    padding = desired_num_samples - current_num_samples
    if padding > 0:
        #print('padding')
        pad_left = padding // 2
        pad_right = padding - pad_left
        audio = torch.nn.functional.pad(audio, (pad_left, pad_right))
    return audio

# Average multiple embeddings function
def frame_and_average(audio):
    # Ensure the waveform is mono
    #if audio.size(0) > 1:
        #print("What")
        #audio = audio.mean(dim=0, keepdim=True)
    
    # Frame the audio
    frame_size = window_length * sampling_rate
    hop_size = window_length * sampling_rate
    frames = audio.unfold(0, frame_size, hop_size)
    
    # Generate embeddings for each frame
    l = []
    for frame in frames:
        embedding = perch_network.get_embeddings(frame) 
        l.append(embedding[0]) # To just use embeddings not logits
    
    embeddings = torch.stack(tuple(l))
    
    # Average the embeddings
    averaged_embedding = embeddings.mean(dim=0)
    
    return averaged_embedding


In [29]:
from torch.utils.data import DataLoader

def preprocess(item):
    audio = item['audio']
    return get_embedding(audio)

def collate_fn(batch):
    batch_new = {}
    audios = [preprocess(item) for item in batch]
    batch_new['audio'] =  torch.stack(tuple(audios), dim=0)
    
    #batch_new['labels'] = torch.stack([torch.nn.functional.one_hot(torch.tensor(item['labels'],  dtype=torch.long), num_classes=num_classes) for item in batch]).float() #* For one hot-encoding 
    batch_new['labels'] = torch.tensor([item['labels'] for item in batch])
    return batch_new

train_loader = DataLoader(dataset['train'], batch_size=32, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(dataset['test'], batch_size=32, shuffle=False, collate_fn=collate_fn)
val_loader = DataLoader(dataset['valid'], batch_size=32, shuffle=False, collate_fn=collate_fn)

# Example of iterating through the DataLoader
for batch in train_loader:
    print(batch.keys())
    print(batch['audio'])
    print(batch['labels'])
    print(batch['audio'].shape)    
    break

2024-07-02 01:38:46.110718: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator jax2tf_infer_fn_/assert_equal_1/Assert/AssertGuard/Assert


dict_keys(['audio', 'labels'])
tensor([[[ 0.0078,  0.0187, -0.0098,  ..., -0.0284,  0.1116, -0.0138]],

        [[ 0.0024, -0.0420, -0.0278,  ...,  0.1036,  0.0731, -0.0307]],

        [[-0.0604, -0.0278, -0.0089,  ...,  0.1065,  0.1766, -0.0013]],

        ...,

        [[ 0.0173, -0.0715,  0.0045,  ..., -0.0214,  0.0518, -0.0411]],

        [[-0.0092, -0.0189,  0.0296,  ...,  0.0156,  0.2073, -0.0058]],

        [[ 0.0504,  0.0031,  0.0155,  ...,  0.0448,  0.1135,  0.0121]]])
tensor([30, 25, 10,  1, 19, 13, 15, 29,  8, 28,  2,  7,  2,  5, 16, 24,  4, 10,
        23, 25, 22, 13, 14, 24, 18, 20, 26, 18, 22, 15, 17, 19])
torch.Size([32, 1, 1280])


### 4. Train the classifier

In [30]:
gpu_id = 0  # Change this to the ID of the GPU you want to use
device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu')
#device = 'cpu'
print(f'Using device: {device}') #! Not working right now

Using device: cuda:0


In [31]:
from tqdm import tqdm

import torch.nn as nn
import torch.optim as optim

# Define your classifier model
class Classifier(nn.Module):
    def __init__(self, input_size, num_classes):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(input_size, num_classes)
        self.log_softmax = nn.LogSoftmax(dim=1) #* self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        x = x.squeeze(1)
        x = self.log_softmax(self.fc(x))
        return x

# Create an instance of your classifier model
classifier = Classifier(input_size, num_classes).to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss() #* nn.BCELoss()
optimizer = optim.AdamW(classifier.parameters(), lr=1e-5, weight_decay=0.01)

# Set the number of training epochs
num_epochs = 25

early_stopping_patience = 5
best_loss = float('inf')
patience_counter = 0

# Training loop
for epoch in range(num_epochs):
    classifier.train()
    train_loss = 0.0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs = batch['audio'].to(device)
        labels = batch['labels'].to(device)
        
        optimizer.zero_grad()
        outputs = classifier(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Calculate average loss for this epoch
    train_loss /= len(train_loader)

    # Validate the model (assuming you have a validation loader)
    classifier.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch['audio'].to(device)
            labels = batch['labels'].to(device)
            outputs = classifier(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
    
    val_loss /= len(val_loader)
    
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Early stopping
    if val_loss < best_loss:
        best_loss = val_loss
        patience_counter = 0
        # Save the best model weights if necessary
        torch.save(classifier.state_dict(), 'best_model.pth')
    else:
        patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print("Early stopping triggered")
            break

Epoch 1/25: 100%|██████████| 4/4 [00:01<00:00,  2.65it/s]


Epoch 1/25, Train Loss: 3.4358, Val Loss: 3.4363


Epoch 2/25: 100%|██████████| 4/4 [00:02<00:00,  1.78it/s]


Epoch 2/25, Train Loss: 3.4341, Val Loss: 3.4358


Epoch 3/25: 100%|██████████| 4/4 [00:01<00:00,  2.26it/s]


Epoch 3/25, Train Loss: 3.4331, Val Loss: 3.4353


Epoch 4/25: 100%|██████████| 4/4 [00:01<00:00,  2.23it/s]


Epoch 4/25, Train Loss: 3.4326, Val Loss: 3.4348


Epoch 5/25: 100%|██████████| 4/4 [00:01<00:00,  2.06it/s]


Epoch 5/25, Train Loss: 3.4311, Val Loss: 3.4343


Epoch 6/25: 100%|██████████| 4/4 [00:01<00:00,  2.09it/s]


Epoch 6/25, Train Loss: 3.4306, Val Loss: 3.4338


Epoch 7/25: 100%|██████████| 4/4 [00:02<00:00,  1.99it/s]


Epoch 7/25, Train Loss: 3.4298, Val Loss: 3.4333


Epoch 8/25: 100%|██████████| 4/4 [00:01<00:00,  2.17it/s]


Epoch 8/25, Train Loss: 3.4290, Val Loss: 3.4329


Epoch 9/25: 100%|██████████| 4/4 [00:01<00:00,  2.28it/s]


Epoch 9/25, Train Loss: 3.4274, Val Loss: 3.4324


Epoch 10/25: 100%|██████████| 4/4 [00:01<00:00,  2.15it/s]


Epoch 10/25, Train Loss: 3.4268, Val Loss: 3.4319


Epoch 11/25: 100%|██████████| 4/4 [00:02<00:00,  1.91it/s]


Epoch 11/25, Train Loss: 3.4255, Val Loss: 3.4314


Epoch 12/25: 100%|██████████| 4/4 [00:01<00:00,  2.20it/s]


Epoch 12/25, Train Loss: 3.4244, Val Loss: 3.4309


Epoch 13/25: 100%|██████████| 4/4 [00:02<00:00,  1.92it/s]


Epoch 13/25, Train Loss: 3.4239, Val Loss: 3.4304


Epoch 14/25: 100%|██████████| 4/4 [00:02<00:00,  1.86it/s]


Epoch 14/25, Train Loss: 3.4227, Val Loss: 3.4299


Epoch 15/25: 100%|██████████| 4/4 [00:01<00:00,  2.08it/s]


Epoch 15/25, Train Loss: 3.4221, Val Loss: 3.4294


Epoch 16/25: 100%|██████████| 4/4 [00:01<00:00,  2.20it/s]


Epoch 16/25, Train Loss: 3.4206, Val Loss: 3.4289


Epoch 17/25: 100%|██████████| 4/4 [00:01<00:00,  2.53it/s]


Epoch 17/25, Train Loss: 3.4197, Val Loss: 3.4284


Epoch 18/25: 100%|██████████| 4/4 [00:01<00:00,  2.11it/s]


Epoch 18/25, Train Loss: 3.4185, Val Loss: 3.4279


Epoch 19/25: 100%|██████████| 4/4 [00:02<00:00,  1.96it/s]


Epoch 19/25, Train Loss: 3.4179, Val Loss: 3.4274


Epoch 20/25: 100%|██████████| 4/4 [00:01<00:00,  2.46it/s]


Epoch 20/25, Train Loss: 3.4164, Val Loss: 3.4269


Epoch 21/25: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s]


Epoch 21/25, Train Loss: 3.4157, Val Loss: 3.4264


Epoch 22/25: 100%|██████████| 4/4 [00:02<00:00,  1.96it/s]


Epoch 22/25, Train Loss: 3.4152, Val Loss: 3.4259


Epoch 23/25: 100%|██████████| 4/4 [00:01<00:00,  2.64it/s]


Epoch 23/25, Train Loss: 3.4140, Val Loss: 3.4254


Epoch 24/25: 100%|██████████| 4/4 [00:01<00:00,  2.13it/s]


Epoch 24/25, Train Loss: 3.4129, Val Loss: 3.4250


Epoch 25/25: 100%|██████████| 4/4 [00:02<00:00,  1.85it/s]


Epoch 25/25, Train Loss: 3.4117, Val Loss: 3.4245


In [32]:
from sklearn.metrics import accuracy_score, roc_auc_score
import torchmetrics

# Set the model to evaluation mode
perch_network.eval()


# Initialize the metrics
metrics = torchmetrics.MetricCollection({
    'T1Accuracy': torchmetrics.Accuracy(
        task="multiclass",
        num_classes=num_classes,
        top_k=1
    ),
    'T3Accuracy': torchmetrics.Accuracy(
        task="multiclass",
        num_classes=num_classes,
        top_k=3
    ),
    'AUROC': torchmetrics.AUROC(
        task="multiclass",
        num_classes=num_classes,
        average='macro'
    ),
    'F1': torchmetrics.F1Score(
        task="multiclass",
        num_classes=num_classes
    )
}).to(device)

# Iterate over the test_loader
for batch in test_loader:
    # Forward pass
    inputs = batch['audio'].to(device)
    labels = batch['labels'].to(device)
    #labels = torch.argmax(labels, dim=1) #* For one hot-encoding 
    with torch.no_grad():
        outputs = classifier(inputs)
        outputs = outputs.squeeze(1)
    
    # Update the metrics
    metrics(outputs, labels)

# Compute and print the metric values
metric_values = metrics.compute()
for metric_name, metric_value in metric_values.items():
    print(f"{metric_name}: {metric_value}")


AUROC: 0.6013455390930176
F1: 0.05605095624923706
T1Accuracy: 0.05605095624923706
T3Accuracy: 0.1719745248556137


### <u>Ghani datasets</u>

|Dataset|Classes|Available?|
|-------|-------|----------|
|Godwit Calls|5|No part of a master thesis|
|Yellowhammer Dialects|2|Probably not (Only two classes anyway)|
|Bats|5|Yes but pitch shifting and two sources of which one is private|
|Watkins|32|Yes but removed some classes|
|RFCX Frog & Bird|12+13|Yes but for detection and not split in BEANS|

### <u>Results with Perch</u>
These are the results in this isolated run whereas we compare them to the Birdset Pipeline setup. We used 25 Epochs.
| Dataset         | Classes|AUROC (BirdsetPipeline results) | T1 (B.P.) | Audio lengths |Samples per class|
|--------------------|---|---------------------|-----------------|-----|----|
| beans_watkins      |31|**89** (85)                   |**32%** (23%)|Different lengths 1-45s|~30|
| beans_bats         |10|**79**  (78)                  |38% (**39%**)|0-5s|600|
| beans_cbi          |264|   (96)                 |(51%)|4-10s (Mostly 10)|~50-70|
| beans_dogs         |10|**78**    (75)                |**28%** (31%)|2-30s|13-70|
| beans_humbugdb     |14|**69** (66)|**46%** (12%)|1-55s|~70 or ~400|

### <u>Results with BirdNET</u>
Used 25 Epochs
| Dataset         |AUROC | T1 |
|------------------|---|---|
| beans_watkins      |84|23%|
| beans_bats         |||
| beans_cbi          |||
| beans_dogs         |||
| beans_humbugdb     |||