# Implementing linear probing pipeline from Ghani
Trying to 100% simulate ghani setup

### 1. Load dataset

In [2]:
from birdset.datamodule.beans_datamodule import BEANSDataModule
from birdset.datamodule.base_datamodule import DatasetConfig

datasetconfig = DatasetConfig(dataset_name='beans_watkins', hf_path='DBD-research-group/beans_watkins', hf_name='default')

datamodule = BEANSDataModule(dataset=datasetconfig)
dataset = datamodule._load_data()
dataset['train']

Map:   0%|          | 0/1017 [00:00<?, ? examples/s]

Map:   0%|          | 0/339 [00:00<?, ? examples/s]

Map:   0%|          | 0/339 [00:00<?, ? examples/s]

Map:   0%|          | 0/203 [00:00<?, ? examples/s]

Dataset({
    features: ['audio', 'labels'],
    num_rows: 1017
})

### 2. Load model and set parameters

In [3]:
from birdset.modules.models.perch import PerchModel
import torch.nn as nn

num_classes = 31
sampling_rate = 32_000 
window_length = 5

perch_network = PerchModel(num_classes=num_classes, tfhub_version=4, gpu_to_use=0)

2024-06-18 15:21:02.225414: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-18 15:21:02.225491: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-18 15:21:02.225537: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-06-18 15:21:02.234045: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-06-18 15:21:04.429950: I tensorflow/core/comm

### 3. Batch and Preprocess the dataset 

reprocess the dataset 

In [4]:
import torch
import torchaudio

# Resample function (#! Move resampler out)
# Get embeddings
def get_embedding(audio):
    # Get waveform and sampling rate
    waveform = torch.tensor(audio['array'], dtype=torch.float32)
    dataset_sampling_rate = audio['sampling_rate']
    # Resample audio
    audio = resample_audio(waveform, dataset_sampling_rate, sampling_rate)
    
    # Zero-padding
    audio = zero_pad(waveform)
    
    # Check if audio is too long 
    if waveform.shape[0] > window_length * sampling_rate:
        return frame_and_average(waveform)    
    else:
        return perch_network.get_embeddings(audio)[0] # To just use embeddings not logits

# Resample function
def resample_audio(audio, orig_sr, target_sr):
    resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)
    return resampler(audio)

# Zero-padding function
def zero_pad(audio):
    desired_num_samples = window_length * sampling_rate 
    current_num_samples = audio.shape[0]
    padding = desired_num_samples - current_num_samples
    if padding > 0:
        #print('padding')
        pad_left = padding // 2
        pad_right = padding - pad_left
        audio = torch.nn.functional.pad(audio, (pad_left, pad_right))
    return audio

# Average multiple embeddings function
def frame_and_average(audio):
    # Ensure the waveform is mono
    #if audio.size(0) > 1:
        #print("What")
        #audio = audio.mean(dim=0, keepdim=True)
    
    # Frame the audio
    frame_size = window_length * sampling_rate
    hop_size = window_length * sampling_rate
    frames = audio.unfold(0, frame_size, hop_size)
    
    # Generate embeddings for each frame
    l = []
    for frame in frames:
        embedding = perch_network.get_embeddings(frame) 
        l.append(embedding[0]) # To just use embeddings not logits
    
    embeddings = torch.stack(tuple(l))
    
    # Average the embeddings
    averaged_embedding = embeddings.mean(dim=0)
    
    return averaged_embedding


In [48]:
from torch.utils.data import DataLoader

def preprocess(item):
    audio = item['audio']
    return get_embedding(audio)

def collate_fn(batch):
    batch_new = {}
    audios = [preprocess(item) for item in batch]
    batch_new['audio'] =  torch.stack(tuple(audios), dim=0)
    
    #batch_new['labels'] = torch.stack([torch.nn.functional.one_hot(torch.tensor(item['labels'],  dtype=torch.long), num_classes=num_classes) for item in batch]).float() #* For one hot-encoding 
    batch_new['labels'] = torch.tensor([item['labels'] for item in batch])
    return batch_new

