In [3]:
from google.colab import drive
import os
# ====== Google Drive 設定 ======
drive.mount('/content/drive')
SAVE_DIR = "/content/drive/MyDrive/asr_checkpoints"
LOG_PATH = os.path.join(SAVE_DIR, "validation_log.json")
os.makedirs(SAVE_DIR, exist_ok=True)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [10]:
# ===== 安裝必要套件（只需執行一次） =====
!pip install datasets
!pip install transformers
!pip install torchaudio
!pip install jiwer
!pip install mamba-ssm
#要注意一下，如果沒有連線GPU的話mamba的載入可能會出問題

!pip install matplotlib

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.

In [11]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
import torchaudio
import torchaudio.transforms as T
from mamba_ssm import Mamba
import random

import jiwer  # 用於計算 WER
import json
import matplotlib.pyplot as plt

In [12]:
# ====== Config ======
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 4
SEQ_LEN = 1024
MEL_DIM = 80
MAX_TOKEN_LEN = 128

In [13]:
# ====== Load Tokenizer ======
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# ====== Audio Feature Extractor ======
melspec_transform = T.MelSpectrogram(sample_rate=16000, n_fft=400, hop_length=160, n_mels=MEL_DIM)

def extract_features(waveform, sample_rate):
    if sample_rate != 16000:
        resample = T.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resample(waveform)
    mel = melspec_transform(waveform).squeeze(0).transpose(0, 1)
    return mel[:SEQ_LEN]

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [14]:
# ====== Dataset ======
class ASRDataset(Dataset):
    def __init__(self, split="train.clean.100", limit=100):
        self.dataset = load_dataset("librispeech_asr", split=split)
        self.dataset = self.dataset.select(range(min(len(self.dataset), limit)))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        waveform = torch.tensor(sample['audio']['array'])
        sr = sample['audio']['sampling_rate']
        mel = extract_features(waveform, sr)

        text = sample['text'].lower()
        token = tokenizer(text, padding="max_length", truncation=True, max_length=MAX_TOKEN_LEN, return_tensors="pt")

        return mel, token.input_ids.squeeze(0), token.attention_mask.squeeze(0), text


def collate_fn(batch):
    mels, input_ids, attn_masks, texts = zip(*batch)
    mels = nn.utils.rnn.pad_sequence(mels, batch_first=True)
    input_ids = torch.stack(input_ids)
    attn_masks = torch.stack(attn_masks)
    return mels, input_ids, attn_masks, texts

In [15]:
# ====== Mamba Encoder + Transformer Decoder ======
class Seq2SeqASR(nn.Module):
    def __init__(self, mel_dim, model_dim, vocab_size):
        super().__init__()
        self.encoder_proj = nn.Linear(mel_dim, model_dim)
        self.encoder = Mamba(d_model=model_dim)

        decoder_layer = nn.TransformerDecoderLayer(d_model=model_dim, nhead=4, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=2)

        self.token_embedding = nn.Embedding(vocab_size, model_dim)
        self.output_proj = nn.Linear(model_dim, vocab_size)

    def forward(self, mel, tgt_input, tgt_mask=None):
        x = self.encoder_proj(mel)
        memory = self.encoder(x)

        tgt_emb = self.token_embedding(tgt_input)
        out = self.decoder(tgt=tgt_emb, memory=memory, tgt_mask=tgt_mask)
        return self.output_proj(out)

def generate_square_subsequent_mask(sz):
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1).to(device)

In [16]:
# ====== Checkpoint loading ======
def load_checkpoint(filepath, model, optimizer):
    checkpoint = torch.load(filepath, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"✅ Loaded checkpoint from epoch {checkpoint['epoch'] + 1}")
    return model, optimizer, start_epoch


In [17]:
# ====== Inference ======
def greedy_decode(model, mel, max_len=100):
    model.eval()
    with torch.no_grad():
        memory = model.encoder(model.encoder_proj(mel))
        ys = torch.full((1, 1), tokenizer.cls_token_id, dtype=torch.long, device=device)

        for i in range(max_len):
            tgt_mask = generate_square_subsequent_mask(ys.size(1))
            out = model.decoder(model.token_embedding(ys), memory, tgt_mask=tgt_mask)
            prob = model.output_proj(out[:, -1])
            next_token = prob.argmax(dim=-1).unsqueeze(1)
            ys = torch.cat([ys, next_token], dim=1)
            if next_token.item() == tokenizer.sep_token_id:
                break
        return tokenizer.decode(ys.squeeze(), skip_special_tokens=True)

