<a href="https://colab.research.google.com/github/CityHuman/Auto-GPT/blob/master/ast%E5%BE%AE%E8%B0%832.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
! pip install transformers datasets torch torchaudio librosa evaluate scikit-learn accelerate

Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.6


In [2]:
import torch
import numpy as np
import warnings
import os
from datasets import load_dataset, Audio
from transformers import ASTFeatureExtractor, ASTForAudioClassification, TrainingArguments, Trainer
import evaluate

# -----------------------------------------------------------------------------
# 0. 硬件加速优化设置
# -----------------------------------------------------------------------------
# 开启 TensorFloat-32 (TF32) 计算，在 A100/3090+ 上能大幅提升矩阵乘法速度
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# 忽略非关键警告
warnings.filterwarnings("ignore")
# 设置环境变量，减少 HuggingFace 的冗余输出
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# -----------------------------------------------------------------------------
# 1. 激进的参数配置 (针对 80G VRAM + 160G RAM)
# -----------------------------------------------------------------------------
MODEL_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"

# 【优化点1】Batch Size 大幅提升
# AST Base 模型的 Patch 序列较长，80G 显存保守估计可开 64，激进可尝试 96 或 128
# 如果爆显存，请降为 48 或 32
BATCH_SIZE = 64

# 【优化点2】取消梯度累积
# 既然 Batch Size 已经够大（>32），直接设为 1 即可，减少通信开销
GRADIENT_ACCUMULATION_STEPS = 1

# 【优化点3】提高 CPU 数据加载并行度
# 拥有 160G 内存，可以大胆使用多进程加载
NUM_WORKERS = 16

LEARNING_RATE = 5e-5
NUM_EPOCHS = 10 # 速度快了，可以多跑几轮
MAX_AUDIO_LENGTH = 1024

# -----------------------------------------------------------------------------
# 2. 加载数据
# -----------------------------------------------------------------------------
print(f">> [System] 使用设备: {torch.cuda.get_device_name(0)}")
print(">> [Data] 正在加载 ESC-50 数据集...")
dataset = load_dataset("ashraq/esc50", split="train")
dataset = dataset.train_test_split(test_size=0.2, seed=42)

# 标签映射
labels_list = dataset["train"].unique("category")
labels_list.sort()
label2id = {label: i for i, label in enumerate(labels_list)}
id2label = {i: label for i, label in enumerate(labels_list)}
num_labels = len(labels_list)

def encode_labels(example):
    example["labels"] = label2id[example["category"]]
    return example

dataset = dataset.map(encode_labels)

# -----------------------------------------------------------------------------
# 3. 高性能预处理
# -----------------------------------------------------------------------------
print(">> [Data] 正在进行多进程音频预处理...")

feature_extractor = ASTFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)
target_sampling_rate = feature_extractor.sampling_rate

dataset = dataset.cast_column("audio", Audio(sampling_rate=target_sampling_rate))

def preprocess_function(examples):
    audio_arrays = [x["array"] for x in examples["audio"]]
    inputs = feature_extractor(
        audio_arrays,
        sampling_rate=target_sampling_rate,
        max_length=MAX_AUDIO_LENGTH,
        padding="max_length",
        truncation=True,
        return_tensors="np"
    )
    return inputs

cols_to_remove = [col for col in dataset["train"].column_names if col not in ["labels"]]

# 【优化点4】多进程 Map
# 利用你的多核 CPU 并行处理波形转频谱图
encoded_dataset = dataset.map(
    preprocess_function,
    remove_columns=cols_to_remove,
    batched=True,
    batch_size=100, # 批处理量增大
    num_proc=16,    # 开启 16 个进程并行处理
    desc="Preprocessing audio"
)

# -----------------------------------------------------------------------------
# 4. 模型加载
# -----------------------------------------------------------------------------
print(">> [Model] 加载模型...")
model = ASTForAudioClassification.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True
)

# -----------------------------------------------------------------------------
# 5. 训练与评估设置
# -----------------------------------------------------------------------------
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)

# 检测是否支持 BF16 (A100/H100 专属)
use_bf16 = torch.cuda.is_bf16_supported()
if use_bf16:
    print(">> [System] 检测到支持 BF16，已启用以获得最佳性能。")

