In [1]:
# 📦 Dependencies
# !uv pip install -U transformers datasets evaluate wandb

In [53]:
# 📚 Imports
import os
import numpy as np
import pandas as pd
import torch
from datasets import DatasetDict, IterableDataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainer, Seq2SeqTrainingArguments
import evaluate
import wandb
from sklearn.model_selection import train_test_split
import gc

def flush_gpu():
    import torch, gc
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    print("✅ GPU memory flushed.")
    
flush_gpu()
!nvidia-smi

✅ GPU memory flushed.
Tue May  6 11:49:17 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.120                Driver Version: 550.120        CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4070 ...    Off |   00000000:00:10.0 Off |                  N/A |
| 30%   47C    P2             23W /  285W |    4103MiB /  16376MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 

In [32]:
# 🔐 WANDB setup
os.environ["WANDB_PROJECT"] = "whisperlaz-asr-ja"
os.environ["WANDB_LOG_MODEL"] = "false"
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mhrnph[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [33]:
# 📂 Load preprocessed segment index
df = pd.read_csv("./manifest/preprocessed-segments-index.csv")
df = df[df.lang == "ja"].reset_index(drop=True)
print(f"Loaded {len(df)} JA training samples")

Loaded 16978 JA training samples


In [34]:
# 🔀 Split into train, val, test
train_df, test_df = train_test_split(df, test_size=0.1, random_state=42)
train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42)
print(f"Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")

Train: 13752, Val: 1528, Test: 1698


In [35]:
# 🧠 Load model + processor
model_name = "openai/whisper-small"
processor = WhisperProcessor.from_pretrained(model_name, )
model = WhisperForConditionalGeneration.from_pretrained(model_name)
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

In [43]:
# 🔄 Preprocessing
def preprocess(row):
    data = np.load(row["npz_path"], allow_pickle=True)
    waveform = data["audio"]
    text = str(data["text"])

    inputs = processor(
        waveform,
        sampling_rate=16000,
        return_tensors="pt"
    )

    labels = processor.tokenizer(text, return_tensors="pt").input_ids[0]

    return {
        "input_features": inputs.input_features[0],
        "labels": labels
    }

In [44]:
# 🧠 Dataset generators
def make_generator(df):
    for _, row in df.iterrows():
        try:
            data = np.load(row.npz_path, allow_pickle=True)
            yield {
                "audio": {"array": data["audio"], "sampling_rate": 16000},
                "text": str(data["text"]),
                "start": float(data["start"]),
                "end": float(data["end"])
            }
        except Exception as e:
            print(f"Skip: {row.npz_path} — {type(e).__name__}: {e}")

In [45]:
# 🧱 Build lazy datasets
dataset = DatasetDict({
    "train": IterableDataset.from_generator(lambda: map(preprocess, train_df.to_dict(orient="records"))),
    "val": IterableDataset.from_generator(lambda: map(preprocess, val_df.to_dict(orient="records"))),
    "test": IterableDataset.from_generator(lambda: map(preprocess, test_df.to_dict(orient="records")))
})

In [46]:
import evaluate

# 🧪 Evaluation metric
metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    return {"wer": metric.compute(predictions=pred_str, references=label_str)}

In [50]:
# 📚 Imports
import os
import numpy as np
import pandas as pd
import torch
from datasets import DatasetDict, IterableDataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainer, Seq2SeqTrainingArguments
import evaluate
import wandb
from sklearn.model_selection import train_test_split

# 🧮 Compute max_steps from known train_df size
num_samples = len(train_df)               # e.g., 13752
batch_size = 8                            # per device
accum_steps = 2                           # gradient accumulation
steps_per_epoch = num_samples // (batch_size * accum_steps)
max_steps = steps_per_epoch * 5          # for 5 epochs


print(f"🧾 Estimated max_steps: {max_steps}")

model.model_input_names = ["input_features"] # <- this shit fixes the input_ids mismatch, normally expected input_ids
# ⚙️ Training config
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-ja-asmr-sm-2-earlyst",
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=accum_steps,
    max_steps=max_steps,                      # dynamically computed from dataset size
    learning_rate=5e-6,
    fp16=True,
    
    # label_smoothing_factor=0.1,
    
    # Logging & saving
    logging_steps=50,
    save_steps=200,
    save_total_limit=2,
    
    # Evaluation & generation
    eval_strategy="steps",
    eval_steps=200,
    predict_with_generate=True,
    generation_max_length=256,                # ← add this to limit decoding memory use
    generation_num_beams=1,                   # ← beam search = 1 for speed/memory

    # Model selection
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,

    # Tracking
    report_to="wandb",
    save_only_model=True,
    save_safetensors=True
)

🧾 Estimated max_steps: 4295


In [51]:
from transformers import EarlyStoppingCallback
from typing import Any, Dict, List, Union
import transformers
import torch

class DataCollatorSpeechSeq2SeqWithPadding:
    def __init__(self, processor, padding=True, return_tensors="pt"):
        self.processor = processor
        self.padding = padding
        self.return_tensors = return_tensors

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Separate input_features and labels
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        label_features = [feature["labels"] for feature in features]

        # Pad input features
        batch = self.processor.feature_extractor.pad(
            input_features,
            padding=self.padding,
            return_tensors=self.return_tensors
        )

        # Pad labels
        labels_batch = self.processor.tokenizer.pad(
            {"input_ids": label_features},
            padding=self.padding,
            return_tensors=self.return_tensors
        )

        # Replace padding token id's in labels by -100 so they're ignored by the loss function
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch["input_ids"] == self.processor.tokenizer.pad_token_id, -100
        )

        batch["labels"] = labels

        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    padding=True,
    return_tensors="pt",
)


trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["val"],
    data_collator=data_collator,
    tokenizer=processor,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]  # ⬅️ Added
)

  trainer = Seq2SeqTrainer(


In [52]:
trainer.train()

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss,Validation Loss,Wer
200,1.115,0.785599,1.657823


Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


KeyboardInterrupt: 

In [None]:
model.save_pretrained("./whisper-ja-asmr-sm-2-earlyst/final", safe_serialization=True)