### Pre-process audio for *train*
1. compute the **NMF** bases of every `wav` file and save each result as an `.npy` file
2. derive the ground-truth **labels** from each `wav` file's path and save them as `.npy` files
3. save the paths of all `.npy` files in `train.h5` and `val.h5`

In [None]:
import os
import numpy as np
from util import audio_import

# Instrument name -> slot in the 1000-dim label vector (names[k] uses idxs[k]).
# NOTE(review): these look like ImageNet class indices — TODO confirm.
idxs=[401,402,486,513,558,642,776,889]
names=['accordion','acoustic_guitar','cello','trumpet','flute','xylophone','saxophone','violin']

path = 'h:/Study/bpfile/dataset/audios/'
# True: (re)compute the NMF bases and save them next to each wav.
# False: assume the `*_bases.npy` files already exist (the step was disabled
# in the original run); the paths are recorded either way.
COMPUTE_NMF = False
wav_files = []
bases_filepath = []
labels_filepath = []

# Record every .wav file under `path` as a normcase'd (lower-cased on
# Windows), forward-slash path, e.g.
# 'h:/study/bpfile/dataset/audios/duet/acoustic_guitarviolin/1.wav'.
for root, dirs, files in os.walk(path, topdown=True):
    for name in files:
        if name.endswith('.wav'):
            wav_files.append(os.path.normcase(os.path.join(root, name)).replace('\\', '/'))
# os.walk order depends on the filesystem; sort so the lists (and the h5
# files built from them) are reproducible.
wav_files.sort()

for fname in wav_files:
    stem = fname[:-len('.wav')]
    bases_path = stem + '_bases.npy'
    labels_path = stem + '_labels.npy'

    # NMF bases of this wav (original note: np.array(nmf_base[0]).shape == (2401, 16)).
    if COMPUTE_NMF:
        nmf_base = audio_import.nmf([fname])
        np.save(bases_path, nmf_base[0])
    # Always record the path: the split cell indexes bases_filepath in
    # lockstep with labels_filepath.
    bases_filepath.append(bases_path)

    # Ground truth from the path: every instrument name found in it is set
    # to 1, so a duet folder such as 'acoustic_guitarviolin' yields two labels.
    label = np.zeros(1000)
    for instrument, index in zip(names, idxs):
        if instrument in fname:
            label[index] = 1
    print(sum(label))
    np.save(labels_path, label)
    labels_filepath.append(labels_path)

    print(bases_path)
    print(labels_path)
    

### Generate `train.h5` and `val.h5` (random 70% / 30% split)

In [None]:
import h5py
import random

dataset_num = len(wav_files)
train_num = int(0.7*dataset_num)

# Every wav must have exactly one bases path and one labels path; otherwise
# the index-based split below would pair the wrong files or fail midway.
if not len(bases_filepath) == len(labels_filepath) == dataset_num:
    raise ValueError(
        'path lists are out of sync: %d wavs, %d bases paths, %d labels paths'
        % (dataset_num, len(bases_filepath), len(labels_filepath)))

# Randomly pick 70% of the sample indices for the TRAINING set; the remaining
# 30% form the validation set.
train_bases = random.sample(range(dataset_num), train_num)
train_index_set = set(train_bases)   # O(1) membership instead of a list scan

# Store paths as UTF-8 bytes (per the original note, plain str raised an error).
train_bases_encode = []
train_labels_encode = []
val_bases_encode = []
val_labels_encode = []
for i, (bases_path, labels_path) in enumerate(zip(bases_filepath, labels_filepath)):
    if i in train_index_set:
        train_bases_encode.append(bases_path.encode('utf-8'))
        train_labels_encode.append(labels_path.encode('utf-8'))
    else:
        val_bases_encode.append(bases_path.encode('utf-8'))
        val_labels_encode.append(labels_path.encode('utf-8'))

# `with` closes each file even if a dataset write fails.
with h5py.File('dataset/train.h5', 'w') as h5f:
    h5f.create_dataset('bases', data=train_bases_encode)
    h5f.create_dataset('labels', data=train_labels_encode)

