In [None]:
import os
import numpy as np
import torch
import pickle
import requests
import fnmatch
from datetime import datetime
import sys

In [None]:
# Clone VoiceForge: the cells below import its ECAPA-TDNN model and audio
# loading utilities from the VoiceForge.audio_feature_ext package.
!git clone https://github.com/BlairLeng/VoiceForge.git

In [None]:
# Mount Google Drive so the audio folder under /content/drive/MyDrive (used
# further down) is readable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import numpy as np
import torch
import pickle


import requests
from VoiceForge.audio_feature_ext.modules.ecapa_tdnn import EcapaTdnn, SpeakerIdetification
from VoiceForge.audio_feature_ext.data_utils.reader import load_audio, CustomDataset

class AudioFeatureExtraction:
    """Speaker-embedding extractor wrapping VoiceForge's ECAPA-TDNN model.

    Construction builds the model, downloads any missing weight files into
    ``model_director`` and loads the weights onto ``device``.

    :param model_director: directory that holds (or will receive) the files
        ``model.pth``, ``model.state`` and ``optimizer.pth``.
    :param feature_method: audio pre-processing method forwarded to
        VoiceForge's ``CustomDataset`` / ``load_audio``.
    :param device: torch device (or device string) to run on. ``None``
        (default) selects CUDA when available and falls back to CPU.
    """

    def __init__(self, model_director='./audio_feature_ext/models', feature_method='melspectrogram',
                 device=None):
        self.use_model = ''
        self.model_director = model_director
        self.feature_method = feature_method
        self.model = None
        # Resolved in load_model() when not given explicitly.
        self.device = torch.device(device) if device is not None else None
        self.load_model()

    def init_models(self, path):
        """Download every model file that is not already present in ``path``.

        Each file is written to a ``.part`` temp file and renamed only after a
        complete download, so an interrupted run never leaves a truncated file
        that a later run would mistake for a valid one.

        :param path: existing directory to download into.
        :raises requests.HTTPError: if the server answers with an error status.
        """
        model_urls = ['https://huggingface.co/scixing/voicemodel/resolve/main/model.pth',
                      'https://huggingface.co/scixing/voicemodel/resolve/main/model.state',
                      'https://huggingface.co/scixing/voicemodel/resolve/main/optimizer.pth']
        existing = set(os.listdir(path))
        for url in model_urls:
            filename = url.split('/')[-1]
            if filename in existing:
                continue
            print(f'downloading model pth {filename}')
            # A timeout keeps a stalled connection from hanging the notebook forever.
            r = requests.get(url, allow_redirects=True, timeout=600)
            # Without this check an HTTP error page would be saved as a "model" file.
            r.raise_for_status()
            target = os.path.join(path, filename)
            tmp_target = target + '.part'
            with open(tmp_target, 'wb') as fh:
                fh.write(r.content)
            os.replace(tmp_target, target)
            print(f'{filename} success download')

    def load_model(self):
        """Build the ECAPA-TDNN speaker model and load the pretrained weights.

        Missing weight files are downloaded first. Checkpoint tensors whose
        shape differs from the freshly built model are dropped, and the rest
        are loaded non-strictly. The model is left on ``self.device`` in eval
        mode.
        """
        dataset = CustomDataset(data_list_path=None, feature_method=self.feature_method)
        ecapa_tdnn = EcapaTdnn(input_size=dataset.input_size)
        self.model = SpeakerIdetification(backbone=ecapa_tdnn)
        if self.device is None:
            # Fall back to CPU so the class also works on runtimes without a GPU.
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        os.makedirs(self.model_director, exist_ok=True)
        model_files = ['model.pth', 'model.state', 'optimizer.pth']
        # init_models() fetches every missing file in one pass, so one call suffices.
        if any(not os.path.exists(os.path.join(self.model_director, f)) for f in model_files):
            self.init_models(self.model_director)

        # Load the pretrained weights.
        model_path = os.path.join(self.model_director, 'model.pth')
        model_dict = self.model.state_dict()
        # map_location lets a checkpoint saved on CUDA load on CPU (and vice versa).
        param_state_dict = torch.load(model_path, map_location=self.device)
        for name, weight in model_dict.items():
            # Drop checkpoint entries whose shape does not fit this model.
            if name in param_state_dict and list(weight.shape) != list(param_state_dict[name].shape):
                param_state_dict.pop(name, None)
        self.model.load_state_dict(param_state_dict, strict=False)
        print(f"成功加载模型参数和优化方法参数：{model_path}")
        self.model.to(self.device)
        self.model.eval()

    def infer(self, audio_path, duration):
        """Return the backbone embedding of a single audio file.

        :param audio_path: path to the audio file.
        :param duration: chunk duration forwarded to ``load_audio``.
        :return: numpy array computed from a batch of size 1.
        """
        data = load_audio(audio_path, mode='infer', feature_method=self.feature_method,
                          chunk_duration=duration)
        batch = torch.tensor(data[np.newaxis, :], dtype=torch.float32, device=self.device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            feature = self.model.backbone(batch)
        return feature.cpu().numpy()

    def feats_extract(self, dataloader):
        """Compute backbone embeddings for every batch yielded by ``dataloader``.

        :param dataloader: iterable of feature tensors (e.g. a torch DataLoader).
        :return: numpy array with the per-batch outputs concatenated along axis 0.
        :raises ValueError: if ``dataloader`` yields no batches.
        """
        feat_list = []
        with torch.no_grad():
            for data in dataloader:
                feats = self.model.backbone(data.to(self.device))
                feat_list.append(feats.cpu().numpy())
        if not feat_list:
            raise ValueError('dataloader yielded no batches')
        # Concatenate the list directly. The last batch is usually smaller, and
        # np.array() over ragged batches raises (NumPy >= 1.24) or builds an
        # object array.
        return np.concatenate(feat_list, axis=0)

In [None]:
class CustomDefDataset(torch.utils.data.Dataset):
    """Load and pre-process audio files from an in-memory list of paths.

    :param data_list: sequence of audio file paths.
    :param feature_method: pre-processing method ('melspectrogram' or 'spectrogram').
    :param mode: how the data is processed: 'train', 'eval' or 'infer'.
    :param sr: sample rate.
    :param chunk_duration: audio length used for training or evaluation.
    :param min_duration: minimum audio length used for training or evaluation.
    :param augmentors: data augmentation methods.
    :param max_retries: how many random replacement items ``__getitem__``
        tries after a load failure before raising.

    Caveat: when ``data_list[idx]`` fails to load, ``__getitem__`` logs the
    error to stderr and returns the features of a randomly chosen *other*
    item. That keeps training running, but for eval/feature extraction the
    returned row no longer corresponds to ``data_list[idx]``.
    """
    def __init__(self, data_list,
                 feature_method='melspectrogram',
                 mode='eval',
                 sr=16000,
                 chunk_duration=2,
                 min_duration=2,
                 augmentors=None,
                 max_retries=100):
        super(CustomDefDataset, self).__init__()
        self.lines = data_list
        self.feature_method = feature_method
        self.mode = mode
        self.sr = sr
        self.chunk_duration = chunk_duration
        self.min_duration = min_duration
        self.augmentors = augmentors
        self.max_retries = max_retries

    def __getitem__(self, idx):
        """Return the pre-processed features for ``data_list[idx]``.

        :raises IndexError: if ``idx`` is out of range.
        :raises RuntimeError: if the requested item and every retry fail to load.
        """
        current = idx
        # Iterative retry instead of recursion: a run of unreadable files can
        # no longer end in a RecursionError.
        for _ in range(max(self.max_retries, 0) + 1):
            audio_path = self.lines[current]
            try:
                # Load and pre-process the audio.
                return load_audio(audio_path, feature_method=self.feature_method, mode=self.mode, sr=self.sr,
                                  chunk_duration=self.chunk_duration, min_duration=self.min_duration,
                                  augmentors=self.augmentors)
            except Exception as ex:
                print(f"[{datetime.now()}] 数据: {audio_path} 出错，错误信息: {ex}", file=sys.stderr)
                current = np.random.randint(len(self))
        raise RuntimeError(
            f'failed to load {max(self.max_retries, 0) + 1} samples in a row (starting at index {idx})')

    def __len__(self):
        return len(self.lines)

    @property
    def input_size(self):
        """Feature dimension produced by ``feature_method``."""
        if self.feature_method == 'melspectrogram':
            return 80
        elif self.feature_method == 'spectrogram':
            return 201
        else:
            raise Exception(f'预处理方法 {self.feature_method} 不存在！')


In [None]:
# Folder on the mounted Google Drive that holds the source .wav clips.
folder_path = "/content/drive/MyDrive/GPTData/Haruhi_audio"


def find_audio_files(folder, pattern='*.wav'):
    """Recursively collect files under ``folder`` whose name matches ``pattern``.

    ``os.walk`` yields entries in an arbitrary, filesystem-dependent order, so
    the result is sorted. That makes the order, and therefore the row order
    of any features extracted from these files, reproducible.

    :param folder: root directory to search (a missing folder yields []).
    :param pattern: ``fnmatch`` pattern applied to file names.
    :return: sorted list of matching file paths.
    """
    matches = []
    for root, _dirs, files in os.walk(folder):
        matches.extend(os.path.join(root, name) for name in files if fnmatch.fnmatch(name, pattern))
    return sorted(matches)


audio_path_list = find_audio_files(folder_path)
# An empty list usually means Drive is not mounted or the path is wrong.
print(f"found {len(audio_path_list)} audio files under {folder_path}")

In [None]:
from torch.utils.data import DataLoader

# shuffle=False keeps batches in audio_path_list order, so feature rows line
# up with the file list (barring load-failure substitution in the dataset).
dataset = CustomDefDataset(audio_path_list)
# Never request more worker processes than there are CPUs: PyTorch warns about
# excessive workers, and oversubscription slows or stalls loading.
NUM_WORKERS = min(6, os.cpu_count() or 1)
dataloader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
import time

# Construction may download the model weights (when missing), so it stays
# outside the timed region.
AFE = AudioFeatureExtraction()

# perf_counter() is monotonic and high-resolution; time.time() follows the
# wall clock and can jump, which makes it unsuitable for measuring durations.
since = time.perf_counter()
feature_np = AFE.feats_extract(dataloader)
elapsed = time.perf_counter() - since

print(f"extracted features with shape {feature_np.shape} in {elapsed:.1f} s")