train_loader = DataLoader(dataset['train'], batch_size=32, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(dataset['test'], batch_size=32, shuffle=False, collate_fn=collate_fn)
val_loader = DataLoader(dataset['valid'], batch_size=32, shuffle=False, collate_fn=collate_fn)

# Example of iterating through the DataLoader
for batch in train_loader:
    print(batch.keys())
    print(batch['audio'])
    print(batch['labels'])
    print(batch['audio'].shape)    
    break

dict_keys(['audio', 'labels'])
tensor([[[-0.0153, -0.0212, -0.0061,  ..., -0.0100,  0.1000,  0.0346]],

        [[-0.0792, -0.0941, -0.0358,  ..., -0.0205,  0.0859,  0.0776]],

        [[-0.0760, -0.0653, -0.0081,  ..., -0.0532,  0.1101,  0.0993]],

        ...,

        [[-0.0229,  0.0147,  0.0107,  ...,  0.0608,  0.1445,  0.0305]],

        [[ 0.0732, -0.1096, -0.0057,  ...,  0.0099,  0.0103, -0.0198]],

        [[ 0.1139,  0.0057, -0.0244,  ...,  0.0200,  0.1124, -0.0140]]])
tensor([ 7,  3, 27, 19, 21,  7, 12, 14, 10,  9,  9, 17, 30, 12, 16, 29, 15, 24,
         8, 19, 19,  5, 23, 18, 18, 19,  3, 29, 23, 14, 16, 12])
torch.Size([32, 1, 1280])


### 4. Train the classifier

In [49]:
gpu_id = 1  # Change this to the ID of the GPU you want to use
device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}') #! Not working right now

Using device: cuda:1


In [50]:
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch.nn as nn
import torch.optim as optim

# Define your classifier model
class Classifier(nn.Module):
    def __init__(self, input_size, num_classes):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(input_size, num_classes)
        self.log_softmax = nn.LogSoftmax(dim=1) #* self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        x = x.squeeze(1)
        x = self.log_softmax(self.fc(x))
        return x

# Set the input size and number of classes
input_size = 1280

# Create an instance of your classifier model
classifier = Classifier(input_size, num_classes).to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss() #* nn.BCELoss()
optimizer = optim.AdamW(classifier.parameters(), lr=1e-5, weight_decay=0.01)

# Set the number of training epochs
num_epochs = 100  

early_stopping_patience = 5
best_loss = float('inf')
patience_counter = 0

