In [14]:
import numpy as np
import matplotlib.pyplot as plt
import librosa
from scipy.signal import butter,filtfilt,find_peaks,find_peaks_cwt,medfilt,savgol_filter
from utils import butter_lowpass_filter, butter_highpass_filter, smooth, positions2onehot, normalize
import torch
import torch.nn as nn
import pandas as pd
from test_metric_utils import *
from model_unet import Unet
from model_AAE import FCAE
from tqdm import tqdm_notebook
import os
import time
%matplotlib inline

In [2]:
def inference(model,Speech,n_frame = 192,window_step = 32):
    """Predict an EGG waveform from speech by averaging overlapping model windows.

    The signal is cut into windows of ``n_frame`` samples taken every
    ``window_step`` samples. Each window is passed through ``normalize`` and
    then ``model``; the outputs are summed into place and every sample is
    finally divided by the number of windows that covered it.

    Parameters
    ----------
    model : torch.nn.Module
        Network applied to one window at a time. It is switched to eval mode
        (and left in eval mode). The input tensor is created on the device of
        the model's first parameter, so CPU and CUDA models both work.
    Speech : array-like, 1-D
        Input samples. Integer input is accumulated in floating point.
    n_frame : int
        Window length in samples.
    window_step : int
        Hop between consecutive windows; must be positive and divide ``n_frame``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(Speech), 1)``. Samples that no complete window
        covers (the trailing remainder, or everything when the input is
        shorter than ``n_frame``) are left at zero.

    Raises
    ------
    AssertionError
        If ``window_step`` is not positive or does not divide ``n_frame``.
    """
    assert window_step > 0 and n_frame % window_step == 0, \
        "window_step must be positive and divide n_frame"

    model.eval()

    # Follow the model's device instead of assuming CUDA is present.
    try:
        device = next(model.parameters()).device
    except (StopIteration, AttributeError):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    Speech = np.expand_dims(np.asarray(Speech), axis=-1)
    # Float accumulators even for integer audio; float32/float64 input keeps its dtype.
    acc_dtype = np.result_type(Speech.dtype, np.float32)
    EGG_pred = np.zeros(Speech.shape, dtype=acc_dtype)
    ratio = np.zeros(Speech.shape, dtype=acc_dtype)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for start in range(0, len(Speech) - n_frame + 1, window_step):
            stop = start + n_frame
            window = Speech[start:stop]
            # preprocessing: normalise, then add a leading batch axis
            # (assumes normalize returns an array-like of the window's shape — TODO confirm)
            batch = torch.as_tensor(
                np.asarray(normalize(window), dtype=np.float32)[np.newaxis],
                device=device,
            )

            result = model(batch).detach().cpu().numpy()[0]
            # postprocessing: overlap-add, remembering how many windows hit each sample
            EGG_pred[start:stop] += result
            ratio[start:stop] += 1

    # Average the overlaps; uncovered samples (ratio == 0) stay at zero.
    covered = ratio != 0
    EGG_pred[covered] /= ratio[covered]
    return EGG_pred

In [4]:
# Build the fully-connected auto-encoder and restore its cosine-loss weights.
# Each half is wrapped in DataParallel before its checkpoint is loaded
# (presumably the checkpoints were saved from wrapped modules — TODO confirm).
AAE_cos = FCAE()
for _part, _weights in (("encoder", "./models/AAI/STZ-cosloss.pth"),
                        ("decoder", "./models/AAI/ZTE-cosloss.pth")):
    setattr(AAE_cos, _part, nn.DataParallel(getattr(AAE_cos, _part)))
    getattr(AAE_cos, _part).load_state_dict(torch.load(_weights))
# Last expression of the cell: moves the model to the GPU and displays its structure.
AAE_cos.cuda()

