In [1]:
import os
import numpy as np
import torch
import pickle
import requests

In [2]:
# Fetch the VoiceForge repo, which provides the `audio_feature_ext` package used
# below. Guarded so re-running the cell doesn't fail with "destination path
# already exists"; check=True makes a genuine clone failure raise instead of
# just printing an error and letting later imports fail mysteriously.
import subprocess

if not os.path.isdir('VoiceForge'):
    subprocess.run(['git', 'clone', 'https://github.com/BlairLeng/VoiceForge.git'], check=True)

Cloning into 'VoiceForge'...
remote: Enumerating objects: 37, done.[K
remote: Counting objects: 100% (37/37), done.[K
remote: Compressing objects: 100% (31/31), done.[K
remote: Total 37 (delta 3), reused 33 (delta 3), pack-reused 0[K
Receiving objects: 100% (37/37), 324.89 KiB | 12.50 MiB/s, done.
Resolving deltas: 100% (3/3), done.


In [5]:
# Mount Google Drive (Colab only) so the audio clips under /content/drive/MyDrive
# can be read by later cells.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [6]:
import os
import numpy as np
import torch
import pickle


import requests
from VoiceForge.audio_feature_ext.modules.ecapa_tdnn import EcapaTdnn, SpeakerIdetification
from VoiceForge.audio_feature_ext.data_utils.reader import load_audio, CustomDataset

class AudioFeatureExtraction:
    """Compute speaker embeddings with a pretrained ECAPA-TDNN backbone.

    On construction the checkpoint files are downloaded into ``model_director``
    (if missing), the model is built and its ``model.pth`` weights are loaded.
    """

    # Remote location of the pretrained checkpoint and the files fetched from it.
    MODEL_BASE_URL = 'https://huggingface.co/scixing/voicemodel/resolve/main'
    MODEL_FILES = ('model.pth', 'model.state', 'optimizer.pth')

    def __init__(self, model_director='./audio_feature_ext/models', feature_method='melspectrogram', device=None):
        """
        Args:
            model_director: directory holding (or receiving) the checkpoint files.
            feature_method: feature type forwarded to ``CustomDataset``/``load_audio``.
            device: optional torch device (or device string). When omitted, CUDA
                is used if available, otherwise CPU.
        """
        self.use_model = ''  # not read anywhere in the visible code; kept for compatibility
        self.model_director = model_director
        self.feature_method = feature_method
        self.model = None
        self.device = torch.device(device) if device is not None else None
        self.load_model()

    @staticmethod
    def _default_device():
        """Return the CUDA device if torch can see a GPU, else the CPU device."""
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @staticmethod
    def _drop_shape_mismatches(model_state, checkpoint_state):
        """Remove checkpoint entries whose shape differs from the model's parameter.

        Mutates and returns ``checkpoint_state``. Removed entries keep the model's
        freshly initialised values because loading uses ``strict=False``.
        """
        for name, weight in model_state.items():
            if name in checkpoint_state and list(weight.shape) != list(checkpoint_state[name].shape):
                checkpoint_state.pop(name)
        return checkpoint_state

    def init_models(self, path, timeout=60):
        """Download every file in ``MODEL_FILES`` that is not already in ``path``.

        Each file is streamed to ``<name>.part`` and renamed only once the download
        has finished. An interrupted download therefore never leaves a truncated
        file that a later run would treat as complete.

        Args:
            path: existing directory to download into.
            timeout: seconds passed to ``requests.get`` (connect / between reads).

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        present = set(os.listdir(path))
        for filename in self.MODEL_FILES:
            if filename in present:
                continue
            url = f'{self.MODEL_BASE_URL}/{filename}'
            target = os.path.join(path, filename)
            partial = target + '.part'
            print(f'downloading model pth {filename}')
            with requests.get(url, allow_redirects=True, stream=True, timeout=timeout) as response:
                # Without this, an HTML error page would be saved as the model file.
                response.raise_for_status()
                with open(partial, 'wb') as fh:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        fh.write(chunk)
            os.replace(partial, target)
            print(f'{filename} success download')

    def load_model(self):
        """Build the model, fetch missing checkpoint files and load ``model.pth``."""
        dataset = CustomDataset(data_list_path=None, feature_method=self.feature_method)
        ecapa_tdnn = EcapaTdnn(input_size=dataset.input_size)
        self.model = SpeakerIdetification(backbone=ecapa_tdnn)
        if getattr(self, 'device', None) is None:
            self.device = self._default_device()
        self.model.to(self.device)

        os.makedirs(self.model_director, exist_ok=True)
        if not all(os.path.exists(os.path.join(self.model_director, f)) for f in self.MODEL_FILES):
            self.init_models(self.model_director)

        # Load the model weights. Only model.pth is used here; model.state and
        # optimizer.pth are downloaded but not read by this class.
        model_path = os.path.join(self.model_director, 'model.pth')
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        # NOTE(review): torch.load unpickles a file fetched from the network.
        param_state_dict = torch.load(model_path, map_location=self.device)
        self._drop_shape_mismatches(self.model.state_dict(), param_state_dict)
        self.model.load_state_dict(param_state_dict, strict=False)
        print(f"成功加载模型参数和优化方法参数：{model_path}")
        self.model.eval()

    def infer(self, audio_path, duration):
        """Return the backbone embedding for ``audio_path`` as a numpy array.

        Args:
            audio_path: path of the audio file to embed.
            duration: chunk duration forwarded to ``load_audio``.

        Returns:
            The backbone output for a batch of one clip (a leading batch axis is
            added before the forward pass), moved to CPU.
        """
        data = load_audio(audio_path, mode='infer', feature_method=self.feature_method,
                          chunk_duration=duration)
        batch = torch.tensor(data[np.newaxis, :], dtype=torch.float32, device=self.device)
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            feature = self.model.backbone(batch)
        return feature.cpu().numpy()

In [8]:
# Build the extractor once; this downloads missing checkpoints and loads weights.
AFE = AudioFeatureExtraction(
    model_director='./audio_feature_ext/models',
    feature_method='melspectrogram',
)

成功加载模型参数和优化方法参数：./audio_feature_ext/models/model.pth


In [10]:
import wave
import contextlib

In [14]:
def wav_duration(path):
    """Return the length of a WAV file in seconds (frame count / frame rate)."""
    with wave.open(path, 'rb') as wav:
        return wav.getnframes() / float(wav.getframerate())


# Raw string: the filename embeds subtitle karaoke tags like `{\k20}`. In a normal
# literal `\k` is an invalid escape sequence (SyntaxWarning on Python 3.12+).
file_path = r"/content/drive/MyDrive/GPTData/Haruhi_audio/01/312)}{\k20}凉{\k20}宫{\k20}春{\k20}日{\k20}团00:23:07.070000.wav"

duration = wav_duration(file_path)
AFE.infer(file_path, duration)[0]

array([-0.25482637, -0.08357781, -0.07106941, -0.26909938,  0.18829505,
        0.26165918,  0.4537298 ,  0.73450327, -0.35488844,  0.11526221,
       -0.00976957, -0.14743726, -0.15223378,  0.36701366, -0.43090054,
       -0.34027702,  0.700111  ,  0.20501454,  0.23966783,  0.10887624,
        0.6455976 ,  0.19876142, -0.21320389, -0.5663119 , -0.14633986,
        0.14458369,  0.09441376, -0.14459994,  0.19150178,  0.18343854,
       -0.02837923, -0.325045  , -0.13683006,  0.04951802, -0.04076549,
       -0.20020688, -0.12568444, -0.20528723,  0.25369284, -0.23654728,
        0.08003059,  0.6285238 ,  0.06177036,  0.38667247, -0.21959871,
       -0.45782745,  0.25508097, -0.31731164,  0.5698923 , -0.26788864,
        0.6064223 ,  0.09479906,  0.33312544,  0.29221392,  0.17239876,
       -0.24952959,  0.21743146,  0.4922386 ,  0.26232612, -0.17286031,
        0.45107996, -0.18003249,  0.37111294, -0.32601923, -0.30233046,
       -0.3790557 , -0.19180658,  0.2996932 , -0.05059752, -0.73