# Training loop
for epoch in range(num_epochs):
    classifier.train()
    train_loss = 0.0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs = batch['audio'].to(device)
        labels = batch['labels'].to(device)
        
        optimizer.zero_grad()
        outputs = classifier(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Calculate average loss for this epoch
    train_loss /= len(train_loader)

    # Validate the model (assuming you have a validation loader)
    classifier.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch['audio'].to(device)
            labels = batch['labels'].to(device)
            outputs = classifier(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
    
    val_loss /= len(val_loader)
    
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Early stopping
    if val_loss < best_loss:
        best_loss = val_loss
        patience_counter = 0
        # Save the best model weights if necessary
        torch.save(classifier.state_dict(), 'best_model.pth')
    else:
        patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print("Early stopping triggered")
            break

Epoch 1/100: 100%|██████████| 32/32 [00:32<00:00,  1.02s/it]


Epoch 1/100, Train Loss: 3.4394, Val Loss: 3.4329


Epoch 2/100: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 2/100, Train Loss: 3.4337, Val Loss: 3.4278


Epoch 3/100: 100%|██████████| 32/32 [00:33<00:00,  1.05s/it]


Epoch 3/100, Train Loss: 3.4282, Val Loss: 3.4227


Epoch 4/100: 100%|██████████| 32/32 [00:33<00:00,  1.05s/it]


Epoch 4/100, Train Loss: 3.4226, Val Loss: 3.4177


Epoch 5/100: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]


Epoch 5/100, Train Loss: 3.4171, Val Loss: 3.4125


Epoch 6/100: 100%|██████████| 32/32 [00:32<00:00,  1.02s/it]


Epoch 6/100, Train Loss: 3.4117, Val Loss: 3.4076


Epoch 7/100: 100%|██████████| 32/32 [00:33<00:00,  1.05s/it]


Epoch 7/100, Train Loss: 3.4063, Val Loss: 3.4025


Epoch 8/100: 100%|██████████| 32/32 [00:33<00:00,  1.06s/it]


Epoch 8/100, Train Loss: 3.4008, Val Loss: 3.3975


Epoch 9/100: 100%|██████████| 32/32 [00:30<00:00,  1.05it/s]


Epoch 9/100, Train Loss: 3.3954, Val Loss: 3.3925


Epoch 10/100: 100%|██████████| 32/32 [00:31<00:00,  1.01it/s]


Epoch 10/100, Train Loss: 3.3901, Val Loss: 3.3875


Epoch 11/100: 100%|██████████| 32/32 [00:33<00:00,  1.05s/it]


Epoch 11/100, Train Loss: 3.3847, Val Loss: 3.3825


Epoch 12/100: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 12/100, Train Loss: 3.3794, Val Loss: 3.3776


Epoch 13/100: 100%|██████████| 32/32 [00:31<00:00,  1.01it/s]


Epoch 13/100, Train Loss: 3.3739, Val Loss: 3.3727


Epoch 14/100: 100%|██████████| 32/32 [00:32<00:00,  1.01s/it]


Epoch 14/100, Train Loss: 3.3686, Val Loss: 3.3678


Epoch 15/100: 100%|██████████| 32/32 [00:32<00:00,  1.02s/it]


Epoch 15/100, Train Loss: 3.3633, Val Loss: 3.3629


Epoch 16/100: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 16/100, Train Loss: 3.3581, Val Loss: 3.3581


Epoch 17/100: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 17/100, Train Loss: 3.3527, Val Loss: 3.3532


Epoch 18/100: 100%|██████████| 32/32 [00:31<00:00,  1.02it/s]


Epoch 18/100, Train Loss: 3.3476, Val Loss: 3.3484


Epoch 19/100: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 19/100, Train Loss: 3.3423, Val Loss: 3.3436


Epoch 20/100: 100%|██████████| 32/32 [00:33<00:00,  1.06s/it]


Epoch 20/100, Train Loss: 3.3372, Val Loss: 3.3388


Epoch 21/100: 100%|██████████| 32/32 [00:31<00:00,  1.02it/s]


Epoch 21/100, Train Loss: 3.3320, Val Loss: 3.3341


Epoch 22/100: 100%|██████████| 32/32 [00:33<00:00,  1.05s/it]


Epoch 22/100, Train Loss: 3.3268, Val Loss: 3.3293


Epoch 23/100: 100%|██████████| 32/32 [00:32<00:00,  1.01s/it]


Epoch 23/100, Train Loss: 3.3217, Val Loss: 3.3246


Epoch 24/100: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 24/100, Train Loss: 3.3164, Val Loss: 3.3200


Epoch 25/100: 100%|██████████| 32/32 [00:33<00:00,  1.06s/it]


Epoch 25/100, Train Loss: 3.3118, Val Loss: 3.3152


Epoch 26/100: 100%|██████████| 32/32 [00:33<00:00,  1.03s/it]


Epoch 26/100, Train Loss: 3.3062, Val Loss: 3.3106


Epoch 27/100: 100%|██████████| 32/32 [00:31<00:00,  1.02it/s]


Epoch 27/100, Train Loss: 3.3016, Val Loss: 3.3060


Epoch 28/100: 100%|██████████| 32/32 [00:31<00:00,  1.02it/s]


Epoch 28/100, Train Loss: 3.2966, Val Loss: 3.3013


Epoch 29/100: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 29/100, Train Loss: 3.2914, Val Loss: 3.2968


Epoch 30/100: 100%|██████████| 32/32 [00:32<00:00,  1.01s/it]


Epoch 30/100, Train Loss: 3.2866, Val Loss: 3.2922


Epoch 31/100: 100%|██████████| 32/32 [00:34<00:00,  1.07s/it]


Epoch 31/100, Train Loss: 3.2814, Val Loss: 3.2876


Epoch 32/100: 100%|██████████| 32/32 [00:34<00:00,  1.07s/it]


Epoch 32/100, Train Loss: 3.2765, Val Loss: 3.2831


Epoch 33/100: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]


Epoch 33/100, Train Loss: 3.2717, Val Loss: 3.2785


