# Waveform model

Build and train a model that takes raw waveforms as input.

## Import libraries

In [None]:
import os
import torch
import torch.nn as nn
import torchsummary
import torch.nn.functional as F
import librosa
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence
from pathlib import Path
from sklearn.metrics import confusion_matrix

## Create custom dataset

In [None]:
SAMPLE_RATE = 16000
FRAME_LENGTH = SAMPLE_RATE  # 1 second per frame (16000 samples at 16 kHz)
FRAME_HOP = FRAME_LENGTH // 2  # 50% overlap between frames

class BirdAudioDataset(Dataset):
    """Dataset of audio clips split into fixed-length, overlapping frames.

    ``root_dir`` is expected to contain one sub-directory per class; the
    sub-directory name is used as the label of every ``.wav`` / ``.mp3``
    file found directly inside it.

    Each item is a ``(frames, label)`` tuple where ``frames`` is produced by
    ``Tensor.unfold`` over the loaded waveform (``(num_frames, frame_length)``
    for the 1-D waveform that ``librosa.load`` returns by default) and
    ``label`` is the integer id assigned by ``LabelEncoder``.  If a file
    cannot be loaded the item is ``None``, so a ``collate_fn`` that drops
    ``None`` entries is needed when batching.

    Args:
        root_dir: Directory holding one sub-directory per class.
        sample_rate: Rate (Hz) every clip is resampled to on load.
        frame_length: Number of samples per frame.
        frame_hop: Number of samples between the starts of adjacent frames.
        transform: Optional callable applied to the frames tensor.
        max_classes: Keep at most this many class directories (in sorted
            order); ``None`` keeps all of them.
        max_files_per_class: Keep at most this many audio files per class
            (in sorted order); ``None`` keeps all of them.
        max_duration: Clips are truncated to this many seconds before
            framing; ``None`` disables truncation.
    """

    # File-name suffixes (case-sensitive) that are treated as audio clips.
    AUDIO_EXTENSIONS = (".wav", ".mp3")

    def __init__(self, root_dir, sample_rate=SAMPLE_RATE, frame_length=FRAME_LENGTH, frame_hop=FRAME_HOP, transform=None,
                 max_classes=4, max_files_per_class=2, max_duration=20):
        self.root_dir = root_dir
        self.sample_rate = sample_rate
        self.frame_length = frame_length
        self.frame_hop = frame_hop
        self.transform = transform
        self.max_classes = max_classes
        self.max_files_per_class = max_files_per_class
        self.max_duration = max_duration
        self.label_encoder = LabelEncoder()

        # Discover audio files and their (string) class labels.
        self.audio_paths, self.labels = self._collect_audio_files(
            root_dir,
            max_classes=max_classes,
            max_files_per_class=max_files_per_class,
        )

        # Encode class names as integers.  The ids are only comparable between
        # two datasets (e.g. train / validation) when both roots expose the
        # same set of class directories.
        self.labels = self.label_encoder.fit_transform(self.labels)

    @classmethod
    def _collect_audio_files(cls, root_dir, max_classes=None, max_files_per_class=None):
        """Return ``(paths, labels)`` for the audio files under ``root_dir``.

        Directory listings are sorted because ``os.listdir`` returns entries
        in arbitrary order; without sorting, the subset selected by the
        limits would not be reproducible.  The limits are applied after
        filtering, so stray files never consume a class or file slot.
        """
        class_names = [
            entry
            for entry in sorted(os.listdir(root_dir))
            if os.path.isdir(os.path.join(root_dir, entry))
        ][:max_classes]

        audio_paths = []
        labels = []
        for class_name in class_names:
            class_dir = os.path.join(root_dir, class_name)
            file_names = [
                file_name
                for file_name in sorted(os.listdir(class_dir))
                if file_name.endswith(cls.AUDIO_EXTENSIONS)
            ][:max_files_per_class]
            audio_paths.extend(os.path.join(class_dir, file_name) for file_name in file_names)
            labels.extend([class_name] * len(file_names))
        return audio_paths, labels

    def _to_frames(self, waveform):
        """Truncate / zero-pad a 1-D waveform tensor and split it into frames."""
        if self.max_duration is not None:
            max_length = int(self.sample_rate * self.max_duration)
            waveform = waveform[:max_length]

        # unfold() needs at least one full frame, so right-pad short clips
        # with zeros up to frame_length.
        shortfall = self.frame_length - waveform.shape[0]
        if shortfall > 0:
            waveform = F.pad(waveform, (0, shortfall))

        # Trailing samples that do not fill a complete frame are dropped.
        return waveform.unfold(dimension=0, size=self.frame_length, step=self.frame_hop)

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.audio_paths)

    def __getitem__(self, idx):
        """Load clip ``idx`` and return ``(frames, label)``, or ``None`` on load failure."""
        audio_path = self.audio_paths[idx]
        label = self.labels[idx]

        try:
            waveform, _ = librosa.load(audio_path, sr=self.sample_rate)  # Resamples to sample_rate if needed
        except Exception as e:
            # Best effort: report the unreadable file and let the caller skip it.
            print(f"Error loading {audio_path}: {e}")
            return None

        frames = self._to_frames(torch.tensor(waveform))

        # Apply any additional transform
        if self.transform is not None:
            frames = self.transform(frames)

        return frames, label