# Training baseline model notebook

## Imports and constants

In [27]:
import sys
import os
sys.path.append('../')

import torch
import onnx
import onnxruntime as ort
import librosa
import pandas as pd
import numpy as np

from glob import glob
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from ml_base.model import BaselineBirdClassifier

In [2]:
TRAIN_DATA_PATH = os.path.realpath('../data/train_data_s3/')
MODEL_SAVE_PATH = os.path.realpath('../data/models')
VAL_FRAC = 0.1
BATCH_SIZE = 16
SAMPLE_LEN_SEC = 10
SAMPLE_RATE = 32000
EPOCHS_COUNT = 2
EVAL_EVERY_EPOCHS = 10

## Searching for all downloaded audio files

Use `download_data_s3.py` to download all the "checked" data

In [3]:
all_files = glob(os.path.join(TRAIN_DATA_PATH, '**/*.ogg'))

In [4]:
len(all_files)

1021

Classes count:

In [5]:
len(glob(os.path.join(TRAIN_DATA_PATH, '*')))

149

In [6]:
all_df = pd.DataFrame({'file_path': all_files})
all_df['class'] = all_df['file_path'].apply(lambda filepath: os.path.basename(os.path.dirname(filepath)))

Converting class to class id:

In [7]:
CLASS2ID = {classname: i for i, classname in enumerate(all_df['class'].unique())}
ID2CLASS = {i: classname for classname, i in CLASS2ID.items()}

In [9]:
all_df['class_id'] = all_df['class'].apply(CLASS2ID.get)

In [10]:
all_df

Unnamed: 0,file_path,class,class_id
0,E:\_UNIVER\UCU\2 sem\MLOps\bird-project\data\t...,asbfly,0
1,E:\_UNIVER\UCU\2 sem\MLOps\bird-project\data\t...,ashdro1,1
2,E:\_UNIVER\UCU\2 sem\MLOps\bird-project\data\t...,ashdro1,1
3,E:\_UNIVER\UCU\2 sem\MLOps\bird-project\data\t...,ashpri1,2
4,E:\_UNIVER\UCU\2 sem\MLOps\bird-project\data\t...,ashpri1,2
...,...,...,...
1016,E:\_UNIVER\UCU\2 sem\MLOps\bird-project\data\t...,zitcis1,148
1017,E:\_UNIVER\UCU\2 sem\MLOps\bird-project\data\t...,zitcis1,148
1018,E:\_UNIVER\UCU\2 sem\MLOps\bird-project\data\t...,zitcis1,148
1019,E:\_UNIVER\UCU\2 sem\MLOps\bird-project\data\t...,zitcis1,148


### Train-val split

In [11]:
val_df = all_df.sample(int(VAL_FRAC * len(all_df)))
train_df = all_df.loc[~all_df.index.isin(val_df.index)]
len(train_df), len(val_df)

(919, 102)

### Datasets creation

In [12]:
class AudioDataset(Dataset):
    def __init__(self, paths, labels=None, sample_len=SAMPLE_LEN_SEC, sr=SAMPLE_RATE):
        assert labels is None or len(paths) == len(labels), "Data and targets should be of the same samples count"
        self.paths = paths
        self.labels = labels
        self.sample_len = sample_len
        self.sr = sr

    def __getitem__(self, i):
        audio, sr = librosa.load(self.paths[i], sr=self.sr)

        if self.sample_len is not None:
            desired_len = self.sample_len * sr
            if len(audio) >desired_len:
                audio = audio[:desired_len]
            else:
                audio =  np.pad(audio, (0, desired_len - len(audio)))

        if self.labels is not None:
            return audio, self.labels[i]
        else:
            return audio

    def __len__(self):
        return len(self.paths)

In [13]:
train_ds = AudioDataset(train_df['file_path'].tolist(), train_df['class_id'].tolist())
val_ds = AudioDataset(val_df['file_path'].tolist(), val_df['class_id'].tolist(), sample_len=None)

In [14]:
train_ds[3]

(array([ 9.1765696e-07, -3.1738637e-05, -4.2713637e-06, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00], dtype=float32),
 2)

In [15]:
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=1)

## Training

### Training preparation

In [16]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [17]:
model = BaselineBirdClassifier(len(CLASS2ID), sr=SAMPLE_RATE).to(device)

STFT kernels created, time used = 0.0278 seconds


In [18]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.RAdam(model.parameters(), lr=1e-3)

### Training itself

In [19]:
batch_num = 0

