In [1]:
from typing import Optional

import torch
import torchaudio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.fftpack import dct
from nnAudio.features.gammatone import Gammatonegram

In [4]:
class GFCCExtractor:
    """Gammatone-frequency cepstral coefficient (GFCC) extractor built on nnAudio.

    Pipeline per signal: Gammatone spectrogram (nnAudio ``Gammatonegram``)
    -> natural log -> orthonormal DCT-II over the filter axis -> first
    ``n_gfcc`` coefficients. First/second-order deltas can be appended along
    the time axis.
    """

    def __init__(self,
                 sample_rate: int = 16000,
                 n_gfcc: int = 40,
                 n_fft: int = 2048,
                 hop_length: int = 512,
                 f_min: float = 20.0,
                 f_max: Optional[float] = None,
                 window: str = 'hann',
                 center: bool = True,
                 device: str = 'cuda:0',
                 n_bins: int = 64):
        """Build the extractor and its Gammatone front-end.

        Args:
            sample_rate: audio sample rate in Hz.
            n_gfcc: number of cepstral coefficients kept; must be in [1, n_bins].
            n_fft: STFT window size in samples.
            hop_length: hop between frames in samples.
            f_min: lowest filter frequency in Hz.
            f_max: highest filter frequency in Hz. ``None`` is forwarded to
                nnAudio unchanged (presumably Nyquist there — confirm);
                ``self.f_max`` records ``sample_rate // 2`` in that case.
            window: STFT window type passed to nnAudio.
            center: whether frames are centred on their time step.
            device: torch device the transform lives on and inputs are sent to.
            n_bins: number of Gammatone filters. Appended as the last
                parameter so existing positional calls keep working.

        Raises:
            ValueError: if ``n_gfcc`` is not between 1 and ``n_bins``.
        """
        # The DCT over n_bins filters yields only n_bins coefficients; asking
        # for more used to be truncated silently.
        if not 1 <= n_gfcc <= n_bins:
            raise ValueError(
                f"n_gfcc must be between 1 and n_bins ({n_bins}), got {n_gfcc}")

        self.sample_rate = sample_rate
        self.n_gfcc = n_gfcc
        self.n_bins = n_bins
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.f_min = f_min
        self.f_max = f_max if f_max is not None else sample_rate // 2
        self.window = window
        self.center = center
        self.device = device

        # Fixed (non-trainable) filterbank and STFT kernels: this is a pure
        # feature front-end.
        self.gammatone_transform = Gammatonegram(
            sr=sample_rate,
            n_fft=n_fft,
            n_bins=n_bins,
            hop_length=hop_length,
            window=window,
            center=center,
            fmin=f_min,
            fmax=f_max,
            trainable_bins=False,
            trainable_STFT=False,
            verbose=False
        ).to(device)

    def extract_gfcc_features(self, waveform: torch.Tensor) -> torch.Tensor:
        """Compute GFCCs for one or more waveforms.

        Args:
            waveform: audio shaped ``[samples]``, ``[batch, samples]`` or
                ``[batch, channels, samples]``.

        Returns:
            float32 CPU tensor: ``[n_gfcc, time]`` for 1-D input,
            ``[batch, n_gfcc, time]`` when there is a single channel, otherwise
            ``[batch, channels, n_gfcc, time]``.

        Raises:
            ValueError: if ``waveform`` is not 1-, 2- or 3-dimensional.
        """
        original_ndim = waveform.dim()
        if original_ndim == 1:
            waveform = waveform.reshape(1, 1, -1)      # [1, 1, samples]
        elif original_ndim == 2:
            waveform = waveform.unsqueeze(1)           # [batch, 1, samples]
        elif original_ndim != 3:
            raise ValueError(
                f"waveform must be 1-, 2- or 3-D, got shape {tuple(waveform.shape)}")

        batch_size, channels, samples = waveform.shape

        # Fold channels into the batch so every signal goes through the
        # transform in a single call instead of a Python loop.
        flat = waveform.reshape(batch_size * channels, samples).to(self.device)

        with torch.no_grad():
            gammatone_spec = self.gammatone_transform(flat)  # [B*C, n_bins, time]
            # Small floor so log never sees an exact zero.
            log_spec = torch.log(gammatone_spec + 1e-10)

        # Orthonormal DCT-II over the filter axis, keeping the first n_gfcc rows.
        log_spec_np = log_spec.cpu().numpy()
        gfcc = dct(log_spec_np, type=2, axis=1, norm='ortho')[:, :self.n_gfcc, :]

        gfcc_tensor = torch.from_numpy(np.ascontiguousarray(gfcc)).float()
        gfcc_tensor = gfcc_tensor.reshape(batch_size, channels, *gfcc_tensor.shape[1:])

        if channels == 1:
            gfcc_tensor = gfcc_tensor.squeeze(1)   # [batch, n_gfcc, time]
        if original_ndim == 1:
            gfcc_tensor = gfcc_tensor.squeeze(0)   # [n_gfcc, time]
        return gfcc_tensor

    def compute_delta_features(self, features: torch.Tensor, delta_window: int = 2) -> torch.Tensor:
        """Regression-based first-order delta along the last (time) axis.

        ``delta[t] = sum_{i=1..N} i * (x[t+i] - x[t-i]) / (2 * sum_{i=1..N} i^2)``
        with the first/last frame replicated beyond the edges, so the output has
        the same shape as the input.

        Args:
            features: floating-point tensor of any rank whose last dimension is
                time, e.g. ``[n_features, time]`` or ``[batch, n_features, time]``.
            delta_window: N in the formula above; must be >= 1.

        Returns:
            Tensor with the same shape as ``features``.

        Raises:
            ValueError: if ``delta_window`` < 1 (the denominator would be 0).
        """
        if delta_window < 1:
            raise ValueError(f"delta_window must be >= 1, got {delta_window}")

        n_frames = features.shape[-1]
        if n_frames == 0:
            return torch.zeros_like(features)

        # Replicate-pad the time axis by hand: F.pad(mode='replicate') with a
        # 2-element pad only accepts 2-D/3-D input, which broke on
        # multi-channel [batch, channels, n_gfcc, time] features.
        edge_shape = (*features.shape[:-1], delta_window)
        left = features[..., :1].expand(edge_shape)
        right = features[..., -1:].expand(edge_shape)
        padded = torch.cat([left, features, right], dim=-1)

        denominator = 2 * sum(i ** 2 for i in range(1, delta_window + 1))
        # Shifted slices compute every frame at once instead of looping over t.
        numerator = sum(
            i * (padded[..., delta_window + i: delta_window + i + n_frames]
                 - padded[..., delta_window - i: delta_window - i + n_frames])
            for i in range(1, delta_window + 1)
        )
        return numerator / denominator

    def extract_gfcc_with_delta(self, waveform: torch.Tensor,
                                include_delta: bool = True,
                                include_delta2: bool = True,
                                delta_window: int = 2) -> torch.Tensor:
        """Compute GFCCs optionally stacked with their first/second deltas.

        Args:
            waveform: input audio (see ``extract_gfcc_features``).
            include_delta: append the first-order delta.
            include_delta2: append the second-order delta (delta of the delta).
            delta_window: window passed to ``compute_delta_features``.

        Returns:
            ``[gfcc, delta?, delta2?]`` concatenated along the coefficient axis
            (second-to-last dimension).
        """
        gfcc = self.extract_gfcc_features(waveform)
        features_list = [gfcc]

        # The second-order delta is built from the first, so compute it once
        # whenever either is requested.
        if include_delta or include_delta2:
            delta_gfcc = self.compute_delta_features(gfcc, delta_window)
            if include_delta:
                features_list.append(delta_gfcc)
            if include_delta2:
                features_list.append(self.compute_delta_features(delta_gfcc, delta_window))

        return torch.cat(features_list, dim=-2)

    def to(self, device: str):
        """Move the Gammatone transform to ``device`` and send future inputs there.

        Returns:
            ``self``, for chaining.
        """
        self.device = device
        self.gammatone_transform = self.gammatone_transform.to(device)
        return self

