In [45]:
import os

# Path to the quantised TFLite model that is passed to tf.lite.Interpreter
# later in this notebook. The default is a machine-specific absolute path;
# set the SPEAKER_MODEL_PATH environment variable to use another file
# without editing the notebook.
model_path = os.environ.get(
    'SPEAKER_MODEL_PATH',
    '/home/lidonghaowsl/develop/vesc2025/algo/final_algo/EMBD/student_k4_diydataset_xinyuanlaisi.tflite',
)


In [46]:
import multiprocessing as mp
import numpy as np
from scipy.fft import rfft, fft
import torch

In [47]:
class MFCCProcessor:
    """
    Python implementation of MFCC (Mel-Frequency Cepstral Coefficients),
    based on the RISC-V DSP library implementation.

    Per frame: Hann window -> |FFT| -> mel filter bank -> log -> DCT.
    The window, mel filter bank and DCT tables below are hard-coded for
    fft_len=256, nb_mel_filters=40 and nb_dct_outputs=13; the constructor
    rejects other sizes instead of silently producing wrong features.
    """

    def __init__(self, fft_len=256, nb_mel_filters=40, nb_dct_outputs=13, use_cfft=False):
        """
        Initialize MFCC processor.

        Args:
            fft_len: FFT length; must equal the window table length (256).
            nb_mel_filters: Number of mel filter banks; must be 40.
            nb_dct_outputs: Number of DCT outputs (MFCC coefficients); must be 13.
            use_cfft: Compute the spectrum with a complex FFT instead of a
                real FFT (default: False). For real input both paths give the
                same magnitudes.

        Raises:
            ValueError: if the requested sizes do not match the built-in tables.
        """
        self.fft_len = fft_len
        self.nb_mel_filters = nb_mel_filters
        self.nb_dct_outputs = nb_dct_outputs
        self.use_cfft = use_cfft

        # Coefficient tables, filled by _load_coefficients().
        self.window_coefs = None
        self.filter_pos = None
        self.filter_lengths = None
        self.filter_coefs = None
        self.dct_coefs = None

        self._load_coefficients()
        self._validate_tables()
        # Dense (nb_mel_filters, fft_len) form of the sparse filter tables so
        # that mel filtering is a single matrix-vector product per frame.
        self._mel_matrix = self._build_mel_matrix()

    def _load_coefficients(self):
        """Load the pre-computed window, mel filter bank and DCT tables.

        Raises:
            ValueError: if the DCT table size is not
                nb_dct_outputs * nb_mel_filters.
        """

        # 256-point periodic Hann window: w[n] = 0.5 - 0.5*cos(2*pi*n/256).
        self.window_coefs = np.array([
            0.000000,0.000151,0.000602,0.001355,0.002408,0.003760,0.005412,0.007361,0.009607,0.012149,
            0.014984,0.018112,0.021530,0.025236,0.029228,0.033504,0.038060,0.042895,0.048005,0.053388,
            0.059039,0.064957,0.071136,0.077573,0.084265,0.091208,0.098396,0.105827,0.113495,0.121396,
            0.129524,0.137876,0.146447,0.155230,0.164221,0.173414,0.182803,0.192384,0.202150,0.212096,
            0.222215,0.232501,0.242949,0.253551,0.264302,0.275194,0.286222,0.297379,0.308658,0.320052,
            0.331555,0.343159,0.354858,0.366644,0.378510,0.390449,0.402455,0.414519,0.426635,0.438795,
            0.450991,0.463218,0.475466,0.487729,0.500000,0.512271,0.524534,0.536782,0.549009,0.561205,
            0.573365,0.585481,0.597545,0.609551,0.621490,0.633356,0.645142,0.656841,0.668445,0.679947,
            0.691342,0.702621,0.713778,0.724806,0.735698,0.746449,0.757051,0.767499,0.777785,0.787904,
            0.797850,0.807616,0.817197,0.826586,0.835779,0.844770,0.853553,0.862124,0.870476,0.878604,
            0.886505,0.894173,0.901604,0.908792,0.915735,0.922427,0.928864,0.935044,0.940961,0.946612,
            0.951995,0.957105,0.961940,0.966496,0.970772,0.974764,0.978470,0.981888,0.985016,0.987851,
            0.990393,0.992639,0.994588,0.996240,0.997592,0.998645,0.999398,0.999849,1.000000,0.999849,
            0.999398,0.998645,0.997592,0.996240,0.994588,0.992639,0.990393,0.987851,0.985016,0.981888,
            0.978470,0.974764,0.970772,0.966496,0.961940,0.957105,0.951995,0.946612,0.940961,0.935044,
            0.928864,0.922427,0.915735,0.908792,0.901604,0.894173,0.886505,0.878604,0.870476,0.862124,
            0.853553,0.844770,0.835779,0.826586,0.817197,0.807616,0.797850,0.787904,0.777785,0.767499,
            0.757051,0.746449,0.735698,0.724806,0.713778,0.702621,0.691342,0.679947,0.668445,0.656841,
            0.645142,0.633356,0.621490,0.609551,0.597545,0.585481,0.573365,0.561205,0.549009,0.536782,
            0.524534,0.512271,0.500000,0.487729,0.475466,0.463218,0.450991,0.438795,0.426635,0.414519,
            0.402455,0.390449,0.378510,0.366644,0.354858,0.343159,0.331555,0.320052,0.308658,0.297379,
            0.286222,0.275194,0.264302,0.253551,0.242949,0.232501,0.222215,0.212096,0.202150,0.192384,
            0.182803,0.173414,0.164221,0.155230,0.146447,0.137876,0.129524,0.121396,0.113495,0.105827,
            0.098396,0.091208,0.084265,0.077573,0.071136,0.064957,0.059039,0.053388,0.048005,0.042895,
            0.038060,0.033504,0.029228,0.025236,0.021530,0.018112,0.014984,0.012149,0.009607,0.007361,
            0.005412,0.003760,0.002408,0.001355,0.000602,0.000151
        ], dtype=np.float32)

        # First FFT bin of each of the 40 mel filters.
        self.filter_pos = np.array([
            1,2,3,4,5,6,8,9,11,12,
            14,15,17,19,21,23,25,27,30,32,
            35,38,40,43,46,50,53,57,60,64,
            68,73,77,82,87,92,97,103,109,115,
        ], dtype=np.uint32)

        # Number of FFT bins (taps) covered by each mel filter.
        self.filter_lengths = np.array([
            2,2,2,2,3,3,3,3,3,3,
            3,4,4,4,4,4,5,5,5,6,
            5,5,6,7,7,7,7,7,8,9,
            9,9,10,10,10,11,12,12,13,13,
        ], dtype=np.uint32)

        # Concatenated filter weights: filter i uses the next filter_lengths[i] values.
        self.filter_coefs = np.array([
            0.940365,0.158628,0.841372,0.293816,0.706184,0.462403,0.537597,0.661904,0.338096,0.890104,
            0.145015,0.109896,0.854985,0.424850,0.575150,0.727995,0.052989,0.272005,0.947011,0.398503,
            0.601497,0.763326,0.146352,0.236674,0.853648,0.546566,0.453434,0.963036,0.394905,0.036964,
            0.605095,0.841380,0.301730,0.158620,0.698270,0.775275,0.261386,0.224725,0.738614,0.759477,
            0.269002,0.240523,0.730998,0.789451,0.320349,0.210549,0.679651,0.861250,0.411736,0.138750,
            0.588264,0.971416,0.539920,0.116902,0.028584,0.460080,0.883098,0.702035,0.295011,0.297965,
            0.704989,0.895539,0.503343,0.118164,0.104461,0.496657,0.881836,0.739755,0.367882,0.002322,
            0.260245,0.632118,0.997678,0.642866,0.289313,0.357134,0.710687,0.941471,0.599160,0.262206,
            0.058529,0.400840,0.737794,0.930444,0.603716,0.281873,0.069556,0.396284,0.718127,0.964769,
            0.652268,0.344238,0.040553,0.035231,0.347732,0.655761,0.959447,0.741092,0.445738,0.154382,
            0.258908,0.554262,0.845618,0.866915,0.583236,0.303246,0.026850,0.133085,0.416764,0.696754,
            0.973150,0.753958,0.484481,0.218335,0.246042,0.515519,0.781665,0.955439,0.695714,0.439085,
            0.185479,0.044561,0.304286,0.560915,0.814521,0.934825,0.687055,0.442105,0.199909,0.065175,
            0.312945,0.557895,0.800091,0.960408,0.723542,0.489253,0.257486,0.028188,0.039592,0.276458,
            0.510747,0.742514,0.971812,0.801306,0.576789,0.354590,0.134660,0.198694,0.423211,0.645410,
            0.865340,0.916954,0.701428,0.488037,0.276741,0.067498,0.083046,0.298572,0.511963,0.723259,
            0.932502,0.860269,0.655015,0.451700,0.250287,0.050740,0.139731,0.344985,0.548300,0.749713,
            0.949260,0.853026,0.657111,0.462962,0.270549,0.079839,0.146974,0.342889,0.537038,0.729451,
            0.920161,0.890805,0.703415,0.517642,0.333459,0.150837,0.109195,0.296585,0.482358,0.666541,
            0.849163,0.969752,0.790177,0.612087,0.435458,0.260267,0.086489,0.030248,0.209823,0.387913,
            0.564542,0.739733,0.913511,0.914103,0.743086,0.573416,0.405074,0.238037,0.072286,0.085897,
            0.256914,0.426584,0.594926,0.761963,0.927714,0.907802,0.744564,0.582555,0.421755,0.262148,
            0.103715,0.092198,0.255436,0.417445,0.578245,0.737852,0.896285,0.946440,0.790305,0.635294,
            0.481391,0.328580,0.176846,0.026174,0.053560,0.209695,0.364706,0.518609,0.671420,0.823154,
            0.973826,0.876550,0.727957,0.580384,0.433814,0.288236,0.143636
        ], dtype=np.float32)

        # DCT matrix (13 x 40), row-major; entry (k, n) is
        # sqrt(2/40) * cos(pi * k * (2n + 1) / 80), so row 0 is constant 0.223607.
        dct_coefs_flat = np.array([
            0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,
            0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,
            0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,
            0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,0.223607,
            0.223434,0.222057,0.219310,0.215212,0.209786,0.203067,0.195096,0.185922,0.175602,0.164200,
            0.151784,0.138434,0.124229,0.109259,0.093615,0.077394,0.060696,0.043624,0.026282,0.008779,
            -0.008779,-0.026282,-0.043624,-0.060696,-0.077394,-0.093615,-0.109259,-0.124229,-0.138434,-0.151784,
            -0.164200,-0.175602,-0.185922,-0.195096,-0.203067,-0.209786,-0.215212,-0.219310,-0.222057,-0.223434,
            0.222917,0.217429,0.206586,0.190656,0.170032,0.145221,0.116834,0.085571,0.052200,0.017544,
            -0.017544,-0.052200,-0.085571,-0.116834,-0.145221,-0.170032,-0.190656,-0.206586,-0.217429,-0.222917,
            -0.222917,-0.217429,-0.206586,-0.190656,-0.170032,-0.145221,-0.116834,-0.085571,-0.052200,-0.017544,
            0.017544,0.052200,0.085571,0.116834,0.145221,0.170032,0.190656,0.206586,0.217429,0.222917,
            0.222057,0.209786,0.185922,0.151784,0.109259,0.060696,0.008779,-0.043624,-0.093615,-0.138434,
            -0.175602,-0.203067,-0.219310,-0.223434,-0.215212,-0.195096,-0.164200,-0.124229,-0.077394,-0.026282,
            0.026282,0.077394,0.124229,0.164200,0.195096,0.215212,0.223434,0.219310,0.203067,0.175602,
            0.138434,0.093615,0.043624,-0.008779,-0.060696,-0.109259,-0.151784,-0.185922,-0.209786,-0.222057,
            0.220854,0.199235,0.158114,0.101515,0.034980,-0.034980,-0.101515,-0.158114,-0.199235,-0.220854,
            -0.220854,-0.199235,-0.158114,-0.101515,-0.034980,0.034980,0.101515,0.158114,0.199235,0.220854,
            0.220854,0.199235,0.158114,0.101515,0.034980,-0.034980,-0.101515,-0.158114,-0.199235,-0.220854,
            -0.220854,-0.199235,-0.158114,-0.101515,-0.034980,0.034980,0.101515,0.158114,0.199235,0.220854,
            0.219310,0.185922,0.124229,0.043624,-0.043624,-0.124229,-0.185922,-0.219310,-0.219310,-0.185922,
            -0.124229,-0.043624,0.043624,0.124229,0.185922,0.219310,0.219310,0.185922,0.124229,0.043624,
            -0.043624,-0.124229,-0.185922,-0.219310,-0.219310,-0.185922,-0.124229,-0.043624,0.043624,0.124229,
            0.185922,0.219310,0.219310,0.185922,0.124229,0.043624,-0.043624,-0.124229,-0.185922,-0.219310,
            0.217429,0.170032,0.085571,-0.017544,-0.116834,-0.190656,-0.222917,-0.206586,-0.145221,-0.052200,
            0.052200,0.145221,0.206586,0.222917,0.190656,0.116834,0.017544,-0.085571,-0.170032,-0.217429,
            -0.217429,-0.170032,-0.085571,0.017544,0.116834,0.190656,0.222917,0.206586,0.145221,0.052200,
            -0.052200,-0.145221,-0.206586,-0.222917,-0.190656,-0.116834,-0.017544,0.085571,0.170032,0.217429,
            0.215212,0.151784,0.043624,-0.077394,-0.175602,-0.222057,-0.203067,-0.124229,-0.008779,0.109259,
            0.195096,0.223434,0.185922,0.093615,-0.026282,-0.138434,-0.209786,-0.219310,-0.164200,-0.060696,
            0.060696,0.164200,0.219310,0.209786,0.138434,0.026282,-0.093615,-0.185922,-0.223434,-0.195096,
            -0.109259,0.008779,0.124229,0.203067,0.222057,0.175602,0.077394,-0.043624,-0.151784,-0.215212,
            0.212663,0.131433,0.000000,-0.131433,-0.212663,-0.212663,-0.131433,-0.000000,0.131433,0.212663,
            0.212663,0.131433,0.000000,-0.131433,-0.212663,-0.212663,-0.131433,-0.000000,0.131433,0.212663,
            0.212663,0.131433,0.000000,-0.131433,-0.212663,-0.212663,-0.131433,-0.000000,0.131433,0.212663,
            0.212663,0.131433,-0.000000,-0.131433,-0.212663,-0.212663,-0.131433,-0.000000,0.131433,0.212663,
            0.209786,0.109259,-0.043624,-0.175602,-0.223434,-0.164200,-0.026282,0.124229,0.215212,0.203067,
            0.093615,-0.060696,-0.185922,-0.222057,-0.151784,-0.008779,0.138434,0.219310,0.195096,0.077394,
            -0.077394,-0.195096,-0.219310,-0.138434,0.008779,0.151784,0.222057,0.185922,0.060696,-0.093615,
            -0.203067,-0.215212,-0.124229,0.026282,0.164200,0.223434,0.175602,0.043624,-0.109259,-0.209786,
            0.206586,0.085571,-0.085571,-0.206586,-0.206586,-0.085571,0.085571,0.206586,0.206586,0.085571,
            -0.085571,-0.206586,-0.206586,-0.085571,0.085571,0.206586,0.206586,0.085571,-0.085571,-0.206586,
            -0.206586,-0.085571,0.085571,0.206586,0.206586,0.085571,-0.085571,-0.206586,-0.206586,-0.085571,
            0.085571,0.206586,0.206586,0.085571,-0.085571,-0.206586,-0.206586,-0.085571,0.085571,0.206586,
            0.203067,0.060696,-0.124229,-0.222057,-0.164200,0.008779,0.175602,0.219310,0.109259,-0.077394,
            -0.209786,-0.195096,-0.043624,0.138434,0.223434,0.151784,-0.026282,-0.185922,-0.215212,-0.093615,
            0.093615,0.215212,0.185922,0.026282,-0.151784,-0.223434,-0.138434,0.043624,0.195096,0.209786,
            0.077394,-0.109259,-0.219310,-0.175602,-0.008779,0.164200,0.222057,0.124229,-0.060696,-0.203067,
            0.199235,0.034980,-0.158114,-0.220854,-0.101515,0.101515,0.220854,0.158114,-0.034980,-0.199235,
            -0.199235,-0.034980,0.158114,0.220854,0.101515,-0.101515,-0.220854,-0.158114,0.034980,0.199235,
            0.199235,0.034980,-0.158114,-0.220854,-0.101515,0.101515,0.220854,0.158114,-0.034980,-0.199235,
            -0.199235,-0.034980,0.158114,0.220854,0.101515,-0.101515,-0.220854,-0.158114,0.034980,0.199235
        ], dtype=np.float32)

        expected_size = self.nb_dct_outputs * self.nb_mel_filters
        if dct_coefs_flat.size != expected_size:
            raise ValueError(
                f"DCT table has {dct_coefs_flat.size} coefficients, expected "
                f"nb_dct_outputs * nb_mel_filters = {expected_size}"
            )
        # Row k holds the weights of MFCC coefficient k over the mel bands.
        self.dct_coefs = dct_coefs_flat.reshape(self.nb_dct_outputs, self.nb_mel_filters)

    def _validate_tables(self):
        """Check that the hard-coded tables fit the requested sizes.

        Raises:
            ValueError: on any mismatch. Without this check a wrong fft_len
                silently skipped the window and out-of-range filter taps were
                silently dropped.
        """
        if len(self.window_coefs) != self.fft_len:
            raise ValueError(
                f"fft_len={self.fft_len} does not match the "
                f"{len(self.window_coefs)}-point window table"
            )
        if (len(self.filter_pos) != self.nb_mel_filters
                or len(self.filter_lengths) != self.nb_mel_filters):
            raise ValueError(
                f"nb_mel_filters={self.nb_mel_filters} does not match the "
                f"{len(self.filter_pos)}-filter mel tables"
            )
        if int(np.sum(self.filter_lengths)) != len(self.filter_coefs):
            raise ValueError(
                "sum of mel filter lengths does not match the number of filter coefficients"
            )
        if int(np.max(self.filter_pos + self.filter_lengths)) > self.fft_len // 2 + 1:
            raise ValueError("a mel filter extends past the Nyquist bin")

    def _build_mel_matrix(self):
        """Expand (filter_pos, filter_lengths, filter_coefs) into a dense matrix.

        Filter i covers FFT bins filter_pos[i] .. filter_pos[i] + filter_lengths[i] - 1,
        weighted by the next filter_lengths[i] entries of filter_coefs.

        Returns:
            float32 array of shape (nb_mel_filters, fft_len).
        """
        mel = np.zeros((self.nb_mel_filters, self.fft_len), dtype=np.float32)
        offset = 0
        bins = zip(self.filter_pos.tolist(), self.filter_lengths.tolist())
        for i, (start, length) in enumerate(bins):
            mel[i, start:start + length] = self.filter_coefs[offset:offset + length]
            offset += length
        return mel

    def compute_mfcc(self, input_signal):
        """
        Compute MFCC features from one frame of audio.

        Args:
            input_signal: 1-D array-like of exactly fft_len samples (any
                numeric dtype; converted to float32).

        Returns:
            mfcc_features: float32 array of nb_dct_outputs coefficients.

        Raises:
            ValueError: if the input is not 1-D or its length is not fft_len.
        """
        src = np.asarray(input_signal, dtype=np.float32)
        if src.ndim != 1:
            raise ValueError(f"Input signal must be 1-D, got shape {src.shape}")
        if src.shape[0] != self.fft_len:
            raise ValueError(f"Input signal length ({src.shape[0]}) must match FFT length ({self.fft_len})")

        # Peak normalisation (divide by max|x| before the FFT, multiply back
        # afterwards) is not applied: windowing and the FFT are linear and
        # |.| is homogeneous, so it would only change rounding.
        if self.window_coefs is not None:
            src = src * self.window_coefs

        # Magnitude spectrum over all fft_len bins.
        if self.use_cfft:
            spectrum_mag = np.abs(fft(src.astype(np.complex64)))
        else:
            half = self.fft_len // 2
            half_mag = np.abs(rfft(src))  # bins 0 .. fft_len/2 inclusive
            spectrum_mag = np.empty(self.fft_len, dtype=np.float32)
            spectrum_mag[:half + 1] = half_mag
            # Real input => |X[N-k]| == |X[k]|; mirror so both paths agree
            # bin for bin (including the Nyquist bin, previously zeroed here).
            spectrum_mag[half + 1:] = half_mag[1:half][::-1]

        # Mel filter bank, then log with a small offset to avoid log(0).
        mel_outputs = self._mel_matrix @ spectrum_mag
        log_mel = np.log(mel_outputs + 1.0e-6)

        # DCT to cepstral coefficients.
        if self.dct_coefs is not None:
            return self.dct_coefs @ log_mel
        return log_mel[:self.nb_dct_outputs]


