In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import librosa
from tqdm import tqdm
import seaborn
import torch
import torch.nn as nn
# Root directory of the cough data set: the cells below read the .mp4 videos and
# .ass annotation files from it and write the .wav audio and filename2index.txt into it.
BILICOUGH_ROOT = "G:/DATAS-Medical/BILIBILICOUGH/"

# 从mp4中提取wav音频

In [None]:
import subprocess

# Titles (file names without extension) of the videos whose audio track is extracted.
names = ["金_属_音_咳_嗽","剧烈的咳嗽。","女孩感冒哮喘发作","四个常见咳嗽声音最后一个要重视_听一下你属于哪一种","我咳嗽的样子", "小朋友哮喘发作（看着好痛苦）"]
for name in names:
    filename = BILICOUGH_ROOT + name + ".mp4"
    outname = BILICOUGH_ROOT + name + ".wav"
    # The arguments are passed as a list, so no shell parses the command line:
    # titles containing spaces or shell metacharacters reach ffmpeg unchanged.
    result = subprocess.run(["ffmpeg", "-i", filename, "-f", "wav", "-ar", "44100", outname])
    if result.returncode != 0:
        # Report the failure instead of silently continuing with a missing .wav file.
        print("ffmpeg failed for {} (exit code {})".format(filename, result.returncode))

In [None]:
# Collect the .ass annotation files of the data directory. The dot is part of the
# test so that names that merely end in "ass" are not picked up.
file_list = [item for item in os.listdir(BILICOUGH_ROOT) if item.endswith(".ass")]
# Write one "bilicough_<index>,<file name without extension>" line per annotation file.
# The with-block closes the index file even when a write fails.
with open(BILICOUGH_ROOT + "filename2index.txt", 'w') as name_mapper:
    for idx, item in enumerate(file_list):
        print(BILICOUGH_ROOT+item)
        # zfill pads to three digits without cutting off digits of indices >= 1000.
        name_mapper.write("bilicough_{},".format(str(idx).zfill(3))+item[:-4]+"\n")

# 读取整个音频并标注其咳嗽段

In [None]:
# Read the index file back into name_list, one stripped line per entry.
name_list = []
with open(BILICOUGH_ROOT + "filename2index.txt", 'r') as wavfin:
    # The first line is discarded, exactly as before.
    # NOTE(review): presumably a header row - the cell that writes this file emits none; confirm.
    wavfin.readline()
    for line in wavfin:
        name_list.append(line.strip())
print(name_list)

In [None]:
def min2sec(t: str):
    """Convert a colon-separated timestamp into seconds.

    The last field is parsed as a float number of seconds; every field to its
    left is parsed as an integer and weighted with the next power of 60
    (minutes, hours, ...). A string without any colon is returned as float.
    """
    *leading, last = t.split(':')
    total = float(last)
    # Walk from the field next to the seconds towards the left-most field.
    for power, field in enumerate(reversed(leading), start=1):
        total += int(field) * 60 ** power
    return total

def wav_plot(wavfile, label_list, idx=0, gap=8000):
    """Plot the annotated segments of one recording back to back.

    Args:
        wavfile: File name of the recording, relative to BILICOUGH_ROOT.
        label_list: Iterable of items whose first two elements are the start
            and end timestamps of a segment, in the format min2sec accepts.
        idx: Number of the matplotlib figure to draw into.
        gap: Number of zero samples appended after each segment as a visual
            separator (8000, the previously hard-coded value, by default).
    """
    y, sr = librosa.load(BILICOUGH_ROOT + wavfile)
    print("sample rate:", sr)
    pieces = []
    for item in label_list:
        st, en = int(min2sec(item[0])*sr), int(min2sec(item[1])*sr+1)
        print("st, en:", st, en)
        pieces.append(y[st: en])
        pieces.append(np.zeros(gap))
    # Concatenate once at the end; growing the array inside the loop re-copies
    # everything collected so far on every iteration.
    y_plt = np.concatenate(pieces, axis=0) if pieces else np.array([])
    plt.figure(idx)
    plt.plot(y_plt)

In [None]:
# Inspect a single recording: parse its .ass annotation file and plot the labelled segments.
idx = 17
wavtest = name_list[idx] + ".wav"
asstest = name_list[idx] + ".ass"
print(wavtest, asstest)
label_list = []
with open(BILICOUGH_ROOT + asstest, 'r', encoding="utf-8") as assfin:
    # Skip everything up to and including the "[Events]" section header. Reaching
    # the end of the file without finding it is reported instead of looping forever.
    for line in assfin:
        if line.strip() == "[Events]":
            break
    else:
        raise ValueError("No [Events] section found in " + asstest)
    # The line right after the header is skipped as well (presumably the
    # "Format:" row of the section - confirm against the annotation files).
    assfin.readline()
    for line in assfin:
        parts = line.split(',')
        if len(parts) < 10:
            # Blank or truncated lines carry no label field.
            continue
        label_text = parts[9].strip()
        if label_text == "useless":
            continue
        # Fields 1 and 2 hold the start and end timestamps, field 9 the label text.
        label_list.append([parts[1], parts[2], label_text])
