# Load Libraries

In [1]:
import warnings
warnings.filterwarnings('ignore')

import glob
import io
import datasets
import os
import time
import joblib
import json
import csv
import pathlib
import librosa
import librosa.display

import pandas as pd
import numpy as np

from tqdm.notebook import tqdm
from collections import Counter
from pprint import pprint
%matplotlib inline

from sklearn.model_selection import train_test_split

from sklearn.metrics import (
    classification_report, 
    confusion_matrix, 
    accuracy_score,
    roc_curve,
    roc_auc_score, 
    precision_recall_curve,
    auc,
    precision_score, 
    recall_score, 
    f1_score
    )

import torch.utils.data
from torch.utils.data import Dataset, DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, DatasetDict, Audio
from transformers import WhisperModel, WhisperFeatureExtractor, AdamW
# from transformers import WhisperEncoder
from transformers import WhisperProcessor

from functions_whisper_model import SpeechClassificationDataset, SpeechClassifier, train, evaluate

In [2]:
# Dataset combinations to run. Every inner list is one experiment whose CSVs are pooled.
# Reference stats for single-dataset runs (label 1 = cough, 0 = other):
#   ['fsdkaggle'] -> ~2% cough,   Counter({0: 1570, 1: 30})
#   ['virufy']    -> 100% cough,  Counter({1: 121})
#   ['esc50']     -> ~2% cough,   Counter({0: 1960, 1: 40})
#   ['coughvid']  -> Counter({1: 19777, 0: 10267})  NOTE(review): originally noted as 30% cough, but these counts give ~66% — confirm
#   ['coswara']   -> ~25% cough,  Counter({0: 18914, 1: 5408})  (counts give ~22%)
list_datasets = [
    ['coswara', 'coughvid', 'esc50', 'fsdkaggle', 'virufy'],  # all five datasets pooled
]

# Main: train and evaluate the Whisper classifier for each window length

In [3]:
# Train and evaluate a Whisper-encoder classifier for every (window length, dataset
# combination) pair. Each run writes its best checkpoint, the per-clip test predictions
# and a one-row summary CSV into Results/Model_Whisper/<dataset_str>/.
DATA_ROOT = '/home/l083319/Cough_Related/'  # NOTE(review): absolute local path; make configurable before sharing
SEED = 42  # seeds the per-class subsample, the splits and torch's global RNG

columns = ['dataset', 'dataset_counter', 'label_count', 'window_length',
           'model',
           'acc', 'sen', 'spe', 'pre', 'rec', 'f1', 'auc', 'auprc', 'cm']


def _to_audio_dataset(split_df):
    """Build a Hugging Face dataset of (audio path, label) decoded at 16 kHz from a split frame."""
    return datasets.Dataset.from_dict({
        "audio": split_df["full_path"].tolist(),
        "labels": split_df["classID"].tolist(),
    }).cast_column("audio", Audio(sampling_rate=16_000))


