In [None]:
!pip install wandb
!pip install import-ipynb



In [None]:

from google.colab import drive
import import_ipynb
import sys

# Mount Google Drive under /content/drive/ (an existing mount is reused, not forced to remount).
drive.mount('/content/drive/',force_remount = False)
# Project root on the mounted drive; hard-coded, so adjust it when the project lives elsewhere.
proj_dir_path = '/content/drive/MyDrive/Study_materials/Voice_disorder_detection_project/'
# Put the project root on sys.path; the `src.*` imports further down presumably resolve against it.
sys.path.append(proj_dir_path)
%cd $proj_dir_path


In [None]:
from __future__ import print_function, division
import os
from random import sample
# from cv2 import transform
import torch
import re
import glob
from torch.utils.data import Dataset
from scipy.io import wavfile
import torch.nn as nn
import pandas as pd
# from src.models.yamnet_model import Identity
from sklearn.model_selection import train_test_split
from sklearn.utils import resample

from src.transforms.transforms import PadWhiteNoise,ToTensor,Truncate,ToOneHot,WaveformToInput,Inflate,Deflate,CFloat
from src.params.params import CommonParams as cfg
from src.params.params import Pathologies as PathologiesToIndex
from src.params.params import data_location

import librosa
import random, sys

# from torch_audiomentations import Compose, PitchShift,TimeInversion,AddBackgroundNoise
from torchaudio.transforms import Spectrogram,TimeStretch, TimeMasking, FrequencyMasking, InverseSpectrogram,GriffinLim
import wandb
import numpy as np

