# Implementing linear probing pipeline from Ghani
Trying to 100% simulate ghani setup

### 1. Load dataset

In [1]:
from birdset.datamodule.beans_datamodule import BEANSDataModule
from birdset.datamodule.base_datamodule import DatasetConfig

datasetconfig = DatasetConfig(dataset_name='beans_watkins', hf_path='DBD-research-group/beans_watkins', hf_name='default')

datamodule = BEANSDataModule(dataset=datasetconfig)
dataset = datamodule._load_data()
dataset['train']

Map:   0%|          | 0/1017 [00:00<?, ? examples/s]

Map:   0%|          | 0/339 [00:00<?, ? examples/s]

Map:   0%|          | 0/339 [00:00<?, ? examples/s]

Map:   0%|          | 0/203 [00:00<?, ? examples/s]

Dataset({
    features: ['audio', 'labels'],
    num_rows: 1017
})

### 2. Load model and set parameters

In [2]:
from birdset.modules.models.perch import PerchModel
import torch.nn as nn

num_classes = 31
sampling_rate = 32_000 
window_length = 5

perch_network = PerchModel(num_classes=num_classes, tfhub_version=4, gpu_to_use=0)



2024-06-17 22:42:50.700071: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-17 22:42:50.700175: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-17 22:42:50.700209: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-06-17 22:42:50.708775: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-06-17 22:42:52.998956: I tensorflow/core/comm

### 3. Batch and Preprocess the dataset 

reprocess the dataset 

In [9]:
import torch
import torchaudio

# Resample function (#! Move resampler out)
# Get embeddings
def get_embedding(audio):
    # Get waveform and sampling rate
    waveform = torch.tensor(audio['array'], dtype=torch.float32)
    dataset_sampling_rate = audio['sampling_rate']
    # Resample audio
    audio = resample_audio(waveform, dataset_sampling_rate, sampling_rate)
    
    # Zero-padding
    audio = zero_pad(waveform)
    
    # Check if audio is too long 
    if waveform.shape[0] > window_length * sampling_rate:
        return frame_and_average(waveform)    
    else:
        return perch_network.get_embeddings(audio)[0] # To just use embeddings not logits

# Resample function
def resample_audio(audio, orig_sr, target_sr):
    resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)
    return resampler(audio)

# Zero-padding function
def zero_pad(audio):
    desired_num_samples = window_length * sampling_rate 
    current_num_samples = audio.shape[0]
    padding = desired_num_samples - current_num_samples
    if padding > 0:
        #print('padding')
        pad_left = padding // 2
        pad_right = padding - pad_left
        audio = torch.nn.functional.pad(audio, (pad_left, pad_right))
    return audio

# Average multiple embeddings function
def frame_and_average(audio):
    # Ensure the waveform is mono
    #if audio.size(0) > 1:
        #print("What")
        #audio = audio.mean(dim=0, keepdim=True)
    
    # Frame the audio
    frame_size = window_length * sampling_rate
    hop_size = window_length * sampling_rate
    frames = audio.unfold(0, frame_size, hop_size)
    
    # Generate embeddings for each frame
    l = []
    for frame in frames:
        embedding = perch_network.get_embeddings(frame) 
        l.append(embedding[0]) # To just use embeddings not logits
    
    embeddings = torch.stack(tuple(l))
    
    # Average the embeddings
    averaged_embedding = embeddings.mean(dim=0)
    
    return averaged_embedding


In [10]:
from torch.utils.data import DataLoader

def preprocess(item):
    audio = item['audio']
    return get_embedding(audio)

def collate_fn(batch):
    batch_new = {}
    audios = [preprocess(item) for item in batch]
    batch_new['audio'] =  torch.stack(tuple(audios), dim=0)
    batch_new['labels'] = torch.tensor([item['labels'] for item in batch])
    return batch_new

