### Audio-based respiratory disease classification

The dataset source: [Respiratory Sound Database](https://www.kaggle.com/vbookshelf/respiratory-sound-database).

The dataset contains __920__ annotated WAV recordings of varying length (10 s to 90 s) from 126 patients.
The patients span all age groups: children, adults and the elderly.
The recordings were taken at different lung locations and with various recording devices.

In this example, several audio features are extracted first, and then a deep learning model is used as a classifier.
All observations are treated as independent, irrespective of patient IDs. In practice, patient ID and lung location information could be used to improve the model further.

The source code is available on [GitHub](https://github.com/alexander-pv/educational_repo/tree/master/ds_tasks/audio_respiratory_diseases_classification).

In [1]:
import os
import sys
import warnings
warnings.filterwarnings('ignore')

path_append_flag = True
if path_append_flag:
    sys.path.append('../')
    path_append_flag = False

import copy
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
import librosa
import torchaudio
import seaborn as sns
import multiprocessing as mp
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from src.dataset import RespiratoryDataset
from src.model import RDClassifier
from src import training

sns.set()

seed = 42

In [2]:
torch.cuda.is_available()

True

In [3]:
%load_ext watermark
%watermark -p pandas,numpy,seaborn,matplotlib,torch,torchaudio,librosa -vw

Python implementation: CPython
Python version       : 3.9.7
IPython version      : 7.31.0

pandas    : 1.3.4
numpy     : 1.20.3
seaborn   : 0.11.2
matplotlib: 3.5.0
torch     : 1.10.1
torchaudio: 0.10.1
librosa   : 0.8.1

Watermark: 2.3.0



In [4]:
dataset = RespiratoryDataset(
    annotation_path='../data/Respiratory_Sound_Database/patient_diagnosis.csv',
    audio_dir='../data/Respiratory_Sound_Database/audio_and_txt_files',
    transform=None,
    target_sample_rate=16000,
    name='default',
    seed=seed,
    
)

default subset classes mapping:
{'Asthma': 0, 'Bronchiectasis': 1, 'Bronchiolitis': 2, 'COPD': 3, 'Healthy': 4, 'LRTI': 5, 'Pneumonia': 6, 'URTI': 7}
default subset was created. Size: 920


In [5]:
dataset.df_prepared_annot.head()

| | patient_id | class_label | wav_path |
|---|---|---|---|
| 0 | 130 | COPD | ../data/Respiratory_Sound_Database/audio_and_t... |
| 1 | 138 | COPD | ../data/Respiratory_Sound_Database/audio_and_t... |
| 2 | 130 | COPD | ../data/Respiratory_Sound_Database/audio_and_t... |
| 3 | 158 | COPD | ../data/Respiratory_Sound_Database/audio_and_t... |
| 4 | 104 | COPD | ../data/Respiratory_Sound_Database/audio_and_t... |


In [6]:
dataset.df_prepared_annot.groupby(by=['class_label'])['patient_id'].count().sort_values(ascending=False)

class_label
COPD              793
Pneumonia          37
Healthy            35
URTI               23
Bronchiectasis     16
Bronchiolitis      13
LRTI                2
Asthma              1
Name: patient_id, dtype: int64

Unfortunately, the dataset has only 2 LRTI recordings and 1 Asthma recording, so these two labels are dropped.
The remaining classes are still heavily imbalanced. There are several options for dealing with this, for example:

1. Weighted loss function.
2. Weighted batching.
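
As a rough illustration of both options, the sketch below derives inverse-frequency class weights from the counts printed above, leaving out LRTI and Asthma. The helper name `inverse_frequency_weights` and its normalization are assumptions made for this example; the repository's `training.set_weighted_loss` may compute the weights differently.

```python
# Recording counts per class, copied from the output above
# (LRTI and Asthma are left out because they are too rare).
class_counts = {
    "COPD": 793, "Pneumonia": 37, "Healthy": 35,
    "URTI": 23, "Bronchiectasis": 16, "Bronchiolitis": 13,
}


def inverse_frequency_weights(counts):
    """Weight each class by total / (n_classes * count); hypothetical helper."""
    total = sum(counts.values())
    n_classes = len(counts)
    return {label: total / (n_classes * n) for label, n in counts.items()}


class_weights = inverse_frequency_weights(class_counts)

# Option 1 (weighted loss): the per-class weights, ordered by class index,
# could be passed as the `weight` tensor of torch.nn.CrossEntropyLoss.
# Option 2 (weighted batching): every recording gets the weight of its class,
# which could feed torch.utils.data.WeightedRandomSampler.
labels = ["COPD", "COPD", "Healthy", "Bronchiolitis"]  # toy label list
sample_weights = [class_weights[label] for label in labels]

for label, w in sorted(class_weights.items(), key=lambda kv: kv[1]):
    print(f"{label:15s} {w:.3f}")
```

With this normalization the rarest class (Bronchiolitis) receives the largest weight and COPD the smallest, so errors on rare classes cost more in the loss or are sampled more often in a batch.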

In [7]:
dataset = RespiratoryDataset(
    annotation_path='../data/Respiratory_Sound_Database/patient_diagnosis.csv',
    audio_dir='../data/Respiratory_Sound_Database/audio_and_txt_files',
    transform=None,
    target_sample_rate=16000,
    exclude_classes=('LRTI', 'Asthma'),
    name='default',
    seed=seed,
)

default subset classes mapping:
{'Bronchiectasis': 0, 'Bronchiolitis': 1, 'COPD': 2, 'Healthy': 3, 'Pneumonia': 4, 'URTI': 5}
default subset was created. Size: 917


In [8]:
dataset.df_prepared_annot.groupby(by=['class_label'])['patient_id'].count().sort_values(ascending=False)

class_label
COPD              793
Pneumonia          37
Healthy            35
URTI               23
Bronchiectasis     16
Bronchiolitis      13
Name: patient_id, dtype: int64

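A stratified split into train, validation and test subsets is needed. Each subset is created by a separate `RespiratoryDataset` instance, so a fixed seed for the pseudorandom number generator keeps the split reproducible and prevents the subsets from overlapping.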
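
A minimal sketch of a seeded, stratified three-way split in plain Python. It is not the code used inside `RespiratoryDataset`; the function `stratified_split` and its rounding rule are assumptions for illustration. Here `val_size` is read as the fraction of what remains after the train split, which appears consistent with the subset sizes reported further down (641 / 165 / 111).

```python
import random
from collections import defaultdict


def stratified_split(labels, train_size=0.7, val_size=0.6, seed=42):
    """Return index lists for train/val/test, split class by class.

    val_size is applied to what remains after the train split.
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    rng = random.Random(seed)  # same seed -> same shuffle -> same subsets
    split = {"train": [], "val": [], "test": []}
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        n_train = int(len(indices) * train_size)
        n_val = int((len(indices) - n_train) * val_size)
        split["train"] += indices[:n_train]
        split["val"] += indices[n_train:n_train + n_val]
        split["test"] += indices[n_train + n_val:]
    return split


# Toy labels with an imbalance similar in spirit to the dataset.
labels = ["COPD"] * 80 + ["Healthy"] * 10 + ["Pneumonia"] * 10
split_a = stratified_split(labels, seed=42)
split_b = stratified_split(labels, seed=42)
print({name: len(idx) for name, idx in split_a.items()})
```

Calling the function twice with the same seed yields identical index lists, which is why independently constructed subsets do not overlap.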

In [9]:
train_size, val_size = 0.7, 0.6
n_mfcc, n_mels, n_bands = 64, 128, 6
sample_rate = 16000
batch_size = 32
n_epochs = 25
channels = 1
features_length = sum((n_mfcc, n_mels, n_bands+1))
exclude_classes = ('LRTI', 'Asthma')
device = 'cuda'
weighted_loss = False

num_workers = mp.cpu_count()//2
prefetch = 4

feature_extractors = (
    torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=n_mfcc),
    torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate, n_mels=n_mels),
    lambda x: torch.Tensor(
        librosa.feature.spectral_contrast(
            S=np.abs(librosa.stft(x.numpy()[0])), sr=sample_rate, n_bands=n_bands)
    ).unsqueeze(0)
)

In [10]:
datasets, dataloaders = {}, {}
for name in ('train', 'val', 'test'):
    datasets.update(
        {
            name: RespiratoryDataset(
                annotation_path='../data/Respiratory_Sound_Database/patient_diagnosis.csv',
                audio_dir='../data/Respiratory_Sound_Database/audio_and_txt_files',
                transform=feature_extractors,
                target_sample_rate=sample_rate,
                exclude_classes=exclude_classes,
                name=name,
                train_size=train_size,
                val_size=val_size,
                seed=seed,

            )
        }
    )
    
    dataloaders.update(
        {name: DataLoader(datasets[name],
                         batch_size=batch_size,
                         shuffle=(name == 'train'),  # reshuffle only the training data
                         num_workers=num_workers,
                         prefetch_factor=prefetch
                         )
         
        }
    )


class_dict = datasets['train'].class_map
obs_count = datasets['train'].subset_annotation['class_label'].value_counts(
).to_dict()
obs_count = {class_dict[k]: v for k, v in obs_count.items()}

train subset classes mapping:
{'Bronchiectasis': 0, 'Bronchiolitis': 1, 'COPD': 2, 'Healthy': 3, 'Pneumonia': 4, 'URTI': 5}
train subset was created. Size: 641
val subset classes mapping:
{'Bronchiectasis': 0, 'Bronchiolitis': 1, 'COPD': 2, 'Healthy': 3, 'Pneumonia': 4, 'URTI': 5}
val subset was created. Size: 165
test subset classes mapping:
{'Bronchiectasis': 0, 'Bronchiolitis': 1, 'COPD': 2, 'Healthy': 3, 'Pneumonia': 4, 'URTI': 5}
test subset was created. Size: 111


In [11]:
dataloaders

{'train': <torch.utils.data.dataloader.DataLoader at 0x7f8365715a90>,
 'val': <torch.utils.data.dataloader.DataLoader at 0x7f8365717df0>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x7f83656d9ac0>}

Model initialization and training. By default, the loss function is unweighted (`weighted_loss = False`).

The number of epochs is limited to 25 because the dataset is small and because Jupyter stores the per-epoch logs in the notebook output.

In [12]:
model = RDClassifier(input_shape=(batch_size, channels, features_length), n_classes=len(class_dict.keys()))

In [13]:
loss_weights = training.set_weighted_loss(label_to_count=obs_count, device=device) if weighted_loss else None
criterion = torch.nn.CrossEntropyLoss(weight=loss_weights)
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

In [14]:
training.model_train(
    net=model,
    train_loader=dataloaders['train'],
    val_loader=dataloaders['val'],
    epochs=n_epochs,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    device=device,
    save_path='../weights',
    class_dict=class_dict,
    model_name='rd_classifier',
    verbose=0
)

100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.28it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.24it/s]


[SaveBestModelCallback] val_loss was improved: inf -> 0.827966570854187. Model was saved.
[12.180683 sec.][Epoch 1] train_loss: 2.3886371655389667, val_loss: 0.827966570854187, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.29it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.23it/s]


[SaveBestModelCallback] val_loss was improved: 0.827966570854187 -> 0.42635267972946167. Model was saved.
[12.195124 sec.][Epoch 2] train_loss: 0.4988556755706668, val_loss: 0.42635267972946167, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.18it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.19it/s]


[12.397045 sec.][Epoch 3] train_loss: 0.75326386699453, val_loss: 0.49514541029930115, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.02it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.23it/s]


[SaveBestModelCallback] val_loss was improved: 0.42635267972946167 -> 0.40727174282073975. Model was saved.
[13.366245 sec.][Epoch 4] train_loss: 0.2772985593182966, val_loss: 0.40727174282073975, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.18it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.16it/s]


[12.426057 sec.][Epoch 5] train_loss: 0.4763962486758828, val_loss: 0.46971023082733154, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.13it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.17it/s]


[SaveBestModelCallback] val_loss was improved: 0.40727174282073975 -> 0.32516953349113464. Model was saved.
[12.908104 sec.][Epoch 6] train_loss: 0.22997818852309138, val_loss: 0.32516953349113464, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.05it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.05it/s]


[SaveBestModelCallback] val_loss was improved: 0.32516953349113464 -> 0.3183504641056061. Model was saved.
[13.507227 sec.][Epoch 7] train_loss: 0.20976010174490511, val_loss: 0.3183504641056061, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  1.92it/s]
100%|█████████████████████████████████████████████| 6/6 [00:03<00:00,  1.90it/s]


[14.128706 sec.][Epoch 8] train_loss: 1.6567334469873458, val_loss: 0.3954455554485321, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:11<00:00,  1.85it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.05it/s]


[SaveBestModelCallback] val_loss was improved: 0.3183504641056061 -> 0.28562048077583313. Model was saved.
[14.570497 sec.][Epoch 9] train_loss: 0.1935860259036417, val_loss: 0.28562048077583313, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.07it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.05it/s]


[13.088851 sec.][Epoch 10] train_loss: 0.1764235484879464, val_loss: 0.3442363142967224, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.08it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.19it/s]


[12.851225 sec.][Epoch 11] train_loss: 3.7020758234430104, val_loss: 0.3527678847312927, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  1.98it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.13it/s]


[13.420935 sec.][Epoch 12] train_loss: 2.249194824602455, val_loss: 0.39600199460983276, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.14it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.17it/s]


[12.592979 sec.][Epoch 13] train_loss: 0.23690630169582505, val_loss: 0.35751792788505554, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.09it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.08it/s]


[12.920663 sec.][Epoch 14] train_loss: 0.9464375688694417, val_loss: 0.2891508936882019, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.13it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.19it/s]


[12.605518 sec.][Epoch 15] train_loss: 0.21819687425158918, val_loss: 0.47460365295410156, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.07it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.15it/s]


[12.960888 sec.][Epoch 16] train_loss: 0.1519299757637782, val_loss: 0.3634588420391083, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.11it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.14it/s]


[12.756235 sec.][Epoch 17] train_loss: 0.13840530323795974, val_loss: 0.4350879192352295, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.11it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.12it/s]


[SaveBestModelCallback] val_loss was improved: 0.28562048077583313 -> 0.21398834884166718. Model was saved.
[13.081311 sec.][Epoch 18] train_loss: 0.11442522157449275, val_loss: 0.21398834884166718, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.13it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.21it/s]


[12.579839 sec.][Epoch 19] train_loss: 0.10258940717903897, val_loss: 0.33991217613220215, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.07it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.09it/s]


[13.025398 sec.][Epoch 20] train_loss: 0.07551983585290145, val_loss: 0.32219451665878296, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.07it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.11it/s]


[12.989878 sec.][Epoch 21] train_loss: 0.06822906743036583, val_loss: 0.3603914678096771, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.11it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.17it/s]


[12.715775 sec.][Epoch 22] train_loss: 0.31295500590931624, val_loss: 0.27668967843055725, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  1.99it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.09it/s]


[13.4094 sec.][Epoch 23] train_loss: 0.12359182850923389, val_loss: 0.3752691447734833, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.07it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.07it/s]

[13.064438 sec.][Epoch 24] train_loss: 0.0657175074738916, val_loss: 0.3324386477470398, learning_rate: 0.001.





In [28]:
best_model = RDClassifier(input_shape=(batch_size, channels, features_length), n_classes=len(class_dict.keys()))
best_model.load_state_dict(torch.load(os.path.join('..', 'weights', 'pytorch_rd_classifier_epoch_18.pth')))
best_model.to(device)
best_model.eval()

RDClassifier(
  (conv1): Conv1d(1, 128, kernel_size=(5,), stride=(1,))
  (conv2): Conv1d(128, 256, kernel_size=(5,), stride=(1,))
  (maxpool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv1d(256, 512, kernel_size=(5,), stride=(1,))
  (dropout): Dropout(p=0.2, inplace=False)
  (linear1): Linear(in_features=46592, out_features=1024, bias=True)
  (linear2): Linear(in_features=1024, out_features=6, bias=True)
  (relu): ReLU()
)
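
The `in_features=46592` of `linear1` can be checked by hand. The sketch below assumes that each recording is reduced to a vector of `features_length = 64 + 128 + 7 = 199` values and that the layers are applied in the printed order (`conv1`, `conv2`, `maxpool`, `conv3`). The forward pass is not shown in this notebook, so that order is inferred from the numbers and is not a confirmed detail.

```python
def conv1d_out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of Conv1d / MaxPool1d (floor mode), per the PyTorch formula."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


n_mfcc, n_mels, n_bands = 64, 128, 6
# spectral_contrast yields n_bands + 1 rows, hence the "+ 1".
features_length = n_mfcc + n_mels + (n_bands + 1)

length = features_length                                   # 199
length = conv1d_out_len(length, kernel_size=5)             # conv1   -> 195
length = conv1d_out_len(length, kernel_size=5)             # conv2   -> 191
length = conv1d_out_len(length, kernel_size=2, stride=2)   # maxpool -> 95
length = conv1d_out_len(length, kernel_size=5)             # conv3   -> 91

flattened = 512 * length  # 512 output channels of conv3
print(features_length, length, flattened)  # 199 91 46592
```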

In [32]:
report_dict, report_txt = training.evaluate(best_model, dataloaders['test'], device)
print(report_txt)

100%|█████████████████████████████████████████████| 4/4 [00:02<00:00,  1.62it/s]

                precision    recall  f1-score   support

Bronchiectasis       1.00      0.50      0.67         2
 Bronchiolitis       1.00      0.50      0.67         2
          COPD       0.98      0.99      0.98        96
       Healthy       0.57      1.00      0.73         4
     Pneumonia       0.75      0.75      0.75         4
          URTI       1.00      0.33      0.50         3

      accuracy                           0.95       111
     macro avg       0.88      0.68      0.72       111
  weighted avg       0.96      0.95      0.94       111






Although the overall accuracy (0.95) and the weighted F1-score (0.94) are quite high, precision for the Healthy class is only 0.57 at recall 1.00, which means that
about 43% of the recordings predicted as Healthy actually came from patients with a disease.
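
To make the Healthy numbers concrete, the sketch below recomputes precision, recall and F1 from counts. The counts are back-calculated from the rounded report (support 4, recall 1.00, precision 0.57) and are not read from an actual confusion matrix, so treat them as an inferred illustration.

```python
def precision_recall_f1(tp, fp, fn):
    """Per-class precision, recall and F1 from TP / FP / FN counts."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1


# Healthy class, inferred counts: all 4 healthy recordings are found (fn=0),
# and 3 recordings from other classes are wrongly labelled Healthy (fp=3).
p, r, f1 = precision_recall_f1(tp=4, fp=3, fn=0)
print(f"precision={p:.2f} recall={r:.2f} f1={f1:.2f}")
# precision=0.57 recall=1.00 f1=0.73
```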

Let's compare this model with another one trained with a weighted loss function.

In [35]:
weighted_loss = True
model = RDClassifier(input_shape=(batch_size, channels, features_length), n_classes=len(class_dict.keys()))
loss_weights = training.set_weighted_loss(label_to_count=obs_count, device=device) if weighted_loss else None
criterion = torch.nn.CrossEntropyLoss(weight=loss_weights)
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

In [36]:
training.model_train(
    net=model,
    train_loader=dataloaders['train'],
    val_loader=dataloaders['val'],
    epochs=n_epochs,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    device=device,
    save_path='../weights',
    class_dict=class_dict,
    model_name='rd_classifier',
    verbose=0
)

100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.25it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.19it/s]


[SaveBestModelCallback] val_loss was improved: inf -> 1.593238353729248. Model was saved.
[15.639461 sec.][Epoch 1] train_loss: 5.343869060277939, val_loss: 1.593238353729248, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.24it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.32it/s]


[11.952528 sec.][Epoch 2] train_loss: 3.1308335959911346, val_loss: 1.7156362533569336, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.25it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.20it/s]


[12.073997 sec.][Epoch 3] train_loss: 0.8634307552128888, val_loss: 1.625204086303711, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.21it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.18it/s]


[SaveBestModelCallback] val_loss was improved: 1.593238353729248 -> 1.3312671184539795. Model was saved.
[12.579539 sec.][Epoch 4] train_loss: 1.2088197115808725, val_loss: 1.3312671184539795, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.13it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.16it/s]


[SaveBestModelCallback] val_loss was improved: 1.3312671184539795 -> 1.1905274391174316. Model was saved.
[12.932196 sec.][Epoch 5] train_loss: 0.7687633177265525, val_loss: 1.1905274391174316, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.08it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.04it/s]


[13.060271 sec.][Epoch 6] train_loss: 1.210467946715653, val_loss: 1.5179728269577026, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.13it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.12it/s]


[12.723468 sec.][Epoch 7] train_loss: 0.6292246170341969, val_loss: 2.29421067237854, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.10it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.11it/s]


[12.848886 sec.][Epoch 8] train_loss: 1.6639399891719222, val_loss: 1.6943998336791992, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.06it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.12it/s]


[13.039041 sec.][Epoch 9] train_loss: 0.7450844729319215, val_loss: 1.2438884973526, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.06it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.02it/s]


[13.155625 sec.][Epoch 10] train_loss: 0.5597356810467318, val_loss: 1.4617385864257812, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.09it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.12it/s]


[12.87257 sec.][Epoch 11] train_loss: 0.5405278741382062, val_loss: 2.4444162845611572, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.12it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.09it/s]


[12.786447 sec.][Epoch 12] train_loss: 0.5331633116584271, val_loss: 1.7160789966583252, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.10it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.05it/s]


[12.961545 sec.][Epoch 13] train_loss: 0.4292934502736898, val_loss: 1.5918339490890503, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.05it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.04it/s]


[13.205779 sec.][Epoch 14] train_loss: 0.3967362390831113, val_loss: 1.7146241664886475, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.07it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.09it/s]


[13.045783 sec.][Epoch 15] train_loss: 0.6751249814406037, val_loss: 1.420485496520996, learning_rate: 0.001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.08it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.08it/s]


[13.015201 sec.][Epoch 16] train_loss: 0.36333544133231044, val_loss: 1.7126591205596924, learning_rate: 0.0001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.07it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.06it/s]


[13.048181 sec.][Epoch 17] train_loss: 0.3200715179555118, val_loss: 1.5231831073760986, learning_rate: 0.0001.


100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.10it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.02it/s]


[12.980816 sec.][Epoch 18] train_loss: 0.26708702580071986, val_loss: 1.4537360668182373, learning_rate: 0.0001.


100%|███████████████████████████████████████████| 21/21 [00:09<00:00,  2.12it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.04it/s]


[12.839237 sec.][Epoch 19] train_loss: 0.37886928860098124, val_loss: 1.3515390157699585, learning_rate: 0.0001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.08it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.00it/s]


[SaveBestModelCallback] val_loss was improved: 1.1905274391174316 -> 0.8690584897994995. Model was saved.
[13.421531 sec.][Epoch 20] train_loss: 0.2520470902090892, val_loss: 0.8690584897994995, learning_rate: 0.0001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.07it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.08it/s]


[13.032807 sec.][Epoch 21] train_loss: 0.31553258933126926, val_loss: 1.4591901302337646, learning_rate: 0.0001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.10it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.08it/s]


[12.910144 sec.][Epoch 22] train_loss: 0.2062397941481322, val_loss: 1.473633050918579, learning_rate: 0.0001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.07it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.17it/s]


[12.90195 sec.][Epoch 23] train_loss: 0.21640073833987117, val_loss: 1.265589714050293, learning_rate: 0.0001.


100%|███████████████████████████████████████████| 21/21 [00:10<00:00,  2.04it/s]
100%|█████████████████████████████████████████████| 6/6 [00:02<00:00,  2.03it/s]

[13.275742 sec.][Epoch 24] train_loss: 4.47603911254555, val_loss: 1.5099372863769531, learning_rate: 0.0001.
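
The drop of the learning rate to 0.0001 at epoch 16 in the log above matches what `ReduceLROnPlateau` does with its default arguments (`factor=0.1`, `patience=10`, relative `threshold=1e-4`). The plain-Python sketch below re-implements only that core rule, as a simplified illustration and not the library code, and replays the validation losses from the log (rounded to four decimals). It assumes that the training loop calls `scheduler.step(val_loss)` once per epoch before logging the learning rate.

```python
def simulate_plateau(val_losses, lr=1e-3, factor=0.1, patience=10, threshold=1e-4):
    """Replay a simplified 'min'-mode plateau rule; return the LR after each epoch."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best * (1 - threshold):   # meaningful improvement
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:           # plateau lasted longer than `patience`
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history


# val_loss for epochs 1..24 of the weighted-loss run, rounded from the log.
val_losses = [
    1.5932, 1.7156, 1.6252, 1.3313, 1.1905, 1.5180, 2.2942, 1.6944,
    1.2439, 1.4617, 2.4444, 1.7161, 1.5918, 1.7146, 1.4205, 1.7127,
    1.5232, 1.4537, 1.3515, 0.8691, 1.4592, 1.4736, 1.2656, 1.5099,
]
lr_per_epoch = simulate_plateau(val_losses)
for epoch in (15, 16, 24):
    print(epoch, f"{lr_per_epoch[epoch - 1]:.4f}")
```

The best loss before epoch 20 is reached at epoch 5; epochs 6 to 16 are eleven epochs without improvement, which exceeds the patience of 10 and triggers the reduction.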





In [37]:
best_model = RDClassifier(input_shape=(batch_size, channels, features_length), n_classes=len(class_dict.keys()))
best_model.load_state_dict(torch.load(os.path.join('..', 'weights', 'pytorch_rd_classifier_epoch_20.pth')))
best_model.to(device)
best_model.eval()

RDClassifier(
  (conv1): Conv1d(1, 128, kernel_size=(5,), stride=(1,))
  (conv2): Conv1d(128, 256, kernel_size=(5,), stride=(1,))
  (maxpool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv1d(256, 512, kernel_size=(5,), stride=(1,))
  (dropout): Dropout(p=0.2, inplace=False)
  (linear1): Linear(in_features=46592, out_features=1024, bias=True)
  (linear2): Linear(in_features=1024, out_features=6, bias=True)
  (relu): ReLU()
)

In [38]:
report_dict, report_txt = training.evaluate(best_model, dataloaders['test'], device)
print(report_txt)

100%|█████████████████████████████████████████████| 4/4 [00:02<00:00,  1.64it/s]

                precision    recall  f1-score   support

Bronchiectasis       0.17      0.50      0.25         2
 Bronchiolitis       0.50      0.50      0.50         2
          COPD       0.99      0.95      0.97        96
       Healthy       1.00      0.75      0.86         4
     Pneumonia       0.50      0.75      0.60         4
          URTI       1.00      0.67      0.80         3

      accuracy                           0.91       111
     macro avg       0.69      0.69      0.66       111
  weighted avg       0.95      0.91      0.92       111






With the weighted loss function, the averaged F1-scores dropped (macro 0.72 to 0.66, weighted 0.94 to 0.92). However, precision and recall for the Healthy class became more balanced: precision is 1.00, which means that every recording predicted as Healthy was indeed healthy, while recall fell to 0.75.
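
The gap between the macro and weighted averages in both reports comes from the class imbalance: the weighted average is dominated by the 96 COPD recordings. The short sketch below recomputes both averages from the per-class F1-scores printed in the two reports. Because those inputs are already rounded to two decimals, the last digit can differ from the report.

```python
# Order: Bronchiectasis, Bronchiolitis, COPD, Healthy, Pneumonia, URTI
support = [2, 2, 96, 4, 4, 3]
f1_unweighted_loss = [0.67, 0.67, 0.98, 0.73, 0.75, 0.50]
f1_weighted_loss = [0.25, 0.50, 0.97, 0.86, 0.60, 0.80]


def macro_avg(scores):
    """Plain mean: every class counts equally."""
    return sum(scores) / len(scores)


def weighted_avg(scores, support):
    """Mean weighted by the number of test recordings per class."""
    return sum(s * n for s, n in zip(scores, support)) / sum(support)


for name, scores in [("unweighted loss", f1_unweighted_loss),
                     ("weighted loss", f1_weighted_loss)]:
    print(f"{name:16s} macro={macro_avg(scores):.3f} "
          f"weighted={weighted_avg(scores, support):.3f}")
```

The macro average reacts strongly to the rare classes, where a single recording changes the per-class score a lot, so it is the more sensitive of the two figures on a test set this small.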

There are several possible ways to improve sound-based disease classification:

1. Add more audio features.
2. Add several channels with various audio feature aggregations, so that the task can be treated as image classification.
3. Add recurrent layers.
4. Change the dropout rate.
5. Change the sampling rate.
6. Use deeper models.
7. Apply audio augmentation to make classification more robust.

Experiment-tracking tools such as MLflow or Neptune can be used to orchestrate these experiments.

Model calibration is also important here because neural networks tend to be overconfident in their predictions. However, fair calibration requires an additional held-out data subset.
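
One common post-hoc approach is temperature scaling: the logits are divided by a single scalar T fitted on held-out data, which changes the confidence but not the predicted class. The sketch below is a plain-Python illustration on made-up logits, using a coarse grid search instead of gradient-based fitting. The function names and the toy data are assumptions, and nothing here comes from the repository.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax of logits / temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def nll(logit_rows, labels, temperature):
    """Mean negative log-likelihood of the true labels."""
    losses = [-math.log(softmax(row, temperature)[y])
              for row, y in zip(logit_rows, labels)]
    return sum(losses) / len(losses)


def fit_temperature(logit_rows, labels, grid=None):
    """Pick the temperature from a grid that minimizes the held-out NLL."""
    grid = grid or [0.5 + 0.1 * i for i in range(46)]  # 0.5 .. 5.0
    return min(grid, key=lambda t: nll(logit_rows, labels, t))


# Made-up held-out logits: four confident predictions, one of them wrong.
held_out_logits = [[6.0, 0.0, 0.0], [0.0, 5.0, 0.0],
                   [7.0, 0.0, 0.0], [0.0, 0.0, 6.0]]
held_out_labels = [0, 1, 1, 2]  # the third sample is misclassified

T = fit_temperature(held_out_logits, held_out_labels)
before = max(softmax(held_out_logits[2]))
after = max(softmax(held_out_logits[2], T))
print(f"T={T:.1f}, confidence on the wrong sample: {before:.3f} -> {after:.3f}")
```

Fitting T on the test set would leak information into the reported metrics, which is why a separate calibration subset is needed.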

From a production point of view, PyTorch models are convenient because they can be easily converted to ONNX and served via onnxruntime or an onnxruntime server on the backend.