In [None]:
class patients_dataset:
  """Patient table with binary (healthy / ill) labels, plus split and up-sampling helpers.

  Expected call order: ``create_binary_labeled_dataset()`` builds ``self.df``,
  ``train_val_test_split()`` fills ``train_set`` / ``val_set`` / ``test_set`` and
  ``upsample_all_subsets()`` balances the labels inside each of them.
  """

  def __init__(self,audio_files_dir,excel_sheet_path = ""):
    """Store the input locations; nothing is read until the dataset is built.

    Args:
      audio_files_dir: directory that is searched for ``patient_<id>.wav`` files.
      excel_sheet_path: path of the patients spreadsheet (required).

    Raises:
      AssertionError: if ``excel_sheet_path`` is empty.
    """
    assert excel_sheet_path, 'No excel sheet path given'
    self.excel_path = excel_sheet_path
    self.dir = audio_files_dir


  def create_binary_labeled_dataset(self):
    """Build ``self.df`` with one row per patient that has a sheet entry AND a recording.

    Columns are ``id``, ``filepathway`` and ``diagnosis`` (0 when the sheet says
    'none', 1 otherwise). ``diagnosis`` stays the last column because
    ``binary_label_up_sample`` reads the label from the last column.
    Also sets ``self.df_size``.
    """
    patients_sheet = pd.read_excel(self.excel_path)
    # remove patients with incomplete data within the spreadsheet,
    # the date of birth column is taken as the indicator
    patients_sheet = patients_sheet[patients_sheet['DOB'].notna()]

    # Ids of the recordings that actually exist. Files that do not follow the
    # patient_<id>.wav naming are skipped instead of raising IndexError.
    recorded_ids = set()
    for path in glob.glob(os.path.join(self.dir, "*.wav")):
      match = re.search(r'patient_(\d+)\.wav$', os.path.basename(path).lower())
      if match:
        recorded_ids.add(int(match.group(1)))

    # Filter the sheet itself so that id and diagnosis always come from the same
    # row. (Previously the labels of *all* sheet rows were combined positionally
    # with the sorted id intersection, which either raised on a length mismatch
    # or attached labels to the wrong patients.)
    patients_sheet = patients_sheet[patients_sheet['#'].isin(recorded_ids)]
    # One row per patient, ordered by id (same order np.intersect1d used to give).
    patients_sheet = patients_sheet.drop_duplicates(subset='#').sort_values('#')

    patient_ids = patients_sheet['#'].astype(int).to_numpy()
    # Convert diagnostics to healthy/ill labels.
    # NOTE(review): this still fails on non-string cells; newer pandas versions may
    # parse the literal text 'None' as NaN when reading the sheet - confirm.
    labels = patients_sheet['dysphonia diagnosis'].apply(
        lambda x: 0 if x.lower() == 'none' else 1).to_numpy()
    # NOTE(review): the path is rooted at data_location['preprocessed_data'] (not
    # self.dir) and carries no '.wav' extension - kept as in the original, confirm.
    files_array = [os.path.join(data_location['preprocessed_data'], 'patient_' + str(patient_id))
                   for patient_id in patient_ids]

    self.df = pd.DataFrame({'id': patient_ids, 'filepathway': files_array, 'diagnosis': labels})
    self.df_size = len(self.df)


  def train_val_test_split(self,split =(0.8,0.1,0.1),seed = None):
    """Shuffle ``self.df`` and split it into train / validation / test frames.

    Args:
      split: (train, val, test) fractions. Train and validation sizes are
        truncated to whole rows; the test set receives all remaining rows, so
        the third fraction is effectively implied by the first two.
      seed: ``random_state`` forwarded to both ``train_test_split`` calls.

    Returns:
      ``(train_set, val_set, test_set)``, also stored on the instance.
    """
    train_fraction, val_fraction, _ = split
    train_size = int(train_fraction*self.df_size)
    val_size = int(val_fraction*self.df_size)
    test_size = self.df_size - train_size - val_size
    # first split into train / (val + test)
    self.train_set, val_test_sets = train_test_split(
        self.df, train_size=train_size, test_size=(val_size+test_size),
        random_state=seed, shuffle=True)
    # then split the remainder into val / test
    self.val_set, self.test_set = train_test_split(
        val_test_sets, train_size=val_size, test_size=test_size,
        random_state=seed, shuffle=True)
    return self.train_set, self.val_set , self.test_set

  def binary_label_up_sample(self,df_to_up, up_amount=1, plot = False):
    """Balance a binary-labelled frame by re-sampling the minority label with replacement.

    Args:
      df_to_up: frame whose LAST column holds the 0 / 1 label.
      up_amount: multiplier on the majority-minority gap; the number of added
        minority rows is ``int(gap * up_amount)`` (1 -> fully balanced).
      plot: when True, print the sizes and show before/after bar charts.

    Returns:
      A new, shuffled frame with a fresh 0..n-1 index. When nothing can be added
      (already balanced, or no minority rows to copy) the rows are returned
      unchanged apart from the shuffle.
    """
    labels = df_to_up.iloc[:,-1]
    sick_samples = df_to_up[labels == 1]
    healthy_samples = df_to_up[labels == 0]
    if len(sick_samples) > len(healthy_samples):
      majority, minority = sick_samples, healthy_samples
    else:
      majority, minority = healthy_samples, sick_samples
    difference = len(majority) - len(minority)

    n_new = int(difference*up_amount)
    if n_new > 0 and len(minority) > 0:
      new_samples = resample(minority,replace=True,n_samples=n_new,random_state = 0)
    else:
      # resample() cannot draw zero rows / draw from an empty frame.
      if n_new > 0:
        print("Upsampling skipped: no minority samples to draw from")
      new_samples = minority.iloc[0:0]

    new_minority = pd.concat([minority, new_samples])
    # Empty pieces are left out of the final concat (pd.concat needs >= 1 object).
    pieces = [piece for piece in (new_minority, majority) if len(piece)]
    entire_df = pd.concat(pieces) if pieces else df_to_up.iloc[0:0]
    #shuffle new samples into the df
    entire_df = entire_df.sample(frac=1).reset_index(drop=True)

    if (plot):
      # Imported lazily so that plain up-sampling does not require matplotlib.
      import matplotlib.pyplot as plt
      print("entire df size: ", len(df_to_up))
      print("Major label size is: ", len(majority))
      print("Minor label size is: ", len(minority))
      print("Difference in labels is: ", difference)
      print(f"new minor label size is: {len(minority)}+{len(new_samples)} = {len(new_minority)} ")
      print("new df size is: ", len(entire_df))
      counts_before = df_to_up.iloc[:,-1].value_counts()
      counts_after = entire_df.iloc[:,-1].value_counts()

      plt.subplot(1, 2, 1) # row 1, col 2 index 1
      counts_before.plot(kind='bar')
      plt.xlabel('1: Sick     0: Healthy')
      plt.ylabel('Count')
      plt.title('Distribution of patients originally')

      plt.subplot(1, 2, 2) # index 2
      counts_after.plot(kind='bar')
      plt.xlabel('1: Sick     0: Healthy')
      plt.ylabel('Count')
      plt.title('After upsampling')
      plt.show()
    return entire_df

  def upsample_all_subsets(self,up_amount=1):
    """Up-sample train, validation and test sets in place.

    ``up_amount`` is forwarded to ``binary_label_up_sample`` (it used to be
    ignored). Errors are reported with a print, as before, but the three
    subsets are only replaced when all of them were up-sampled successfully,
    so a failure no longer leaves a mix of balanced and unbalanced subsets.
    """
    try:
      upsampled = [
          self.binary_label_up_sample(df_to_up=subset,up_amount=up_amount, plot = False)
          for subset in (self.train_set, self.val_set, self.test_set)
      ]
    except Exception as e:
      print("Upsampling error: ",e)
      return
    self.train_set, self.val_set, self.test_set = upsampled