train_loader = DataLoader(dataset['train'], batch_size=32, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(dataset['test'], batch_size=32, shuffle=False, collate_fn=collate_fn)
val_loader = DataLoader(dataset['valid'], batch_size=32, shuffle=False, collate_fn=collate_fn)

# Example of iterating through the DataLoader
for batch in train_loader:
    print(batch.keys())
    print(batch['audio'])
    print(batch['labels'])
    print(batch['audio'].shape)    
    break

dict_keys(['audio', 'labels'])
tensor([[[-0.0308,  0.0046, -0.0026,  ...,  0.0564,  0.2090,  0.0186]],

        [[-0.0090, -0.0200, -0.0013,  ...,  0.0420,  0.2236,  0.0357]],

        [[-0.0533, -0.1032, -0.0113,  ...,  0.1077,  0.1384, -0.0003]],

        ...,

        [[ 0.1132, -0.0464, -0.0139,  ...,  0.0277,  0.1555,  0.0788]],

        [[ 0.0108, -0.0542,  0.0072,  ...,  0.1012,  0.0907, -0.0042]],

        [[-0.0511, -0.0258, -0.0234,  ...,  0.0520,  0.1342, -0.0264]]])
tensor([28,  0, 21, 25,  1,  1,  9,  3,  8,  5, 13,  2,  8,  5, 19,  2,  3, 21,
        20, 23, 13,  3, 10, 19, 13, 22, 20,  3,  4,  4, 26, 13])
torch.Size([32, 1, 1280])


### 4. Train the classifier

In [11]:
gpu_id = 1  # Change this to the ID of the GPU you want to use
device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}') #! Not working right now

Using device: cuda:1


In [18]:
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch.nn as nn
import torch.optim as optim

# Define your classifier model
class Classifier(nn.Module):
    def __init__(self, input_size, num_classes):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(input_size, num_classes)
    
    def forward(self, x):
        x = self.fc(x)
        x = x.squeeze(1)
        return x

class Classifier2(nn.Module):
    def __init__(self, input_size, num_classes):
        super(Classifier2, self).__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, num_classes)
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = x.squeeze(1)
        return x

# Set the input size and number of classes
input_size = 1280

# Create an instance of your classifier model
classifier = Classifier(input_size, num_classes).to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(classifier.parameters(), lr=1e-5, weight_decay=0.01)

# Set the number of training epochs
num_epochs = 50  

early_stopping_patience = 5
best_loss = float('inf')
patience_counter = 0

# Training loop
for epoch in range(num_epochs):
    classifier.train()
    train_loss = 0.0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs = batch['audio'].to(device)
        labels = batch['labels'].to(device)
        
        optimizer.zero_grad()
        outputs = classifier(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Calculate average loss for this epoch
    train_loss /= len(train_loader)

    # Validate the model (assuming you have a validation loader)
    classifier.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch['audio'].to(device)
            labels = batch['labels'].to(device)
            outputs = classifier(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
    
    val_loss /= len(val_loader)
    
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Early stopping
    if val_loss < best_loss:
        best_loss = val_loss
        patience_counter = 0
        # Save the best model weights if necessary
        torch.save(classifier.state_dict(), 'best_model.pth')
    else:
        patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print("Early stopping triggered")
            break

Epoch 1/50: 100%|██████████| 32/32 [00:35<00:00,  1.10s/it]


Epoch 1/50, Train Loss: 3.4328, Val Loss: 3.4309


Epoch 2/50: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]


Epoch 2/50, Train Loss: 3.4272, Val Loss: 3.4258


Epoch 3/50: 100%|██████████| 32/32 [00:34<00:00,  1.06s/it]


Epoch 3/50, Train Loss: 3.4215, Val Loss: 3.4206


Epoch 4/50: 100%|██████████| 32/32 [00:34<00:00,  1.07s/it]


Epoch 4/50, Train Loss: 3.4161, Val Loss: 3.4155


Epoch 5/50: 100%|██████████| 32/32 [00:35<00:00,  1.10s/it]


Epoch 5/50, Train Loss: 3.4105, Val Loss: 3.4105


Epoch 6/50: 100%|██████████| 32/32 [00:33<00:00,  1.06s/it]


Epoch 6/50, Train Loss: 3.4050, Val Loss: 3.4055


Epoch 7/50: 100%|██████████| 32/32 [00:35<00:00,  1.11s/it]


Epoch 7/50, Train Loss: 3.3996, Val Loss: 3.4004


Epoch 8/50: 100%|██████████| 32/32 [00:33<00:00,  1.03s/it]


Epoch 8/50, Train Loss: 3.3943, Val Loss: 3.3954


Epoch 9/50: 100%|██████████| 32/32 [00:33<00:00,  1.06s/it]


Epoch 9/50, Train Loss: 3.3887, Val Loss: 3.3905


Epoch 10/50: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]