with h5py.File('dataset/val.h5', 'w') as h5f:
    h5f.create_dataset('bases', data=val_bases_encode)
    h5f.create_dataset('labels', data=val_labels_encode)

### Generate `solo.h5` from the `solo` recordings only

In [8]:
import os
import numpy as np

# Instrument name -> slot in the 1000-dim label vector (names[k] uses idxs[k]).
idxs=[401,402,486,513,558,642,776,889]
names=['accordion','acoustic_guitar','cello','trumpet','flute','xylophone','saxophone','violin']

path = 'h:/Study/bpfile/dataset/audios/'
# Both steps were disabled (commented out) in the original run; the paths are
# recorded either way.
COMPUTE_NMF = False   # True: (re)compute and save `*_bases.npy`
SAVE_LABELS = False   # True: (re)write `*_labels.npy`
wav_files = []
bases_filepath = []
labels_filepath = []

# Record every .wav file outside the 'duet' folders as a normcase'd,
# forward-slash path, e.g. 'h:/study/bpfile/dataset/audios/solo/accordion/1.wav'.
for root, dirs, files in os.walk(path, topdown=True):
    # Test only the part of the path below the dataset root, so a 'duet' in
    # `path` itself cannot filter out every file.
    if 'duet' in os.path.relpath(root, path):
        continue
    for name in files:
        if name.endswith('.wav'):
            wav_files.append(os.path.normcase(os.path.join(root, name)).replace('\\', '/'))
# os.walk order depends on the filesystem; sort for a reproducible solo.h5.
wav_files.sort()

for fname in wav_files:
    stem = fname[:-len('.wav')]
    bases_path = stem + '_bases.npy'
    labels_path = stem + '_labels.npy'

    # NMF bases of this wav (original note: np.array(nmf_base[0]).shape == (2401, 16)).
    if COMPUTE_NMF:
        from util import audio_import   # only needed when recomputing NMF
        nmf_base = audio_import.nmf([fname])
        np.save(bases_path, nmf_base[0])
    bases_filepath.append(bases_path)

    # Ground truth from the path: each instrument name found in it is set to 1
    # (a solo file is expected to match exactly one; the print shows the count).
    label = np.zeros(1000)
    for instrument, index in zip(names, idxs):
        if instrument in fname:
            label[index] = 1
    print(sum(label))
    if SAVE_LABELS:
        np.save(labels_path, label)
    labels_filepath.append(labels_path)

    print(bases_path)
    print(labels_path)
    

1.0
h:/study/bpfile/dataset/audios/solo/accordion/1_bases.npy
h:/study/bpfile/dataset/audios/solo/accordion/1_labels.npy
1.0
h:/study/bpfile/dataset/audios/solo/accordion/10_bases.npy
h:/study/bpfile/dataset/audios/solo/accordion/10_labels.npy
1.0
h:/study/bpfile/dataset/audios/solo/accordion/11_bases.npy
h:/study/bpfile/dataset/audios/solo/accordion/11_labels.npy
1.0
h:/study/bpfile/dataset/audios/solo/accordion/12_bases.npy
h:/study/bpfile/dataset/audios/solo/accordion/12_labels.npy
1.0
h:/study/bpfile/dataset/audios/solo/accordion/13_bases.npy
h:/study/bpfile/dataset/audios/solo/accordion/13_labels.npy
1.0
h:/study/bpfile/dataset/audios/solo/accordion/14_bases.npy
h:/study/bpfile/dataset/audios/solo/accordion/14_labels.npy
1.0
h:/study/bpfile/dataset/audios/solo/accordion/15_bases.npy
h:/study/bpfile/dataset/audios/solo/accordion/15_labels.npy
1.0
h:/study/bpfile/dataset/audios/solo/accordion/16_bases.npy
h:/study/bpfile/dataset/audios/solo/accordion/16_labels.npy
1.0
h:/study/bpfil

In [9]:
import h5py