# Factory: returns a new MFCCProcessor instance on every call.
def mfcc_processor_factory():
    """Build a fresh MFCCProcessor for 256-sample frames, 40 mel filters and
    13 coefficients, using the complex-FFT spectrum path."""
    settings = dict(
        fft_len=256,
        nb_mel_filters=40,
        nb_dct_outputs=13,
        use_cfft=True,
    )
    return MFCCProcessor(**settings)

# Worker function: invoked (possibly in a child process) once per audio row.
def process_row(row_data):
    """
    Compute MFCC features for one row of audio (8000 samples in this notebook).

    The row is cut into 31 consecutive, non-overlapping 256-sample frames;
    samples past 31 * 256 = 7936 are ignored. A new processor is created per
    call so each worker process builds its own.

    Returns:
        float32 array of shape (31, 13).
    """
    frame_len, n_frames, n_coeffs = 256, 31, 13
    processor = mfcc_processor_factory()
    features = np.zeros((n_frames, n_coeffs), dtype=np.float32)
    for idx in range(n_frames):
        offset = idx * frame_len
        features[idx] = processor.compute_mfcc(row_data[offset:offset + frame_len])
    return features
def compute_batch_mfcc_features(input_array):
    """
    Single-process MFCC extraction for an (n, 8000) audio array.

    Returns:
        (n, 31, 13) array: process_row applied to each row, stacked.
    """
    per_row = list(map(process_row, input_array))
    return np.stack(per_row, axis=0)