min_eval_loss = np.inf
for epoch in tqdm(range(EPOCHS_COUNT), desc='Epoch'):
    running_loss = 0.
    last_loss = 0.

    for audios, labels in train_loader:
        audios = audios.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(audios)

        loss = loss_fn(outputs, labels)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
        if batch_num % EVAL_EVERY_EPOCHS == EVAL_EVERY_EPOCHS - 1:
            last_loss = running_loss / EVAL_EVERY_EPOCHS
            print(f'Batch {batch_num + 1}. Loss: {last_loss:.6f}.', end=' ')
            running_loss = 0.

            model.eval()
            eval_running_loss = 0.
            with torch.no_grad():
                for audios, labels in val_loader:
                    audios = audios.to(device)
                    labels = labels.to(device)

                    outputs = model(audios)
                    loss = loss_fn(outputs, labels)

                    eval_running_loss += loss.item()
            
            print(f'Val loss: {eval_running_loss/len(val_ds):.6f}.')

            if eval_running_loss < min_eval_loss:
                min_eval_loss = eval_running_loss
                print("Saving the model")

                torch.save(model.state_dict(), os.path.join(MODEL_SAVE_PATH, f'baseline-{len(CLASS2ID)}.pt'))

            model.train()
        batch_num += 1

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Batch 10. Loss: 4.997318. Val loss: 5.002830.
Saving the model
Batch 20. Loss: 5.001742. Val loss: 5.002751.
Saving the model
Batch 30. Loss: 5.001061. Val loss: 5.002618.
Saving the model
Batch 40. Loss: 5.004285. Val loss: 5.002512.
Saving the model
Batch 50. Loss: 5.001182. Val loss: 5.002398.
Saving the model


Epoch:  50%|█████     | 1/2 [02:06<02:06, 126.35s/it]

Batch 60. Loss: 1.001172. Val loss: 5.002195.
Saving the model
Batch 70. Loss: 5.000711. Val loss: 5.001924.
Saving the model
Batch 80. Loss: 4.998894. Val loss: 5.001608.
Saving the model
Batch 90. Loss: 4.998351. Val loss: 5.001476.
Saving the model
Batch 100. Loss: 5.001676. Val loss: 5.001209.
Saving the model
Batch 110. Loss: 4.998280. Val loss: 5.000982.
Saving the model


Epoch: 100%|██████████| 2/2 [03:43<00:00, 111.62s/it]


## ONNX export

Loading the best model:

In [20]:
model.load_state_dict(torch.load(os.path.join(MODEL_SAVE_PATH, f'baseline-{len(CLASS2ID)}.pt'), map_location=torch.device('cpu')))
model.eval()

BaselineBirdClassifier(
  (feature_extractor): Sequential(
    (0): STFT(n_fft=1024, Fourier Kernel size=(513, 1, 1024), iSTFT=False, trainable=False)
    (1): MelScale()
    (2): AmplitudeToDB()
  )
  (backbone): LSTM(64, 32, num_layers=3, batch_first=True, dropout=0.05, bidirectional=True)
  (head): Sequential(
    (0): Dropout(p=0.1, inplace=False)
    (1): Linear(in_features=64, out_features=16, bias=True)
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
    (4): Linear(in_features=16, out_features=149, bias=True)
    (5): Sigmoid()
  )
)

And exporting it to ONNX:

In [22]:
torch_input = torch.randn(8, SAMPLE_RATE*SAMPLE_LEN_SEC)
torch.onnx.export(model.cpu(),
                  torch_input,
                 os.path.join(MODEL_SAVE_PATH, f'baseline-{len(CLASS2ID)}.onnx'),
                 export_params=True,
                 do_constant_folding=True,
                 input_names = ['input'],
                 output_names = ['output'],
                 dynamic_axes={'input' : {0: 'batch_size', 1: 'sample_length'},
                               'output' : {0: 'batch_size'}}
)

  if self.num_samples < self.pad_amount:
  if return_spec:
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


Checking that everything is OK:

In [28]:
onnx_model = onnx.load(os.path.join(MODEL_SAVE_PATH, f'baseline-{len(CLASS2ID)}.onnx'))
onnx.checker.check_model(onnx_model)

In [35]:
ort_sess = ort.InferenceSession(os.path.join(MODEL_SAVE_PATH, f'baseline-{len(CLASS2ID)}.onnx'))
outputs = ort_sess.run(None, {'input': np.random.randn(1, 128302).astype(np.float32)})

In [39]:
outputs[0][0][:10], outputs[0].shape

(array([0.5311749 , 0.5527598 , 0.5498919 , 0.4817451 , 0.55824643,
        0.4672618 , 0.53212756, 0.48078275, 0.56325716, 0.49729684],
       dtype=float32),
 (1, 149))