dataset_num = len(wav_files)
# Each solo wav needs exactly one bases path and one labels path.
if not len(bases_filepath) == len(labels_filepath) == dataset_num:
    raise ValueError(
        'path lists are out of sync: %d wavs, %d bases paths, %d labels paths'
        % (dataset_num, len(bases_filepath), len(labels_filepath)))

# Every solo sample goes into solo.h5 (no train/val split here). Paths are
# stored as UTF-8 bytes (per the original note, plain str raised an error).
train_bases_encode = [p.encode('utf-8') for p in bases_filepath]
train_labels_encode = [p.encode('utf-8') for p in labels_filepath]

# `with` closes the file even if a dataset write fails.
with h5py.File('solo.h5', 'w') as h5f:
    h5f.create_dataset('bases', data=train_bases_encode)
    h5f.create_dataset('labels', data=train_labels_encode)


  from ._conv import register_converters as _register_converters


In [None]:
from util.extract_key_bases import extract_key_bases


## train ...

### Pre-process audio and images for *test*
1. compute the **NMF** bases of every mixture `wav` file and save each result as an `.npy` file
2. compute the **labels** from the video frames with `feat_extractor` and save them as `.npy` files
3. save the paths of all `.npy` files in `test.h5`

In [2]:
import os
import random
import numpy as np
from util import audio_import
from util.feat_extractor import load_model, get_CAM, feat_pred

# Instrument name -> slot in the 1000-dim label vector (names[k] uses idxs[k]).
idxs=[401,402,486,513,558,642,776,889]
names=['accordion','acoustic_guitar','cello','trumpet','flute','xylophone','saxophone','violin']

test_path = 'h:/Study/bpfile/testset25/'
COMPUTE_NMF = False   # True: (re)compute and save NMF bases for every test wav
SAVE_LABELS = False   # True: save the frame-predicted labels.npy for every video
N_SAMPLE_IMGS = 30    # frames randomly sampled per video (fewer if not available)
wav_files = []
img_files = []
bases_filepath = []
labels_filepath = []
locations = {}

# Audio and frame directories as forward-slash paths (normcase also
# lower-cases them on Windows), e.g.
# wav_dir = 'h:/study/bpfile/testset25/gt_audio/'
# img_dir = 'h:/study/bpfile/testset25/testimage/'
wav_dir = os.path.normcase(os.path.join(test_path, 'gt_audio/')).replace('\\','/')
img_dir = os.path.normcase(os.path.join(test_path, 'testimage/')).replace('\\','/')

# Keep only the mixture wavs (skip the '_gt1' / '_gt2' per-source tracks),
# e.g. wav_files[0] = 'accordion_1_saxophone_1.wav'.
for f in os.listdir(wav_dir):
    if f.endswith('.wav') and not f.endswith(('_gt1.wav', '_gt2.wav')):
        wav_files.append(f)
# Keep only sub-folders (one folder of frames per video),
# e.g. img_files[0] = 'accordion_1_saxophone_1'.
for f in os.listdir(img_dir):
    if os.path.isdir(img_dir + f):
        img_files.append(f)
        locations[f + '.mp4'] = []

# Sorting pairs each wav with its frame folder by name. zip() below would
# silently drop unmatched entries, so require the counts to agree.
wav_files.sort()
img_files.sort()
if len(wav_files) != len(img_files):
    raise ValueError('found %d mixture wavs but %d image folders'
                     % (len(wav_files), len(img_files)))
load_model()