Epoch 34/100: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 34/100, Train Loss: 3.2667, Val Loss: 3.2740


Epoch 35/100: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 35/100, Train Loss: 3.2618, Val Loss: 3.2695


Epoch 36/100: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]


Epoch 36/100, Train Loss: 3.2566, Val Loss: 3.2651


Epoch 37/100: 100%|██████████| 32/32 [00:32<00:00,  1.02s/it]


Epoch 37/100, Train Loss: 3.2520, Val Loss: 3.2606


Epoch 38/100: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 38/100, Train Loss: 3.2469, Val Loss: 3.2561


Epoch 39/100: 100%|██████████| 32/32 [00:32<00:00,  1.01s/it]


Epoch 39/100, Train Loss: 3.2424, Val Loss: 3.2516


Epoch 40/100: 100%|██████████| 32/32 [00:32<00:00,  1.00s/it]


Epoch 40/100, Train Loss: 3.2374, Val Loss: 3.2472


Epoch 41/100: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]


Epoch 41/100, Train Loss: 3.2327, Val Loss: 3.2428


Epoch 42/100: 100%|██████████| 32/32 [00:33<00:00,  1.03s/it]


Epoch 42/100, Train Loss: 3.2278, Val Loss: 3.2384


Epoch 43/100: 100%|██████████| 32/32 [00:32<00:00,  1.01s/it]


Epoch 43/100, Train Loss: 3.2231, Val Loss: 3.2341


Epoch 44/100: 100%|██████████| 32/32 [00:33<00:00,  1.05s/it]


Epoch 44/100, Train Loss: 3.2187, Val Loss: 3.2297


Epoch 45/100: 100%|██████████| 32/32 [00:33<00:00,  1.05s/it]


Epoch 45/100, Train Loss: 3.2135, Val Loss: 3.2253


Epoch 46/100: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 46/100, Train Loss: 3.2084, Val Loss: 3.2210


Epoch 47/100: 100%|██████████| 32/32 [00:32<00:00,  1.02s/it]


Epoch 47/100, Train Loss: 3.2039, Val Loss: 3.2167


Epoch 48/100: 100%|██████████| 32/32 [00:31<00:00,  1.02it/s]


Epoch 48/100, Train Loss: 3.1995, Val Loss: 3.2123


Epoch 49/100: 100%|██████████| 32/32 [00:32<00:00,  1.01s/it]


Epoch 49/100, Train Loss: 3.1944, Val Loss: 3.2081


Epoch 50/100: 100%|██████████| 32/32 [00:32<00:00,  1.01s/it]


Epoch 50/100, Train Loss: 3.1901, Val Loss: 3.2038


Epoch 51/100: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 51/100, Train Loss: 3.1850, Val Loss: 3.1995


Epoch 52/100: 100%|██████████| 32/32 [00:32<00:00,  1.02s/it]


Epoch 52/100, Train Loss: 3.1804, Val Loss: 3.1952


Epoch 53/100: 100%|██████████| 32/32 [00:33<00:00,  1.06s/it]


Epoch 53/100, Train Loss: 3.1761, Val Loss: 3.1910


Epoch 54/100: 100%|██████████| 32/32 [00:32<00:00,  1.03s/it]


Epoch 54/100, Train Loss: 3.1710, Val Loss: 3.1867


Epoch 55/100: 100%|██████████| 32/32 [00:32<00:00,  1.02s/it]


Epoch 55/100, Train Loss: 3.1669, Val Loss: 3.1825


Epoch 56/100: 100%|██████████| 32/32 [00:34<00:00,  1.07s/it]


Epoch 56/100, Train Loss: 3.1619, Val Loss: 3.1783


Epoch 57/100: 100%|██████████| 32/32 [00:32<00:00,  1.02s/it]


Epoch 57/100, Train Loss: 3.1570, Val Loss: 3.1741


Epoch 58/100: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 58/100, Train Loss: 3.1526, Val Loss: 3.1699


Epoch 59/100: 100%|██████████| 32/32 [00:33<00:00,  1.03s/it]


Epoch 59/100, Train Loss: 3.1487, Val Loss: 3.1658


Epoch 60/100: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 60/100, Train Loss: 3.1438, Val Loss: 3.1616