def save_gfcc_to_csv(gfcc_features: torch.Tensor,
                     output_path: str,
                     feature_names: Optional[list] = None) -> None:
    """Write a GFCC matrix to CSV with one row per time frame.

    Args:
        gfcc_features: ``[n_features, time_frames]`` tensor or array. A 3-D
            ``[batch, n_features, time_frames]`` input is accepted and only the
            first batch item is written.
        output_path: destination CSV path.
        feature_names: column names, one per feature; defaults to
            ``gfcc_0 .. gfcc_{n-1}``.

    Raises:
        ValueError: if the (batch-reduced) input is not 2-D, or if
            ``feature_names`` does not have one entry per feature.
    """
    if isinstance(gfcc_features, torch.Tensor):
        if gfcc_features.dim() == 3:  # [batch, features, time]
            gfcc_features = gfcc_features[0]  # keep only the first batch item
        # detach() so tensors that are part of an autograd graph can be exported.
        gfcc_np = gfcc_features.detach().cpu().numpy()
    else:
        gfcc_np = np.asarray(gfcc_features)
        if gfcc_np.ndim == 3:  # same batch handling as for tensors
            gfcc_np = gfcc_np[0]

    if gfcc_np.ndim != 2:
        raise ValueError(
            f"expected a [n_features, time_frames] matrix, got shape {gfcc_np.shape}")

    n_features = gfcc_np.shape[0]
    if feature_names is None:
        feature_names = [f'gfcc_{i}' for i in range(n_features)]
    elif len(feature_names) != n_features:
        raise ValueError(
            f"feature_names has {len(feature_names)} entries but there are "
            f"{n_features} features")

    # Transpose so each CSV row is one time frame.
    gfcc_df = pd.DataFrame(gfcc_np.T, columns=feature_names)
    gfcc_df.to_csv(output_path, index=False)
    print(f"GFCC特征已保存到: {output_path}")