def compute_batch_mfcc_features_parallel(input_array):
    """
    Multi-process MFCC extraction for an (n, 8000) audio array.

    Maps process_row over the rows with a multiprocessing pool and returns
    the stacked (n, 31, 13) features. The pool is capped at n workers so a
    few seconds of audio does not spawn one idle process per CPU core.

    NOTE(review): process_row is pickled by reference. With the "spawn"
    start method, functions defined in a notebook's __main__ typically
    cannot be loaded by the workers, so this likely only works with "fork".
    Confirm for your platform / Python version.

    Raises:
        ValueError: (from np.stack) if input_array has no rows.
    """
    n_rows = input_array.shape[0]
    n_workers = max(1, min(mp.cpu_count(), n_rows))

    with mp.Pool(processes=n_workers) as pool:
        results = pool.map(process_row, list(input_array))

    return np.stack(results, axis=0)

In [48]:
import numpy as np
import math

def hierarchical_embedding_aggregation(mfcc_batch, speaker_id, enroll_db, num_samples=None, rng=None):
    """
    Build an enrollment embedding from random subset means (random-sampling version).

    1. Embed every MFCC segment with infer_one_mfcc.
    2. Repeatedly draw ceil(sqrt(n)) distinct segment embeddings at random
       and average them (all n are used when ceil(sqrt(n)) >= n).
    3. L2-normalise each subset mean and sum the results.

    Args:
        mfcc_batch: sequence of n MFCC inputs; mfcc_batch[i][..., 0] is passed
            to infer_one_mfcc (callers in this notebook pass (n, 31, 12, 1)).
        speaker_id: key under which the result is stored in enroll_db.
        enroll_db: dict, updated in place with enroll_db[speaker_id] = result.
        num_samples: number of random subsets; defaults to
            max(n, 3 * ceil(sqrt(n))).
        rng: optional numpy Generator / RandomState for reproducible sampling;
            defaults to the global np.random state (the previous behaviour).

    Returns:
        The aggregated embedding: a sum of unit vectors, so its norm grows
        with num_samples. Subset means with zero norm are skipped instead of
        contributing NaN.

    Raises:
        ValueError: if mfcc_batch is empty.
    """
    n = len(mfcc_batch)
    if n == 0:
        raise ValueError("mfcc_batch is empty; need at least one segment to enroll")
    rng = np.random if rng is None else rng

    # (n, d) matrix of per-segment embeddings; infer_one_mfcc returns (1, d).
    embeddings = np.stack([infer_one_mfcc(mfcc_batch[i][..., 0])[0] for i in range(n)])

    sample_size = int(math.ceil(math.sqrt(n)))  # vectors per subset
    if num_samples is None:
        num_samples = max(n, sample_size * 3)

    unit_means = []
    for _ in range(num_samples):
        if sample_size >= n:
            subset = embeddings  # too few segments to subsample: use them all
        else:
            subset = embeddings[rng.choice(n, size=sample_size, replace=False)]
        avg_emb = subset.mean(axis=0)
        norm = np.linalg.norm(avg_emb)
        if norm > 0:  # a zero vector has no direction; normalising it gives NaN
            unit_means.append(avg_emb / norm)

    if unit_means:
        final_emb = np.sum(unit_means, axis=0)
    else:
        final_emb = np.zeros(embeddings.shape[1], dtype=embeddings.dtype)

    enroll_db[speaker_id] = final_emb
    return final_emb

