In [None]:
import os
import random

import numpy as np
import pandas as pd
import soundfile as sf
import torch
import torchaudio
from audiomentations import Compose, AddGaussianNoise, PitchShift, Gain, Resample
from audiomentations.core.transforms_interface import BaseTransform, BaseWaveformTransform
from tqdm.notebook import tqdm

# --- Configuration -----------------------------------------------------------
# Working sample rate (Hz) and target clip length (seconds).
SAMPLE_RATE = 44_100
TARGET_DURATION = 5

# Clips longer than this many seconds are randomly cropped to this length
# before preprocessing.
MAX_DURATION_TRUNCATE = 30.0

# Augmented copies written per source clip, in addition to the
# preprocessed "original" copy.
NUM_AUGMENTATIONS_PER_SAMPLE = 2

# Where augmented WAV files and the index CSV describing them are written.
OUTPUT_DIR = '/home/renmengxing/audioRec/audio_data/augmented_audio_data'
INFO_CSV_NAME = '/home/renmengxing/audioRec/audio_data/augmented_train.csv'

class PadOrTruncate(BaseWaveformTransform):
    """Pad or truncate a waveform to a fixed number of samples.

    This inherits from ``BaseWaveformTransform`` rather than the bare
    ``BaseTransform``. In the installed audiomentations version only
    ``BaseWaveformTransform`` defines ``__call__(samples, sample_rate)``, so
    instances of the old subclass were not callable ("'PadOrTruncate' object
    is not callable"). ``__call__`` decides, with probability ``p``, whether
    to apply the transform and then delegates to :meth:`apply`.

    Args:
        target_duration_seconds: Desired output length in seconds.
        sample_rate: Rate used to convert the duration into a sample count.
            The count is fixed at construction time; the ``sample_rate``
            passed at call time does not change it.
        p: Probability of applying the transform (1.0 = always).
    """

    def __init__(self, target_duration_seconds, sample_rate, p=1.0):
        super().__init__(p=p)
        self.target_duration_seconds = target_duration_seconds
        self.sample_rate = sample_rate
        self.target_samples = int(target_duration_seconds * sample_rate)

    def apply(self, samples, sample_rate):
        """Return *samples* resized to ``self.target_samples``.

        Shorter input is zero-padded on both sides, with any odd extra sample
        going on the right. Longer input is cropped to a window that starts
        at a uniformly random offset. Input already at the target length is
        returned unchanged.
        """
        current_samples = samples.shape[0]
        if current_samples < self.target_samples:
            pad_needed = self.target_samples - current_samples
            pad_left = pad_needed // 2
            pad_right = pad_needed - pad_left
            # A single (before, after) pair is given, so this assumes mono 1-D input.
            samples = np.pad(samples, (pad_left, pad_right), 'constant')
        elif current_samples > self.target_samples:
            start_index = np.random.randint(0, current_samples - self.target_samples + 1)
            samples = samples[start_index:start_index + self.target_samples]
        return samples

# --- Preprocessing and augmentation are kept as two separate steps ----------

# 1. Preprocessing-only transform: brings audio of any length to
#    TARGET_DURATION seconds by padding or cropping. With p=1.0 it is
#    applied to every clip.
preprocess_transform = PadOrTruncate(
    TARGET_DURATION,
    SAMPLE_RATE,
    p=1.0,
)

# 2. Augmentation-only pipeline. It operates on audio that the preprocessing
#    step has already brought to the target length. Each transform carries
#    its own application probability p.
augmentation_pipeline = Compose(
    transforms=[
        AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
        PitchShift(min_semitones=-4, max_semitones=4, p=0.3),
        Gain(min_gain_db=-6.0, max_gain_db=6.0, p=0.3),
        # NOTE(review): resampling presumably changes the number of samples,
        # so output from this step may no longer be exactly TARGET_DURATION
        # seconds long. Confirm against the audiomentations docs.
        Resample(
            min_sample_rate=int(SAMPLE_RATE * 0.9),
            max_sample_rate=int(SAMPLE_RATE * 1.1),
            p=0.2,
        ),
    ]
)

