# Implementing linear probing pipeline from Ghani
Trying to 100% simulate ghani setup

### 1. Load dataset

In [1]:
from birdset.datamodule.beans_datamodule import BEANSDataModule
from birdset.datamodule.base_datamodule import DatasetConfig

datasetconfig = DatasetConfig(dataset_name='beans_watkins', hf_path='DBD-research-group/beans_watkins', hf_name='default')

datamodule = BEANSDataModule(dataset=datasetconfig)
dataset = datamodule._load_data()
dataset['train']

Map:   0%|          | 0/1017 [00:00<?, ? examples/s]

Map:   0%|          | 0/339 [00:00<?, ? examples/s]

Map:   0%|          | 0/339 [00:00<?, ? examples/s]

Map:   0%|          | 0/203 [00:00<?, ? examples/s]

Dataset({
    features: ['audio', 'labels'],
    num_rows: 1017
})

### 2. Load model and set parameters

In [2]:
from birdset.modules.models.perch import PerchModel
import torch.nn as nn

num_classes = 31
sampling_rate = 32_000 
window_length = 5

perch_network = PerchModel(num_classes=num_classes, tfhub_version=4, gpu_to_use=0)



2024-06-17 12:02:33.914340: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-17 12:02:33.914446: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-17 12:02:33.914479: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-06-17 12:02:33.923028: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-06-17 12:02:36.032213: I tensorflow/core/comm

### 3. Batch and Preprocess the dataset 

reprocess the dataset 

In [25]:
import torch
import torchaudio

# Resample function (#! Move resampler out)
# Get embeddings
def get_embedding(audio):
    # Get waveform and sampling rate
    waveform = torch.tensor(audio['array'], dtype=torch.float32)
    dataset_sampling_rate = audio['sampling_rate']
    # Resample audio
    audio = resample_audio(waveform, dataset_sampling_rate, sampling_rate)
    
    # Zero-padding
    audio = zero_pad(waveform)
    
    # Check if audio is too long 
    if waveform.shape[0] > window_length * sampling_rate:
        return frame_and_average(waveform)    
    else:
        return perch_network.get_embeddings(audio)[0] # To just use embeddings not logits

# Resample function
def resample_audio(audio, orig_sr, target_sr):
    resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)
    return resampler(audio)

# Zero-padding function
def zero_pad(audio):
    desired_num_samples = window_length * sampling_rate 
    current_num_samples = audio.shape[0]
    padding = desired_num_samples - current_num_samples
    if padding > 0:
        #print('padding')
        pad_left = padding // 2
        pad_right = padding - pad_left
        audio = torch.nn.functional.pad(audio, (pad_left, pad_right))
    return audio

# Average multiple embeddings function
def frame_and_average(audio):
    # Ensure the waveform is mono
    #if audio.size(0) > 1:
        #print("What")
        #audio = audio.mean(dim=0, keepdim=True)
    
    # Frame the audio
    frame_size = window_length * sampling_rate
    hop_size = window_length * sampling_rate
    frames = audio.unfold(0, frame_size, hop_size)
    
    # Generate embeddings for each frame
    l = []
    for frame in frames:
        embedding = perch_network.get_embeddings(frame) 
        l.append(embedding[0]) # To just use embeddings not logits
    
    embeddings = torch.stack(tuple(l))
    
    # Average the embeddings
    averaged_embedding = embeddings.mean(dim=0)
    
    return averaged_embedding


In [47]:
from torch.utils.data import DataLoader

def preprocess(item):
    audio = item['audio']
    return get_embedding(audio)

def collate_fn(batch):
    batch_new = {}
    audios = [preprocess(item) for item in batch]
    batch_new['audio'] =  torch.stack(tuple(audios), dim=0)
    batch_new['labels'] = torch.tensor([item['labels'] for item in batch])
    return batch_new