In [49]:
import numpy as np
import torchaudio
from sklearn.metrics.pairwise import cosine_similarity

# === Registry of enrolled speakers: {speaker_id: embedding} ===
enroll_db = dict()

# === The MFCC batch helpers are used from earlier in this notebook ===
# If they live in a module instead, import them here, e.g.
# from your_module import compute_batch_mfcc_features_parallel

# === Inference helpers ===
def quantize_input(data, input_details):
    """
    Quantise float features to int8 with the model's input parameters.

    Args:
        data: float array to quantise.
        input_details: interpreter input details; input_details[0]['quantization']
            is read as (scale, zero_point).

    Returns:
        int8 array round(data / scale + zero_point), saturated to [-128, 127].
        A bare astype(np.int8) does not saturate; out-of-range values
        typically wrap around (e.g. 130 -> -126).
    """
    scale, zero_point = input_details[0]['quantization']
    q = np.round(data / scale + zero_point)
    int8_range = np.iinfo(np.int8)
    return np.clip(q, int8_range.min, int8_range.max).astype(np.int8)

def dequantize_output(data, output_details):
    """Map quantised model output back to float32 using
    output_details[0]['quantization'] = (scale, zero_point)."""
    quant_scale, quant_zero = output_details[0]['quantization']
    centered = data.astype(np.float32) - quant_zero
    return centered * quant_scale