for wav_fname, img_folder in zip(wav_files, img_files):
    # Explicit check rather than `assert`, which `python -O` strips.
    if wav_fname[:-4] != img_folder:
        raise ValueError('wav %r does not match image folder %r'
                         % (wav_fname, img_folder))
    frame_dir = img_dir + img_folder

    # NMF bases path, e.g.
    # 'h:/study/bpfile/testset25/gt_audio/accordion_1_saxophone_1_bases.npy'.
    # Always recorded so test.h5 pairs one bases path with each labels path.
    bases_path = wav_dir + wav_fname[:-4] + '_bases.npy'
    if COMPUTE_NMF:
        # original note: np.array(nmf_base[0]).shape == (2401, 16)
        nmf_base = audio_import.nmf([wav_dir + wav_fname])
        np.save(bases_path, nmf_base[0])
    bases_filepath.append(bases_path)

    # Randomly sample up to N_SAMPLE_IMGS frames; random.sample raises
    # ValueError when asked for more items than exist.
    imgs = [img for img in os.listdir(frame_dir)
            if img.endswith(('.jpg', '.png'))]
    imgs = random.sample(imgs, min(N_SAMPLE_IMGS, len(imgs)))

    # Sum over the sampled frames: the 8 per-instrument scores from feat_pred
    # and the 8-slot location vector from get_CAM (per the original notes the
    # location is a column position and the heatmap is written to a file).
    probs = np.zeros([8])
    location = np.zeros([8])
    for img in imgs:
        probs = probs + np.array(feat_pred(frame_dir, img))
        location = location + get_CAM(frame_dir, 'results', img)

    locations[img_folder + '.mp4'] = location
    print(location)

    # Spread the 8 accumulated scores into a 1000-slot label vector at idxs.
    labels_path = frame_dir + '/labels.npy'
    if SAVE_LABELS:
        softmax = np.zeros(1000)
        for class_index, prob in zip(idxs, probs):
            softmax[class_index] = prob
        np.save(labels_path, softmax)
    labels_filepath.append(labels_path)
    print(labels_path)
    

  nn.init.kaiming_normal(m.weight.data)


[41638. 28137. 23348. 26450. 21528. 29065. 38568. 18622.]
h:/study/bpfile/testset25/testimage/accordion_1_saxophone_1/labels.npy
[44184. 22514. 45354. 28776. 23004. 42053. 36036. 49654.]
h:/study/bpfile/testset25/testimage/accordion_2_acoustic_guitar_2/labels.npy
[ 7318. 10006. 16459. 11275. 19657. 10952.  7481. 12197.]
h:/study/bpfile/testset25/testimage/accordion_2_cello_3/labels.npy
[42947. 37449. 35432. 34241. 40664. 41247. 40478. 33705.]
h:/study/bpfile/testset25/testimage/accordion_3_acoustic_guitar_3/labels.npy
[16221. 12031. 14893. 13090. 11563. 16330. 13942. 12595.]
h:/study/bpfile/testset25/testimage/accordion_4_saxophone_3/labels.npy
[30113. 27039. 24987. 32208. 32384. 29942. 27369. 26171.]
h:/study/bpfile/testset25/testimage/accordion_5_cello_1/labels.npy
[11254. 16168. 14288. 23516. 14177. 22383. 22966. 27767.]
h:/study/bpfile/testset25/testimage/accordion_6_saxophone_2/labels.npy
[16209. 20243. 29032. 36223. 23102. 24569. 12941. 37013.]
h:/study/bpfile/testset25/testimage

In [8]:
import h5py

# The two lists come from the test pre-processing cell and must pair up 1:1.
# Iterating over bases_filepath alone silently wrote an empty file whenever
# no bases paths had been recorded.
if len(bases_filepath) != len(labels_filepath):
    raise ValueError('got %d bases paths but %d labels paths'
                     % (len(bases_filepath), len(labels_filepath)))
if not bases_filepath:
    raise ValueError('no test samples recorded; run the pre-processing cell first')

# Store paths as UTF-8 bytes (per the original note, plain str raised an error).
test_bases_encode = [p.encode('utf-8') for p in bases_filepath]
test_labels_encode = [p.encode('utf-8') for p in labels_filepath]

# `with` closes the file even if a dataset write fails.
with h5py.File('dataset/test.h5', 'w') as h5f:
    h5f.create_dataset('bases', data=test_bases_encode)
    h5f.create_dataset('labels', data=test_labels_encode)

In [11]:
import json

