In [1]:
# -*- coding: utf-8 -*-
# Author: Vi
# Created on: 2024-06-13 16:35:53
# Description: 创建数据分类并测试

from datasets.models.label import Label
from datasets.models.sources import US8KDataSource, ESC50DataSource, ProvinceDataSource
from datasets.models.category import Category

In [2]:
import yaml
# Parse the YAML configuration and keep the 'DataSources' section, which
# supplies the constructor keyword arguments for each data source.
# NOTE(review): no explicit encoding is given, so the file is decoded with the
# platform default — confirm this matches how configs.yml is saved.
with open("configs.yml", mode='r') as config_file:
    configs = yaml.safe_load(config_file)

datasources_info = configs['DataSources']

In [3]:
# The US8K and ESC50 sources are each built from their own config section.
us8k = US8KDataSource(**datasources_info['US8K'])
esc50 = ESC50DataSource(**datasources_info['ESC50'])

# One ProvinceDataSource per top-level noise category. All of them share the
# 'Province' config section and differ only in `name`; they are constructed in
# the same order as the target names on the left.
(
    province_traffic,
    province_nature,
    province_industry,
    province_social,
    province_building,
) = (
    ProvinceDataSource(name=noise_category, **datasources_info['Province'])
    for noise_category in ("交通噪声", "自然噪声", "工业噪声", "社会噪声", "建筑施工噪声")
)

In [4]:
from enum import Enum

class DataSources(Enum):
    """Enumeration of the data source objects created earlier in the notebook.

    Each member's value is the corresponding module-level instance
    (``us8k``, ``esc50`` and the five ``province_*`` sources).

    NOTE(review): ``Enum`` turns members with equal values into aliases of the
    first one; this assumes the source objects do not compare equal to each
    other — TODO confirm against their ``__eq__`` implementation.
    """
    US8K = us8k
    ESC50 = esc50
    PROVINCE_TRAFFIC = province_traffic
    PROVINCE_NATURE = province_nature
    PROVINCE_INDUSTRY = province_industry
    PROVINCE_SOCIAL = province_social
    PROVINCE_BUILDING = province_building

In [5]:
# Display the ESC50 child class names as a single comma-separated string.
','.join(esc50.childs_info)

'dog,rooster,pig,cow,frog,cat,hen,insects,sheep,crow,rain,sea_waves,crackling_fire,crickets,chirping_birds,water_drops,wind,pouring_water,toilet_flush,thunderstorm,crying_baby,sneezing,clapping,breathing,coughing,footsteps,laughing,brushing_teeth,snoring,drinking_sipping,door_wood_knock,mouse_click,keyboard_typing,door_wood_creaks,can_opening,washing_machine,vacuum_cleaner,clock_alarm,clock_tick,glass_breaking,helicopter,chainsaw,siren,car_horn,engine,train,church_bells,airplane,fireworks,hand_saw'

In [6]:
# Convert the whole "自然噪声" source into one Label and display it; the
# recorded output lists its child sources under `sources`.
province_nature.to_label()