# === Load the TFLite model ===
# NOTE: newer TF releases emit a deprecation warning for tf.lite.Interpreter
# recommending the ai_edge_litert package (see this cell's output).
import tensorflow as tf
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()
input_details, output_details = interpreter.get_input_details(), interpreter.get_output_details()

def infer_one_mfcc(mfcc):
    """
    Run the TFLite model on one MFCC matrix.

    Args:
        mfcc: 2-D MFCC matrix, (31, 13) or already (31, 12). With 13 columns
            the first one (C0) is dropped first.

    Returns:
        Dequantised model output; the original note gives its shape as (1, 128).
    """
    if mfcc.shape[1] == 13:
        mfcc = mfcc[:, 1:]  # drop C0 if still present
    model_input = mfcc[np.newaxis, ..., np.newaxis].astype(np.float32)  # (1, 31, 12, 1)
    interpreter.set_tensor(input_details[0]['index'], quantize_input(model_input, input_details))
    interpreter.invoke()
    return dequantize_output(interpreter.get_tensor(output_details[0]['index']), output_details)

# === Enrollment ===
def register_speaker(speaker_id, wav, sr, enroll_db):
    """
    Enroll a speaker: split into 1-second segments, compute MFCCs, embed,
    and aggregate.

    Args:
        speaker_id: key to store the embedding under.
        wav: array-like whose first row wav[0] holds the samples
            (e.g. the tensor returned by torchaudio.load).
        sr: sample rate; must be 8000.
        enroll_db: dict, updated in place by hierarchical_embedding_aggregation.

    Returns:
        The aggregated embedding from hierarchical_embedding_aggregation.

    Raises:
        AssertionError: if sr != 8000.
        ValueError: if fewer than 8000 samples are available.
    """
    assert sr == 8000, "当前注册流程仅支持8kHz音频"
    seg_len = 8000
    channel = wav[0]
    n_segments = channel.shape[0] // seg_len
    if not n_segments:
        raise ValueError("音频太短，无法切出1秒段")

    segments = np.array(channel[:n_segments * seg_len].reshape(n_segments, seg_len))
    features = compute_batch_mfcc_features(segments)  # (n, 31, 13)
    features = features[:, :, 1:][..., None]          # drop C0 -> (n, 31, 12, 1)

    return hierarchical_embedding_aggregation(features, speaker_id, enroll_db, num_samples=1000)