Epoch 10/50, Train Loss: 3.3834, Val Loss: 3.3855


Epoch 11/50: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]


Epoch 11/50, Train Loss: 3.3781, Val Loss: 3.3806


Epoch 12/50: 100%|██████████| 32/32 [00:33<00:00,  1.05s/it]


Epoch 12/50, Train Loss: 3.3728, Val Loss: 3.3757


Epoch 13/50: 100%|██████████| 32/32 [00:34<00:00,  1.07s/it]


Epoch 13/50, Train Loss: 3.3675, Val Loss: 3.3708


Epoch 14/50: 100%|██████████| 32/32 [00:35<00:00,  1.12s/it]


Epoch 14/50, Train Loss: 3.3621, Val Loss: 3.3660


Epoch 15/50: 100%|██████████| 32/32 [00:34<00:00,  1.07s/it]


Epoch 15/50, Train Loss: 3.3568, Val Loss: 3.3612


Epoch 16/50: 100%|██████████| 32/32 [00:35<00:00,  1.10s/it]


Epoch 16/50, Train Loss: 3.3517, Val Loss: 3.3564


Epoch 17/50: 100%|██████████| 32/32 [00:34<00:00,  1.09s/it]


Epoch 17/50, Train Loss: 3.3465, Val Loss: 3.3515


Epoch 18/50: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]


Epoch 18/50, Train Loss: 3.3414, Val Loss: 3.3467


Epoch 19/50: 100%|██████████| 32/32 [00:35<00:00,  1.11s/it]


Epoch 19/50, Train Loss: 3.3360, Val Loss: 3.3420


Epoch 20/50: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]


Epoch 20/50, Train Loss: 3.3311, Val Loss: 3.3372


Epoch 21/50: 100%|██████████| 32/32 [00:34<00:00,  1.09s/it]


Epoch 21/50, Train Loss: 3.3257, Val Loss: 3.3325


Epoch 22/50: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]


Epoch 22/50, Train Loss: 3.3207, Val Loss: 3.3278


Epoch 23/50: 100%|██████████| 32/32 [00:35<00:00,  1.10s/it]


Epoch 23/50, Train Loss: 3.3154, Val Loss: 3.3231


Epoch 24/50: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]


Epoch 24/50, Train Loss: 3.3104, Val Loss: 3.3185


Epoch 25/50: 100%|██████████| 32/32 [00:35<00:00,  1.10s/it]


Epoch 25/50, Train Loss: 3.3053, Val Loss: 3.3138


Epoch 26/50: 100%|██████████| 32/32 [00:35<00:00,  1.10s/it]


Epoch 26/50, Train Loss: 3.3003, Val Loss: 3.3091


Epoch 27/50: 100%|██████████| 32/32 [00:35<00:00,  1.11s/it]


Epoch 27/50, Train Loss: 3.2952, Val Loss: 3.3045


Epoch 28/50: 100%|██████████| 32/32 [00:34<00:00,  1.07s/it]


Epoch 28/50, Train Loss: 3.2903, Val Loss: 3.3000


Epoch 29/50: 100%|██████████| 32/32 [00:35<00:00,  1.12s/it]


Epoch 29/50, Train Loss: 3.2852, Val Loss: 3.2954


Epoch 30/50: 100%|██████████| 32/32 [00:35<00:00,  1.12s/it]


Epoch 30/50, Train Loss: 3.2803, Val Loss: 3.2907


Epoch 31/50: 100%|██████████| 32/32 [00:33<00:00,  1.06s/it]


Epoch 31/50, Train Loss: 3.2750, Val Loss: 3.2862


Epoch 32/50: 100%|██████████| 32/32 [00:34<00:00,  1.09s/it]


Epoch 32/50, Train Loss: 3.2702, Val Loss: 3.2817


Epoch 33/50: 100%|██████████| 32/32 [00:34<00:00,  1.09s/it]


Epoch 33/50, Train Loss: 3.2656, Val Loss: 3.2772