In [None]:
def create_transformations(augmentations):
    """Build a waveform -> spectrogram -> augmentations -> waveform pipeline.

    Args:
        augmentations: iterable of augmentation names (``"TimeStretch"``,
            ``"FrequencyMasking"``, ``"TimeMasking"``), a single name, or
            ``None`` for no augmentation. They are applied in the given order.

    Returns:
        ``nn.Sequential`` of ``Spectrogram``, the requested augmentations,
        ``CFloat`` and ``InverseSpectrogram``.

    Raises:
        KeyError: if a name is not one of the supported augmentations.
    """
    print(augmentations)
    # Factories instead of instances: only the requested transforms get built.
    name_to_aug = {
        "TimeStretch": lambda: TimeStretch(fixed_rate=0.8),
        "FrequencyMasking": lambda: FrequencyMasking(freq_mask_param=80),
        "TimeMasking": lambda: TimeMasking(time_mask_param=80),
    }

    # None means "no augmentations"; a bare string is one name, not a sequence
    # of characters.
    if augmentations is None:
        requested = []
    elif isinstance(augmentations, str):
        requested = [augmentations]
    else:
        requested = list(augmentations)

    unknown = [name for name in requested if name not in name_to_aug]
    if unknown:
        raise KeyError(f"Unknown augmentation(s) {unknown}; expected one of {sorted(name_to_aug)}")

    # NOTE(review): Spectrogram() is created with its default arguments while
    # TimeStretch / InverseSpectrogram are documented for complex spectrograms;
    # CFloat only runs after the augmentations - confirm this ordering is intended.
    transforms = [name_to_aug[name]() for name in requested]
    transforms = [Spectrogram()] + transforms + [CFloat(), InverseSpectrogram()]
    return nn.Sequential(*transforms)

# Default `label_transform` of the dataset classes in this file: a single ToOneHot step.
default_label_transforms = nn.Sequential(ToOneHot())

def create_datasets(root_dir,split:tuple,hp,filter_gender=None,seed=None,**kwargs)->list:
    """Collect the wav files under ``root_dir`` and split them into datasets.

    Args:
        root_dir: directory that is walked recursively for ``*.wav`` files
            (hidden files are ignored).
        split: fractions of the files per dataset; must sum to 1.
        hp: hyper-parameter mapping. ``hp["filter_gender"]``, when present and
            not None, names the sub-directory of ``root_dir`` to use.
            ``hp["seed"]`` is WRITTEN by this function with the seed used.
        filter_gender: fallback for the sub-directory when ``hp`` does not
            provide one.
        seed: shuffle seed; a random one is drawn when None.
        **kwargs: forwarded to ``SvdExtendedVoiceDataset``.

    Returns:
        A list with one ``SvdExtendedVoiceDataset`` per entry of ``split``.

    Raises:
        AssertionError: if the fractions do not sum to 1.
    """
    # Tolerant comparison: e.g. 0.7 + 0.2 + 0.1 is 0.9999999999999999 in floats.
    assert np.isclose(sum(split), 1.0), "Splits fraction array should sum up to 1"

    gender = hp.get("filter_gender")
    if gender is None:
        gender = filter_gender
    if gender is not None:
        root_dir = os.path.join(root_dir, gender)

    files_array = []
    for root, _dirs, files in os.walk(root_dir):
        files_array += [os.path.join(root, f) for f in files
                        if not f.startswith('.') and f.endswith('.wav')]
    # os.walk order depends on the file system; sort first so that the same
    # seed yields the same split on every machine.
    files_array.sort()

    if seed is None:
        seed = random.randrange(sys.maxsize)
    random.Random(seed).shuffle(files_array)

    # Cut points are the cumulative fractions (last one dropped) scaled to the
    # file count. The small epsilon keeps float error from stealing an item
    # (0.8999999999999999 * 10 must floor to 9, not 8).
    boundaries = np.cumsum(split)[:-1]
    cut_points = [int(np.floor(b * len(files_array) + 1e-9)) for b in boundaries]

    files_split = np.split(files_array, cut_points)
    hp['seed'] = seed
    return [SvdExtendedVoiceDataset(sp, hp, **kwargs) for sp in files_split]