# === Identification ===
def identify_speaker(wav, sr, enroll_db):
    """
    Identify the enrolled speaker closest to a whole recording.

    The audio is cut into non-overlapping 1-second (8000-sample) segments.
    Their embeddings are averaged and compared with every enrolled embedding
    by cosine similarity, normalising BOTH vectors. Previously only the test
    embedding was normalised, so speakers whose enrollment vector had a
    larger norm were systematically favoured. The score now matches the
    cosine used by identify_speaker_streaming.

    Args:
        wav: array-like whose first row wav[0] holds the samples.
        sr: sample rate; must be 8000.
        enroll_db: dict {speaker_id: embedding}.

    Returns:
        (speaker_id, similarity) of the best match; similarity is a cosine in
        [-1, 1] (0.0 when either vector is all zeros).

    Raises:
        AssertionError: if sr != 8000.
        ValueError: if enroll_db is empty or the audio is shorter than 1 s.
    """
    assert sr == 8000, "当前识别流程仅支持8kHz音频"
    if not enroll_db:
        raise ValueError("enroll_db is empty; register at least one speaker first")

    samples = wav[0]
    n = samples.shape[0] // 8000
    if n == 0:
        raise ValueError("音频太短")
    batch = samples[:n*8000].reshape(n, 8000)
    mfcc = compute_batch_mfcc_features_parallel(np.array(batch))[:, :, 1:]  # drop C0 -> (n, 31, 12)

    test_emb = np.mean([infer_one_mfcc(mfcc[i])[0] for i in range(n)], axis=0)
    test_norm = np.linalg.norm(test_emb)

    sims = {}
    for sid, ref in enroll_db.items():
        denom = test_norm * np.linalg.norm(ref)
        sims[sid] = float(np.dot(test_emb, ref) / denom) if denom > 0 else 0.0

    best = max(sims.items(), key=lambda x: x[1])
    print(f"识别为：{best[0]}，相似度：{best[1]:.3f}")
    return best

    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    