# --- Main augmentation routine ---


def _load_mono(audio_path, target_sr, resamplers):
    """Read *audio_path* as 1-D float32 audio at *target_sr*.

    Multi-channel files are averaged down to one channel. Files at another
    rate are resampled with torchaudio. Resamplers are cached in the
    *resamplers* dict (keyed by source rate) so each one is built only once.

    Returns:
        ``(samples, sample_rate)``; ``sample_rate`` always equals *target_sr*.
    """
    # soundfile defaults to float64; audiomentations works on float32.
    samples, sr = sf.read(audio_path, dtype='float32')
    if samples.ndim > 1:
        samples = samples.mean(axis=1)

    if sr != target_sr:
        resampler = resamplers.get(sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
            resamplers[sr] = resampler
        samples = resampler(torch.from_numpy(samples).float().unsqueeze(0)).squeeze(0).numpy()
        sr = target_sr
    return samples, sr


def _fit_length(samples, num_samples):
    """Zero-pad at the end, or trim from the end, to exactly *num_samples*."""
    if len(samples) < num_samples:
        return np.pad(samples, (0, num_samples - len(samples)), 'constant')
    return samples[:num_samples]


def augment_audio_dataset_fixed(input_csv_path, output_dir, target_duration,
                                max_duration_truncate, num_augmentations_per_sample,
                                preprocessor, augmentor, *,
                                audio_dir='/home/renmengxing/audioRec/audio_data/train/',
                                info_csv_path=None, max_rows=5, show_traceback=False):
    """Preprocess and augment the clips listed in a CSV; write WAVs and an index CSV.

    For each row (the ``fname`` and ``label`` columns are read), the clip at
    ``audio_dir/fname`` goes through these steps:

    1. Load it as mono float32 at ``SAMPLE_RATE``.
    2. Randomly crop it to at most *max_duration_truncate* seconds.
    3. Run it through *preprocessor* and save it as ``original_<stem>.wav``.
    4. Save *num_augmentations_per_sample* variants produced by *augmentor*
       as ``aug_<i>_<stem>.wav``.

    Every file goes under ``output_dir/<label>/``. Each one is forced to
    exactly ``target_duration`` seconds, so the ``duration`` column written
    to the index CSV is accurate even if an augmentation changed the length.
    Missing files are skipped with a warning. Files that raise are skipped
    with an error message (best-effort: one bad file does not abort the run).

    Args:
        input_csv_path: CSV with at least ``fname`` and ``label`` columns.
        output_dir: Root directory for the generated WAV files.
        target_duration: Length in seconds of every written clip.
        max_duration_truncate: Longer clips are randomly cropped to this many
            seconds before preprocessing.
        num_augmentations_per_sample: Augmented copies per source clip.
        preprocessor: Callable ``(samples=, sample_rate=) -> samples`` applied once.
        augmentor: Callable ``(samples=, sample_rate=) -> samples`` applied per copy.
        audio_dir: Directory that the ``fname`` values are relative to.
        info_csv_path: Where to write the index CSV. Defaults to ``INFO_CSV_NAME``.
        max_rows: Process only the first N rows. This is the debug limit from
            the original code; pass ``None`` to process the whole CSV.
        show_traceback: Also print the full traceback for files that fail.

    Returns:
        A DataFrame with columns ``fname`` (path relative to *output_dir*),
        ``label`` and ``duration``, with one row per written file.
    """
    import traceback

    if info_csv_path is None:
        info_csv_path = INFO_CSV_NAME

    df = pd.read_csv(input_csv_path)
    if max_rows is not None:
        df = df.head(max_rows)

    os.makedirs(output_dir, exist_ok=True)
    for label in df['label'].unique():
        os.makedirs(os.path.join(output_dir, str(label)), exist_ok=True)

    resamplers = {}
    augmented_data = []
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Augmenting audio"):
        fname = row['fname']
        label = row['label']
        audio_path = os.path.join(audio_dir, fname)

        if not os.path.exists(audio_path):
            print(f"Warning: File not found {audio_path}, skipping.")
            continue

        try:
            samples, sr = _load_mono(audio_path, SAMPLE_RATE, resamplers)

            # Crop based on the real sample count, not the CSV 'duration':
            # a stale CSV value could otherwise give randint a negative bound.
            num_samples_to_keep = int(max_duration_truncate * sr)
            if len(samples) > num_samples_to_keep:
                start_idx = random.randint(0, len(samples) - num_samples_to_keep)
                samples = samples[start_idx:start_idx + num_samples_to_keep]

            target_samples = int(target_duration * sr)
            stem = os.path.splitext(fname)[0]

            # Step 1: the preprocessed "original" version.
            processed_samples = preprocessor(samples=samples, sample_rate=sr)
            original_rel = os.path.join(str(label), f"original_{stem}.wav")
            sf.write(os.path.join(output_dir, original_rel),
                     _fit_length(processed_samples, target_samples), sr)
            augmented_data.append({'fname': original_rel, 'label': label,
                                   'duration': target_duration})

            # Step 2: augmented variants of the preprocessed clip. Their length
            # is re-fixed because some transforms (e.g. resampling) can
            # change the number of samples.
            for i in range(num_augmentations_per_sample):
                augmented_samples = augmentor(samples=processed_samples, sample_rate=sr)
                augmented_rel = os.path.join(str(label), f"aug_{i+1}_{stem}.wav")
                sf.write(os.path.join(output_dir, augmented_rel),
                         _fit_length(augmented_samples, target_samples), sr)
                augmented_data.append({'fname': augmented_rel, 'label': label,
                                       'duration': target_duration})

        except Exception as e:
            print(f"Error processing file {audio_path}: {e}, skipping.")
            if show_traceback:
                traceback.print_exc()

    # Explicit columns keep the CSV header even when no file was produced.
    augmented_df = pd.DataFrame(augmented_data, columns=['fname', 'label', 'duration'])
    augmented_df.to_csv(info_csv_path, index=False)
    print(f"\nData augmentation completed! Augmented data info saved to {info_csv_path}")
    print(f"Generated {len(augmented_df)} audio files in total.")
    return augmented_df


# --- Run the data augmentation when executed as a script ---
if __name__ == "__main__":
    input_csv_path = '/home/renmengxing/audioRec/audio_data/train_with_duration.csv'
    run_config = dict(
        output_dir=OUTPUT_DIR,
        target_duration=TARGET_DURATION,
        max_duration_truncate=MAX_DURATION_TRUNCATE,
        num_augmentations_per_sample=NUM_AUGMENTATIONS_PER_SAMPLE,
        preprocessor=preprocess_transform,
        augmentor=augmentation_pipeline,
    )
    augment_audio_dataset_fixed(input_csv_path=input_csv_path, **run_config)

PadOrTruncate initialized: target_samples=220500, target_duration_seconds=5


Augmenting audio:   0%|          | 0/9473 [00:00<?, ?it/s]

Error processing file /home/renmengxing/audioRec/audio_data/train/00044347.wav: 'PadOrTruncate' object is not callable, skipping.
Error processing file /home/renmengxing/audioRec/audio_data/train/001ca53d.wav: 'PadOrTruncate' object is not callable, skipping.
Error processing file /home/renmengxing/audioRec/audio_data/train/002d256b.wav: 'PadOrTruncate' object is not callable, skipping.
Error processing file /home/renmengxing/audioRec/audio_data/train/0033e230.wav: 'PadOrTruncate' object is not callable, skipping.
Error processing file /home/renmengxing/audioRec/audio_data/train/00353774.wav: 'PadOrTruncate' object is not callable, skipping.
Error processing file /home/renmengxing/audioRec/audio_data/train/003b91e8.wav: 'PadOrTruncate' object is not callable, skipping.
Error processing file /home/renmengxing/audioRec/audio_data/train/003da8e5.wav: 'PadOrTruncate' object is not callable, skipping.
Error processing file /home/renmengxing/audioRec/audio_data/train/0048fd00.wav: 'PadOrTrun

KeyboardInterrupt: 

In [None]:
# Minimal cell that writes "1" followed by a newline to stdout.
print(1, end="\n")

1