# TODO: make this inherit AudioFolderDataset
class SvdExtendedVoiceDataset(Dataset):
    """Dataset of Saarbruecken (SVD) voice recordings, one item per wav file.

    The class of a recording is the name of the directory three levels above
    the file name in its '/'-separated path; ``class_definitions`` maps that
    name to a class index.
    """

    # Target length (in samples) handed to PadWhiteNoise and Truncate.
    _TARGET_NUM_SAMPLES = 50000

    def __init__(self, files, hp,label_transform=default_label_transforms, class_definitions=None,classification_binary=True):
        """
        Args:
            files: non-empty sequence of wav file paths.
            hp: hyper-parameter mapping; only ``hp['seed']`` is read (None when absent).
            label_transform: applied to the class name when
                ``classification_binary`` is False.
            class_definitions: mapping class name -> index; defaults to
                ``PathologiesToIndex``.
            classification_binary: when True the label is ``class index != 0``.

        Raises:
            AssertionError: if ``files`` is empty.
        """
        self.data_transform = nn.Sequential(
            ToTensor(),
            PadWhiteNoise(self._TARGET_NUM_SAMPLES),
            Truncate(self._TARGET_NUM_SAMPLES),
        )
        self.label_transform = label_transform
        self.classification_binary = classification_binary
        self.class_definitions = class_definitions if class_definitions is not None else PathologiesToIndex
        self.seed = hp.get('seed')
        self.files = files
        assert len(self.files) > 0,f"Directory should not be empty, it is {self.files}"

    def _load_wav(self,wav_file):
        """Return ``(sample_rate, samples)`` as read by ``scipy.io.wavfile``."""
        return wavfile.read(wav_file)

    def _get_class(self,wav_file_path):
        """Return ``(class index, class name)`` for a file path."""
        # Runs for every item, so there is deliberately no debug output here.
        class_name = wav_file_path.split('/')[-3]
        return self.class_definitions[class_name], class_name

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        """Return a dict with keys 'data', 'sampling_rate', 'classification', 'original_class'."""
        if torch.is_tensor(index):
            index = index.item()
        if isinstance(index,list):
            index = index[0]
        samplerate, data = self._load_wav(self.files[index])
        classification_index,classification = self._get_class(self.files[index])

        if self.data_transform is not None:
            data = self.data_transform(data)

        if self.classification_binary:
            label = classification_index != 0
        elif self.label_transform is not None:
            # NOTE(review): the transform receives the class NAME, not its index
            # (unchanged from the original) - confirm ToOneHot expects that.
            label = self.label_transform(classification)
        else:
            # No transform configured: fall back to the raw class index instead
            # of failing with an unbound `label`.
            label = classification_index
        return {'data':data, 'sampling_rate':samplerate,'classification':label,'original_class':classification}

class SvdCutOffShort(SvdExtendedVoiceDataset):
    """SVD dataset that drops recordings shorter than ``cfg.VOICE_SAMPLE_MIN_LENGTH``.

    The remaining files are shuffled; with ``overfit_test`` only the first 40
    of them are kept.
    """
    def __init__(self, files, hp,label_transform=default_label_transforms, class_definitions=None,classification_binary=True,overfit_test = False):
        super().__init__(files,hp,label_transform,class_definitions,classification_binary)
        # NOTE(review): recent librosa releases deprecate the `filename=` keyword
        # of get_duration in favour of `path=` - confirm against the pinned version.
        self.files = [file for file in self.files if librosa.get_duration(filename=file)>=cfg.VOICE_SAMPLE_MIN_LENGTH]
        # Seeded with the dataset seed so that the order - and with it the
        # overfit subset - is reproducible between runs.
        random.Random(self.seed).shuffle(self.files)
        if overfit_test:
            self.files = self.files[:40]