Epoch 34/50: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 34/50, Train Loss: 3.2606, Val Loss: 3.2727


Epoch 35/50: 100%|██████████| 32/32 [00:35<00:00,  1.12s/it]


Epoch 35/50, Train Loss: 3.2558, Val Loss: 3.2682


Epoch 36/50: 100%|██████████| 32/32 [00:32<00:00,  1.03s/it]


Epoch 36/50, Train Loss: 3.2508, Val Loss: 3.2637


Epoch 37/50: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]


Epoch 37/50, Train Loss: 3.2459, Val Loss: 3.2593


Epoch 38/50: 100%|██████████| 32/32 [00:34<00:00,  1.07s/it]


Epoch 38/50, Train Loss: 3.2410, Val Loss: 3.2549


Epoch 39/50: 100%|██████████| 32/32 [00:34<00:00,  1.07s/it]


Epoch 39/50, Train Loss: 3.2361, Val Loss: 3.2504


Epoch 40/50: 100%|██████████| 32/32 [00:35<00:00,  1.10s/it]


Epoch 40/50, Train Loss: 3.2314, Val Loss: 3.2460


Epoch 41/50: 100%|██████████| 32/32 [00:34<00:00,  1.09s/it]


Epoch 41/50, Train Loss: 3.2265, Val Loss: 3.2417


Epoch 42/50: 100%|██████████| 32/32 [00:34<00:00,  1.09s/it]


Epoch 42/50, Train Loss: 3.2219, Val Loss: 3.2373


Epoch 43/50: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]


Epoch 43/50, Train Loss: 3.2173, Val Loss: 3.2329


Epoch 44/50: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]


Epoch 44/50, Train Loss: 3.2124, Val Loss: 3.2286


Epoch 45/50: 100%|██████████| 32/32 [00:34<00:00,  1.07s/it]


Epoch 45/50, Train Loss: 3.2077, Val Loss: 3.2242


Epoch 46/50: 100%|██████████| 32/32 [00:33<00:00,  1.06s/it]


Epoch 46/50, Train Loss: 3.2027, Val Loss: 3.2199


Epoch 47/50: 100%|██████████| 32/32 [00:34<00:00,  1.07s/it]


Epoch 47/50, Train Loss: 3.1983, Val Loss: 3.2156


Epoch 48/50: 100%|██████████| 32/32 [00:32<00:00,  1.01s/it]


Epoch 48/50, Train Loss: 3.1932, Val Loss: 3.2113


Epoch 49/50: 100%|██████████| 32/32 [00:33<00:00,  1.06s/it]


Epoch 49/50, Train Loss: 3.1888, Val Loss: 3.2070


Epoch 50/50: 100%|██████████| 32/32 [00:35<00:00,  1.11s/it]


Epoch 50/50, Train Loss: 3.1839, Val Loss: 3.2028


In [19]:
from sklearn.metrics import accuracy_score, roc_auc_score
import torchmetrics

# Set the model to evaluation mode
perch_network.eval()


# Initialize the metrics
metrics = torchmetrics.MetricCollection({
    'T1Accuracy': torchmetrics.Accuracy(
        task="multiclass",
        num_classes=num_classes,
        top_k=1
    ),
    'T3Accuracy': torchmetrics.Accuracy(
        task="multiclass",
        num_classes=num_classes,
        top_k=3
    ),
    'AUROC': torchmetrics.AUROC(
        task="multiclass",
        num_classes=num_classes,
        average='macro'
    ),
    'F1': torchmetrics.F1Score(
        task="multiclass",
        num_classes=num_classes
    )
}).to(device)

# Iterate over the test_loader
for batch in test_loader:
    # Forward pass
    inputs = batch['audio'].to(device)
    labels = batch['labels'].to(device)
    with torch.no_grad():
        outputs = classifier(inputs)
        outputs = outputs.squeeze(1)
    
    # Update the metrics
    metrics(outputs, labels)

# Compute and print the metric values
metric_values = metrics.compute()
for metric_name, metric_value in metric_values.items():
    print(f"{metric_name}: {metric_value}")


AUROC: 0.9358636140823364
F1: 0.32743361592292786
T1Accuracy: 0.32743361592292786
T3Accuracy: 0.6342182755470276