Label(id=None, name='自然噪声', sources=[ProvinceDataSource(base_dir="\\10.166.168.123\典型城市声纹数据库-标签过", name="北红尾鸲叫声", label=0, length=1200, parent="自然噪声"), ProvinceDataSource(base_dir="\\10.166.168.123\典型城市声纹数据库-标签过", name="叉尾太阳鸟叫声", label=1, length=1200, parent="自然噪声"), ProvinceDataSource(base_dir="\\10.166.168.123\典型城市声纹数据库-标签过", name="大鹰鹃叫声", label=2, length=1200, parent="自然噪声"), ProvinceDataSource(base_dir="\\10.166.168.123\典型城市声纹数据库-标签过", name="强脚树莺叫声", label=3, length=1200, parent="自然噪声"), ProvinceDataSource(base_dir="\\10.166.168.123\典型城市声纹数据库-标签过", name="普通夜鹰叫声", label=4, length=1200, parent="自然噪声"), ProvinceDataSource(base_dir="\\10.166.168.123\典型城市声纹数据库-标签过", name="棕颈钩嘴鹛叫声", label=5, length=1200, parent="自然噪声"), ProvinceDataSource(base_dir="\\10.166.168.123\典型城市声纹数据库-标签过", name="淡脚柳莺叫声", label=6, length=1200, parent="自然噪声"), ProvinceDataSource(base_dir="\\10.166.168.123\典型城市声纹数据库-标签过", name="潮汐声", label=7, length=1200, parent="自然噪声"), ProvinceDataSource(base_dir="\\10.166.168.123

In [7]:
# Province bird-call classes that get merged into the single "鸟叫" label.
bird_call_names = [
    '北红尾鸲叫声',
    '叉尾太阳鸟叫声',
    '大鹰鹃叫声',
    '强脚树莺叫声',
    '普通夜鹰叫声',
    '棕颈钩嘴鹛叫声',
    '淡脚柳莺叫声',
]

# Two nature classes keep a label of their own ...
thunder_label = province_nature.get_child("雷声").to_label()
frog_label = province_nature.get_child("蛙声").to_label()
# ... while the bird calls from the province data and the ESC50
# 'chirping_birds' class are combined into one label.
bird_label = Label(
    name="鸟叫",
    sources=province_nature.get_childs(bird_call_names) + esc50.get_childs('chirping_birds'),
)

# Every child of the traffic source contributes one label each.
labels = [thunder_label, frog_label, bird_label]
labels.extend(child.to_label() for child in province_traffic.childs)
len(labels)

23

In [8]:
# Bundle the labels into a Category named 'test' and display its length
# (presumably the total number of samples across all labels — TODO confirm
# against Category.__len__).
category = Category(name='test', labels=labels)
len(category)

34840

In [9]:
# Number of labels held by the category; expected to match len(labels) above.
len(category.labels)

23

In [10]:
# Display the id/name pair of every label in the category.
category.labels_info

[{'id': 0, 'name': '三轮车'},
 {'id': 1, 'name': '公交车'},
 {'id': 2, 'name': '地铁'},
 {'id': 3, 'name': '小艇'},
 {'id': 4, 'name': '拖拉机'},
 {'id': 5, 'name': '摩托车'},
 {'id': 6, 'name': '救火警铃'},
 {'id': 7, 'name': '有轨电车'},
 {'id': 8, 'name': '汽车'},
 {'id': 9, 'name': '汽车刹车声'},
 {'id': 10, 'name': '汽车鸣笛'},
 {'id': 11, 'name': '直升机'},
 {'id': 12, 'name': '船'},
 {'id': 13, 'name': '蛙声'},
 {'id': 14, 'name': '警铃'},
 {'id': 15, 'name': '货车'},
 {'id': 16, 'name': '车辆防盗报警'},
 {'id': 17, 'name': '铁轨'},
 {'id': 18, 'name': '长途客车'},
 {'id': 19, 'name': '雷声'},
 {'id': 20, 'name': '飞机'},
 {'id': 21, 'name': '高铁'},
 {'id': 22, 'name': '鸟叫'}]

In [11]:
import os
# Dump every item of the category to test/category.txt, one str(item) per line,
# so that the train/test split can later be checked against this file.
# exist_ok=True creates the directory only when it is missing, replacing the
# separate exists() check that could race with another process creating it.
os.makedirs('test', exist_ok=True)
with open('test/category.txt', 'w', encoding='utf-8') as f:
    # Index access is kept on purpose: only len() and [] are known to be
    # supported by Category from the visible code.
    f.writelines(str(category[i]) + '\n' for i in range(len(category)))

In [12]:
# Split the data by ratio, with stratification.
X_train, X_test, y_train, y_test = category.get_train_test_data()
# The call prints the class distribution of both splits (see recorded output);
# its two return values are discarded.
# NOTE(review): assigning to `_` rebinds the name IPython uses for the last
# cell output — confirm this is intended, or suppress the output differently.
_, _ = category.count_train_test_data(y_train=y_train, y_test=y_test)


训练集分类分布: Counter({22: 6752, 7: 960, 21: 960, 15: 960, 16: 960, 5: 960, 1: 960, 12: 960, 3: 960, 8: 960, 9: 960, 17: 960, 20: 960, 19: 960, 6: 960, 4: 960, 18: 960, 11: 960, 10: 960, 0: 960, 14: 960, 13: 960, 2: 960})
测试集分类分布: Counter({22: 1688, 16: 240, 11: 240, 8: 240, 14: 240, 7: 240, 9: 240, 2: 240, 18: 240, 3: 240, 13: 240, 19: 240, 4: 240, 21: 240, 5: 240, 17: 240, 10: 240, 15: 240, 0: 240, 12: 240, 6: 240, 20: 240, 1: 240})


In [13]:
# Display the sizes of the test and training splits; the X and y of each split
# should have the same length.
len(X_test), len(y_test), len(X_train), len(y_train)

(6968, 6968, 27872, 27872)

In [14]:
# Read the dumped category file back in full.
# Despite its name, `lines` is ONE string holding the entire file (not a list
# of lines), so an `in` test against it is a substring search.
with open('test/category.txt', 'r', encoding='utf-8') as f:
    lines = f.read()

In [15]:
# Check the result: every (x, y) pair of the training and test splits must
# occur in the dump read from test/category.txt; any pair that does not is
# printed (followed by a blank line).
#
# The dump is split into a set of individual lines once, so the common case
# (pair found) is a constant-time lookup instead of a scan over the whole file
# text for each of the tens of thousands of pairs.
dumped_lines = set(lines.split('\n'))

for split_X, split_y in ((X_train, y_train), (X_test, y_test)):
    for x, y in zip(split_X, split_y):
        res = str((x, y))
        # Only on a set miss fall back to the original substring test, so the
        # reported pairs are exactly the same as before (this also covers
        # items whose string form spans several lines).
        if res not in dumped_lines and res + '\n' not in lines:
            print(res + '\n')