for item in label_list:
    print(item)
wav_plot(wavtest, label_list, idx)

### 批量绘图

In [None]:
# Parse the annotation file of every recording and plot its labelled segments,
# one figure per recording (the figure number is the position in name_list).
for idx, name in enumerate(name_list):
    wavtest = name + ".wav"
    asstest = name + ".ass"

    label_list = []
    with open(BILICOUGH_ROOT + asstest, 'r', encoding="utf-8") as assfin:
        # Skip everything up to and including the "[Events]" section header. Reaching
        # the end of the file without finding it is reported instead of looping forever.
        for line in assfin:
            if line.strip() == "[Events]":
                break
        else:
            raise ValueError("No [Events] section found in " + asstest)
        # The line right after the header is skipped as well (presumably the
        # "Format:" row of the section - confirm against the annotation files).
        assfin.readline()
        for line in assfin:
            parts = line.split(',')
            if len(parts) < 10:
                # Blank or truncated lines carry no label field.
                continue
            label_text = parts[9].strip()
            if label_text == "useless":
                continue
            # Fields 1 and 2 hold the start and end timestamps, field 9 the label text.
            label_list.append([parts[1], parts[2], label_text])
    for item in label_list:
        print(item)

    wav_plot(wavtest, label_list, idx)

# 二分类及其标注
- 非咳嗽的标注：0，"useless", "silence", "noise"
- 咳嗽的标注：1，即除上述三类以外的所有标签

### 第一步，读取所有的ass文件
- 查看标签有哪些，来自哪些文件
- 查看标签的个数
- 查看时长分布

In [None]:
# Read the index file into name_list, one stripped line per entry.
name_list = []
with open("G:/DATAS-Medical/BILIBILICOUGH/filename2index.txt", 'r') as wavfin:
    # The first line is discarded, exactly as before.
    # NOTE(review): presumably a header row - confirm against the file on disk.
    wavfin.readline()
    for line in wavfin:
        name_list.append(line.strip())
print(name_list)

### 注意！此处涉及重要文件“bilicough_metainfo.csv”的创建和写入（下方单元格中的写入语句目前处于注释状态，需要重新生成该文件时请先取消注释）

In [None]:
# Count the labels of all annotation files, both as written (label_dict) and
# reduced to one of the known class names (label_cnt).
print(name_list)
label_dict = dict()  # raw label text -> number of occurrences
label_names = ["breathe", "cough","clearthroat","exhale", "hum", "inhale","noise", "silence", "sniff","speech", "vomit","whooping"]
label_cnt = dict()  # class name -> number of occurrences
name2label = {"breathe":0, "cough":2,"clearthroat":1,"exhale":3, "hum":4, "inhale":5,"noise":6, "silence":7, "sniff":8,"speech":9, "vomit":10,"whooping":11}
# The statements that write bilicough_metainfo.csv are kept commented out;
# re-enable all four of them together to regenerate the file.
# metainfo_file = open("G:/DATAS-Medical/BILIBILICOUGH/bilicough_metainfo.csv", 'w')
# metainfo_file.write("filename,st,en,labelfull,labelname,label\n")
for idx, name in enumerate(name_list):
    wavtest = name + ".wav"
    asstest = name + ".ass"
    label_list = []
    with open("G:/DATAS-Medical/BILIBILICOUGH/" + asstest, 'r', encoding="utf-8") as assfin:
        # Skip everything up to and including the "[Events]" section header. Reaching
        # the end of the file without finding it is reported instead of looping forever.
        for line in assfin:
            if line.strip() == "[Events]":
                break
        else:
            raise ValueError("No [Events] section found in " + asstest)
        # The line right after the header is skipped as well (presumably the
        # "Format:" row of the section - confirm against the annotation files).
        assfin.readline()
        for line in assfin:
            parts = line.split(',')
            if len(parts) < 10:
                # Blank or truncated lines carry no label field.
                continue
            lab_tmp = parts[9].strip()
            if lab_tmp == "useless":
                continue
            label_list.append([parts[1], parts[2], lab_tmp])
            label_dict[lab_tmp] = label_dict.get(lab_tmp, 0) + 1

            # The class is the known name the label text starts with. No class
            # name is a prefix of another one, so at most one name can match.
            label = next((cls for cls in label_names if lab_tmp.startswith(cls)), None)
            if label is None:
                print(lab_tmp, name)
                raise Exception("Unknown Class.")

            label_cnt[label] = label_cnt.get(label, 0) + 1
            # metainfo_file.write("{},{},{},{},{},{}\n".format(name_list[idx], parts[1], parts[2] ,lab_tmp, label,name2label[label]))