In [50]:
def _unit_rows(x):
    """Scale each row of the 2-D array x to unit L2 norm; all-zero rows stay
    zero, so their cosine with anything is 0."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    return x / norms


def identify_speaker_streaming(pcm_data: np.ndarray, enroll_db: dict, threshold=0.7, m: int = 1):
    """
    Label each 1-second chunk of 8 kHz PCM audio with the closest enrolled speaker.

    For chunk i, the embeddings of the last m chunks (i-m+1 .. i) are
    averaged and compared with every enrolled embedding by cosine similarity.

    Args:
        pcm_data: np.ndarray or torch.Tensor of samples (int16 or float32);
            for multi-dimensional input only pcm_data[0] (the first channel) is used.
        enroll_db: enrollment database {speaker_id: embedding}.
        threshold: similarity below this value is reported as "Unknown".
        m: number of past seconds to average (>= 1); 1 uses only the current second.

    Returns:
        List[str]: one label per second.

    Raises:
        ValueError: if m < 1, enroll_db is empty, or the audio is shorter than 1 s.
    """
    if m < 1:
        raise ValueError(f"m must be >= 1, got {m}")
    if not enroll_db:
        raise ValueError("enroll_db is empty; register at least one speaker first")

    if isinstance(pcm_data, torch.Tensor):
        pcm_data = pcm_data.numpy()
    if len(pcm_data.shape) > 1:
        pcm_data = pcm_data[0]  # keep only the first channel

    n = pcm_data.shape[0] // 8000
    if n == 0:
        raise ValueError("PCM 数据太短")

    batch = pcm_data[:n * 8000].reshape(n, 8000)
    mfcc = compute_batch_mfcc_features(np.array(batch))[:, :, 1:]  # drop C0 -> (n, 31, 12)

    # Normalise the enrolled embeddings once instead of once per chunk.
    speaker_ids = list(enroll_db)
    ref_units = _unit_rows(np.stack(
        [np.asarray(enroll_db[sid], dtype=np.float64).ravel() for sid in speaker_ids]
    ))

    emb_history = []
    result = []
    for i in range(n):
        emb_history.append(infer_one_mfcc(mfcc[i])[0])
        avg_emb = np.mean(emb_history[-m:], axis=0)  # sliding average over the last m seconds
        sims = ref_units @ _unit_rows(np.asarray(avg_emb, dtype=np.float64)[None, :])[0]
        best = int(np.argmax(sims))  # first maximum, same tie-break as max() over the dict
        result.append(speaker_ids[best] if sims[best] >= threshold else "Unknown")

    return result

In [51]:
import pickle

# === Enroll every speaker from their first recording (resets enroll_db) ===
enroll_db = {}
ENROLLMENT_FILES = [
    ('XiaoXin', "Xin-1.wav"),
    ('XiaoYuan', "Yuan-1.wav"),
    ('XiaoSi', "Si-1.wav"),
    ('XiaoLai', "Lai-1.wav"),
]
for speaker_name, file_path in ENROLLMENT_FILES:
    pcm, enroll_sr = torchaudio.load(file_path)
    # pcm = remove_silence_8k_pcm(pcm)  # silence removal, disabled in the original run
    # Pass the file's real sample rate (not a literal 8000) so that
    # register_speaker's 8 kHz check actually catches mismatched files.
    embd = register_speaker(speaker_name, pcm, enroll_sr, enroll_db)

  mfcc_batch = compute_batch_mfcc_features(np.array(batch))  # (n, 31, 13)


In [52]:
from collections import Counter

def count_categories(data_list):
    """
    Count how often each category occurs.

    Args:
        data_list (list): sequence of category labels.

    Returns:
        dict: {category: count}, in first-occurrence order.
    """
    tally = Counter()
    tally.update(data_list)
    return dict(tally)

In [53]:
# Evaluate on Xin-2.wav: per-second labels (m=5 averages the last 5 seconds), then counts.
pcm, sr = torchaudio.load("Xin-2.wav")
assert sr == 8000

labels = identify_speaker_streaming(
    pcm.numpy(),
    enroll_db,
    threshold=0.7,
    m=5,
)
result = count_categories(labels)
print(result)

{'Unknown': 4, 'XiaoSi': 1, 'XiaoXin': 148}


In [54]:
# Evaluate on Yuan-2.wav: per-second labels (m=5 averages the last 5 seconds), then counts.
pcm, sr = torchaudio.load("Yuan-2.wav")
assert sr == 8000

labels = identify_speaker_streaming(
    pcm.numpy(),
    enroll_db,
    threshold=0.7,
    m=5,
)
result = count_categories(labels)
print(result)

{'Unknown': 5, 'XiaoYuan': 132}


In [55]:
# Evaluate on Si-2.wav: per-second labels (m=5 averages the last 5 seconds), then counts.
pcm, sr = torchaudio.load("Si-2.wav")
assert sr == 8000

labels = identify_speaker_streaming(
    pcm.numpy(),
    enroll_db,
    threshold=0.7,
    m=5,
)
result = count_categories(labels)
print(result)

{'Unknown': 6, 'XiaoSi': 150}


In [56]:
# Evaluate on Lai-2.wav: per-second labels (m=5 averages the last 5 seconds), then counts.
pcm, sr = torchaudio.load("Lai-2.wav")
assert sr == 8000

labels = identify_speaker_streaming(
    pcm.numpy(),
    enroll_db,
    threshold=0.7,
    m=5,
)
result = count_categories(labels)
print(result)

{'Unknown': 5, 'XiaoLai': 148, 'XiaoYuan': 1}


In [57]:
# Print each enrolled embedding as "<speaker>:" followed by one comma-separated line.
for key, value in enroll_db.items():
    csv_line = ",".join(str(component) for component in value)
    print(f"{key}:\n{csv_line}")

XiaoXin:
-3.699422,87.3447,46.563026,-43.793304,89.66671,-101.02349,-19.723362,-97.49725,10.242561,118.54056,34.218098,-121.03215,-40.551174,50.92178,-81.7186,73.594505,153.16034,53.471954,71.334984,-23.26353,-79.14674,62.385242,60.181408,-58.599174,-21.358622,26.40811,58.648365,67.21203,51.10134,20.134949,-30.209806,-5.1059794,-14.72544,52.21114,61.208317,-68.13978,82.50538,-28.502335,123.20865,-2.5192437,-181.69925,55.196083,19.90288,-79.376236,112.64569,52.514572,-37.711216,42.697964,26.566025,-65.16125,39.664173,-33.410378,-57.345207,108.743286,102.224915,7.5548706,-46.195545,-226.52396,30.16622,-62.388103,-8.654097,-41.88243,71.476974,22.33639,-31.810633,-88.16103,59.304375,-48.400097,-87.3758,-51.49459,24.850143,62.3882,-28.21096,-30.042171,68.869545,14.099492,4.5524964,-27.136549,95.56441,150.82874,-13.297154,30.455515,-78.919624,43.92391,35.941242,116.47799,18.834042,62.621185,-24.385736,90.8056,-76.28048,73.8557,126.56467,28.441467,-82.25712,69.91427,76.14204,24.337841,-5.7052