for window_length in [1, 5, 10]:
    summary_rows = []  # one summary row per dataset combination at this window length
    for datasets_name in list_datasets:
        # Sort a copy so the folder name is order-independent without mutating list_datasets.
        datasets_name = sorted(datasets_name)
        print('')
        print('#'*60)
        print(', '.join(datasets_name))
        print(f'Window Length: {window_length}')
        print('#'*60)

        dataset_str = '_'.join(datasets_name)
        results_dir = f'Results/Model_Whisper/{dataset_str}'
        os.makedirs(results_dir, exist_ok=True)

        path_model_save = f'{results_dir}/whisper_best_model_{window_length}s.pt'

        ################################################################
        # Load Data
        ################################################################
        # Read all per-dataset CSVs, then concatenate once.
        df_all = pd.concat(
            [pd.read_csv(f'Results/Sliced_Wav/dataset_{name}_{window_length}s.csv')
             for name in datasets_name],
            axis=0,
        ).reset_index(drop=True)

        ################################################################
        # Prepare Data
        ################################################################
        df_all['filepath'] = DATA_ROOT + df_all['filepath']
        # Keep only clips whose mean amplitude exceeds 0.005 (presumably to drop near-silent slices).
        df_all = df_all[df_all['mean_amplitude'] > 0.005].reset_index(drop=True)

        # Drop metadata columns not used downstream; errors='ignore' skips any that are absent.
        df_all = df_all.drop(
            columns=['prob', 'status', 'age', 'Unnamed: 0', 'gender', 'mean_amplitude'],
            errors='ignore',
        )

        audio_df = df_all.rename(columns={
            'label': 'classID',
            'filepath': 'full_path',
        })

        print(audio_df.shape)
        # Shuffle, then keep at most 1000 clips per class. Seeded so the subsample, and
        # therefore the train/val/test split, is reproducible between runs.
        audio_df = (
            audio_df.sample(frac=1, random_state=SEED)
            .groupby('classID')
            .head(1000)
            .reset_index(drop=True)
        )

        print(Counter(audio_df['dataset']))
        print(Counter(audio_df['classID']))

        # 70 / 15 / 15 train / val / test split.
        train_df, temp_df = train_test_split(audio_df, test_size=0.3, random_state=SEED)
        val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=SEED)

        print('Train:', len(train_df))
        print('Val  :', len(val_df))
        print('Test :', len(test_df))

        train_audio_dataset = _to_audio_dataset(train_df)
        test_audio_dataset = _to_audio_dataset(test_df)
        val_audio_dataset = _to_audio_dataset(val_df)

        ################################################################
        # Load Whisper
        ################################################################
        # Reload pretrained weights on every run. SpeechClassifier wraps `encoder` and
        # presumably fine-tunes it in place, so a previous run must not leak into this one.
        model_checkpoint = "openai/whisper-base"
        processor = WhisperProcessor.from_pretrained(model_checkpoint)
        whisper_model = WhisperModel.from_pretrained(model_checkpoint)
        encoder = whisper_model.encoder  # this is the encoder module
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        train_dataset = SpeechClassificationDataset(train_audio_dataset, processor, encoder)
        test_dataset = SpeechClassificationDataset(test_audio_dataset, processor, encoder)
        val_dataset = SpeechClassificationDataset(val_audio_dataset, processor, encoder)

        batch_size = 32

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        # Unshuffled: the predictions are written back row-by-row onto test_df below.
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        num_labels = 2

        # Seed torch's global RNG (used for DataLoader shuffling and any random weight init).
        torch.manual_seed(SEED)
        model = SpeechClassifier(num_labels, encoder).to(device)
        # torch.optim.AdamW with weight_decay=0.0 matches the defaults of the deprecated
        # transformers.AdamW that was used here before.
        optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, betas=(0.9, 0.999), eps=1e-08,
                                      weight_decay=0.0)
        criterion = nn.CrossEntropyLoss()

        num_epochs = 1

        ################################################################
        # Train Whisper
        ################################################################
        start = time.time()
        train(model, train_loader, val_loader, optimizer, criterion, device, num_epochs, path_model_save)
        end = time.time()

        print(f"Total runtime of the program is {round(end - start, 3)}s")
        print('Training Done!')

        ################################################################
        # Test Whisper
        ################################################################
        print('Load Whisper Model')

        # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
        state_dict = torch.load(path_model_save, map_location=device)

        # Create a new instance of the model and load the saved checkpoint into it.
        model = SpeechClassifier(num_labels, encoder).to(device)
        model.load_state_dict(state_dict)

        print('Evaluate Data')
        _, _, _, all_labels, all_preds, all_probs = evaluate(model, test_loader, optimizer, criterion, device)

        print(classification_report(all_labels, all_preds))
        print('ACC:', round(accuracy_score(all_labels, all_preds), 3))
        print('Test Done!')

        y_test = all_labels
        y_predict = all_preds
        # Column 1 is used as the score for class 1 (the positive class) in the curves below.
        pos_score = np.asarray(all_probs)[:, 1]

        acc = accuracy_score(y_test, y_predict)
        # labels=[0, 1] keeps the matrix 2x2 even if one class is missing from the test split.
        cm = confusion_matrix(y_test, y_predict, labels=[0, 1])
        print(cm)

        lr_fpr, lr_tpr, _ = roc_curve(y_test, pos_score)
        roc_auc = auc(lr_fpr, lr_tpr)
        precision, recall, _ = precision_recall_curve(y_test, pos_score)
        pr_auc = auc(recall, precision)

        pre = precision_score(y_test, y_predict)
        rec = recall_score(y_test, y_predict)
        f1 = f1_score(y_test, y_predict)
        tn, fp, fn, tp = cm.ravel()
        spe = tn / (tn + fp) if (tn + fp) else float('nan')  # undefined with no negatives
        sen = rec

        results = [[
            dataset_str,
            Counter(audio_df['dataset']),
            Counter(audio_df['classID']),
            window_length, 'Whisper',
            acc, sen, spe, pre, rec, f1,
            roc_auc, pr_auc, cm]]
        summary_rows.extend(results)

        # One summary file per dataset combination: the path the aggregation cell reads back.
        pd.DataFrame(results, columns=columns).to_csv(
            f'{results_dir}/results_summary_{window_length}s.csv', index=False)

        # Copy before adding a column so we never write into a view of audio_df.
        test_df = test_df.copy()
        test_df['pred'] = all_preds
        test_df.to_csv(f'{results_dir}/results_test_data_{window_length}s.csv', index=False)

        # Check which data is predicted wrongly
        test_df_wrong = test_df[test_df['classID'] != test_df['pred']]

    # Summary of every dataset combination evaluated at this window length.
    df_results = pd.DataFrame(summary_rows, columns=columns)

print('All Done!')


############################################################
coswara, coughvid, esc50, fsdkaggle, virufy
Window Length: 1
############################################################
(45949, 6)
Counter({'coswara': 779, 'coughvid': 734, 'fsdkaggle': 315, 'esc50': 164, 'virufy': 8})
Counter({0: 1000, 1: 1000})
Train: 1400
Val  : 300
Test : 300


