-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_file.py
103 lines (93 loc) · 3.66 KB
/
test_file.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
@authors: Helin Wang, Dongchao Yang
"""
import os
import torch
import argparse
from net import *
import librosa
import numpy as np
class Detection():
    """Binary audio-event detector wrapping a pretrained `net` model.

    Loads model weights from a checkpoint and scores 16 kHz mono audio
    files; a score above the threshold is the positive class.
    """

    def __init__(self, model_pth, gpuid):
        """Build the model and load weights.

        Args:
            model_pth: Path to a checkpoint containing "model_state_dict".
            gpuid: Sequence of GPU ids; empty means CPU, otherwise the
                first id is used.
        """
        # net() hyperparameters come from net.py; presumably
        # (sample_rate, n_fft, hop, n_mels, ..., fmax, n_classes) — TODO confirm.
        model = net(16000, 1024, 320, 64, 50, 8000, 2, False)
        # Load on CPU first so a GPU-saved checkpoint works anywhere,
        # then move the model to the selected device.
        dicts = torch.load(model_pth, map_location='cpu')
        model.load_state_dict(dicts["model_state_dict"])
        self.gpuid = tuple(gpuid)
        self.device = torch.device(
            'cuda:{}'.format(gpuid[0]) if len(gpuid) > 0 else 'cpu')
        self.model = model.to(self.device)

    def inference(self, file_path):
        """Return True if the model scores `file_path` above 0.5."""
        self.model.eval()
        with torch.no_grad():
            (audio, _) = librosa.core.load(file_path, sr=16000, mono=True)
            print("Compute on utterance {}...".format(file_path))
            audio = torch.from_numpy(audio).to(self.device)
            if audio.dim() == 1:
                # Add a batch dimension for the model.
                audio = torch.unsqueeze(audio, 0)
            out = self.model(audio)
            return bool(out[0, 0] > 0.5)

    def test(self, file_path, threshold=0.5):
        """Walk `file_path`, score every file, and print P/R/ACC metrics.

        The ground-truth label is read from the 5th-from-last character of
        each filename (e.g. "..._1.wav" -> label 1) — TODO confirm naming
        convention against the dataset.
        """
        self.model.eval()
        data_list = []
        for root, dirs, files in os.walk(file_path):
            for name in files:
                data_list.append(os.path.join(root, name))
        TP = 0.0
        FP = 0.0
        TN = 0.0
        FN = 0.0
        for file in data_list:
            with torch.no_grad():
                label = int(file[-5])
                (audio, _) = librosa.core.load(file, sr=16000, mono=True)
                audio = torch.from_numpy(audio).to(self.device)
                if audio.dim() == 1:
                    audio = torch.unsqueeze(audio, 0)
                out = self.model(audio)
                # Score above threshold is predicted positive.
                pred = bool(out[0, 0] > threshold)
                if pred:
                    print("Compute on utterance {}: True".format(file))
                else:
                    print("Compute on utterance {}: False".format(file))
                # BUG FIX: the original TN branch repeated the TP condition
                # (prediction positive AND label positive), so TN was never
                # incremented and scores exactly at the threshold were not
                # counted at all. Count every sample into exactly one cell.
                if pred and label == 1:        # label 1, predict 1
                    TP += 1.0
                elif pred and label == 0:      # label 0, predict 1
                    FP += 1.0
                elif not pred and label == 1:  # label 1, predict 0
                    FN += 1.0
                else:                          # label 0, predict 0
                    TN += 1.0
        # Guard denominators: an empty directory (or no positives) must not
        # raise ZeroDivisionError.
        Precision = TP / (TP + FP) if (TP + FP) else 0.0
        Recall = TP / (TP + FN) if (TP + FN) else 0.0
        total = FP + TP + FN + TN
        ACC = (TP + TN) / total if total else 0.0
        print('Precision ', Precision)
        print('Recall ', Recall)
        print('ACC: {}'.format(ACC))
        print('True Acceptance: {}'.format(TP))
        print('False Acceptance: {}'.format(FP))
        print('True Rejection: {}'.format(TN))
        print('False Rejection: {}'.format(FN))
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
'-model_pth', type=str, default='./checkpoint/net/best.pt', help="Path to model file.")
parser.add_argument(
'-gpuid', type=str, default='0', help='Enter GPU id number')
parser.add_argument(
'-file_path', type=str,
default='/home/pkusz/home/PKU_team/new_data/test/20211016T191855_121568658-1_165_1.wav',
help='test file path')
parser.add_argument(
'-test_single_file', type=bool,
default=False,
help='whether test single file')
args = parser.parse_args()
gpuid = [int(i) for i in args.gpuid.split(',')]
separation = Detection(args.model_pth, gpuid)
print(separation.inference(args.file_path))
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()