training_args = TrainingArguments(
    output_dir="./ast_esc50_result_optimized",
    eval_strategy="epoch",
    save_strategy="epoch",

    # 学习率与 Batch Size 的关系
    learning_rate=LEARNING_RATE,

    # 核心显存优化
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE * 2, # 验证时不计算梯度，Batch 可以翻倍
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,

    num_train_epochs=NUM_EPOCHS,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",

    # 【优化点5】精度设置
    bf16=use_bf16,      # 如果是 A100，优先用 bf16
    fp16=(not use_bf16),# 否则回退到 fp16

    save_total_limit=1,
    remove_unused_columns=False,

    # 【优化点6】DataLoader 性能
    dataloader_num_workers=NUM_WORKERS, # 这里设为 16，利用你的内存带宽
    dataloader_pin_memory=True,         # 锁页内存，加速 CPU 到 GPU 传输

    # 优化器设置 (使用 Fused AdamW 加速)
    optim="adamw_torch"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
    compute_metrics=compute_metrics,
)

# -----------------------------------------------------------------------------
# 6. 开始极速微调
# -----------------------------------------------------------------------------
print(">> [Train] 开始训练...")
trainer.train()

# -----------------------------------------------------------------------------
# 7. 保存结果
# -----------------------------------------------------------------------------
print(">> [Save] 保存模型...")
final_path = "./ast_esc50_final_optimized"
trainer.save_model(final_path)
feature_extractor.save_pretrained(final_path)
print(f">> 完成。模型保存在: {final_path}")

  self.setter(val)


>> [System] 使用设备: NVIDIA A100-SXM4-80GB
>> [Data] 正在加载 ESC-50 数据集...


README.md:   0%|          | 0.00/345 [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


dataset_infos.json: 0.00B [00:00, ?B/s]



data/train-00000-of-00002-2f1ab7b824ec75(…):   0%|          | 0.00/387M [00:00<?, ?B/s]

data/train-00001-of-00002-27425e5c1846b4(…):   0%|          | 0.00/387M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2000 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/1600 [00:00<?, ? examples/s]

Map:   0%|          | 0/1600 [00:00<?, ? examples/s]

Map:   0%|          | 0/400 [00:00<?, ? examples/s]

>> [Data] 正在进行多进程音频预处理...


preprocessor_config.json:   0%|          | 0.00/297 [00:00<?, ?B/s]

Preprocessing audio (num_proc=16):   0%|          | 0/1600 [00:00<?, ? examples/s]

Preprocessing audio (num_proc=16):   0%|          | 0/400 [00:00<?, ? examples/s]

>> [Model] 加载模型...


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/203 [00:00<?, ?it/s]

ASTForAudioClassification LOAD REPORT from: MIT/ast-finetuned-audioset-10-10-0.4593
Key                     | Status   |                                                                                        
------------------------+----------+----------------------------------------------------------------------------------------
classifier.dense.weight | MISMATCH | Reinit due to size mismatch ckpt: torch.Size([527, 768]) vs model:torch.Size([50, 768])
classifier.dense.bias   | MISMATCH | Reinit due to size mismatch ckpt: torch.Size([527]) vs model:torch.Size([50])          

Notes:
- MISMATCH	:ckpt weights were loaded, but they did not match the original empty weight shapes.


Downloading builder script: 0.00B [00:00, ?B/s]

warmup_ratio is deprecated and will be removed in v5.2. Use `warmup_steps` instead.


>> [System] 检测到支持 BF16，已启用以获得最佳性能。
>> [Train] 开始训练...


Epoch,Training Loss,Validation Loss,Accuracy
1,2.7534,1.038467,0.885
2,0.182234,0.232376,0.95
3,0.042307,0.148556,0.965
4,0.011707,0.142576,0.96
5,0.004679,0.111623,0.9775
6,0.00245,0.116884,0.9775
7,0.001938,0.114137,0.975
8,0.0017,0.113235,0.975
9,0.001595,0.112721,0.9775
10,0.001576,0.11284,0.9775


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

>> [Save] 保存模型...


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

>> 完成。模型保存在: ./ast_esc50_final_optimized