# metainfo_file.close()
print("标签分布：")
for k,v in label_dict.items():
    print("key:{},\tcount:{}".format(k,v))
print("---------------=============----------------")
for k,v in label_cnt.items():
    print("key:{},\tcount:{}".format(k,v))


## 读取metainfo文件，创建不同任务的标注

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import librosa

In [None]:
# Load the per-segment metadata table and derive a two-valued label column from it.
metadf = pd.read_csv(
    "G:/DATAS-Medical/BILIBILICOUGH/bilicough_metainfo.csv",
    delimiter=',',
    header=0,
    index_col=None,
    encoding="ansi",
)
print(metadf)
# newdf is another name for the same frame (no copy is made), so the column
# added below is visible through metadf as well.
newdf = metadf
# "binlab" keeps the value 2 for rows whose label is 2 and is 0 for every other label.
newdf["binlab"] = newdf["label"].apply(lambda value: 0 if value != 2 else 2)
newdf

In [None]:
def min2sec(t: str):
    """Return the number of seconds described by a colon-separated timestamp.

    The right-most field is read as float seconds. The remaining fields are
    read as integers from right to left and weighted with 60, 3600, ...
    """
    fields = t.split(':')
    seconds = float(fields.pop())
    weight = 60
    # pop() hands out the fields from right to left: minutes first, then hours.
    while fields:
        seconds += int(fields.pop()) * weight
        weight *= 60
    return seconds


In [None]:
# Duration (in seconds) of every segment whose numeric label is 2.
# Tuple position 0 is the frame index, so item[2], item[3] and item[6] are the
# 2nd, 3rd and 6th columns (assumed to be st, en and label - confirm against the CSV header).
sn_list = []
sr = 22050
for item in metadf.itertuples():
    if item[6] == 2:
        st, en = int(min2sec(item[2])*sr), int(min2sec(item[3])*sr+1)
        # Divide by the same sample rate that was used to compute the sample indices.
        sn_list.append((en - st)/sr)

# plt.hist(sn_list, bins=12)

# A duration is counted in the first threshold it stays below, with a tolerance
# of 0.1 s; durations of 10.1 s and more are not counted in any bar.
trs = [0.3, 0.5, 0.8, 1.0, 1.2, 1.5, 1.7, 2.0, 2.5, 10]
cnt_list = [0] * len(trs)
for sn in sn_list:
    for i in range(len(trs)):
        if sn < trs[i]+0.1:
            cnt_list[i] += 1
            break
plt.figure(0)
plt.bar([str(item) for item in trs], cnt_list, width=0.2)
plt.xlabel("duration threshold (s)")
plt.ylabel("number of segments")

# 通过滑动窗口截取数据片段
## 在所有数据中获取有效片段和无效片段

In [None]:
import pandas as pd
import librosa
# newdf.groupby("binlab").count()
def min2sec(t: str):
    """Translate a timestamp with colon-separated fields into seconds.

    The final field is taken as float seconds; each field before it is taken as
    an integer and multiplied by 60 once more than the field to its right.
    """
    pieces = t.split(':')
    result = float(pieces[-1])
    scale = 1
    # pieces[-2::-1] runs from the field next to the seconds back to the first one.
    for piece in pieces[-2::-1]:
        scale *= 60
        result += int(piece) * scale
    return result