2025-09-18 08:22:37.868920: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-09-18 08:22:37.879398: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-09-18 08:22:37.895233: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-09-18 08:22:37.899920: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-09-18 08:22:37.911988: I tensorflow/core/platform/cpu_feature_guar

Epoch 1/1, Batch 20/44, Train Loss: 0.5722, Run-time: 109.623s
Epoch 1/1, Batch 40/44, Train Loss: 0.3241, Run-time: 103.509s


Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/1, Val Loss: 0.2185, Val Accuracy: 0.9467, Val F1: 0.9466, Best Accuracy: 0.9467
Total runtime of the program is 271.127s
Training Done!
Load Whisper Model
Evaluate Data


Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.92      0.90      0.91       136
           1       0.92      0.93      0.92       164

    accuracy                           0.92       300
   macro avg       0.92      0.91      0.92       300
weighted avg       0.92      0.92      0.92       300

ACC: 0.917
Test Done!
[[122  14]
 [ 11 153]]

############################################################
coswara, coughvid, esc50, fsdkaggle, virufy
Window Length: 5
############################################################
(11194, 6)
Counter({'coswara': 781, 'coughvid': 674, 'fsdkaggle': 390, 'esc50': 133, 'virufy': 22})
Counter({0: 1000, 1: 1000})
Train: 1400
Val  : 300
Test : 300
Epoch 1/1, Batch 20/44, Train Loss: 0.5365, Run-time: 102.526s
Epoch 1/1, Batch 40/44, Train Loss: 0.1623, Run-time: 102.342s


Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/1, Val Loss: 0.1980, Val Accuracy: 0.9267, Val F1: 0.9267, Best Accuracy: 0.9267
Total runtime of the program is 260.722s
Training Done!
Load Whisper Model
Evaluate Data


Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.96      0.90      0.93       148
           1       0.91      0.96      0.93       152

    accuracy                           0.93       300
   macro avg       0.93      0.93      0.93       300
weighted avg       0.93      0.93      0.93       300

ACC: 0.93
Test Done!
[[133  15]
 [  6 146]]

############################################################
coswara, coughvid, esc50, fsdkaggle, virufy
Window Length: 10
############################################################
(7197, 6)
Counter({'coswara': 711, 'coughvid': 591, 'fsdkaggle': 420, 'esc50': 227, 'virufy': 51})
Counter({0: 1000, 1: 1000})
Train: 1400
Val  : 300
Test : 300
Epoch 1/1, Batch 20/44, Train Loss: 0.5656, Run-time: 102.439s
Epoch 1/1, Batch 40/44, Train Loss: 0.1782, Run-time: 101.744s


Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/1, Val Loss: 0.2115, Val Accuracy: 0.9333, Val F1: 0.9333, Best Accuracy: 0.9333
Total runtime of the program is 259.123s
Training Done!
Load Whisper Model
Evaluate Data


Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.96      0.94      0.95       141
           1       0.94      0.96      0.95       159

    accuracy                           0.95       300
   macro avg       0.95      0.95      0.95       300
weighted avg       0.95      0.95      0.95       300

ACC: 0.95
Test Done!
[[132   9]
 [  6 153]]
All Done!


In [4]:
# Create an empty DataFrame to hold the combined data
combined_df = pd.DataFrame()

for window_length in [1, 5, 10]:
    df_results = []
    for datasets_name in list_datasets:
        datasets_name.sort()
        
        dataset_str = '_'.join(datasets_name)
    
        df = pd.read_csv(f'Results/Model_Whisper/{dataset_str}/results_summary_{window_length}s.csv')
        combined_df = pd.concat([combined_df, df], ignore_index=True)

# Display or save the result
print(combined_df)
combined_df.to_csv(f'Results/Model_Whisper/{dataset_str}/results_summary_All.csv')

                                   dataset  \
0  coswara_coughvid_esc50_fsdkaggle_virufy   
1  coswara_coughvid_esc50_fsdkaggle_virufy   
2  coswara_coughvid_esc50_fsdkaggle_virufy   

                                     dataset_counter  \
0  Counter({'coswara': 779, 'coughvid': 734, 'fsd...   
1  Counter({'coswara': 781, 'coughvid': 674, 'fsd...   
2  Counter({'coswara': 711, 'coughvid': 591, 'fsd...   

                   label_count  window_length    model       acc       sen  \
0  Counter({0: 1000, 1: 1000})              1  Whisper  0.916667  0.932927   
1  Counter({0: 1000, 1: 1000})              5  Whisper  0.930000  0.960526   
2  Counter({0: 1000, 1: 1000})             10  Whisper  0.950000  0.962264   

        spe       pre       rec        f1       auc     auprc  \
0  0.897059  0.916168  0.932927  0.924471  0.962383  0.967086   
1  0.898649  0.906832  0.960526  0.932907  0.981108  0.980885   
2  0.936170  0.944444  0.962264  0.953271  0.965610  0.956189   

                