In [None]:
# 虚拟环境必须 python 3.9+(因为whisper)

In [1]:
!pip install -q pysrt
!pip install -q pysubs2

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/104.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.4/104.4 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for pysrt (setup.py) ... [?25l[?25hdone


### 1.根据字幕切割出分段音频

In [2]:
import os
import re
import subprocess
from collections import Counter

import chardet
import pysrt
import pysubs2
from tqdm import tqdm

#### 常用函数



In [3]:
def get_subdir(directory):
    """Return every sub-directory below ``directory`` (walked recursively) as a sorted list of paths."""
    return sorted(
        os.path.join(parent, child)
        for parent, children, _ in os.walk(directory)
        for child in children
    )

def get_filename(directory, format=None):
    """Recursively list the files below ``directory`` as sorted ``[name, path]`` pairs.

    :param directory: folder to walk
    :param format: optional suffix filter (e.g. ``'.wav'``); a falsy value keeps every file
    :return: list of ``[file name, full path]`` lists, sorted by name then path
    """
    matches = [
        [name, os.path.join(root, name)]
        for root, _, names in os.walk(directory)
        for name in names
        if not format or name.endswith(format)
    ]
    matches.sort()
    return matches


#获取一级子目录
def get_first_subdir(directory):
    """Return the immediate (one level deep) sub-directories of ``directory`` as sorted paths."""
    candidates = (os.path.join(directory, name) for name in os.listdir(directory))
    return sorted(path for path in candidates if os.path.isdir(path))

In [4]:
def detect_encoding(file_name):
    """Guess the text encoding of ``file_name`` by running chardet over its full byte content."""
    with open(file_name, 'rb') as handle:
        raw_bytes = handle.read()
    detection = chardet.detect(raw_bytes)
    return detection['encoding']

def most_common_element(lst, num=1):
    """Return the ``num`` most frequent items of ``lst`` as a list of ``(item, count)`` tuples."""
    return Counter(lst).most_common(num)


def make_filename_safe(filename):
    """Turn arbitrary text (e.g. a subtitle line) into a usable file name.

    Characters illegal in Windows/POSIX file names are replaced by ``_``, every
    whitespace run (including newlines) is collapsed to one space, and the result
    is stripped at both ends.
    """
    underscored = re.sub(r'[\\/:*?"<>|]', '_', filename)
    single_spaced = re.sub(r'\s+', ' ', underscored)
    return single_spaced.strip()


#### VideoSegmentation

In [5]:


class VideoSegmentation:
    """Cut one audio clip per subtitle line out of every video in a folder.

    For each video ``<name>.<ext>`` found (recursively) in ``video_lis_pth``, the
    subtitle ``<subtitle_dir>/<name>.srt`` or ``<subtitle_dir>/<name>.ass`` is read and
    every line is exported with ffmpeg as
    ``<audio_out_dir>/<name>/voice/<index>_<sanitized text>.wav``.
    For ``.ass`` subtitles only the most frequent style (detected once, on the first
    subtitle file processed) is kept — presumably the main dialogue style.
    """

    # Subtitle extensions this class knows how to parse.
    SUBTITLE_FORMATS = ('.srt', '.ass')

    def __init__(self, video_lis_pth, audio_out_dir, subtitle_dir):
        self.video_lis_pth = video_lis_pth
        self.audio_out_dir = audio_out_dir
        self.subtitle_dir = subtitle_dir

    def process(self):
        """Create the per-video voice folders and cut a clip for every subtitle line.

        Videos without a matching subtitle file are skipped with a warning.
        Unlike before, this no longer calls ``exit()`` at the end, which used to kill
        the notebook kernel.
        """
        video_lis = get_filename(self.video_lis_pth)

        style = ''
        sub_format = ''
        voice_dir = 'voice'
        for file, pth in tqdm(video_lis, desc='Processing Videos'):
            filename, _ = os.path.splitext(file)
            # makedirs also creates the parent <audio_out_dir>/<filename> folder.
            out_dir = f'{self.audio_out_dir}/{filename}/{voice_dir}'
            os.makedirs(out_dir, exist_ok=True)

            if not self.subtitle_dir:
                continue
            if not sub_format:
                sub_format = self._subtitle_format()

            # The subtitle must have the same base name as the video.
            cur_sub_file = f'{self.subtitle_dir}/{filename}{sub_format}'
            if not os.path.isfile(cur_sub_file):
                print(f'warning: subtitle {cur_sub_file} not found for video {file}, skipped')
                continue
            # Detect per file: subtitle files from different sources may use different encodings.
            encoding = detect_encoding(cur_sub_file)

            if sub_format == '.srt':
                srt_file = pysrt.open(cur_sub_file, encoding=encoding)
                segments = [(str(sub.start.to_time()), str(sub.end.to_time()), sub.text)
                            for sub in srt_file]
            else:  # '.ass'
                subs = pysubs2.load(cur_sub_file, encoding=encoding)
                if not style and len(subs) > 0:
                    style = most_common_element([sub.style for sub in subs])[0][0]
                # pysubs2 times are in milliseconds; ffmpeg takes seconds.
                segments = [(str(sub.start / 1000), str(sub.end / 1000), sub.text)
                            for sub in subs if sub.style == style]

            for index, (start, end, text) in enumerate(segments):
                audio_output = f'{out_dir}/{index}_{make_filename_safe(text)}.wav'
                self._cut_audio(pth, start, end, audio_output)

    def _subtitle_format(self):
        """Return the extension of the first (sorted) supported subtitle file in ``subtitle_dir``."""
        for name in sorted(os.listdir(self.subtitle_dir)):
            ext = os.path.splitext(name)[1]
            if ext in self.SUBTITLE_FORMATS:
                return ext
        raise FileNotFoundError(f'no .srt or .ass subtitle file found in {self.subtitle_dir}')

    @staticmethod
    def _cut_audio(video_path, start, end, audio_output):
        """Export the ``[start, end]`` span of ``video_path``'s audio as 16-bit PCM wav."""
        # WAV output: the original note says switching to mp3 produced no output.
        # -y overwrites clips from a previous run instead of prompting; stdin is closed
        # so ffmpeg can never block waiting for keyboard input.
        command = ['ffmpeg', '-loglevel', 'quiet', '-y', '-ss', start, '-to', end, '-i', video_path,
                   '-vn', '-c:a', 'pcm_s16le', audio_output]
        result = subprocess.run(command, stdin=subprocess.DEVNULL)
        if result.returncode != 0:
            print(f'warning: ffmpeg exited with code {result.returncode} for {audio_output}')


#### 自定义config参数

In [6]:
# Edit the values below for your setup. Folders marked ** must already exist and contain the expected files.
GPT_DATA_ROOT = "/content/drive/MyDrive/GPTData"  # common root of every path below; change it once here

video_config = {
    "video_lis_pth": f"{GPT_DATA_ROOT}/origin_video",  # ** folder with the source videos (prepare in advance)
}
audio_config = {
    "audio_model_pth": f"{GPT_DATA_ROOT}/voicemodel",  # ** model weights; download: git clone https://huggingface.co/scixing/voicemodel
    "audio_roles_dir": f"{GPT_DATA_ROOT}/roles",  # ** voice clips sorted by hand into one sub-folder per role
    "audio_out_dir": f"{GPT_DATA_ROOT}/audio",  # output folder for the clips cut from the videos
}

srt_config = {
    "subtitle_dir": f"{GPT_DATA_ROOT}/srt",  # ** subtitles for the videos; subtitle and video base names must match
    "srt_out_dir": f"{GPT_DATA_ROOT}/roletxt",  # output folder for the predicted role labels
}


# Whisper checkpoints and their download links; the tiny and large models are not recommended.
# "large" is an alias of the "large-v2" checkpoint.
WHISPER_MODELS_LINKS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
    "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}
# Derived from the dict so the list of names can never drift from the links (dict order is preserved).
WHISPER_MODELS_LIST = list(WHISPER_MODELS_LINKS)

####  Video to Subtitles with Whisper

In [None]:
!pip install openai-whisper

In [24]:
# Show the download URL of the "small" Whisper checkpoint (e.g. to fetch it with the !wget cell below).
print(WHISPER_MODELS_LINKS["small"])

https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt


In [1]:
# !wget 模型地址

In [None]:
import ffmpeg
import tempfile
import torch
import whisper
from whisper.utils import get_writer
from typing import Iterator, TextIO

In [None]:
class Video2Subtitles(object):
    """Transcribe (or translate) a video's speech with Whisper and write a subtitle file next to it."""

    def __init__(self):
        pass

    def srt_format_timestamp(self, seconds: float) -> str:
        """Format ``seconds`` as an SRT timestamp ``HH:MM:SS,mmm`` (hours zero-padded as SRT requires)."""
        assert seconds >= 0, "non-negative timestamp expected"
        milliseconds = round(seconds * 1000.0)
        hours, milliseconds = divmod(milliseconds, 3_600_000)
        minutes, milliseconds = divmod(milliseconds, 60_000)
        secs, milliseconds = divmod(milliseconds, 1_000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"

    def write_srt(self, transcript: Iterator[dict], file: TextIO):
        """Write Whisper segments (dicts with 'start', 'end', 'text') to ``file`` as SRT cues.

        Each cue is: 1-based index, ``start --> end`` line, stripped text, blank line.
        ``-->`` inside the text is rewritten to ``->`` so it cannot be mistaken for a timing line.
        """
        for count, segment in enumerate(transcript, start=1):
            print(
                f"{count}\n"
                f"{self.srt_format_timestamp(segment['start'])} --> {self.srt_format_timestamp(segment['end'])}\n"
                f"{segment['text'].replace('-->', '->').strip()}\n",
                file=file,
                flush=True,
            )

    def transcribe(self, input_video: str, lang: str, MODEL_WHISPER: str, task: str, subtitle_format: str,
                   AddSrtToVideo: bool):
        """Run Whisper on ``input_video`` and write ``<video path without extension>.<subtitle_format>``.

        Parameters
        ----------
        input_video: path of the video, or an object with a ``.name`` attribute holding the path
        lang: language spoken in the input file
        MODEL_WHISPER: model name (tiny/base/small/medium/large...) or a local checkpoint path, e.g. /tiny.pt
        task: "transcribe", or "translate" (any language to English)
        subtitle_format: "txt", "vtt", "srt", "tsv" or "json"
        AddSrtToVideo: if True, also mux the subtitle into a new mp4 and return that path

        Returns
        -------
        Path of the subtitle file, or of the muxed video when ``AddSrtToVideo`` is True.
        """
        # fp16 inference is only possible on a GPU.
        use_fp16 = torch.cuda.is_available()
        model = whisper.load_model(MODEL_WHISPER)
        input_video_ = input_video if isinstance(input_video, str) else input_video.name
        result = model.transcribe(
            input_video_,
            task=task,
            language=lang,
            verbose=True,
            initial_prompt=None,
            word_timestamps=False,
            fp16=use_fp16
        )
        # splitext (not rsplit('.')) so dots in folder names do not break the path.
        subtitle_file = os.path.splitext(input_video_)[0] + "." + subtitle_format
        print("subtitle_file:", subtitle_file)
        if subtitle_format == "srt":
            with open(subtitle_file, "w", encoding="utf-8") as srt:
                self.write_srt(result["segments"], file=srt)
        else:
            # Whisper's writer saves <output_dir>/<audio basename>.<ext>, i.e. exactly subtitle_file
            # when output_dir is the video's own folder (previously it wrote to the temp dir).
            writer = get_writer(subtitle_format, os.path.dirname(subtitle_file) or ".")
            writer(result, input_video_)
        if AddSrtToVideo:
            return self.add_srt_to_video(input_video_, subtitle_file)
        return subtitle_file

    def add_srt_to_video(self, input_video_, subtitle_file):
        """Mux ``subtitle_file`` as a soft subtitle track into ``<input_video_>_output.mp4`` (audio/video copied)."""
        video_out = input_video_ + "_output.mp4"
        input_ffmpeg = ffmpeg.input(input_video_)
        input_ffmpeg_sub = ffmpeg.input(subtitle_file)
        stream = ffmpeg.output(
            input_ffmpeg['v'], input_ffmpeg['a'], input_ffmpeg_sub['s'], video_out,
            # MP4 only carries text subtitles as mov_text; the previous code referenced an
            # undefined `subtitle_format` here (it only worked through a notebook global).
            vcodec='copy', acodec='copy', scodec='mov_text'
        )
        stream = ffmpeg.overwrite_output(stream)
        ffmpeg.run(stream)
        return video_out


##### 运行: Video to Subtitles with Whisper

In [None]:
# Transcribe one video from the configured video folder with Whisper.
input_video = f'{video_config["video_lis_pth"]}/input.mp4'
MODEL_WHISPER = "small"  # a name from WHISPER_MODELS_LIST, or the path of a downloaded checkpoint
lang = "zh"  # language spoken in the input video; Whisper supports many languages
task = "transcribe"  # "transcribe", or "translate" (to English)
AddSrtToVideo = False
subtitle_format = "srt"
Video2Subtitles().transcribe(
    input_video=input_video,
    lang=lang,
    MODEL_WHISPER=MODEL_WHISPER,
    task=task,
    subtitle_format=subtitle_format,
    AddSrtToVideo=AddSrtToVideo,
)

#### 运行1 音频提取分割

In [7]:
# video_segmentor = VideoSegmentation(video_config['video_lis_pth'],
#                                         audio_config['audio_out_dir'],
#                                         srt_config['subtitle_dir'])
# video_segmentor.process()

In [8]:
# !git clone https://huggingface.co/scixing/voicemodel

### 2.音频特征提取
Audio Feature Extraction  

In [9]:

import os
import numpy as np
import torch
import pickle
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import random
import sys
from datetime import datetime
import librosa
from torch.utils import data


#### Audio模型定义辅助函数

In [10]:
def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
    """Build one SE-Res2 block as an ``nn.Sequential``.

    Stages: 1x1 Conv1dReluBn -> Res2Conv1dReluBn (dilated, ``scale`` splits)
    -> 1x1 Conv1dReluBn -> SE_Connect. The channel count is unchanged throughout.
    """
    def pointwise():
        # 1x1 convolution that keeps the channel count
        return Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0)

    stages = [
        pointwise(),
        Res2Conv1dReluBn(channels, kernel_size, stride, padding, dilation, scale=scale),
        pointwise(),
        SE_Connect(channels),
    ]
    return nn.Sequential(*stages)


def load_audio(audio_path,
               feature_method='melspectrogram',
               mode='train',
               sr=16000,
               chunk_duration=3,
               min_duration=0.5,
               augmentors=None):
    """
    Load an audio file and turn it into a normalized 2-D feature matrix.

    :param audio_path: path of the audio file
    :param feature_method: 'melspectrogram' (80 mel bands) or 'spectrogram' (STFT magnitude)
    :param mode: how the data is processed: 'train' (length check, random crop, augmentation),
                 'eval' (keep only the first chunk_duration seconds), or e.g. 'infer' (no cropping)
    :param sr: sample rate the audio is loaded at
    :param chunk_duration: audio length in seconds used for training / evaluation; clips not longer
                           than this are padded up to it in every mode
    :param min_duration: minimum accepted clip length in seconds (checked in 'train' mode only)
    :param augmentors: optional dict of callables; the 'specaug' entry is applied to the features,
                       every other entry to the raw waveform (both in 'train' mode only)
    :return: feature matrix normalized with mean / std taken along axis 0
    :raises Exception: if a 'train' clip is shorter than min_duration, or feature_method is unknown
    """
    # Read the waveform at `sr` (the returned rate `sr_ret` is not used)
    wav, sr_ret = librosa.load(audio_path, sr=sr)
    num_wav_samples = wav.shape[0]
    # Clips that are too short are not useful for training
    if mode == 'train':
        if num_wav_samples < int(min_duration * sr):
            raise Exception(f'音频长度小于{min_duration}s，实际长度为：{(num_wav_samples / sr):.2f}s')
    # Pad clips not longer than chunk_duration by repeating the signal ('wrap'), in every mode
    num_chunk_samples = int(chunk_duration * sr)
    if num_wav_samples <= num_chunk_samples:
        shortage = num_chunk_samples - num_wav_samples
        wav = np.pad(wav, (0, shortage), 'wrap')
    # Crop to the length that is needed
    if mode == 'train':
        # Random crop of exactly num_chunk_samples samples
        num_wav_samples = wav.shape[0]
        num_chunk_samples = int(chunk_duration * sr)
        if num_wav_samples > num_chunk_samples + 1:
            start = random.randint(0, num_wav_samples - num_chunk_samples - 1)
            stop = start + num_chunk_samples
            wav = wav[start:stop]
            # Half of the time, also silence up to sr // 4 leading samples and drop up to
            # sr // 4 trailing samples, so crops are not always full length
            if random.random() > 0.5:
                wav[:random.randint(1, sr // 4)] = 0
                wav = wav[:-random.randint(1, sr // 4)]
        # Waveform augmentation: every augmentor except 'specaug' (that one runs on the features below)
        if augmentors is not None:
            for key, augmentor in augmentors.items():
                if key == 'specaug': continue
                wav = augmentor(wav)
    elif mode == 'eval':
        # Keep only the first chunk_duration seconds to avoid running out of GPU memory
        num_wav_samples = wav.shape[0]
        num_chunk_samples = int(chunk_duration * sr)
        if num_wav_samples > num_chunk_samples + 1:
            wav = wav[:num_chunk_samples]
    # Compute the audio features
    if feature_method == 'melspectrogram':
        # Mel spectrogram: 80 mel bands, 400-sample FFT/window, hop of 160 samples
        features = librosa.feature.melspectrogram(y=wav, sr=sr, n_fft=400, n_mels=80, hop_length=160, win_length=400)
    elif feature_method == 'spectrogram':
        # Linear spectrogram: magnitude of the STFT
        linear = librosa.stft(wav, n_fft=400, win_length=400, hop_length=160)
        features, _ = librosa.magphase(linear)
    else:
        raise Exception(f'预处理方法 {feature_method} 不存在！')
    # NOTE(review): power_to_db is applied to the magnitude spectrogram as well — confirm this is intended
    features = librosa.power_to_db(features, ref=1.0, amin=1e-10, top_db=None)
    # Feature augmentation: only the 'specaug' augmentor, only in train mode
    if mode == 'train' and augmentors is not None:
        for key, augmentor in augmentors.items():
            if key == 'specaug':
                features = augmentor(features)
    # Normalize with mean / std along axis 0; the 1e-5 avoids division by zero
    mean = np.mean(features, 0, keepdims=True)
    std = np.std(features, 0, keepdims=True)
    features = (features - mean) / (std + 1e-5)
    return features


class Res2Conv1dReluBn(nn.Module):
    """Res2Net-style multi-scale 1-D convolution.

    The input channels are split into ``scale`` groups of ``channels // scale``.
    Group i (for the first ``nums`` groups) is added to the previous group's output,
    then passed through Conv1d -> ReLU -> BatchNorm. When ``scale != 1`` the last
    group is passed through untouched. All groups are concatenated back on dim 1.
    """

    def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False, scale=4):
        super().__init__()
        assert channels % scale == 0, f"{channels} % {scale} != 0"
        self.scale = scale
        self.width = channels // scale
        # scale == 1 means a single, fully convolved group; otherwise the last group is an identity
        self.nums = 1 if scale == 1 else scale - 1

        conv_layers, bn_layers = [], []
        for _ in range(self.nums):
            conv_layers.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
            bn_layers.append(nn.BatchNorm1d(self.width))
        self.convs = nn.ModuleList(conv_layers)
        self.bns = nn.ModuleList(bn_layers)

    def forward(self, x):
        chunks = torch.split(x, self.width, 1)
        outputs = []
        running = None
        for idx, (conv, bn) in enumerate(zip(self.convs, self.bns)):
            running = chunks[idx] if running is None else running + chunks[idx]
            # Order: conv -> relu -> bn
            running = bn(F.relu(conv(running)))
            outputs.append(running)
        if self.scale != 1:
            outputs.append(chunks[self.nums])
        return torch.cat(outputs, dim=1)


class Conv1dReluBn(nn.Module):
    """Conv1d followed by ReLU and then BatchNorm1d (the normalization comes after the activation)."""

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding, dilation=dilation, bias=bias)
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        activated = F.relu(self.conv(x))
        return self.bn(activated)


class SE_Connect(nn.Module):
    """Squeeze-and-excitation over channels.

    The time axis (dim 2) is averaged away, a bottleneck MLP (reduction ``s``)
    produces a sigmoid gate per channel, and the input is rescaled by that gate.
    """

    def __init__(self, channels, s=2):
        super().__init__()
        assert channels % s == 0, f"{channels} % {s} != 0"
        bottleneck = channels // s
        self.linear1 = nn.Linear(channels, bottleneck)
        self.linear2 = nn.Linear(bottleneck, channels)

    def forward(self, x):
        squeezed = x.mean(dim=2)
        gate = torch.sigmoid(self.linear2(F.relu(self.linear1(squeezed))))
        return x * gate.unsqueeze(2)


class AttentiveStatsPool(nn.Module):
    """Attentive statistics pooling over the time axis (dim 2).

    Attention weights (softmax over time) are computed per channel; the output is
    the weighted mean concatenated with the weighted standard deviation, so the
    channel dimension doubles: ``in_dim`` -> ``2 * in_dim``.
    """

    def __init__(self, in_dim, bottleneck_dim):
        super().__init__()
        # Conv1d with kernel 1 acts like Linear but works directly on (N, C, T) without transposing.
        self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1)  # W and b in the paper
        self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1)  # V and k in the paper

    def forward(self, x):
        # tanh, not ReLU: the original author found ReLU hard to converge here.
        scores = self.linear2(torch.tanh(self.linear1(x)))
        weights = torch.softmax(scores, dim=2)
        mean = torch.sum(weights * x, dim=2)
        variance = torch.sum(weights * x ** 2, dim=2) - mean ** 2
        # clamp keeps sqrt away from zero / small negative rounding errors
        std = torch.sqrt(variance.clamp(min=1e-9))
        return torch.cat([mean, std], dim=1)


class EcapaTdnn(nn.Module):
    """ECAPA-TDNN speaker-embedding backbone.

    Input: features shaped (N, input_size, T). Output: embeddings of size ``embd_dim``
    (exposed as ``emb_size``). Three SE-Res2 blocks with dilations 2/3/4 use dense
    residual connections; their outputs are concatenated, pooled with attentive
    statistics (mean + std) and projected to the embedding.
    """

    def __init__(self, input_size=80, channels=512, embd_dim=192):
        super().__init__()
        self.layer1 = Conv1dReluBn(input_size, channels, kernel_size=5, padding=2, dilation=1)

        def res2_block(dilation):
            # padding == dilation keeps the time length unchanged for kernel_size=3
            return SE_Res2Block(channels, kernel_size=3, stride=1, padding=dilation, dilation=dilation, scale=8)

        self.layer2 = res2_block(2)
        self.layer3 = res2_block(3)
        self.layer4 = res2_block(4)

        concat_channels = channels * 3
        pooled_channels = concat_channels * 2  # mean and std are concatenated by the pooling
        self.emb_size = embd_dim
        self.conv = nn.Conv1d(concat_channels, concat_channels, kernel_size=1)
        self.pooling = AttentiveStatsPool(concat_channels, 128)
        self.bn1 = nn.BatchNorm1d(pooled_channels)
        self.linear = nn.Linear(pooled_channels, embd_dim)
        self.bn2 = nn.BatchNorm1d(embd_dim)

    def forward(self, x):
        h1 = self.layer1(x)
        h2 = self.layer2(h1) + h1
        h3 = self.layer3(h1 + h2) + h1 + h2
        h4 = self.layer4(h1 + h2 + h3) + h1 + h2 + h3

        merged = F.relu(self.conv(torch.cat([h2, h3, h4], dim=1)))
        pooled = self.bn1(self.pooling(merged))
        return self.bn2(self.linear(pooled))


class SpeakerIdetification(nn.Module):
    def __init__(
            self,
            backbone,
            num_class=1,
            lin_blocks=0,
            lin_neurons=192,
            dropout=0.1, ):
        """Speaker identification model: an embedding backbone plus a cosine classifier head.

        Args:
            backbone (nn.Module): speaker-embedding network; must expose ``emb_size`` and map
                the input features to (N, emb_size) embeddings
            num_class (int): number of speaker classes in the training data. Defaults to 1.
            lin_blocks (int, optional): number of (BatchNorm1d, Linear) blocks between the
                embedding and the classifier. Defaults to 0.
            lin_neurons (int, optional): output size of each linear block. Defaults to 192.
            dropout (float, optional): dropout applied to the embedding; 0 disables it. Defaults to 0.1.
        """
        super(SpeakerIdetification, self).__init__()
        # The backbone's output is the speaker embedding used at inference time.
        self.backbone = backbone
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

        # Optional linear blocks. nn.ModuleList (not a plain list) registers them, so they
        # move with .to(device), appear in state_dict() and are seen by the optimizer.
        input_size = self.backbone.emb_size
        blocks = []
        for _ in range(lin_blocks):
            blocks.extend([
                nn.BatchNorm1d(input_size),
                nn.Linear(in_features=input_size, out_features=lin_neurons),
            ])
            input_size = lin_neurons
        self.blocks = nn.ModuleList(blocks)

        # Final layer: one weight vector per class, compared by cosine similarity.
        self.weight = Parameter(torch.FloatTensor(num_class, input_size), requires_grad=True)
        nn.init.xavier_normal_(self.weight, gain=1)

    def forward(self, x):
        """Compute class logits for a batch of input features.

        Args:
            x (torch.Tensor): input features accepted by the backbone
                (for EcapaTdnn: [batch, feature_dim, frames])

        Returns:
            torch.Tensor: [batch, num_class] cosine similarities between the (normalized)
            embeddings and the (normalized) class weights, each in [-1, 1]
        """
        x = self.backbone(x)  # (N, emb_size)
        if self.dropout is not None:
            x = self.dropout(x)

        for fc in self.blocks:
            x = fc(x)

        logits = F.linear(F.normalize(x), F.normalize(self.weight, dim=-1))

        return logits


# 数据加载器
class CustomDataset(data.Dataset):
    """
    Dataset of (features, label) pairs described by a list file.

    Each line of the list file holds an audio path and an integer label separated by a tab.
    :param data_list_path: path of the list file, or None when only ``input_size`` is needed (inference)
    :param feature_method: feature type passed to ``load_audio``
    :param mode: processing mode passed to ``load_audio``: 'train', 'eval' or 'infer'
    :param sr: sample rate
    :param chunk_duration: audio length in seconds used for training / evaluation
    :param min_duration: minimum audio length in seconds for training / evaluation
    :param augmentors: augmentation callables passed to ``load_audio``
    """

    # How many random replacement samples to try before giving up on a broken entry.
    max_retries = 100

    def __init__(self, data_list_path,
                 feature_method='melspectrogram',
                 mode='train',
                 sr=16000,
                 chunk_duration=3,
                 min_duration=0.5,
                 augmentors=None):
        super(CustomDataset, self).__init__()
        # No list is needed for inference; an empty list keeps len() valid instead of
        # raising AttributeError.
        self.lines = []
        if data_list_path is not None:
            with open(data_list_path, 'r', encoding='utf-8') as f:
                self.lines = f.readlines()
        self.feature_method = feature_method
        self.mode = mode
        self.sr = sr
        self.chunk_duration = chunk_duration
        self.min_duration = min_duration
        self.augmentors = augmentors

    def __getitem__(self, idx):
        # Out-of-range indices raise IndexError, as the Dataset protocol expects.
        line = self.lines[idx]
        # A broken entry is logged and replaced by a random one. This is a bounded loop
        # rather than the previous unbounded recursion (which ended in RecursionError).
        for _ in range(self.max_retries):
            try:
                audio_path, label = line.rstrip('\r\n').split('\t')
                # Load and preprocess the audio
                features = load_audio(audio_path, feature_method=self.feature_method, mode=self.mode, sr=self.sr,
                                      chunk_duration=self.chunk_duration, min_duration=self.min_duration,
                                      augmentors=self.augmentors)
                return features, np.array(int(label), dtype=np.int64)
            except Exception as ex:
                print(f"[{datetime.now()}] 数据: {line} 出错，错误信息: {ex}", file=sys.stderr)
                line = self.lines[np.random.randint(len(self.lines))]
        raise RuntimeError(f'no loadable sample found after {self.max_retries} attempts (last tried: {line!r})')

    def __len__(self):
        return len(self.lines)

    @property
    def input_size(self):
        """Feature dimension produced by ``feature_method`` (80 mel bands or 201 STFT bins)."""
        if self.feature_method == 'melspectrogram':
            return 80
        elif self.feature_method == 'spectrogram':
            return 201
        else:
            raise Exception(f'预处理方法 {self.feature_method} 不存在！')



#### AudioFeatureExtraction

In [11]:
class AudioFeatureExtraction:
    """Extract speaker embeddings for cut voice clips with a pretrained ECAPA-TDNN model.

    :param model_local_pth: folder containing ``model.pth``
    :param audio_duration: clip length in seconds, passed to ``load_audio`` as ``chunk_duration``
    :param feature_method: 'melspectrogram' or 'spectrogram'; must match the checkpoint
    :param device: torch device (or device string) to run on; defaults to CUDA when available, else CPU
    """

    def __init__(self, model_local_pth, audio_duration=3, feature_method='melspectrogram', device=None):
        self.use_model = ''
        self.audio_duration = audio_duration
        self.feature_method = feature_method
        self.resume = model_local_pth
        self.model = None
        self.device = None if device is None else torch.device(device)
        self.load_model()

    def load_model(self):
        """Build the network and load the matching weights from ``<resume>/<use_model>/model.pth``."""
        dataset = CustomDataset(data_list_path=None, feature_method=self.feature_method)
        ecapa_tdnn = EcapaTdnn(input_size=dataset.input_size)
        self.model = SpeakerIdetification(backbone=ecapa_tdnn)
        if self.device is None:
            # Fall back to CPU so the notebook also runs without a GPU.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        model_path = os.path.join(self.resume, self.use_model, 'model.pth')
        model_dict = self.model.state_dict()
        # map_location allows a checkpoint saved on GPU to be loaded on CPU.
        # NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
        param_state_dict = torch.load(model_path, map_location=self.device)
        # Drop checkpoint tensors whose shape does not match this network, then load
        # non-strictly so missing / extra keys (e.g. the classifier head) are tolerated.
        for name, weight in model_dict.items():
            if name in param_state_dict and list(weight.shape) != list(param_state_dict[name].shape):
                param_state_dict.pop(name, None)
        self.model.load_state_dict(param_state_dict, strict=False)
        print(f"成功加载模型参数和优化方法参数：{model_path}")
        self.model.eval()

    def infer(self, audio_path):
        """Return the backbone embedding of one audio file as a numpy array with a leading batch axis of 1."""
        features = load_audio(audio_path, mode='infer', feature_method=self.feature_method,
                              chunk_duration=self.audio_duration)
        batch = torch.tensor(features[np.newaxis, :], dtype=torch.float32, device=self.device)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            embedding = self.model.backbone(batch)
        return embedding.cpu().numpy()

    def extract_features(self, root_dir):
        """Pickle one embedding per clip: ``<dir>/voice/<clip>`` -> ``<dir>/feature/<clip>.pkl``.

        Every sub-directory of ``root_dir`` (searched recursively) that contains a
        ``voice`` folder with files is processed.
        """
        for episode_dir in get_subdir(root_dir):
            voice_files = get_filename(os.path.join(episode_dir, 'voice'))
            if not voice_files:
                continue
            feature_dir = os.path.join(episode_dir, 'feature')
            os.makedirs(feature_dir, exist_ok=True)
            for file, pth in voice_files:
                feature = self.infer(pth)[0]
                with open(os.path.join(feature_dir, f"{file}.pkl"), "wb") as f:
                    pickle.dump(feature, f)
        print('音频特征提取完成')

#### 运行2 音频embedding生成

In [12]:
# 模型参数在第一部分 自定义config参数
# audio_feature_extractor = AudioFeatureExtraction(audio_config['audio_model_pth'])
# audio_feature_extractor.extract_features(audio_config['audio_out_dir'])

成功加载模型参数和优化方法参数：/content/drive/MyDrive/GPTData/voicemodel/model.pth
音频特征提取完成


### 3.识别台本角色

#### 导包

In [13]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score
from scipy.spatial.distance import cosine

#### My_Classifier

In [14]:
class My_Classifier:
    """1-nearest-neighbour classifier over reference embeddings, using scipy's cosine distance.

    :param feature: sequence of reference feature vectors
    :param labels: label of each reference vector (indexed in the same order as ``feature``)
    """

    def __init__(self, feature, labels):
        self.feature = feature
        self.labels = labels

    def predict(self, x):
        """Return ``(label, distance)`` of the closest reference vector.

        The first reference wins ties; with no references the result is ``(None, inf)``.
        """
        best_label = None
        best_dist = float('inf')
        for idx, reference in enumerate(self.feature):
            dist = cosine(x, reference)
            if not dist < best_dist:
                continue
            best_dist = dist
            best_label = self.labels[idx]
        return best_label, best_dist

#### AudioClassification

In [16]:
class AudioClassification:
    """Assign speaker roles to segmented subtitle audio by nearest-neighbour matching.

    Reference clips are organised as ``audio_roles_dir/<role>/<clip file>``. The
    embedding of a clip is read from ``<candidate>/feature/<clip file>.pkl``, where
    ``<candidate>`` is the first first-level sub-directory of ``audio_out_dir`` whose
    ``voice`` folder contains that clip. The embeddings are presumably written by the
    feature-extraction step earlier in the notebook.

    Args:
        audio_roles_dir: directory with one sub-directory of labelled clips per role.
        srt_out_dir: directory into which ``get_pridict`` writes ``<episode>.txt`` files.
        audio_out_dir: directory whose first-level sub-directories hold ``voice/``
            and ``feature/`` folders.
    """

    # "<id>_<subtitle text>.wav.pkl"; the text itself may contain '_'.
    _FEATURE_FILE_RE = re.compile(r'(0|[1-9][0-9]*)_(.*)\.wav\.pkl', re.DOTALL)

    def __init__(self, audio_roles_dir, srt_out_dir, audio_out_dir):
        self.audio_roles_dir = audio_roles_dir
        self.srt_out_dir = srt_out_dir

        self.audio_first_dir = get_first_subdir(audio_out_dir)
        self.candidate_path = self.audio_first_dir[:]

        self.roles, self.roles_list = self.get_roles_list()
        self.features, self.labels = self.get_features()
        self.feat_sel, self.label_sel = self.get_feat_sel()
        self.my_classifier = My_Classifier(self.feat_sel, self.label_sel)

    def get_roles_list(self):
        """List the roles and the clip file names recorded for each role.

        Returns:
            ``(roles, roles_list)``. ``roles`` holds the role sub-directory names of
            ``audio_roles_dir`` (e.g. 春日, 谷口, 长门 ...), in sorted order.
            ``roles_list[i]`` is the sorted list of clip file names under
            ``roles[i]`` (e.g. '106_我要喝100%果汁.wav').
        """
        roles = []
        roles_list = []
        for role in sorted(os.listdir(self.audio_roles_dir)):
            role_dir = os.path.join(self.audio_roles_dir, role)
            # Skip stray files (e.g. .DS_Store); listing them as folders would fail.
            if not os.path.isdir(role_dir):
                continue
            roles.append(role)
            roles_list.append(sorted(os.listdir(role_dir)))
        return roles, roles_list

    def _find_feature_file(self, clip_name):
        """Return the feature path of ``clip_name``, or None if no candidate has the clip.

        The first candidate whose ``voice`` folder contains the clip wins.
        """
        for candidate in self.candidate_path:
            if os.path.exists(os.path.join(candidate, 'voice', clip_name)):
                return os.path.join(candidate, 'feature', clip_name) + '.pkl'
        return None

    @staticmethod
    def _stack_rows(rows):
        """Stack embeddings into a matrix with one row each; no rows gives shape (0, 0)."""
        if not rows:
            return np.empty((0, 0))
        return np.vstack(rows)

    def get_features(self):
        """Load the embedding of every reference clip of every role.

        Each clip is resolved individually, so clips from any episode are found.
        Missing clips and missing feature files are skipped silently.

        Returns:
            ``(features, labels)``: a matrix with one embedding per row and the
            matching list of role names.
        """
        return self.gather_feature_label(self.roles, self.roles_list, verbose=False)

    def knn_test(self, k=1, cv=5):
        """Print the cross-validated accuracy of a cosine-metric k-NN classifier.

        Uses ``self.features`` (one embedding per row) and ``self.labels`` (one
        role name per row).

        Args:
            k: number of neighbours (default 1).
            cv: number of cross-validation folds (default 5).

        Returns:
            The mean accuracy over the folds.
        """
        knn = KNeighborsClassifier(n_neighbors=k, metric='cosine')

        features = np.array(self.features)
        labels = np.array(self.labels)

        cv_accuracy = cross_val_score(knn, features, labels, cv=cv)

        for fold, accuracy in enumerate(cv_accuracy, 1):
            print(f"Fold {fold}: {accuracy}")

        # Print the mean accuracy.
        mean_accuracy = np.mean(cv_accuracy)
        print(f"Average Accuracy: {mean_accuracy}")
        return mean_accuracy

    def gather_feature_label(self, roles, roles_list, verbose=True):
        """Load the embeddings of the given clips, labelled by role.

        Args:
            roles: role names.
            roles_list: ``roles_list[i]`` is the list of clip file names of ``roles[i]``.
            verbose: print a warning for every clip or feature file not found.

        Returns:
            ``(features, labels)``: a matrix with one embedding per row (shape (0, 0)
            if nothing was found) and the matching list of role names.
        """
        rows = []
        labels = []

        for role, clip_names in zip(roles, roles_list):
            print(role, end=' ')

            for clip_name in clip_names:
                feature_fname = self._find_feature_file(clip_name)

                if feature_fname is None:
                    if verbose:
                        print('warning!', clip_name, 'not found')
                    continue

                if not os.path.exists(feature_fname):
                    if verbose:
                        print('warning!', feature_fname, 'not found')
                    continue

                # NOTE: pickle.load can execute arbitrary code; only load trusted files.
                with open(feature_fname, 'rb') as f:
                    rows.append(pickle.load(f))
                labels.append(role)

        # Stack once at the end (vstack inside the loop is quadratic).
        return self._stack_rows(rows), labels

    def get_feat_sel(self):
        """Gather the reference set for the classifier, visiting each role's clips in random order.

        The clip lists are shuffled on copies so ``self.roles_list`` is left untouched.

        Returns:
            ``(feat_sel, label_sel)`` as returned by ``gather_feature_label``.
        """
        shuffled_lists = []
        for clip_names in self.roles_list:
            shuffled = list(clip_names)
            random.shuffle(shuffled)
            shuffled_lists.append(shuffled)

        return self.gather_feature_label(list(self.roles), shuffled_lists)

    def get_sel_predict(self):
        """Classify every row of ``self.features`` with ``self.my_classifier``.

        The classifier's reference set is gathered from the same role clips, so each
        row will usually match itself at a distance close to 0.

        Returns:
            ``(correct_dists, wrong_dists)``: distances of the rows whose predicted
            label did / did not equal the true label.
        """
        correct_dists = []
        wrong_dists = []

        for feat, label in zip(self.features, self.labels):
            predicted_label, distance = self.my_classifier.predict(feat)
            if label == predicted_label:
                correct_dists.append(distance)
            else:
                wrong_dists.append(distance)

        return correct_dists, wrong_dists

    def get_pridict(self, threshold_certain=0.4, threshold_doubt=0.6):
        """Write one role-annotated transcript per episode folder.

        Every ``<id>_<text>.wav.pkl`` file in an episode's ``feature`` folder is
        classified. The lines are written, in ascending ``id`` order, to
        ``srt_out_dir/<episode>.txt`` as ``role:「text」``.

        Distance handling:
        - below ``threshold_certain``: the predicted role is used;
        - below ``threshold_doubt``: the role is written as ``(可能)role``;
        - otherwise: the role is left empty.

        Episode folders without a ``feature`` folder are skipped with a warning.

        Args:
            threshold_certain: distance below which the prediction is trusted.
            threshold_doubt: distance below which the prediction is marked as possible.
        """
        for episode_dir in self.candidate_path:
            name = os.path.basename(os.path.normpath(episode_dir))
            feature_folder = os.path.join(episode_dir, 'feature')
            if not os.path.isdir(feature_folder):
                print('warning!', feature_folder, 'not found')
                continue

            # Map clip id -> (file name, subtitle text), built once per episode.
            # Only the first '_' separates the id, because the text may contain '_'.
            clips = {}
            for file in sorted(os.listdir(feature_folder)):
                match = self._FEATURE_FILE_RE.fullmatch(file)
                if match:
                    clips.setdefault(int(match.group(1)), (file, match.group(2)))

            save_name = os.path.join(self.srt_out_dir, f'{name}.txt')
            with open(save_name, "w", encoding="utf-8") as f_out:
                for clip_id in sorted(clips):
                    file, text = clips[clip_id]

                    with open(os.path.join(feature_folder, file), 'rb') as f:
                        feature = pickle.load(f)

                    predicted_label, distance = self.my_classifier.predict(feature)

                    if distance < threshold_certain:
                        role_name = predicted_label
                    elif distance < threshold_doubt:
                        role_name = '(可能)' + predicted_label
                    else:
                        role_name = ''

                    output_str = role_name + ':「' + text + '」'
                    f_out.write(output_str + "\n")

#### 运行3 台本识别

In [17]:
# Build the classifier: loads reference-clip embeddings per role and fits the 1-NN matcher.
audio_classification = AudioClassification(
    audio_config['audio_roles_dir'],  # one sub-folder of labelled clips per role
    srt_config['srt_out_dir'],        # destination folder of the per-episode .txt files
    audio_config['audio_out_dir'],    # episode folders containing voice/ and feature/
)



In [18]:
# Classify every clip and write one role-annotated <episode>.txt per episode into srt_out_dir.
audio_classification.get_pridict()