# Each per-video location value is a numpy array (or an empty list) after the
# test pre-processing cell; turn it into a plain list in place so json can
# serialize it. Only values are replaced, so iterating items() is safe.
for video_name, loc in locations.items():
    locations[video_name] = list(loc)

with open('dataset/locations.json', 'w') as out_file:
    json.dump(locations, out_file)

In [2]:
from util import feat_extractor
from util.feat_extractor import feat_pred, feat_pred_by_seg, load_model

# Sanity check on one frame: whole-frame prediction, then the separate
# left/right predictions returned by feat_pred_by_seg.
load_model()
img_dir = '.'
imgname = '000053.jpg'

print(feat_pred(img_dir,imgname))
left, right = feat_pred_by_seg(img_dir, imgname)
print('left ', left)
print('right ', right)


def _tag_sides(dominant, other, dominant_tag, other_tag):
    """Tag the top class of *dominant*, then the top class of *other*.

    Returns an 8-slot list with *dominant_tag* at the argmax of *dominant*
    and *other_tag* at the argmax of *other*. Before the second pick, the
    dominant class is zeroed inside *other* (mutated in place), so the two
    picks differ whenever *other* still has a positive score elsewhere.
    """
    sides = [0] * 8
    best = dominant.index(max(dominant))
    sides[best] = dominant_tag
    other[best] = 0
    sides[other.index(max(other))] = other_tag
    return sides


# 1 marks the left-hand instrument, 2 the right-hand one; the side whose top
# score is larger claims its instrument first.
left_max = max(left)
right_max = max(right)
if left_max > right_max:
    position = _tag_sides(left, right, 1, 2)
else:
    position = _tag_sides(right, left, 2, 1)
print(position)

[4.8531398e-08, 4.203821e-10, 3.0119654e-07, 1.6685843e-09, 7.592569e-08, 0.58534557, 4.232982e-09, 8.99717e-09]
left  [3.3698166e-06, 1.5242497e-08, 1.61589e-05, 3.455999e-08, 1.1760211e-07, 2.40524e-06, 2.1710822e-08, 2.5602446e-06]
right  [2.5203516e-08, 9.114263e-12, 1.9799065e-09, 6.485805e-11, 5.8738996e-09, 0.11121938, 1.4527327e-10, 1.1737265e-10]
[0, 0, 1, 0, 0, 2, 0, 0]


In [1]:
from Evaluate import Evaluate

# All three evaluation directories share one test-set root. Judging by the
# cells below, the returned tuple holds the per-video location hits and the
# per-source SDR values — TODO confirm against Evaluate's definition.
testset_root = 'H:/Study/bpfile/testset25/'
Evaluate(testset_root + 'result_json',
         testset_root + 'result_audio',
         testset_root + 'gt_audio')