Epoch 61/100: 100%|██████████| 32/32 [00:33<00:00,  1.05s/it]


Epoch 61/100, Train Loss: 3.1392, Val Loss: 3.1574


Epoch 62/100: 100%|██████████| 32/32 [00:34<00:00,  1.06s/it]


Epoch 62/100, Train Loss: 3.1347, Val Loss: 3.1533


Epoch 63/100: 100%|██████████| 32/32 [00:33<00:00,  1.06s/it]


Epoch 63/100, Train Loss: 3.1299, Val Loss: 3.1492


Epoch 64/100: 100%|██████████| 32/32 [00:32<00:00,  1.02s/it]


Epoch 64/100, Train Loss: 3.1259, Val Loss: 3.1451


Epoch 65/100: 100%|██████████| 32/32 [00:30<00:00,  1.05it/s]


Epoch 65/100, Train Loss: 3.1210, Val Loss: 3.1409


Epoch 66/100: 100%|██████████| 32/32 [00:33<00:00,  1.06s/it]


Epoch 66/100, Train Loss: 3.1169, Val Loss: 3.1368


Epoch 67/100: 100%|██████████| 32/32 [00:31<00:00,  1.00it/s]


Epoch 67/100, Train Loss: 3.1121, Val Loss: 3.1328


Epoch 68/100: 100%|██████████| 32/32 [00:33<00:00,  1.05s/it]


Epoch 68/100, Train Loss: 3.1076, Val Loss: 3.1287


Epoch 69/100: 100%|██████████| 32/32 [00:33<00:00,  1.05s/it]


Epoch 69/100, Train Loss: 3.1036, Val Loss: 3.1245


Epoch 70/100: 100%|██████████| 32/32 [00:33<00:00,  1.05s/it]


Epoch 70/100, Train Loss: 3.0990, Val Loss: 3.1206


Epoch 71/100: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 71/100, Train Loss: 3.0949, Val Loss: 3.1165


Epoch 72/100: 100%|██████████| 32/32 [00:32<00:00,  1.03s/it]


Epoch 72/100, Train Loss: 3.0902, Val Loss: 3.1125


Epoch 73/100: 100%|██████████| 32/32 [00:33<00:00,  1.06s/it]


Epoch 73/100, Train Loss: 3.0858, Val Loss: 3.1085


Epoch 74/100: 100%|██████████| 32/32 [00:32<00:00,  1.02s/it]


Epoch 74/100, Train Loss: 3.0816, Val Loss: 3.1045


Epoch 75/100: 100%|██████████| 32/32 [00:32<00:00,  1.01s/it]


Epoch 75/100, Train Loss: 3.0773, Val Loss: 3.1005


Epoch 76/100: 100%|██████████| 32/32 [00:33<00:00,  1.05s/it]


Epoch 76/100, Train Loss: 3.0722, Val Loss: 3.0965


Epoch 77/100: 100%|██████████| 32/32 [00:32<00:00,  1.00s/it]


Epoch 77/100, Train Loss: 3.0684, Val Loss: 3.0925


Epoch 78/100: 100%|██████████| 32/32 [00:32<00:00,  1.01s/it]


Epoch 78/100, Train Loss: 3.0636, Val Loss: 3.0885


Epoch 79/100: 100%|██████████| 32/32 [00:32<00:00,  1.01s/it]


Epoch 79/100, Train Loss: 3.0595, Val Loss: 3.0846


Epoch 80/100: 100%|██████████| 32/32 [00:34<00:00,  1.07s/it]


Epoch 80/100, Train Loss: 3.0550, Val Loss: 3.0806


Epoch 81/100: 100%|██████████| 32/32 [00:32<00:00,  1.03s/it]


Epoch 81/100, Train Loss: 3.0511, Val Loss: 3.0767


Epoch 82/100: 100%|██████████| 32/32 [00:33<00:00,  1.05s/it]


Epoch 82/100, Train Loss: 3.0465, Val Loss: 3.0727


Epoch 83/100: 100%|██████████| 32/32 [00:32<00:00,  1.01s/it]


Epoch 83/100, Train Loss: 3.0420, Val Loss: 3.0688


Epoch 84/100: 100%|██████████| 32/32 [00:32<00:00,  1.03s/it]


