<a href="https://colab.research.google.com/github/Icecream0507/KhailGen/blob/main/colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Audio Diffusion Model Training Notebook (Loading Data from Hugging Face)

This is a Google Colab notebook for training an audio diffusion model. It automatically loads the `Ice144/KhailGen` dataset from the Hugging Face Hub.

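**Instructions:**
1.  Run each cell in order, from top to bottom.
2.  In the "Mount Google Drive" cell, authorize access when prompted so that processed data and models can be saved.
3.  Run the code in Section 4. It downloads and processes the dataset, then saves it as `.npy` files for training.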

## 1. Environment Setup and Library Installation

In [1]:
!pip install torch torchvision torchaudio
!pip install numpy tqdm PyYAML soundfile pynvml tensorboardX
# Install Hugging Face datasets and librosa for audio processing
!pip install datasets librosa



## 2. Mount Google Drive

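This step provides persistent storage for data and models, so you don't have to re-download and re-process everything on each run. You will need to follow the link and authorize access.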

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## 3. Create the Configuration File `config.yaml`

These are the configuration parameters used for training and data processing. Make sure the `processed_folder` and `model_folder` paths are correct.

In [3]:
%%writefile config.yaml
# Training configuration
epochs: 1000
learning_rate: 0.0002
batch_size: 1
timesteps: 200

# Dataset configuration
# The folder where processed .npy files will be saved
processed_folder: /content/drive/MyDrive/khailgen_processed_data
audio_folder: /content/drive/MyDrive/KhailGen/data/wav
waveform_length: 1323000 # 60 seconds at 22.05 kHz (1323000 / sample_rate)
sample_rate: 22050
channels: 1 # 1 for mono, 2 for stereo

# Model configuration
embedding_dim: 128

# WaveNet specific parameters
res_channels: 64
skip_channels: 64
num_res_layers: 8
dilation_cycle: 4

# Saving configuration
model_folder: /content/drive/MyDrive/khailgen_models
model_path: /content/drive/MyDrive/khailgen_models/wave_diffusion_best_model.pth
log_dir: /content/drive/MyDrive/khailgen_logs

Overwriting config.yaml
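
Because `waveform_length` is measured in samples, the clip duration it implies depends on `sample_rate`. A minimal sanity check, assuming the values written to `config.yaml` above (the helper name `clip_duration_seconds` is hypothetical and not part of the notebook):

```python
def clip_duration_seconds(config):
    """Return the clip duration in seconds implied by the config values."""
    return config["waveform_length"] / config["sample_rate"]

# Values copied from config.yaml above
config = {"waveform_length": 1323000, "sample_rate": 22050}
clip_seconds = clip_duration_seconds(config)
print(f"Each training clip is {clip_seconds:.1f} s long")
```

Running a check like this before preprocessing can catch a mismatch between the intended clip length and the configured sample rate.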


## 4. Load and Preprocess the Dataset from Hugging Face

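This section downloads the `Ice144/KhailGen` dataset from the Hugging Face Hub, then resamples, normalizes, and segments the audio, and finally saves the result as `.npy` files. This can take some time. Note that in this copy of the notebook the preprocessing cell (In [6]) is entirely commented out and must be uncommented before it produces any files; the variant it would run, `process_and_save_audio`, reads `.wav`/`.mp3` files from `audio_folder` rather than from the Hub.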

In [4]:
!pip install -q datasets librosa soundfile tqdm torchcodec torchaudio


In [5]:
!nvcc --version
!nvidia-smi

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Jun__6_02:18:23_PDT_2024
Cuda compilation tools, release 12.5, V12.5.82
Build cuda_12.5.r12.5/compiler.34385749_0
Fri Aug 22 11:18:13 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   55C    P8             12W /   70W |       0MiB /  15360MiB |      0%      Default |
|                       

In [6]:
# import os
# import librosa
# import numpy as np
# from tqdm import tqdm
# import yaml
# from datasets import load_dataset
# import torch
# import torchcodec
# import soundfile as sf
# from IPython.display import Audio as play_audio


# import torchaudio
# try:
#     torchaudio.set_audio_backend("soundfile")
#     print("Using the SoundFile backend")
# except:
#     print("SoundFile unavailable; using the default backend")


# print(torch.cuda.is_available())


# config_path = 'config.yaml'

# def load_config(config_path):
#     with open(config_path, 'r') as f:
#         config = yaml.safe_load(f)
#     return config

# config = load_config(config_path)

# # def process_and_save_dataset(dataset_name, processed_folder, sr, fixed_length):
# #     """
# #     Loads an audio dataset from Hugging Face, processes it, and saves it as .npy files.
# #     """
# #     os.makedirs(processed_folder, exist_ok=True)

# #     # Load the dataset from Hugging Face
# #     print(f"Loading dataset '{dataset_name}' from Hugging Face...")
# #     dataset = load_dataset(dataset_name)

# #     # Access the 'train' split and the 'audio' column
# #     audio_data = dataset['train']['audio']


# #         # Debug: inspect the first sample in detail
# #     first_sample = dataset['train'][0]
# #     first_audio = first_sample['audio'].get_samples_played_in_range(0, 10)



# #     play_audio(first_audio.data.cpu().numpy(), rate=first_audio.sample_rate)


# #     print(type(first_audio.metadata))


# #     all_samples = first_audio.get_all_samples().data

# #     print(all_samples)

# #     print(f"Starting pre-processing of {len(audio_data)} audio samples...")
# #     with tqdm(total=len(audio_data), desc="Processing audio samples") as pbar:
# #         for i, sample in enumerate(audio_data):
# #           try:
# #               y, sr = librosa.load(sample['array'], sr=sr, mono = True)

# #               # Normalize to [-1, 1]
# #               y = y / np.max(np.abs(y))

# #               # Split into fixed-length segments and save
# #               waveform_len = y.shape[0]
# #               num_segments = int(np.ceil(waveform_len / fixed_length))

# #               for j in range(num_segments):
# #                   start_idx = j * fixed_length
# #                   end_idx = start_idx + fixed_length

# #                   if end_idx <= waveform_len:
# #                       segment = y[start_idx:end_idx]
# #                   else:
# #                       segment = y[start_idx:]
# #                       pad_width = fixed_length - len(segment)
# #                       segment = np.pad(segment, (0, pad_width), mode='constant')

# #                   segment_file_name = f"audio_sample_{i:04d}_segment_{j:04d}.npy"
# #                   processed_file_path = os.path.join(processed_folder, segment_file_name)

# #                   np.save(processed_file_path, segment)

# #           except Exception as e:
# #               print(f"Error processing sample {i}: {e}")
# #               continue

# #           pbar.update(1)

# #     print("Audio pre-processing complete.")

# def process_and_save_audio(audio_folder, processed_folder, sr, fixed_length):
#     """
#     Load raw audio files, split them into fixed-length segments, and save them as .npy files.
#     """
#     os.makedirs(processed_folder, exist_ok=True)

#     audio_files = [os.path.join(audio_folder, f) for f in os.listdir(audio_folder)
#                   if f.endswith('.wav') or f.endswith('.mp3')]

#     print(f"Starting pre-processing of {len(audio_files)} audio files...")
#     with tqdm(total=len(audio_files), desc="Processing audio files") as pbar:
#         for file_path in audio_files:
#             file_name = os.path.basename(file_path)
#             base_name = os.path.splitext(file_name)[0]

#             try:
#                 # Keep mono=True so the file is loaded as a single channel
#                 y, _ = librosa.load(file_path, sr=sr, mono=True)

#                 # Normalize to [-1, 1]
#                 y = y / np.max(np.abs(y))

#                 # Compute how many segments are needed
#                 waveform_len = y.shape[0]
#                 num_segments = int(np.ceil(waveform_len / fixed_length))

#                 # Process each segment
#                 for i in range(num_segments):
#                     start_idx = i * fixed_length
#                     end_idx = start_idx + fixed_length

#                     # Extract the segment
#                     if end_idx <= waveform_len:
#                         segment = y[start_idx:end_idx]
#                     else:
#                         # Last segment: pad with zeros
#                         segment = y[start_idx:]
#                         pad_width = fixed_length - len(segment)
#                         segment = np.pad(segment, (0, pad_width), mode='constant')

#                     # Build a unique file name
#                     segment_file_name = f"{base_name}_segment_{i:04d}.npy"
#                     processed_file_path = os.path.join(processed_folder, segment_file_name)
#                     # Save the segment
#                     np.save(processed_file_path, segment)

#             except Exception as e:
#                 print(f"Error processing {file_name}: {e}")
#                 continue

#             pbar.update(1)

#     print("Audio pre-processing complete.")

# if __name__ == '__main__':
#     process_and_save_audio(
#         audio_folder=config["audio_folder"],
#         processed_folder=config["processed_folder"],
#         sr=config["sample_rate"],
#         fixed_length=config["waveform_length"]
#     )
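
The normalization and segmentation steps in the commented-out cell above can be sketched in isolation, without any file I/O. This is a minimal sketch, not the notebook's own function: the name `normalize_and_segment` is hypothetical, and the guard against silent (all-zero) input is an addition that the original code does not have.

```python
import numpy as np

def normalize_and_segment(y, fixed_length):
    """Peak-normalize a 1D waveform and split it into zero-padded, fixed-length segments."""
    y = np.asarray(y, dtype=np.float32)
    peak = np.max(np.abs(y)) if y.size else 0.0
    if peak > 0:  # added guard: avoid dividing by zero for silent or empty input
        y = y / peak
    num_segments = int(np.ceil(len(y) / fixed_length))
    segments = []
    for j in range(num_segments):
        segment = y[j * fixed_length:(j + 1) * fixed_length]
        if len(segment) < fixed_length:
            # The last segment is shorter than fixed_length, so pad it with zeros
            segment = np.pad(segment, (0, fixed_length - len(segment)), mode="constant")
        segments.append(segment)
    return segments

# Five samples split into segments of two: the third segment is zero-padded
segments = normalize_and_segment(np.array([0.5, -2.0, 1.0, 0.25, 0.1]), fixed_length=2)
print(len(segments), [s.tolist() for s in segments])
```

Each returned segment has exactly `fixed_length` samples, which is what the dataset class in Section 5 expects when it loads the `.npy` files.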


## 5. Combined Model and Training Code

This cell contains all of the model, data-loading, and training code, so everything runs directly in the notebook's memory.
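
To make the noising step easier to follow, here is a NumPy-only sketch of the linear beta schedule and the closed-form forward diffusion that `DiffusionModel` in the cell below implements with PyTorch. The function names and the sine-wave test signal are illustrative assumptions; only the schedule endpoints (`1e-4`, `0.02`) and `timesteps: 200` come from the notebook.

```python
import numpy as np

def linear_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule and the cumulative product of alphas (alpha_bar)."""
    betas = np.linspace(beta_start, beta_end, timesteps)
    alphas_cumprod = np.cumprod(1.0 - betas)
    return betas, alphas_cumprod

def forward_diffusion(x_start, t, alphas_cumprod, rng):
    """Sample x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise."""
    noise = rng.standard_normal(x_start.shape)
    x_t = np.sqrt(alphas_cumprod[t]) * x_start + np.sqrt(1.0 - alphas_cumprod[t]) * noise
    return x_t, noise

betas, alphas_cumprod = linear_schedule(timesteps=200)
rng = np.random.default_rng(0)
x0 = np.sin(np.linspace(0, 2 * np.pi, 1000))  # stand-in for a real waveform
x_t, noise = forward_diffusion(x0, t=199, alphas_cumprod=alphas_cumprod, rng=rng)
print(f"alpha_bar at the last step: {alphas_cumprod[-1]:.3f}")
```

With 200 steps and this beta range, the cumulative product at the final step stays at roughly 0.13 rather than approaching zero, so the most heavily noised sample still carries some of the original signal. It may be worth keeping this in mind when comparing results against setups that use 1000 steps.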

In [7]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
import numpy as np
from tqdm import tqdm
import time
import yaml
import soundfile as sf
import pynvml
import math

# ==================== Data Utilities ====================
def pad_to_fixed_length_waveform(waveform, max_len):
    """
    Pads or truncates a 1D waveform to a fixed length.
    """
    waveform_len = waveform.shape[0]
    if waveform_len < max_len:
        pad_width = max_len - waveform_len
        waveform = np.pad(waveform, (0, pad_width), mode='constant')
    elif waveform_len > max_len:
        waveform = waveform[:max_len]
    return waveform

class AudioDatasetWaveform(Dataset):
    """
    Custom audio dataset to load waveform data from pre-processed .npy files.
    """
    def __init__(self, processed_folder, fixed_length):
        self.processed_files = [
            os.path.join(processed_folder, f)
            for f in os.listdir(processed_folder)
            if f.endswith('.npy')
        ]
        self.fixed_length = fixed_length
        print(f"Found {len(self.processed_files)} pre-processed audio files...")

        print("Loading all waveforms into memory...")
        self.waveforms = []
        with tqdm(total=len(self.processed_files), desc="Loading .npy files") as pbar:
            for file_path in self.processed_files:
                try:
                    waveform = np.load(file_path)
                    if waveform.ndim == 2 and waveform.shape[0] == 1:
                        waveform = waveform.squeeze(0)
                    elif waveform.ndim != 1:
                        print(f"Warning: Skipping {file_path} due to unexpected dimensions: {waveform.shape}")
                        continue
                    waveform = pad_to_fixed_length_waveform(waveform, self.fixed_length)
                    self.waveforms.append(waveform)
                except Exception as e:
                    print(f"Error loading {file_path}: {e}")
                pbar.update(1)

        if not all(w.shape[0] == self.fixed_length for w in self.waveforms):
            print("Warning: Waveform lengths are not consistent. Some files might be corrupted.")

        print("All waveforms loaded.")

    def __len__(self):
        return len(self.waveforms)

    def __getitem__(self, idx):
        waveform = self.waveforms[idx]
        return torch.tensor(waveform, dtype=torch.float32).unsqueeze(0)

# ==================== WaveNet Model ====================
def swish(x):
    return x * torch.sigmoid(x)

class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
        super(Conv, self).__init__()
        self.padding = dilation * (kernel_size - 1) // 2
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding)
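        # Initialize before applying weight_norm: afterwards `weight` is recomputed
        # from weight_g / weight_v on every forward pass, so a later init is lost.
        nn.init.kaiming_normal_(self.conv.weight)
        self.conv = nn.utils.weight_norm(self.conv)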

    def forward(self, x):
        return self.conv(x)

class ZeroConv1d(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(ZeroConv1d, self).__init__()
        self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1, padding=0)
        self.conv.weight.data.zero_()
        self.conv.bias.data.zero_()

    def forward(self, x):
        return self.conv(x)

class ResidualBlock(nn.Module):
    def __init__(self, res_channels, skip_channels, dilation, time_embed_dim):
        super(ResidualBlock, self).__init__()
        self.res_channels = res_channels
        self.fc_t = nn.Linear(time_embed_dim, res_channels)
        self.dilated_conv = Conv(res_channels, 2 * res_channels, kernel_size=3, dilation=dilation)
        self.res_conv = Conv(res_channels, res_channels, kernel_size=1)
        self.skip_conv = Conv(res_channels, skip_channels, kernel_size=1)

    def forward(self, input_data):
        x, time_embed = input_data
        h = x
        time_part = self.fc_t(time_embed).unsqueeze(-1)
        h = h + time_part
        h = self.dilated_conv(h)
        out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :])
        res = self.res_conv(out)
        skip = self.skip_conv(out)
        return (x + res, skip)

class ResidualGroup(nn.Module):
    def __init__(self, res_channels, skip_channels, num_res_layers, dilation_cycle, time_embed_dim):
        super(ResidualGroup, self).__init__()
        self.layers = nn.ModuleList([
            ResidualBlock(
                res_channels=res_channels,
                skip_channels=skip_channels,
                dilation=2**(i % dilation_cycle),
                time_embed_dim=time_embed_dim
            )
            for i in range(num_res_layers)
        ])

    def forward(self, x, time_embed):
        skip_connections = 0
        for layer in self.layers:
            x, skip = layer((x, time_embed))
            skip_connections = skip_connections + skip
        return skip_connections

class WaveNet1D(nn.Module):
    def __init__(self, in_channels, out_channels, time_embedding_dim, res_channels, skip_channels, num_res_layers, dilation_cycle):
        super(WaveNet1D, self).__init__()
        self.time_embedding_dim = time_embedding_dim

        self.time_mlp = nn.Sequential(
            nn.Linear(time_embedding_dim, time_embedding_dim * 4),
            nn.ReLU(),
            nn.Linear(time_embedding_dim * 4, res_channels)
        )

        self.init_conv = nn.Sequential(
            Conv(in_channels, res_channels, kernel_size=1),
            nn.ReLU()
        )

        self.residual_group = ResidualGroup(
            res_channels=res_channels,
            skip_channels=skip_channels,
            num_res_layers=num_res_layers,
            dilation_cycle=dilation_cycle,
            time_embed_dim=res_channels
        )

        self.final_conv = nn.Sequential(
            Conv(skip_channels, skip_channels, kernel_size=1),
            nn.ReLU(),
            ZeroConv1d(skip_channels, out_channels)
        )

    def forward(self, x, t):
        time_embed = self.sinusoidal_embedding(t, self.time_embedding_dim)
        time_embed = self.time_mlp(time_embed)

        x = self.init_conv(x)
        x = self.residual_group(x, time_embed)
        return self.final_conv(x)

    def sinusoidal_embedding(self, t, dim):
        half_dim = dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=t.device) * -embeddings)
        embeddings = t.float()[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

# ==================== Diffusion Model ====================

class DiffusionModel(torch.nn.Module):
    def __init__(self, unet, timesteps=1000, device='cpu'):
        super(DiffusionModel, self).__init__()
        self.unet = unet
        self.timesteps = timesteps
        self.device = device

        betas = torch.linspace(1e-4, 0.02, timesteps).to(device)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)

    def forward_diffusion(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = torch.sqrt(self.alphas_cumprod[t])[:, None, None]
        sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1. - self.alphas_cumprod[t])[:, None, None]

        x_t = sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
        return x_t, noise

    def get_noise_prediction(self, x_t, t):
        return self.unet(x_t, t)

    @torch.no_grad()
    def sample(self, sample_shape):
        batch_size = sample_shape[0]
        channels = sample_shape[1]
        waveform_length = sample_shape[2]
        x_t = torch.randn(batch_size, channels, waveform_length, device=self.device)

        for t in reversed(range(self.timesteps)):
            t_tensor = torch.full((batch_size,), t, device=self.device, dtype=torch.long)
            predicted_noise = self.get_noise_prediction(x_t, t_tensor)
            alpha_t = self.alphas[t]
            alpha_t_cumprod = self.alphas_cumprod[t]

            mean = 1.0 / torch.sqrt(alpha_t) * (x_t - (1.0 - alpha_t) / torch.sqrt(1.0 - alpha_t_cumprod) * predicted_noise)
            variance = self.betas[t]

            if t > 0:
                noise = torch.randn_like(x_t)
                x_t = mean + torch.sqrt(variance) * noise
            else:
                x_t = mean

        return x_t

# ==================== Training Utilities ====================

def get_gpu_memory_info():
    """Gets the GPU memory usage for the first GPU."""
    try:
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        free_memory_mib = info.free / (1024 ** 2)
        total_memory_mib = info.total / (1024 ** 2)
        print(f"GPU 0 - Total: {total_memory_mib:.2f} MiB, Free: {free_memory_mib:.2f} MiB")
        pynvml.nvmlShutdown()
        return free_memory_mib / 1024
    except pynvml.NVMLError as error:
        print(f"Error getting GPU memory info: {error}")
        return 0


def train(config_path):
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    print("Initializing dataset and model...")
    dataset = AudioDatasetWaveform(
        processed_folder=config["processed_folder"],
        fixed_length=config["waveform_length"]
    )
    dataloader = DataLoader(dataset, batch_size=config["batch_size"], shuffle=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")


    wavenet = WaveNet1D(
        in_channels=config["channels"],
        out_channels=config["channels"],
        time_embedding_dim=config["embedding_dim"],
        res_channels=config["res_channels"],
        skip_channels=config["skip_channels"],
        num_res_layers=config["num_res_layers"],
        dilation_cycle=config["dilation_cycle"]
    ).to(device)

    diffusion_model = DiffusionModel(wavenet, timesteps=config["timesteps"], device=device).to(device)

    optimizer = torch.optim.Adam(wavenet.parameters(), lr=config["learning_rate"])
    loss_fn = nn.MSELoss()
    epochs = config["epochs"]
    scaler = GradScaler()

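    audio_save_dir = "temp_noisy_audios"
    os.makedirs(audio_save_dir, exist_ok=True)
    # Make sure the checkpoint folder exists before the first torch.save call
    os.makedirs(config["model_folder"], exist_ok=True)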

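    # Keep only timesteps that exist in the schedule (valid indices are 0..timesteps-1)
    t_to_save = [t for t in (50, 100, 200, 300, 400, 500) if t < diffusion_model.timesteps]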
    saved_t_vals = set()
    best_epoch_loss = float('inf')
    global_step = 0

    for epoch in range(epochs):
        total_epoch_loss = 0
        start_time = time.time()

        for i, waveform_batch in enumerate(dataloader):
            waveform_batch = waveform_batch.to(device)

            try:
                t = torch.randint(0, diffusion_model.timesteps, (waveform_batch.shape[0],), device=device).long()
                x_t, noise = diffusion_model.forward_diffusion(waveform_batch, t)

                optimizer.zero_grad()
                with torch.cuda.amp.autocast():
                    predicted_noise = wavenet(x_t, t)
                    loss = loss_fn(predicted_noise, noise)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                total_epoch_loss += loss.item()
                global_step += 1

                if (i + 1) % 100 == 0:
                    end_time = time.time()
                    duration = end_time - start_time
                    print(f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}, Time: {duration:.2f}s")
                    start_time = time.time()

                if (i + 1) % 500 == 0 and len(saved_t_vals) < len(t_to_save):
                    for t_val in t_to_save:
                        if t_val not in saved_t_vals:
                            random_waveform = waveform_batch[0]
                            t_tensor = torch.full((1,), t_val, device=device).long()
                            noisy_waveform, _ = diffusion_model.forward_diffusion(random_waveform.unsqueeze(0), t_tensor)

                            noisy_waveform_np = noisy_waveform.squeeze().cpu().numpy()
                            sf.write(os.path.join(audio_save_dir, f'noisy_t{t_val}.wav'), noisy_waveform_np, config["sample_rate"])
                            saved_t_vals.add(t_val)
                            print(f"Saved noisy audio for t={t_val}")

            except RuntimeError as e:
                if "out of memory" in str(e):
                    print(f"{e}\nWARNING: Out of memory at step {i+1}. Clearing cache and aborting training.")
                    torch.cuda.empty_cache()
                    time.sleep(5)
                    print("GPU memory status after clearing cache:")
                    print(torch.cuda.memory_stats())
                    del wavenet
                    del diffusion_model
                    del loss_fn
                    del scaler
                    return False
                else:
                    raise e

        avg_epoch_loss = total_epoch_loss / len(dataloader)

        print(f"Epoch {epoch+1} finished, Average Loss: {avg_epoch_loss:.4f}")
        if avg_epoch_loss < best_epoch_loss:
            best_epoch_loss = avg_epoch_loss
            model_save_path = os.path.join(config['model_folder'], "wave_diffusion_best_model.pth")
            torch.save(wavenet.state_dict(), model_save_path)
            print(f"New best model saved! Average Loss: {best_epoch_loss:.4f}")

        model_save_path = os.path.join(config['model_folder'], "wave_diffusion_current_model.pth")
        torch.save(wavenet.state_dict(), model_save_path)

    print("\nTraining finished.")
    time_stamp = time.strftime("%Y%m%d-%H%M%S")
    version = 0
    while True:
        model_save_path = f"diffusion_{time_stamp}_{config['waveform_length']/config['sample_rate']}_v{version}.pth"
        model_save_path = config['model_folder'] + "/" + model_save_path
        try:
            with open(model_save_path, 'x'):
                break
        except FileExistsError:
            version += 1
    torch.save(wavenet.state_dict(), model_save_path)
    print(f"Final model saved at: {model_save_path}")
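
As a rough aid for choosing `num_res_layers` and `dilation_cycle`, the receptive field of the stack of dilated convolutions defined above can be computed directly. This is a sketch under the assumption that only the `kernel_size=3` dilated convolutions widen the receptive field (the 1x1 convolutions do not); the helper name `receptive_field` is hypothetical.

```python
def receptive_field(num_res_layers, dilation_cycle, kernel_size=3):
    """Receptive field (in samples) of a stack of dilated convs with dilation 2**(i % cycle)."""
    dilations = [2 ** (i % dilation_cycle) for i in range(num_res_layers)]
    # Each layer widens the field by (kernel_size - 1) * dilation samples
    field = 1 + sum((kernel_size - 1) * d for d in dilations)
    return field, dilations

# Values from config.yaml: num_res_layers=8, dilation_cycle=4, sample_rate=22050
field, dilations = receptive_field(num_res_layers=8, dilation_cycle=4)
field_ms = field / 22050 * 1000
print(f"dilations={dilations}, receptive field={field} samples ({field_ms:.2f} ms)")
```

With the configured values this comes to 61 samples, which is under 3 ms of audio at 22.05 kHz, so each output sample only sees a very short local context.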


## 6. Start Training

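Run the following cell to start training. During training, model checkpoints are saved to the `model_folder` path on your Google Drive. (The `log_dir` setting is defined in the config, but the training code shown above does not write to it.)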

In [None]:
if __name__ == "__main__":
    config_path = 'config.yaml'
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found at: {config_path}. Please run the previous cell.")

    # Start training
    train(config_path)

Initializing dataset and model...
Found 811 pre-processed audio files...
Loading all waveforms into memory...


Loading .npy files: 100%|██████████| 811/811 [00:25<00:00, 31.33it/s]
  WeightNorm.apply(module, name, dim)


All waveforms loaded.
Using device: cuda


  scaler = GradScaler()
  with torch.cuda.amp.autocast():


Epoch [1/1000], Step [100/811], Loss: 0.5916, Time: 27.53s
Epoch [1/1000], Step [200/811], Loss: 0.0992, Time: 25.74s
Epoch [1/1000], Step [300/811], Loss: 0.0651, Time: 25.76s
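
The log above gives enough information for a back-of-the-envelope estimate of total training time. This sketch assumes the per-step time stays close to the roughly 25.75 s per 100 steps seen at steps 200 and 300, which may not hold across a full run; the helper name `estimate_hours` is hypothetical.

```python
def estimate_hours(seconds_per_100_steps, steps_per_epoch, epochs):
    """Extrapolate total training time in hours from the time taken by 100 steps."""
    seconds_per_step = seconds_per_100_steps / 100
    return seconds_per_step * steps_per_epoch * epochs / 3600

# 811 steps per epoch and 1000 epochs, as shown in the log and config
hours = estimate_hours(seconds_per_100_steps=25.75, steps_per_epoch=811, epochs=1000)
print(f"Estimated total training time: {hours:.0f} h")
```

Under these assumptions the full 1000 epochs would take roughly 58 hours, so it is worth relying on the per-epoch checkpoints saved to Drive rather than expecting a single uninterrupted session.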