([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [9.777360795119105,
  -10.09248636743938,
  10.32364742145534,
  3.057444494070724,
  9.529792246779623,
  -0.6161393633356158,
  3.1117141337320224,
  6.190048117861831,
  2.956510307435125,
  0.04340889726304364,
  3.4566056205549684,
  12.543886389811345,
  -13.151761483136939,
  9.114885674476607,
  -4.766174396467576,
  9.825438513856191,
  6.061299399695891,
  6.848018085174301,
  3.2044448447364045,
  2.2413911499002683,
  -3.1581192467039965,
  5.333239158105924,
  3.3113456254585802,
  4.027945542808805,
  9.941331588664816,
  -8.116518557928018,
  -5.899301523814823,
  6.688577810126427,
  5.115476319864616,
  -1.8459559096740747,
  -2.151789376974273,
  3.8127202483788496,
  -0.8400440406089104,
  3.181078741122456,
  -10.673818773370423,
  9.308022146046875,
  9.815917894217304,
  1.9693316002253545,
  6.0703857165648465,
  -0.6313290081705814,
  1.6345979103775776,
  -1.0661233936990187,
  -11.8

In [9]:
# 2018/12/31 01:00
# accur: per-video location result (1 counts as correct); snr: per-source SDR
# values — presumably pasted from an Evaluate(...) run; TODO confirm.
accur = [1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0]
snr = [11.347610961650894,
  -9.632960200620044,
  8.70021952059402,
  2.5079513008718477,
  9.652218240462485,
  -0.1527975835486499,
  2.0595129473305334,
  5.1103995341998925,
  1.561645286855171,
  -1.313358260349896,
  0.9039634811239038,
  12.656031646268426,
  -13.439097840668804,
  8.627718589954496,
  -8.30973963605447,
  7.548450334721046,
  2.0754457673556868,
  3.9472773987273646,
  2.4069430738422453,
  1.534099362765119,
  -1.4902245012246584,
  4.6937194856242614,
  2.605547560421433,
  4.373700425100127,
  9.875911990888353,
  -8.932793377600113,
  -7.911671922013421,
  6.714576715646839,
  2.952031236943138,
  -4.591842359176853,
  -3.664687551573451,
  2.461675985931538,
  -0.4154217807845178,
  3.641413631044385,
  -6.275929831730655,
  10.901451073989607,
  4.391146147909199,
  -5.780470227883133,
  6.2859942439767735,
  -1.4655790234925066,
  -0.03865087571590465,
  -1.7699167771162454,
  -11.366997010864136,
  14.822491686259596,
  2.4814621671509456,
  -0.5059430078344603,
  8.13116251910653,
  -4.994748791387616,
  6.488218187777051,
  -0.03958338260151157]

# Number of correctly located videos, then the mean SDR.
located = sum(accur)
mean_snr = sum(snr) / len(snr)
print(located)
print(mean_snr)

12
1.5873515312450368


In [10]:
# Count how many entries of the current `snr` list are positive.
cnt = sum(1 for value in snr if value > 0)
print(cnt)

30


In [4]:
# 2018/12/31 11:00
# accur: per-video location result (1 counts as correct); snr: per-source SDR
# values — presumably pasted from an Evaluate(...) run; TODO confirm.
accur = [1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
snr = [11.025667273241673,
  -8.616144398241026,
  8.70033389108073,
  2.5079283361029905,
  9.652112032056358,
  -0.15310698486073632,
  2.05947453003453,
  5.1101977566897325,
  1.562369217599937,
  -1.3130846512270777,
  0.9041050961813324,
  12.655941904901958,
  -13.438770299126723,
  8.627968336798402,
  -8.309988122393,
  7.548632691892022,
  2.0753651542099916,
  3.9472065291720515,
  2.407020892892854,
  1.5340932165582575,
  -1.489769209593552,
  4.694007757740754,
  0.7759560881044568,
  3.236182103382213,
  9.875781822568525,
  -8.932016372898001,
  -7.911161286186349,
  6.71525418993304,
  3.171514461085306,
  -4.228833308769227,
  -3.532583959722861,
  2.446789310157182,
  -0.4158362013207336,
  3.6410206480297926,
  -7.14992673159284,
  10.78672873386352,
  4.3889315979614345,
  -5.781894478256094,
  6.211971539542357,
  -1.0298283188974704,
  -0.04018120965681371,
  -1.7705561093021722,
  -11.368012914148942,
  14.822047305213744,
  2.481380517704754,
  -0.5060000788044992,
  9.136656160609991,
  -3.96703800321657,
  6.488657374146306,
  -0.03982178185451063]

# Mean SDR, then the number of positive SDR values.
print(sum(snr)/len(snr))
cnt = sum(1 for value in snr if value > 0)
print(cnt)

1.5839348409877392
30


In [6]:
# 2018/12/31 12:40
# accur: per-video location result (1 counts as correct); snr: per-source SDR
# values — presumably pasted from an Evaluate(...) run; TODO confirm.
accur = [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
snr = [8.240998820245776,
  -9.807396471718175,
  10.287541607998836,
  3.5836771885271546,
  9.38400357937949,
  -1.4614503957915583,
  -1.162931653049811,
  3.2211157879474457,
  2.063703423985708,
  -1.291734904831695,
  4.429829222801779,
  14.067683199607533,
  -13.22726984159213,
  10.343464283495036,
  -5.098737397905819,
  9.564600102088889,
  5.738229906243913,
  5.4539361177681664,
  4.191032480258785,
  2.1577529623961302,
  -1.8017149742732004,
  5.266018460263124,
  0.2918861621178297,
  2.234977715074784,
  10.71434447462033,
  -9.863619164319644,
  -8.195302961781856,
  6.641515603169252,
  3.9781730166640705,
  -2.749921295435573,
  -1.8009283045528353,
  4.401786400575135,
  -2.76229068903871,
  1.4576142785699133,
  -4.708411700266992,
  12.455044097831768,
  9.808707336216631,
  1.5772010736237747,
  6.0057726111433665,
  -1.667109486944027,
  0.8251959707570942,
  -0.8442919946204832,
  -16.984587959932508,
  11.945206413836587,
  5.969883204262004,
  2.6423938645071305,
  7.588383858182938,
  -8.309931542552043,
  5.249249811584468,
  -0.9721423179563187]

# Mean SDR, then the number of positive SDR values.
print(sum(snr)/len(snr))
cnt = sum(1 for value in snr if value > 0)
print(cnt)

1.9814229995836294
32


In [None]:
# Good result: ['acoustic_guitar_2_violin_1']

In [11]:
# solo 2018/12/31 15:14
# NOTE(review): accur and snr below are identical, value for value, to the
# 2018/12/31 12:40 cell even though this run is labelled 'solo' — possibly
# pasted from that cell; confirm before citing these numbers.
accur = [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
snr = [8.240998820245776,
  -9.807396471718175,
  10.287541607998836,
  3.5836771885271546,
  9.38400357937949,
  -1.4614503957915583,
  -1.162931653049811,
  3.2211157879474457,
  2.063703423985708,
  -1.291734904831695,
  4.429829222801779,
  14.067683199607533,
  -13.22726984159213,
  10.343464283495036,
  -5.098737397905819,
  9.564600102088889,
  5.738229906243913,
  5.4539361177681664,
  4.191032480258785,
  2.1577529623961302,
  -1.8017149742732004,
  5.266018460263124,
  0.2918861621178297,
  2.234977715074784,
  10.71434447462033,
  -9.863619164319644,
  -8.195302961781856,
  6.641515603169252,
  3.9781730166640705,
  -2.749921295435573,
  -1.8009283045528353,
  4.401786400575135,
  -2.76229068903871,
  1.4576142785699133,
  -4.708411700266992,
  12.455044097831768,
  9.808707336216631,
  1.5772010736237747,
  6.0057726111433665,
  -1.667109486944027,
  0.8251959707570942,
  -0.8442919946204832,
  -16.984587959932508,
  11.945206413836587,
  5.969883204262004,
  2.6423938645071305,
  7.588383858182938,
  -8.309931542552043,
  5.249249811584468,
  -0.9721423179563187]

# Mean SDR, then the number of positive SDR values.
print(sum(snr)/len(snr))
cnt = sum(1 for value in snr if value > 0)
print(cnt)

1.9814229995836294
32


In [1]:
# Per-source SDR values from a later run (five entries differ from the
# 12:40 results) — presumably pasted from an Evaluate(...) run; TODO confirm.
snr = [8.240998820245776,
  -9.807396471718175,
  10.287541607998836,
  3.5836771885271546,
  9.38400357937949,
  -1.4614503957915583,
  -1.162931653049811,
  3.2211157879474457,
  2.063703423985708,
  -0.4097723217777879,
  4.429829222801779,
  14.067683199607533,
  -13.22726984159213,
  10.343464283495036,
  -5.098737397905819,
  9.564600102088889,
  5.738229906243913,
  5.4539361177681664,
  4.191032480258785,
  2.1577529623961302,
  -1.8017149742732004,
  5.266018460263124,
  0.2918861621178297,
  2.234977715074784,
  10.71434447462033,
  -9.863619164319644,
  -8.195302961781856,
  6.641515603169252,
  3.7922489840886793,
  -3.476187892045962,
  -1.8009283045528353,
  4.401786400575135,
  -2.76229068903871,
  1.4576142785699133,
  -4.708411700266992,
  12.455044097831768,
  9.808707336216631,
  1.5772010736237747,
  6.0057726111433665,
  -1.667109486944027,
  1.5002059690703295,
  -1.1127767737937773,
  -16.984587959932508,
  11.945206413836587,
  5.969883204262004,
  2.6423938645071305,
  7.588383858182938,
  -8.309931542552043,
  5.249249811584468,
  -0.9721423179563187]

# Mean SDR, then the number of positive SDR values.
print(sum(snr)/len(snr))
cnt = sum(1 for value in snr if value > 0)
print(cnt)

1.9889489430437899
32


In [23]:
# 2018/12/31 22:54 — the complete training run before submitting the
# experiment report (the original comment read "2013", presumably a typo).
# These values match the Evaluate(...) output shown earlier in the notebook.
accur = [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
snr = [9.777360795119105,
  -10.09248636743938,
  10.32364742145534,
  3.057444494070724,
  9.529792246779623,
  -0.6161393633356158,
  3.1117141337320224,
  6.190048117861831,
  2.956510307435125,
  0.04340889726304364,
  3.4566056205549684,
  12.543886389811345,
  -13.151761483136939,
  9.114885674476607,
  -4.766174396467576,
  9.825438513856191,
  6.061299399695891,
  6.848018085174301,
  3.2044448447364045,
  2.2413911499002683,
  -3.1581192467039965,
  5.333239158105924,
  3.3113456254585802,
  4.027945542808805,
  9.941331588664816,
  -8.116518557928018,
  -5.899301523814823,
  6.688577810126427,
  5.115476319864616,
  -1.8459559096740747,
  -2.151789376974273,
  3.8127202483788496,
  -0.8400440406089104,
  3.181078741122456,
  -10.673818773370423,
  9.308022146046875,
  9.815917894217304,
  1.9693316002253545,
  6.0703857165648465,
  -0.6313290081705814,
  1.6345979103775776,
  -1.0661233936990187,
  -11.885080352787893,
  14.333275361866031,
  8.341531499872444,
  4.962668134936543,
  8.200342574478096,
  -8.40075163572855,
  8.695711860595944,
  1.5057541363570497]


def format_sdr_table(values, per_row=5):
    """Return *values* as '%10.4f' columns, *per_row* values per line.

    Every line, including a final partial one, ends with a newline, so
    anything printed afterwards always starts on a fresh line.
    """
    lines = []
    for start in range(0, len(values), per_row):
        chunk = values[start:start + per_row]
        lines.append(''.join('%10.4f' % value for value in chunk) + '\n')
    return ''.join(lines)


print('Location:\n  ', accur)
print('SDR:')
print(format_sdr_table(snr), end='')
print('mean SDR:\n%10.4f' % (sum(snr) / len(snr)))
print('location accuracy:\n%8.2f %%' % (100 * sum(accur) / len(accur)))

Location:
   [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
SDR:
    9.7774  -10.0925   10.3236    3.0574    9.5298
   -0.6161    3.1117    6.1900    2.9565    0.0434
    3.4566   12.5439  -13.1518    9.1149   -4.7662
    9.8254    6.0613    6.8480    3.2044    2.2414
   -3.1581    5.3332    3.3113    4.0279    9.9413
   -8.1165   -5.8993    6.6886    5.1155   -1.8460
   -2.1518    3.8127   -0.8400    3.1811  -10.6738
    9.3080    9.8159    1.9693    6.0704   -0.6313
    1.6346   -1.0661  -11.8851   14.3333    8.3415
    4.9627    8.2003   -8.4008    8.6957    1.5058
mean SDR:
    2.6248
location accuracy:
   92.00 %