Epoch 84/100, Train Loss: 3.0379, Val Loss: 3.0649


Epoch 85/100: 100%|██████████| 32/32 [00:33<00:00,  1.06s/it]


Epoch 85/100, Train Loss: 3.0340, Val Loss: 3.0610


Epoch 86/100: 100%|██████████| 32/32 [00:33<00:00,  1.03s/it]


Epoch 86/100, Train Loss: 3.0289, Val Loss: 3.0571


Epoch 87/100: 100%|██████████| 32/32 [00:32<00:00,  1.02s/it]


Epoch 87/100, Train Loss: 3.0250, Val Loss: 3.0532


Epoch 88/100: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 88/100, Train Loss: 3.0205, Val Loss: 3.0493


Epoch 89/100: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]


Epoch 89/100, Train Loss: 3.0170, Val Loss: 3.0455


Epoch 90/100: 100%|██████████| 32/32 [00:35<00:00,  1.10s/it]


Epoch 90/100, Train Loss: 3.0120, Val Loss: 3.0416


Epoch 91/100: 100%|██████████| 32/32 [00:33<00:00,  1.05s/it]


Epoch 91/100, Train Loss: 3.0079, Val Loss: 3.0377


Epoch 92/100: 100%|██████████| 32/32 [00:34<00:00,  1.06s/it]


Epoch 92/100, Train Loss: 3.0032, Val Loss: 3.0339


Epoch 93/100: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 93/100, Train Loss: 3.0001, Val Loss: 3.0300


Epoch 94/100: 100%|██████████| 32/32 [00:32<00:00,  1.02s/it]


Epoch 94/100, Train Loss: 2.9955, Val Loss: 3.0262


Epoch 95/100: 100%|██████████| 32/32 [00:32<00:00,  1.03s/it]


Epoch 95/100, Train Loss: 2.9908, Val Loss: 3.0224


Epoch 96/100: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 96/100, Train Loss: 2.9868, Val Loss: 3.0185


Epoch 97/100: 100%|██████████| 32/32 [00:32<00:00,  1.02s/it]


Epoch 97/100, Train Loss: 2.9832, Val Loss: 3.0147


Epoch 98/100: 100%|██████████| 32/32 [00:32<00:00,  1.03s/it]


Epoch 98/100, Train Loss: 2.9784, Val Loss: 3.0110


Epoch 99/100: 100%|██████████| 32/32 [00:33<00:00,  1.04s/it]


Epoch 99/100, Train Loss: 2.9742, Val Loss: 3.0071


Epoch 100/100: 100%|██████████| 32/32 [00:30<00:00,  1.04it/s]


Epoch 100/100, Train Loss: 2.9702, Val Loss: 3.0034


In [51]:
from sklearn.metrics import accuracy_score, roc_auc_score
import torchmetrics

# Set the model to evaluation mode
perch_network.eval()


# Initialize the metrics
metrics = torchmetrics.MetricCollection({
    'T1Accuracy': torchmetrics.Accuracy(
        task="multiclass",
        num_classes=num_classes,
        top_k=1
    ),
    'T3Accuracy': torchmetrics.Accuracy(
        task="multiclass",
        num_classes=num_classes,
        top_k=3
    ),
    'AUROC': torchmetrics.AUROC(
        task="multiclass",
        num_classes=num_classes,
        average='macro'
    ),
    'F1': torchmetrics.F1Score(
        task="multiclass",
        num_classes=num_classes
    )
}).to(device)

# Iterate over the test_loader
for batch in test_loader:
    # Forward pass
    inputs = batch['audio'].to(device)
    labels = batch['labels'].to(device)
    #labels = torch.argmax(labels, dim=1) #* For one hot-encoding 
    with torch.no_grad():
        outputs = classifier(inputs)
        outputs = outputs.squeeze(1)
    
    # Update the metrics
    metrics(outputs, labels)

# Compute and print the metric values
metric_values = metrics.compute()
for metric_name, metric_value in metric_values.items():
    print(f"{metric_name}: {metric_value}")


AUROC: 0.9557977318763733
F1: 0.37168142199516296
T1Accuracy: 0.37168142199516296
T3Accuracy: 0.6991150379180908