In [18]:
# ====== Training ======
dataset = ASRDataset(limit=300)
train_set, val_set = torch.utils.data.random_split(dataset, [250, 50])
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_set, batch_size=1, shuffle=False, collate_fn=collate_fn)

model = Seq2SeqASR(mel_dim=MEL_DIM, model_dim=256, vocab_size=tokenizer.vocab_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

start_epoch = 0
resume_path = os.path.join(SAVE_DIR, "checkpoint_epoch_3.pt")
if os.path.exists(resume_path):
    model, optimizer, start_epoch = load_checkpoint(resume_path, model, optimizer)

# 讀取 log 檔
if os.path.exists(LOG_PATH):
    with open(LOG_PATH, 'r') as f:
        validation_log = json.load(f)
else:
    validation_log = []

epochs = 5
for epoch in range(start_epoch, epochs):
    model.train()
    total_loss = 0
    for mel, tgt, attn_mask, _ in train_loader:
        mel, tgt = mel.to(device), tgt.to(device)
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]

        tgt_mask = generate_square_subsequent_mask(tgt_input.size(1))
        logits = model(mel, tgt_input, tgt_mask)

        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_output.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}, Train Loss: {avg_loss:.4f}")

    # ===== Validation =====
    model.eval()
    predictions = []
    references = []
    with torch.no_grad():
        for mel, _, _, texts in val_loader:
            mel = mel.to(device)
            pred = greedy_decode(model, mel)
            predictions.append(pred.lower())
            references.append(texts[0].lower())

    wer = jiwer.wer(references, predictions)
    acc = sum([p.strip() == r.strip() for p, r in zip(predictions, references)]) / len(references)
    print(f"Validation WER: {wer:.3f}, Exact Match Accuracy: {acc:.2%}")

    # ===== Save Checkpoint =====
    ckpt_path = os.path.join(SAVE_DIR, f"checkpoint_epoch_{epoch+1}.pt")
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
    }, ckpt_path)
    print(f"✅ Saved checkpoint to {ckpt_path}")

    # ===== Save Validation Log =====
    validation_log.append({
        "epoch": epoch + 1,
        "train_loss": avg_loss,
        "wer": wer,
        "accuracy": acc
    })
    with open(LOG_PATH, 'w') as f:
        json.dump(validation_log, f, indent=2)
    print(f"📝 Logged validation results to {LOG_PATH}")

README.md:   0%|          | 0.00/10.2k [00:00<?, ?B/s]

librispeech_asr.py:   0%|          | 0.00/11.4k [00:00<?, ?B/s]

The repository for librispeech_asr contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/librispeech_asr.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


Downloading data:   0%|          | 0.00/338M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/314M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/347M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/329M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.39G [00:00<?, ?B/s]

FSTimeoutError: 

In [None]:
# ===== Plot Training Log =====
epochs_list = [entry['epoch'] for entry in validation_log]
loss_list = [entry['train_loss'] for entry in validation_log]
wer_list = [entry['wer'] for entry in validation_log]
acc_list = [entry['accuracy'] for entry in validation_log]

plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
plt.plot(epochs_list, loss_list, marker='o')
plt.title("Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")

plt.subplot(1, 3, 2)
plt.plot(epochs_list, wer_list, marker='o', color='red')
plt.title("WER")
plt.xlabel("Epoch")
plt.ylabel("WER")

plt.subplot(1, 3, 3)
plt.plot(epochs_list, acc_list, marker='o', color='green')
plt.title("Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")

plt.tight_layout()
plt.savefig(os.path.join(SAVE_DIR, "training_curves.png"))
plt.show()


In [None]:
# Try a sample
sample_mel, _, _, _ = dataset[0]
sample_mel = sample_mel.unsqueeze(0).to(device)
transcription = greedy_decode(model, sample_mel)
print("Predicted:", transcription)