def get_bilicough_dataset():
    """Cut the annotated recordings into fixed-length windows with binary labels.

    Reads columns 0, 1, 2 and 5 of bilicough_metainfo.csv (file name, start,
    end and numeric label), loads every referenced .wav file with librosa and
    builds windows of ``sr`` samples (one second) each:

    * Label 0: windows taken from the unannotated gap in front of a segment
      (at most three consecutive windows from the beginning of the gap plus the
      window that ends where the segment starts), and windows of segments whose
      numeric label is 6 or 7 (noise and silence in name2label).
    * Label 1: windows of all other segments.

    A segment shorter than one window is zero-padded at its end. A longer one
    is covered by evenly overlapping windows plus one window aligned with the
    end of the segment.

    Returns:
        tuple: ``(sample_list, label_list)`` - the windows and, at the same
        positions, their 0/1 labels.

    Raises:
        Exception: "Error Index." if a segment, after clipping its end to the
            length of the recording, spans fewer than 100 samples.
    """
    ROOT = "G:/DATAS-Medical/BILIBILICOUGH/"
    non_sound_labels = (6, 7)
    max_gap_windows = 3
    metadf = pd.read_csv(ROOT+"bilicough_metainfo.csv", delimiter=',', header=0, index_col=None, usecols=[0,1,2,5], encoding="ansi")
    print(metadf)
    cur_fname = None
    cur_wav = None
    sr = None
    data_length = None
    sample_list = []
    label_list = []
    sr_list = []
    pre_en = None
    # Tuple layout: item[1] file name, item[2] start, item[3] end, item[4] numeric label.
    for item in metadf.itertuples():
        if cur_fname != item[1]:
            cur_fname = item[1]
            cur_wav, sr = librosa.load(ROOT+cur_fname+".wav")
            if sr not in sr_list:
                sr_list.append(sr)
            data_length = sr
            # A new recording starts: the end of the previous segment belongs to
            # another file and must not be used as the start of the first gap.
            pre_en = None
        st, en = int(min2sec(item[2])*sr), int(min2sec(item[3])*sr+1)
        if en > len(cur_wav):
            en = len(cur_wav)
        if en - st < 100:
            raise Exception("Error Index.")
        sn = en - st

        # Unannotated gap in front of this segment: it starts at the beginning of
        # the recording for the first segment, else at the end of the previous one.
        gap_start = 0 if pre_en is None else pre_en
        if st - gap_start >= data_length:
            st_pos = gap_start
            for _ in range(max_gap_windows):
                if st_pos + data_length > st:
                    break
                sample_list.append(cur_wav[st_pos:st_pos+data_length])
                label_list.append(0)
                st_pos += data_length
            # Window that ends exactly where the segment begins.
            sample_list.append(cur_wav[st-data_length:st])
            label_list.append(0)

        bin_label = 0 if int(item[4]) in non_sound_labels else 1
        if sn == data_length:
            sample_list.append(cur_wav[st:en])
            label_list.append(bin_label)
        elif sn < data_length:
            # Zero-pad short segments at the end. en was clipped to the length of
            # the recording above, so the slice always holds exactly sn samples.
            # NOTE(review): np.zeros yields float64 while the slices keep the
            # dtype of cur_wav; left unchanged for downstream code.
            new_sample = np.zeros(data_length)
            new_sample[:sn] = cur_wav[st:en]
            sample_list.append(new_sample)
            label_list.append(bin_label)
        else:
            # Spread cnt_sum windows over the segment so that the surplus
            # (cnt_sum * data_length - sn) is shared evenly between neighbours.
            cnt_sum = sn // data_length + 1
            res = cnt_sum * data_length - sn
            overlap = res // (cnt_sum-1)
            st_pos = st
            while st_pos + data_length < en:
                sample_list.append(cur_wav[st_pos:st_pos+data_length])
                label_list.append(bin_label)
                st_pos += data_length - overlap
            # Final window aligned with the end of the segment. It carries the
            # label of the segment, as every other window of the segment does.
            sample_list.append(cur_wav[en-data_length:en])
            label_list.append(bin_label)
        pre_en = en
    print("sound count:{}, all count:{}.".format(sum(label_list), len(label_list)))
    print(sr_list)
    return sample_list, label_list

sample_list, label_list = get_bilicough_dataset()
# Distinct window lengths in order of first appearance.
length_list = list(dict.fromkeys(len(sample) for sample in sample_list))
print(length_list)

## 在另外下载的白噪声数据中截取部分用于充实数据的无效片段

In [None]:
import os
import librosa

def load_bilinoise_dataset():
    """Cut one randomly placed one-second clip out of each selected noise file.

    Only the first 19 entries of the noise directory are looked at. Among them,
    every ".wav" file whose name has at least 25 characters is loaded with
    librosa, and a window of ``sr`` samples is taken at a random position.
    Recordings shorter than one window are skipped. The length of every clip
    is printed.

    Returns:
        list: The extracted clips, one array per accepted file.
    """
    NOISE_ROOT = "G:/DATAS-Medical/BILINOISE/"
    filter_length = 25
    max_entries = 19
    new_noise_list = []
    for ind, item in enumerate(os.listdir(NOISE_ROOT)):
        if ind >= max_entries:
            break
        if not (item.endswith(".wav") and len(item) >= filter_length):
            continue
        cur_wav, sr = librosa.load(NOISE_ROOT+item)
        noise_length = sr
        L = len(cur_wav)
        if L < noise_length:
            # Too short to hold one full window.
            print("skipped (shorter than one window):", item)
            continue
        # The upper bound of randint is exclusive; the +1 makes the last possible
        # window reachable and keeps the call valid when L == noise_length.
        st_pos = np.random.randint(0, L-noise_length+1)
        new_noise_list.append(cur_wav[st_pos:st_pos+noise_length])
    for item in new_noise_list:
        print(len(item))
    return new_noise_list

# End