FCAE(
  (encoder): DataParallel(
    (module): FCEncoder(
      (module_list): ModuleList(
        (0): Linear(in_features=192, out_features=175, bias=True)
        (1): Linear(in_features=175, out_features=125, bias=True)
        (2): Linear(in_features=125, out_features=100, bias=True)
      )
      (batches): ModuleList(
        (0): BatchNorm1d(175, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): BatchNorm1d(125, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (decoder): DataParallel(
    (module): FCDecoder(
      (module_list): ModuleList(
        (0): Linear(in_features=100, out_features=125, bias=True)
        (1): Linear(in_features=125, out_features=175, bias=True)
        (2): Linear(in_features=175, out_features=192, bias=True)
      )
      (batches): ModuleList(
        (0): BatchNorm1d(125, eps=1e-05, momentum=0.1,

In [5]:
# Cosine-loss U-Net: construct, wrap in DataParallel, then restore the checkpoint.
_unet_weights = "./models/Unet/best-cosloss.pth"
unet_cos = nn.DataParallel(Unet(4, 10))
unet_cos.load_state_dict(torch.load(_unet_weights))
# Last expression of the cell: moves the model to the GPU and displays its structure.
unet_cos.cuda()

DataParallel(
  (module): Unet(
    (encoder): ModuleList(
      (0): Conv1d(1, 10, kernel_size=(15,), stride=(1,), padding=(7,))
      (1): Conv1d(10, 20, kernel_size=(15,), stride=(1,), padding=(7,))
      (2): Conv1d(20, 30, kernel_size=(15,), stride=(1,), padding=(7,))
      (3): Conv1d(30, 40, kernel_size=(15,), stride=(1,), padding=(7,))
    )
    (decoder): ModuleList(
      (0): Conv1d(80, 40, kernel_size=(5,), stride=(1,), padding=(2,))
      (1): Conv1d(70, 30, kernel_size=(5,), stride=(1,), padding=(2,))
      (2): Conv1d(50, 20, kernel_size=(5,), stride=(1,), padding=(2,))
      (3): Conv1d(30, 10, kernel_size=(5,), stride=(1,), padding=(2,))
    )
    (ebatch): ModuleList(
      (0): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): BatchNorm1d(40, eps=1

# TEST CMU Data
## Normal person data

In [21]:
def _valleys_and_peaks(signal, widths):
    """Return (peak_indices, valley_indices) of a 1-D signal.

    Valleys are the CWT peaks of the negated, positive-clipped signal; each
    peak is the arg-max of the signal between two consecutive valleys.
    """
    clipped = signal.copy()
    clipped[clipped > 0] = 0  # keep only the negative excursions
    valleys = np.asarray(find_peaks_cwt(-clipped, widths), dtype=int)

    peaks = np.array([
        # max(hi, lo + 1): never hand argmax an empty slice if two valleys coincide
        lo + np.argmax(signal[lo:max(hi, lo + 1)])
        for lo, hi in zip(valleys[:-1], valleys[1:])
    ])
    return peaks, valleys


def get_points(EGG, sr=16000, peak_range=(7,15)):
    """Locate landmark instants in an EGG signal and in its derivative.

    The derivative (DEGG) is ``np.gradient`` of the signal followed by a
    3-tap median filter. For both DEGG and EGG the negative-going extrema
    ("low") are found with ``find_peaks_cwt`` and the maximum between each
    pair of consecutive lows ("high") is taken. The evaluation cell of this
    notebook labels DEGG highs as GOI and DEGG lows as GCI.

    Parameters
    ----------
    EGG : array-like, 1-D
        Signal samples. The input is not modified.
    sr : int or float, optional
        Sample rate used to convert sample indices to seconds (default 16000).
    peak_range : tuple, optional
        Arguments forwarded to ``np.arange`` to build the CWT widths
        (default ``(7, 15)``, i.e. widths 7..14).

    Returns
    -------
    tuple of numpy.ndarray
        ``(DEGG_high, DEGG_low, EGG_high, EGG_low)`` in seconds. Each "high"
        array has one element fewer than its "low" array (or is empty).
    """
    EGG = np.asarray(EGG)
    widths = np.arange(*peak_range)

    DEGG = np.gradient(EGG,edge_order = 2)
    DEGG = medfilt(DEGG, 3)  # suppress single-sample spikes in the derivative

    DEGG_high, DEGG_low = _valleys_and_peaks(DEGG, widths)
    EGG_high, EGG_low = _valleys_and_peaks(EGG, widths)

    return DEGG_high/sr,DEGG_low/sr,EGG_high/sr,EGG_low/sr

In [None]:
# Evaluate the cosine-loss U-Net on the CMU test recordings.
# For every file: low-pass the speech channel, keep only the non-silent
# intervals, predict the EGG, extract landmark instants from the true and the
# predicted EGG, and score them with corrected_naylor_metrics. The per-file
# scores are averaged over the files that were actually evaluated.
window_step = 64
n_frame = 192        # window length (samples) used for inference and silence splitting
sample_rate = 16000  # every file is resampled to this rate on load
directory = './datasets/TestData/CMU/'
filelist = sorted(os.listdir(directory))  # os.listdir order is arbitrary; sort for reproducible runs

# Short report name -> key in the dict returned by corrected_naylor_metrics
metric_keys = {
    'IDR': "identification_rate",
    'MR': "miss_rate",
    'FAR': "false_alarm_rate",
    'IDA': "identification_accuracy",
}

DEGG_high_metrics_final = {'IDR':0, 'MR' : 0, 'FAR':0 ,'IDA':0}
DEGG_low_metrics_final = {'IDR':0, 'MR' : 0, 'FAR':0 ,'IDA':0}
EGG_high_metrics_final = {'IDR':0, 'MR' : 0, 'FAR':0 ,'IDA':0}
EGG_low_metrics_final = {'IDR':0, 'MR' : 0, 'FAR':0 ,'IDA':0}

# Same order as the tuple returned by get_points: DEGG_high, DEGG_low, EGG_high, EGG_low
reports = [
    ("=========DEGG_high(GOI) detection========", DEGG_high_metrics_final),
    ("=========DEGG_low(GCI) detection========", DEGG_low_metrics_final),
    ("=========EGG_high detection========", EGG_high_metrics_final),
    ("=========EGG_low detection========", EGG_low_metrics_final),
]

n_evaluated = 0
skipped_files = []

for file in tqdm_notebook(filelist):
    [Speech,EGG_true],sr = librosa.load(os.path.join(directory, file),sr=sample_rate,mono=False)
    Speech = butter_lowpass_filter(Speech,2500,sample_rate)
    itvs = librosa.effects.split(Speech,frame_length = int(n_frame*0.75), hop_length = int(n_frame*0.25),top_db = 10)

    # Without at least one full window of non-silent audio no prediction can be made.
    if len(itvs) == 0 or sum(ed - st for st,ed in itvs) < n_frame:
        skipped_files.append(file)
        continue

    # Keep only the non-silent intervals, cut identically from both channels.
    Speech = np.concatenate([Speech[st:ed] for st,ed in itvs])
    EGG_true = np.concatenate([EGG_true[st:ed] for st,ed in itvs])

    EGG_pred = inference(unet_cos,Speech,n_frame = n_frame,window_step = window_step)
    EGG_pred = np.squeeze(EGG_pred,axis=-1)
    EGG_pred = smooth(EGG_pred, 49)

    # Align lengths before comparing landmark positions.
    l = min(len(EGG_pred),len(EGG_true))
    EGG_true =EGG_true[:l]
    EGG_pred =EGG_pred[:l]

    true_points = get_points(EGG_true)  # (DEGG_high, DEGG_low, EGG_high, EGG_low)
    pred_points = get_points(EGG_pred)

    for (_, totals), reference, estimate in zip(reports, true_points, pred_points):
        file_metrics = corrected_naylor_metrics(reference, estimate)
        for short_name, metric_key in metric_keys.items():
            totals[short_name] += file_metrics[metric_key]
    n_evaluated += 1

# Turn the sums into means over the evaluated files.
if n_evaluated:
    for _, totals in reports:
        for short_name in totals:
            totals[short_name] /= n_evaluated

if skipped_files:
    print("Skipped %d file(s) with less than %d non-silent samples: %s"
          %(len(skipped_files), n_frame, ", ".join(skipped_files)))

for title, totals in reports:
    print(title)
    print("IDR : %.2f MR : %.2f FAR : %.2f IDA : %.2f ms"
          %(totals['IDR']*100,totals['MR']*100,totals['FAR']*100,totals['IDA']*1000))


HBox(children=(IntProgress(value=0, max=1614), HTML(value='')))

  snr = abs(cwt[line[0][0], line[1][0]] / noises[line[1][0]])


# TEST Saarbrücken Data
## Pathologic person data