class SvdWindowedDataset(SvdExtendedVoiceDataset):
    """SVD dataset that yields consecutive windows of every recording.

    Each entry of ``self.files`` is a dict ``{'path', 'window_index'}``; window
    ``i`` covers the samples ``[delta*i, delta*(i+1)) * cfg.SVD_SAMPLE_RATE``.
    Recordings can be filtered by letter and pitch, both parsed from the file
    name (``<id>-<letter>_<pitch>.wav``).
    """
    def __init__(self, files, hp,label_transform=default_label_transforms, class_definitions=None,classification_binary=True,overfit_test = False,delta=1):
        """
        Args:
            files, hp, label_transform, class_definitions, classification_binary:
                see ``SvdExtendedVoiceDataset``.
            hp: additionally ``hp["filter_letter"]`` / ``hp["filter_pitch"]``
                are read; None (or absent) disables the respective filter.
            overfit_test: keep only 40 randomly chosen windows.
            delta: window length and hop, in seconds.
        """
        super().__init__(files,hp,label_transform,class_definitions,classification_binary)
        allowed_letters = hp.get("filter_letter")
        allowed_pitches = hp.get("filter_pitch")

        def _keep(path):
            # Parse the FILE NAME only: underscores or dashes in directory names
            # (e.g. '.../Voice_disorder_detection_project/...') used to be
            # picked up by the split and broke both filters.
            name = os.path.basename(path)
            if allowed_letters is not None and name.split("_")[0].split("-")[1] not in allowed_letters:
                return False
            if allowed_pitches is not None and name.split("_")[1].split(".")[0] not in allowed_pitches:
                return False
            return True

        self.delta=delta
        self.files = [file for file in self.files if _keep(file)]
        self.files = self._inflate_sound_files(self.files)

        if overfit_test:
            # Seeded so that the 40-window subset is the same on every run.
            random.Random(self.seed).shuffle(self.files)
            self.files = self.files[:40]

    def _load_wav(self,wav_file):
        """Return ``(sample_rate, samples)`` of the window described by ``wav_file``."""
        window_index = wav_file["window_index"]
        file_path = wav_file["path"]
        sample_rate,data = wavfile.read(file_path)
        # Window bounds use the configured SVD rate, not the rate read from the file.
        start_index = int(self.delta*window_index*cfg.SVD_SAMPLE_RATE)
        end_index = int(self.delta*(window_index+1)*cfg.SVD_SAMPLE_RATE)
        return sample_rate,data[start_index:end_index]

    def _get_class(self,wav_file):
        """Return ``(class index, class name)`` for a window entry."""
        class_name = wav_file["path"].split('/')[-3]
        return self.class_definitions[class_name], class_name

    def _inflate_sound_files(self,files):
        """Expand each file path into one ``{'path', 'window_index'}`` entry per window."""
        def get_window_count(f):
            # NOTE(review): recent librosa releases deprecate `filename=` in
            # favour of `path=` - confirm against the pinned version.
            total_samples = librosa.get_duration(filename=f)*cfg.SVD_SAMPLE_RATE
            # One second's worth of samples is subtracted (clamped at 0) before
            # counting hops, so every file yields at least one window.
            # NOTE(review): the subtraction is independent of delta - confirm
            # this is intended for delta != 1.
            usable_samples = max(0, total_samples-cfg.SVD_SAMPLE_RATE)
            return int(usable_samples/(self.delta*cfg.SVD_SAMPLE_RATE))+1
        return [{'path':file,'window_index':i} for file in files for i in range(get_window_count(file))]

if __name__ == "__main__":
    from tqdm import tqdm
    from torch.utils.data import DataLoader

    # Hyper-parameters: no augmentations and no pitch / letter / gender filtering.
    hp = dict.fromkeys(
        ("augmentations", "filter_pitch", "filter_letter", "filter_gender")
    )

    # sets = create_datasets(data_location['preprocessed_data'],split=(0.6,0.2,0.2),hp=hp,filter_gender=None)
    # Build the labelled patient table, split it with the default fractions and
    # balance the labels inside every subset.
    sets = patients_dataset(
        audio_files_dir=data_location["preprocessed_data"],
        excel_sheet_path=data_location["data_spreadsheet"],
    )
    sets.create_binary_labeled_dataset()
    sets.train_val_test_split()
    sets.upsample_all_subsets()