In [13]:
pip install torchsummary

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1
[0mNote: you may need to restart the kernel to use updated packages.


In [15]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

import torchaudio
from torchaudio.transforms import *
from torchvision.transforms import Compose

from tqdm import tqdm
from torchsummary import summary
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [5]:
labels = os.listdir("/kaggle/input/birdclef-2023/train_audio")
labels = pd.Series(labels)

In [6]:
class BirdRNNDataset(Dataset):
    def __init__(self, root_dir, transforms=None):
        super(BirdRNNDataset, self).__init__()
        self.root_dir = root_dir
        
        self.labels = os.listdir(root_dir)
        
        self.len_labels = len(labels)
        labels_ind = torch.arange(0, self.len_labels)
        
        self.targets = F.one_hot(labels_ind)
        
        self.items = [(label, elem) for label in labels \
                      for elem in os.listdir(os.path.join(root_dir, label))]
        
        self.transforms = transforms
        
    def __len__(self):
        return len(self.items)
    
    def __getitem__(self, index):
        label, filename = self.items[index]
        filepath = os.path.join(self.root_dir, label, filename)
        
        audio, _ = torchaudio.load(filepath)
        audio = self.to_mono(audio)
        
        if self.transforms:
            audio = transforms(audio)
        
        target_ind = self.labels.index(label)
        target = self.targets[target_ind]
        
        return audio, target
    
        
    def to_mono(self, audio):
        return torch.mean(audio, axis=0)
    
    def labels_count(self):
        return self.len_labels

In [7]:
def collate_batch(data):
    inputs, labels = zip(*data)
    max_length = max(inputs, key=lambda x: x.size()[1]).size()[1]
    dim_size = inputs[0].size()[0]
    
    batch = torch.zeros(size=(len(data), dim_size, max_length))
    for i, elem in enumerate(inputs):
        batch[i] =  torch.cat([elem, torch.zeros(dim_size, max_length - elem.size()[1])], axis=1)
    
    batch = batch.transpose(1, 2)
    targets = torch.stack(labels)
    
    return batch, targets

In [8]:
class BirdNet(nn.Module):
    def __init__(self, n_mels, n_labels, bidirectional=False, hidden_size=600, lstm_dropout=0.):
        super(BirdNet, self).__init__()
        self.gru1 = nn.GRU(input_size=n_mels, hidden_size=hidden_size, batch_first=True, bidirectional=bidirectional)
        self.gru2 = nn.GRU(input_size=hidden_size, hidden_size=hidden_size, batch_first=True, bidirectional=bidirectional)
        self.fc1 = nn.Linear(in_features=hidden_size, out_features=hidden_size//2)
        self.fc2 = nn.Linear(in_features=hidden_size//2, out_features=n_labels)
        
        self.dropout=nn.Dropout()
        self.relu = nn.ReLU()
        
        self.softmax = nn.Softmax()
        
    def forward(self, x):
        output1, h_n = self.gru1(x)
        rel1 = self.relu(output1)
        
        output2 = self.gru(rel1, h_n)
        rel2 = self.rel(output2)
        
        dense1 = self.fc1(rel)
        drop = self.dropout(dense1)
        dense2 = self.fc2(drop)
        #preds = self.softmax(dense2)
        return dense2

In [9]:
TRAIN_DIR = "/kaggle/input/birdclef-2023/train_audio"

device = "cuda" if torch.cuda.is_available() else "cpu"

resample_freq = 16000
freq_mask = 10
time_stretch_coeff = 0.8
n_fft = 1024
n_mels = 256
hidden_size = 600
batch_size = 4


In [10]:
transforms = Compose([
    Resample(orig_freq=32000, new_freq=resample_freq),
    Spectrogram(n_fft=n_fft, power=2.),
    MelScale(n_mels=n_mels, n_stft=n_fft // 2 + 1, sample_rate=resample_freq)
])


  "At least one mel filterbank has all zero values. "


In [11]:
ds = BirdRNNDataset(TRAIN_DIR, bidirectional=True, transforms=transforms)
train_ds, val_ds = torch.utils.data.random_split(ds, [0.8, 0.2])

train_dl = DataLoader(train_ds,
                    batch_size=batch_size,
                    collate_fn=collate_batch,
                     )

val_dl = DataLoader(val_ds,
                    batch_size=batch_size,
                    collate_fn=collate_batch,
                   )

In [12]:
for audio, target in train_dl:
    print(audio.shape)
    break

torch.Size([4, 2072, 256])


In [16]:
n_labels = ds.labels_count()

model = BirdNet(n_mels=n_mels, hidden_size=hidden_size, n_labels=n_labels)
model.to(device)

criterion = nn.CrossEntropyLoss() 
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [17]:
summary(model, (2072, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
               GRU-1  [[-1, 2072, 600], [-1, 2, 600]]               0
              ReLU-2            [-1, 2072, 600]               0
            Linear-3            [-1, 2072, 300]         180,300
           Dropout-4            [-1, 2072, 300]               0
            Linear-5            [-1, 2072, 264]          79,464
Total params: 259,764
Trainable params: 259,764
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 2.02
Forward/backward pass size (MB): 11358.69
Params size (MB): 0.99
Estimated Total Size (MB): 11361.71
----------------------------------------------------------------


In [None]:
n_epochs = 5

train_history = []
val_history = []

for epoch in range(n_epochs):
    print(f'Epoch: {epoch + 1}')
    train_epoch_history = []
    val_epoch_history = []
    
    model.train()
    for data, label in tqdm(train_dl):
        data = data.to(device)
        target = label.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_epoch_history.append(loss.item())
        
    model.eval()
    for data, label in tqdm(val_dl):
        data = data.to(device)
        target = label.to(device)
        
        output = model(data)
        loss = criterion(output, target)
        val_epoch_history.append(loss.item())
    
    train_average_loss = np.mean(train_epoch_history)
    val_average_loss = np.mean(val_epoch_history)
    train_history.append(train_average_loss)
    val_history.append(val_average_loss)
    print(f'train loss: {train_average_loss}')
    print(f'val loss: {val_average_loss}')

In [None]:
torch.save(model.state_dict, './output.pt')

In [None]:
plt.plot(train_history)
plt.plot(val_history)
plt.show()

In [None]:
from IPython.display import FileLink
FileLink(r'output.pt')