def plot_gfcc_features(gfcc_features: torch.Tensor,
                       sample_rate: int = 16000,
                       hop_length: int = 512,
                       title: str = "GFCC Features",
                       save_path: Optional[str] = None) -> None:
    """Display (and optionally save) a GFCC matrix as a heat map.

    Args:
        gfcc_features: ``[n_features, time_frames]`` tensor or array; for a 3-D
            ``[batch, n_features, time_frames]`` input only the first item is drawn.
        sample_rate: audio sample rate in Hz, used to convert frames to seconds.
        hop_length: hop between frames in samples.
        title: figure title.
        save_path: if given, the figure is also written there at 300 dpi.

    Raises:
        ValueError: if the (batch-reduced) input is not a 2-D matrix with at
            least one frame.
    """
    if isinstance(gfcc_features, torch.Tensor):
        if gfcc_features.dim() == 3:  # [batch, features, time]
            gfcc_features = gfcc_features[0]  # keep only the first batch item
        gfcc_np = gfcc_features.detach().cpu().numpy()
    else:
        gfcc_np = np.asarray(gfcc_features)
        if gfcc_np.ndim == 3:
            gfcc_np = gfcc_np[0]

    if gfcc_np.ndim != 2 or gfcc_np.shape[1] == 0:
        raise ValueError(
            f"expected a non-empty [n_features, time_frames] matrix, got shape {gfcc_np.shape}")

    n_coeffs, n_frames = gfcc_np.shape
    # imshow spreads the n_frames columns evenly over the extent, so ending it
    # at n_frames * hop puts column t on [t, t+1) hops. Ending at the last
    # frame's start time would squash the image by one hop and give zero width
    # for a single frame.
    duration = n_frames * hop_length / sample_rate

    fig, ax = plt.subplots(figsize=(12, 8))
    image = ax.imshow(gfcc_np, aspect='auto', origin='lower',
                      extent=[0, duration, 0, n_coeffs])
    fig.colorbar(image, ax=ax, label='GFCC Coefficient Value')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('GFCC Coefficient Index')
    ax.set_title(title)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"GFCC特征图已保存到: {save_path}")

    plt.show()

In [None]:
if __name__ == "__main__":
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"使用设备: {device}")

    # Shared by the extractor, the synthetic signal and the plot so they cannot drift apart.
    sample_rate = 16000
    hop_length = 512

    gfcc_extractor = GFCCExtractor(
        sample_rate=sample_rate,
        n_gfcc=13,
        n_bins=64,
        n_fft=2048,
        hop_length=hop_length,
        f_min=50.0,
        f_max=8000.0,
        device=device
    )

    # To use a real recording instead (its rate must match sample_rate above,
    # resample otherwise):
    # waveform, sample_rate = torchaudio.load('your_audio_file.wav')

    # Synthetic example: 2 seconds of a 440 Hz sine.
    duration = 2.0
    n_samples = int(sample_rate * duration)
    # Sample n is taken at n / sample_rate. linspace(0, duration, n_samples)
    # would include the endpoint, stretch the spacing to duration / (n - 1)
    # and slightly detune the tone.
    t = torch.arange(n_samples) / sample_rate
    waveform = torch.sin(2 * torch.pi * 440 * t)

    print(f"音频波形形状: {waveform.shape}")

    gfcc_features = gfcc_extractor.extract_gfcc_features(waveform)
    print(f"GFCC特征形状: {gfcc_features.shape}")

    gfcc_with_delta = gfcc_extractor.extract_gfcc_with_delta(
        waveform, include_delta=True, include_delta2=True
    )
    print(f"包含差分的GFCC特征形状: {gfcc_with_delta.shape}")

    save_gfcc_to_csv(gfcc_features, 'gfcc_features.csv')

    plot_gfcc_features(gfcc_features, sample_rate=sample_rate,
                       hop_length=hop_length, title="GFCC Features (nnAudio)",
                       save_path='gfcc_features_nnaudio.png')

    print("GFCC特征提取完成！")