In [21]:
# from scipy.io import wavfile
import matplotlib.pyplot as plt
import numpy as np
from scipy.fft import rfft, rfftfreq
from scipy.signal import butter, lfilter, medfilt, cheby1, sosfilt # decimate
# from scipy.stats import median_abs_deviation
import pandas as pdf
# import math
import librosa

%matplotlib inline

import os

# Root of the respiratory sound database. It must end with "/" because the
# paths below (and in later cells) are built by plain string concatenation.
# Set the RESPIRATORY_DATA_PATH environment variable to point elsewhere
# instead of editing this absolute, machine-specific default.
data_path = os.environ.get(
    "RESPIRATORY_DATA_PATH",
    "/home/collins/Desktop/projects/baymax/data/Respiratory_Sound_Database/",
)
audio_file_path = f"{data_path}audio_and_txt_files"

# patient_number = "122" # Those with pneumonia in the dataset are 122, 135, 140, 191, 219 and 226
# sound_location = ""

# Line colour per chest-location code. Every location gets a distinct colour
# so that recordings overlaid in one plot remain distinguishable.
plot_colors = {
    "Tc": "#32F2F5",
    "Al": "#FA9A0A",
    "Ll": "#F53532",
    "Pl": "#2AC126",
    "Ar": "#0A6AFA",
    "Pr": "#BD26C1",
    "Lr": "#8C564B",  # previously duplicated Tc's colour
}

def remove_spikes(data, kernel_size=3):
    """Suppress short impulsive spikes in a signal with a median filter.

    Args:
        data: Signal array (typically 1-D) passed to ``scipy.signal.medfilt``.
        kernel_size: Median window length in samples. It must be >= 1. An even
            value is bumped to the next odd number because ``medfilt`` only
            accepts odd windows. Defaults to 3, which is the window the
            previous implementation actually applied (scipy's default); its
            unused ``ws = 501`` local never reached ``medfilt``.

    Returns:
        The median-filtered signal, same shape as ``data`` (``medfilt``
        zero-pads at the edges).

    Raises:
        ValueError: If ``kernel_size`` is smaller than 1.
    """
    if kernel_size < 1:
        raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
    # scipy.signal.medfilt requires an odd window length.
    if kernel_size % 2 == 0:
        kernel_size += 1
    return medfilt(data, kernel_size)

# Band-pass filter used to suppress low-frequency heart sounds (and content
# above the band of interest) before analysing lung sounds.
def filterNoise(sampling_rate, data, low_cut=300, high_cut=1000, order=3):
    """Apply a Butterworth band-pass filter to ``data``.

    Cutoffs are given in Hz and converted to fractions of the Nyquist
    frequency (half the sampling rate), as ``scipy.signal.butter`` expects.
    The defaults (300-1000 Hz, order 3) reproduce the original fixed filter.
    The signal is not resampled, so the output has the same length and rate.

    Args:
        sampling_rate: Sampling rate of ``data`` in Hz.
        data: Signal samples to filter.
        low_cut: Lower cutoff frequency in Hz.
        high_cut: Upper cutoff frequency in Hz.
        order: Butterworth filter order.

    Returns:
        The filtered signal produced by ``scipy.signal.lfilter``.

    Raises:
        ValueError: If the cutoffs are not ``0 < low_cut < high_cut < Nyquist``.
    """
    nyquist = 0.5 * sampling_rate
    if not 0 < low_cut < high_cut < nyquist:
        raise ValueError(
            f"need 0 < low_cut ({low_cut}) < high_cut ({high_cut}) "
            f"< Nyquist ({nyquist}) for sampling_rate={sampling_rate}"
        )
    b, a = butter(order, [low_cut / nyquist, high_cut / nyquist], 'band')
    return lfilter(b, a, data)


def fourierTransform(patient_number=None, files=None, diagnosis=None,
                     audio_dir=None, out_dir="waveforms"):
    """Plot the magnitude spectrum of each recording of one patient and save it.

    Each non-empty entry of ``files`` is read at its native sampling rate with
    ``scipy.io.wavfile``. Its one-sided FFT magnitude is drawn on a shared
    axis. The figure is written to
    ``{out_dir}/{patient_number}_{diagnosis}_freqdom.png``.

    Args:
        patient_number: Patient identifier, used in the title and file name.
        files: Mapping of chest-location code to wav file name. Falsy values
            (e.g. ``None``) are skipped.
        diagnosis: Diagnosis label, used in the title and file name.
        audio_dir: Directory holding the wav files. Defaults to the notebook's
            ``audio_file_path``.
        out_dir: Directory the PNG is saved into (must already exist).

    Raises:
        ValueError: If ``patient_number``, ``files`` or ``diagnosis`` is missing.
    """
    # The notebook's `from scipy.io import wavfile` is commented out, so import it here.
    from scipy.io import wavfile

    if patient_number is None or files is None or diagnosis is None:
        raise ValueError("fourierTransform() requires patient_number, files and diagnosis")
    if audio_dir is None:
        audio_dir = audio_file_path

    # Use an explicit figure instead of drawing on whatever figure is current.
    fig, ax = plt.subplots()
    ax.set_xlabel("Frequency [Hz]")
    # |rfft| is an amplitude (magnitude) spectrum, not a power spectrum.
    ax.set_ylabel("Magnitude")
    ax.set_title(f"Frequency domain of {diagnosis} patient {patient_number}")

    for location, file_name in files.items():
        if not file_name:
            continue
        sr, data = wavfile.read(f"{audio_dir}/{file_name}")
        if data.ndim > 1:
            # wavfile returns (samples, channels) for multi-channel audio. Average
            # the channels so the FFT runs along time rather than across channels.
            data = data.mean(axis=1)
        xf = rfftfreq(data.shape[0], 1 / sr)
        ax.plot(xf, np.abs(rfft(data)), label=location, color=plot_colors.get(location))

    ax.legend()
    fig.savefig(f"{out_dir}/{patient_number}_{diagnosis}_freqdom.png")
    plt.close(fig)


def createSpectogram():
    """Placeholder for spectrogram creation; not implemented yet, so it returns None."""


def wavToArray(file, target_sr=4000):
    """Load one recording from ``audio_file_path`` as a floating-point signal.

    librosa resamples the audio to ``target_sr`` (4000 Hz by default, matching
    the original behaviour) and returns float samples. librosa scales integer
    PCM into roughly [-1, 1] and downmixes to mono by default.

    Args:
        file: Wav file name relative to ``audio_file_path``.
        target_sr: Sampling rate to resample to, in Hz. ``None`` keeps the
            file's native rate.

    Returns:
        Tuple ``(sr, data)``. The order is the reverse of librosa's
        ``(data, sr)`` and matches ``scipy.io.wavfile.read``.
    """
    data, sr = librosa.load(f"{audio_file_path}/{file}", sr=target_sr)
    return sr, data


def getMelSpectogram():
    """Placeholder for mel-spectrogram extraction; not implemented yet, so it returns None."""


def createlPlots(patient_number, files, diagnosis, chest_location=''):
    """Plot raw waveform, band-passed waveform and mel spectrogram for one patient.

    Each non-empty entry of ``files`` is loaded with ``wavToArray`` and
    filtered with ``filterNoise``. It is then drawn into a 2x2 figure:
    a) raw signal vs. time, b) filtered signal vs. time, and c) the mel
    spectrogram (dB) of the filtered signal. The fourth panel is unused and
    hidden. With several files, the spectrograms share panel c), so only the
    last one is visible there.

    The figure is saved to
    ``waveforms/{patient_number}_{diagnosis}_timedom{chest_location}.png``.

    Args:
        patient_number: Patient identifier, used in the output file name.
        files: Mapping of chest-location code (a key of ``plot_colors``) to a
            wav file name. Falsy values (e.g. ``None``) are skipped.
        diagnosis: Diagnosis label, used in the output file name.
        chest_location: Optional suffix appended to the output file name.

    Returns:
        Dict mapping each plotted chest location to its mel spectrogram in dB
        (the array passed to ``specshow``). Empty if nothing was plotted.
    """
    # librosa.display is a submodule that `import librosa` is not guaranteed
    # to load, so import it explicitly before using specshow.
    import librosa.display

    fmax = 1500  # upper mel-filterbank frequency in Hz, shared by analysis and display

    fig, axs = plt.subplots(2, 2, figsize=(15, 6))
    raw_ax, filtered_ax, spec_ax = axs[0][0], axs[0][1], axs[1][0]
    axs[1][1].axis("off")  # only three panels are used

    for ax, title in ((raw_ax, "a)"), (filtered_ax, "b)")):
        ax.set_xlabel("Time [s]")
        ax.set_ylabel("Amplitude")
        ax.set_title(title)
    spec_ax.set_xlabel("Time [s]")
    spec_ax.set_ylabel("Frequency")

    spectrograms = {}
    img = None
    for location, file_name in files.items():
        if not file_name:
            continue
        sr, raw_signal = wavToArray(file_name)
        signal = filterNoise(sr, raw_signal)  # Remove cardiac sounds
        # filterNoise does not resample, so raw and filtered signals share one time axis.
        time = np.linspace(0, raw_signal.shape[0] / sr, raw_signal.shape[0])

        color = plot_colors[location]
        raw_ax.plot(time, raw_signal, label=location, color=color)
        filtered_ax.plot(time, signal, label=location, color=color)

        # n_fft is the window size -> 2x the sampling rate, because we assume a
        # complete respiratory phase takes 2 seconds (1 s inhale, 1 s exhale).
        S = librosa.feature.melspectrogram(y=signal, sr=sr, n_mels=256, fmax=fmax, n_fft=int(2 * sr))
        S_dB = librosa.power_to_db(S, ref=np.max)
        spectrograms[location] = S_dB
        # specshow needs the same fmax. Otherwise it lays out the mel axis as if
        # it spanned 0..sr/2, and the Hz tick labels are wrong.
        img = librosa.display.specshow(S_dB, x_axis='time', y_axis='mel', sr=sr, fmax=fmax, ax=spec_ax)

    # Add a single colour bar after the loop rather than one per file.
    if img is not None:
        fig.colorbar(img, ax=spec_ax, format='%+2.0f dB')
    spec_ax.set_title("c)")

    raw_ax.legend()
    filtered_ax.legend()

    fig.tight_layout()
    fig.savefig(f"waveforms/{patient_number}_{diagnosis}_timedom{chest_location}.png")
    plt.close(fig)

    return spectrograms

In [22]:
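# Prepare data: build one DataFrame row per Healthy/Pneumonia patient holding the
# diagnosis and demographics; chest-location columns (Tc, Al, ...) holding .wav
# file names are added in the cells below. (Originally sketched as dicts of the form
# {"patient_number": {"Al": {"annotation":{""}}, "Tc": "", "Ar": "", "": ""}}.)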
import csv
import os


# data_path already ends with "/", so no extra separator is inserted here
# (consistent with the two paths below and with audio_file_path).
demographics_file = f"{data_path}demographic_info.csv"
events_path = f"{data_path}events"
diagnosis_file = f"{data_path}patient_diagnosis.csv"


# Diagnosis per patient (headerless CSV). Keep only the two classes studied here.
df = pdf.read_csv(diagnosis_file, header=None, names=["patient_no", "diagnosis"])
df = df[df['diagnosis'].isin(["Healthy", "Pneumonia"])]


# Demographics per patient (headerless CSV; the column order is assumed, so
# confirm it against the dataset description).
df2 = pdf.read_csv(demographics_file, header=None, names=["patient_no", "age", "sex", "adult_bmi", "child_weight", "child_height"])
# Inner join on patient number: patients missing from either table are dropped.
df3 = pdf.merge(df, df2, on="patient_no")

In [23]:
# One initially empty (None) column per chest location, in a fixed order.
# They are filled with the matching .wav file names in the next cell.
df3 = df3.assign(**dict.fromkeys(["Tc", "Al", "Pl", "Ll", "Ar", "Pr", "Lr"]))

In [24]:
# Record each patient's .wav file under its chest-location column. File names
# are expected to have five "_"-separated fields, with the patient number first
# and the chest-location code third (e.g. "<patient_no>_x_<location>_x_x.wav").
chest_locations = ("Tc", "Al", "Pl", "Ll", "Ar", "Pr", "Lr")
# os.listdir order is arbitrary. When a patient has several recordings at one
# location, the last one assigned wins, so sort to make that choice reproducible.
for d in sorted(os.listdir(audio_file_path)):
    if not d.endswith(".wav"):
        continue
    p = d.split("_")
    # Skip names that don't follow the scheme. Otherwise they would raise
    # IndexError/ValueError or create a stray, non-location column.
    if len(p) < 5 or not p[0].isdigit() or p[2] not in chest_locations:
        continue
    df3.loc[df3['patient_no'] == int(p[0]), p[2]] = d

In [25]:
# for _, row in df3.iterrows():
#     plotSoundWaveform(
#         row['patient_no'],
#         {
#          "Tc": row['Tc'], "Al": row['Al'],
#          "Pl": row['Pl'], "Ll": row['Ll'],
#          "Ar": row['Ar'], "Pr": row['Pr'],
#          "Lr": row['Lr']
#         },
#         row['diagnosis']
#     )

In [26]:
# Plots for single chest locations
# for _, row in df3.iterrows():
#     for c_loc in ['Tc', 'Pl', 'Pr', 'Ll', 'Lr', 'Al', 'Ar']:
#         if row[c_loc]:
#             plotSoundWaveform(
#                 row['patient_no'],
#                 {
#                     c_loc: row[c_loc]
#                 },
#                 row['diagnosis'],
#                 c_loc
#             )

In [27]:
# Pneumonia patient and healthy patient samples
# for _, pp in df3.loc[df3["patient_no"].isin([135,159])].iterrows():
#     plotSoundWaveform(
#         pp['patient_no'],
#         {
#          "Tc": pp['Tc'], "Al": pp['Al'],
#          "Pl": pp['Pl'], "Ll": pp['Ll'],
#          "Ar": pp['Ar'], "Pr": pp['Pr'],
#          "Lr": pp['Lr']
#         },
#         pp['diagnosis']
#     )

In [28]:
# Plots for single chest locations
# from tensorflow import keras

# Anterior-right ("Ar") recording of the pneumonia/healthy sample pair (135, 159).
selected_patients = df3[df3["patient_no"].isin([135, 159])]
for _, patient_row in selected_patients.iterrows():
    createlPlots(patient_row['patient_no'], {"Ar": patient_row['Ar']}, patient_row['diagnosis'])

In [29]:
from torch.utils.data import DataLoader, Dataset, random_split
import torchaudio

# ----------------------------
# Sound Dataset
# ----------------------------
class SoundDS(Dataset):
  """Map-style PyTorch Dataset that turns DataFrame rows into augmented spectrograms.

  ``df`` must provide a ``file`` column (audio path relative to ``data_path``)
  and a ``diagnosis`` column, which is returned unchanged as the label. Audio
  loading and feature extraction are delegated to ``AudioUtil``. It is not
  defined in this cell and must exist before items are fetched.
  """

  def __init__(self, df, data_path):
    self.df = df
    self.data_path = str(data_path)
    self.duration = 4000  # target length passed to AudioUtil.pad_trunc (presumably ms; TODO confirm)
    self.sr = 44100       # every clip is resampled to this rate
    self.channel = 2      # every clip is converted to this many channels
    self.shift_pct = 0.4  # passed to AudioUtil.time_shift

  # ----------------------------
  # Number of items in dataset
  # ----------------------------
  def __len__(self):
    """Return the number of rows (items) in the dataset."""
    return len(self.df)

  def _path_and_label(self, idx):
    """Return ``(audio file path, class label)`` for the idx-th row by position."""
    # iloc, not loc: samplers and random_split index positionally (0..len-1),
    # which need not match the DataFrame's index labels (e.g. after filtering).
    row = self.df.iloc[idx]
    # os.path.join works whether or not data_path ends with a separator.
    return os.path.join(self.data_path, row['file']), row['diagnosis']

  # ----------------------------
  # Get i'th item in dataset
  # ----------------------------
  def __getitem__(self, idx):
    """Load, normalise and augment item ``idx``; return ``(spectrogram, label)``."""
    audio_file, class_id = self._path_and_label(idx)

    aud = AudioUtil.open(audio_file)
    # Some sounds have a higher sample rate, or fewer channels compared to the
    # majority. So make all sounds have the same number of channels and same
    # sample rate. Unless the sample rate is the same, the pad_trunc will still
    # result in arrays of different lengths, even though the sound duration is
    # the same.
    reaud = AudioUtil.resample(aud, self.sr)
    rechan = AudioUtil.rechannel(reaud, self.channel)

    dur_aud = AudioUtil.pad_trunc(rechan, self.duration)
    shift_aud = AudioUtil.time_shift(dur_aud, self.shift_pct)
    sgram = AudioUtil.spectro_gram(shift_aud, n_mels=64, n_fft=1024, hop_len=None)
    aug_sgram = AudioUtil.spectro_augment(sgram, max_mask_pct=0.1, n_freq_masks=2, n_time_masks=2)

    return aug_sgram, class_id