train_loader = DataLoader(dataset['train'], batch_size=32, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(dataset['test'], batch_size=32, shuffle=False, collate_fn=collate_fn)
val_loader = DataLoader(dataset['valid'], batch_size=32, shuffle=False, collate_fn=collate_fn)

# Example of iterating through the DataLoader
for batch in train_loader:
    print(batch.keys())
    print(batch['audio'])
    print(batch['labels'])
    print(batch['audio'].shape)    
    break

dict_keys(['audio', 'labels'])
tensor([[[ 1.0092e-02,  2.1218e-01,  3.0808e-03,  ...,  5.2571e-05,
           1.1198e-01, -2.5705e-02]],

        [[-6.4186e-02, -2.8728e-02,  1.0271e-01,  ...,  1.0650e-01,
           6.4827e-02, -7.4666e-03]],

        [[ 2.4990e-02, -8.8441e-02, -2.1365e-02,  ..., -3.5433e-03,
           2.9607e-02, -6.0116e-02]],

        ...,

        [[-3.4528e-02, -1.7473e-02,  1.9105e-02,  ...,  4.8836e-02,
           2.3408e-01,  3.1219e-02]],

        [[-2.4960e-02,  1.1930e-01,  2.1271e-02,  ...,  3.5058e-03,
           5.9749e-03, -3.6126e-02]],

        [[-6.3046e-02,  1.5308e-01,  3.1311e-02,  ..., -2.0878e-03,
           6.5515e-02,  2.8494e-02]]])
tensor([19, 17, 22, 23,  8, 12, 27, 29, 12,  8, 18, 12, 20,  8, 21,  9,  0,  7,
         9, 21, 10, 12, 21,  8, 12, 26, 25, 27, 14,  2, 12, 15])
torch.Size([32, 1, 1280])


### 4. Train the classifier

In [51]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [53]:
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch.nn as nn
import torch.optim as optim

# Define your classifier model
class Classifier(nn.Module):
    def __init__(self, input_size, num_classes):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(input_size, num_classes)
    
    def forward(self, x):
        x = self.fc(x)
        x = x.squeeze(1)
        return x

# Set the input size and number of classes
input_size = 1280  # Update with your desired input size
num_classes = 31  # Update with your desired number of classes

# Create an instance of your classifier model
classifier = Classifier(input_size, num_classes).to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(classifier.parameters(), lr=1e-5, weight_decay=0.01)

# Set the number of training epochs
num_epochs = 50  

early_stopping_patience = 5
best_loss = float('inf')
patience_counter = 0

# Training loop
for epoch in range(num_epochs):
    classifier.train()
    train_loss = 0.0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs = batch['audio'].to(device)
        labels = batch['labels'].to(device)
        
        optimizer.zero_grad()
        outputs = classifier(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Calculate average loss for this epoch
    train_loss /= len(train_loader)

    # Validate the model (assuming you have a validation loader)
    classifier.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch['audio'].to(device)
            labels = batch['labels'].to(device)
            outputs = classifier(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
    
    val_loss /= len(val_loader)
    
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Early stopping
    if val_loss < best_loss:
        best_loss = val_loss
        patience_counter = 0
        # Save the best model weights if necessary
        torch.save(classifier.state_dict(), 'best_model.pth')
    else:
        patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print("Early stopping triggered")
            break

Epoch 1/50:   0%|          | 0/32 [00:00<?, ?it/s]

Epoch 1/50:  22%|██▏       | 7/32 [00:04<00:13,  1.86it/s]

In [None]:
# Training loop
for epoch in range(num_epochs):
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # Get the inputs and labels from the batch
        inputs = batch['audio']
        labels = batch['labels']
        
        # Forward pass
        outputs = classifier(inputs)
        outputs = outputs.squeeze(1)

        loss = criterion(outputs, labels)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    
    # Print the loss for this epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

In [42]:
from sklearn.metrics import accuracy_score, roc_auc_score
import torchmetrics

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# Set the model to evaluation mode
perch_network.eval()


# Initialize the metrics
metrics = torchmetrics.MetricCollection({
    'T1Accuracy': torchmetrics.Accuracy(
        task="multiclass",
        num_classes=num_classes,
        top_k=1
    ),
    'T3Accuracy': torchmetrics.Accuracy(
        task="multiclass",
        num_classes=num_classes,
        top_k=3
    ),
    'AUROC': torchmetrics.AUROC(
        task="multiclass",
        num_classes=num_classes,
        average='macro'
    ),
    'F1': torchmetrics.F1Score(
        task="multiclass",
        num_classes=num_classes
    )
})

# Iterate over the test_loader
for batch in test_loader:
    # Move the batch to the device
    #batch = {key: value.to(device) for key, value in batch.items()}
    
    # Forward pass
    inputs = batch['audio']
    labels = batch['labels']
    with torch.no_grad():
        outputs = classifier(inputs)
        outputs = outputs.squeeze(1)
    
    # Update the metrics
    metrics(outputs, labels)

# Compute and print the metric values
metric_values = metrics.compute()
for metric_name, metric_value in metric_values.items():
    print(f"{metric_name}: {metric_value}")


cuda
AUROC: 0.7451189756393433
F1: 0.11209439486265182
T1Accuracy: 0.11209439486265182
T3Accuracy: 0.2979350984096527
