<a href="https://colab.research.google.com/github/GalJakob/NLP/blob/main/post_asr_20250922_15_25.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
### THIS CELL is for installations and setup ###

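### FIRST THINGS BEFORE STARTING: ###
# 1. Open a terminal.
# 2. Run: hf auth login
# 3. When prompted, paste your own Hugging Face access token (create one at https://huggingface.co/settings/tokens).
#    Never commit a real token to the notebook; the token that used to be written here must be revoked.
# 4. Answer "y" to "Add token as git credential?"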
# Upgrade `datasets` to the latest release first.
!pip install -U datasets
# Install transformers and evaluate quietly (`datasets` is listed again here, without -U).
!pip install transformers datasets evaluate --quiet
# jiwer provides the `wer` function imported in the training cell.
!pip install jiwer
# NOTE(review): torchcodec is not imported in any visible cell; presumably required for audio decoding by `datasets` - confirm.
!pip install torchcodec 

In [None]:
### THIS CELL is for global constants and hyper parameters for H200 + ByT5-small ###
# NOTE: the A40 cell defines the same names; whichever config cell runs last wins,
# so run only the cell that matches your GPU.

import torch

# ---- Data / paths ----
DATASET_DIR = "combined_asr_dataset"
INPUT_COL   = "asr_output"
TARGET_COL  = "sentence"
VAL_SIZE    = 0.10
SEED        = 42

# ---- Model / output ----
MODEL_NAME  = "google/byt5-small"
OUTPUT_DIR  = "checkpoints/byt5_postasr_h200"
PRECISION   = torch.bfloat16  # Hopper: prefer bf16 over fp16/fp32
MODEL_TAG   = "byt5-small"
GPU_TAG     = "H200"

# ---- Tokenization (ByT5: bytes ≈ tokens) ----
# Measured percentiles: inputs P99≈402 bytes (+prefix), targets P99.5≈632 bytes.
# Those would suggest caps of 416 / 640; both caps are currently set higher, at 728.
MAX_INPUT_LEN   = 728      # encoder cap
MAX_TARGET_LEN  = 728      # decoder / labels cap
TRUNCATION_SIDE = "right"
PADDING_SIDE    = "right"
PAD_TO_MULTIPLE_OF = 8     # tensor cores happy


# ---- Training (optimizer/schedule) ----
# H200 has ample VRAM; you can run bigger batches even with long targets.
# Effective train batch = TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * num_gpus
TRAIN_BATCH_SIZE = 32
EVAL_BATCH_SIZE  = 32
GRADIENT_ACCUMULATION_STEPS = 1          # effective train batch = 32*1 = 32 on 1 GPU
LEARNING_RATE   = 1e-4                   # good default for T5 fine-tuning
WEIGHT_DECAY    = 0.01
WARMUP_RATIO    = 0.06
LR_SCHEDULER    = "linear"
LABEL_SMOOTHING = 0.0
GROUP_BY_LENGTH = True
GRADIENT_CLIP_NORM = 1.0

# Choose ONE of these two stopping modes:
NUM_TRAIN_EPOCHS = 3                      # typical fine-tune: 2–5 epochs
MAX_STEPS        = -1                     # set >0 to override epochs

# ---- Eval / save / logging ----
EVALUATION_STRATEGY = "steps"
EVAL_STEPS          = 2000                # evaluate ~every 2k steps
SAVE_STRATEGY       = "steps"
SAVE_STEPS          = 2000
SAVE_TOTAL_LIMIT    = 3
LOGGING_STRATEGY    = "steps"
LOGGING_STEPS       = 100
LOAD_BEST_MODEL_AT_END = True
METRIC_FOR_BEST_MODEL  = "wer"
GREATER_IS_BETTER      = False            # lower WER is better
REPORT_TO              = "none"           # set "tensorboard"/"wandb" if you use them
EARLY_STOPPING_PATIENCE = 3

# ---- Generation (used when predict_with_generate=True) ----
PREDICT_WITH_GENERATE = True
GEN_MAX_NEW_TOKENS    = MAX_TARGET_LEN    # cap decoder output
GEN_NUM_BEAMS         = 1                 # greedy is standard for WER

# If you want sampling/penalties, set via model.generation_config (Trainer ignores here):
DO_SAMPLE             = False
NO_REPEAT_NGRAM_SIZE  = 0
REPETITION_PENALTY    = 1.0
TEMPERATURE           = 1.0
TOP_P                 = 1.0

# ---- Trainer misc ----
REMOVE_UNUSED_COLUMNS = True
LABEL_NAMES           = ["labels"]





In [None]:
### THIS CELL is for global constants and hyper parameters for A40 + ByT5-small ###

import torch

# ---- Data / paths ----
DATASET_DIR = "combined_asr_dataset"
INPUT_COL = "asr_output"
TARGET_COL = "sentence"
VAL_SIZE = 0.10
SEED = 42

# ---- Model / output ----
MODEL_NAME = "google/byt5-small"
MODEL_TAG = "byt5-small"
GPU_TAG = "A40"
OUTPUT_DIR = "checkpoints/byt5_postasr_a40"
PRECISION = torch.bfloat16

# ---- Tokenization (ByT5: bytes ≈ tokens) ----
# Measured percentiles: inputs P99≈402 bytes (+prefix), targets P99.5≈632 bytes.
# Encoder and decoder/label caps share one value, set above those percentiles.
MAX_INPUT_LEN = MAX_TARGET_LEN = 728
TRUNCATION_SIDE = "right"
PADDING_SIDE = "right"
PAD_TO_MULTIPLE_OF = 8  # tensor cores happy

# ---- Training (optimizer/schedule) ----
# Safe defaults for 48 GB VRAM with long decoder sequences.
# Effective train batch = TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * num_gpus
TRAIN_BATCH_SIZE = 16
GRADIENT_ACCUMULATION_STEPS = 2  # effective train batch = 16*2 = 32 on 1 GPU
EVAL_BATCH_SIZE = 32
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.03
LR_SCHEDULER = "linear"
LABEL_SMOOTHING = 0.0
GROUP_BY_LENGTH = True
GRADIENT_CLIP_NORM = 1.0

# Choose ONE of these two stopping modes:
NUM_TRAIN_EPOCHS = 3
MAX_STEPS = -1  # set >0 to override epochs

# ---- Eval / save / logging ----
EVALUATION_STRATEGY = "steps"
EVAL_STEPS = 2000
SAVE_STRATEGY = "steps"
SAVE_STEPS = 2000
SAVE_TOTAL_LIMIT = 3
LOGGING_STRATEGY = "steps"
LOGGING_STEPS = 100
LOAD_BEST_MODEL_AT_END = True
METRIC_FOR_BEST_MODEL = "wer"
GREATER_IS_BETTER = False
REPORT_TO = "none"
EARLY_STOPPING_PATIENCE = 3

# ---- Generation (used when predict_with_generate=True) ----
PREDICT_WITH_GENERATE = True
GEN_NUM_BEAMS = 1  # greedy is standard for WER
GEN_MAX_NEW_TOKENS = MAX_TARGET_LEN  # cap decoder output

# If you want sampling/penalties, set via model.generation_config (Trainer ignores here):
DO_SAMPLE = False
NO_REPEAT_NGRAM_SIZE = 0
REPETITION_PENALTY = 1.0
TEMPERATURE = 1.0
TOP_P = 1.0

# ---- Trainer misc ----
REMOVE_UNUSED_COLUMNS = True
LABEL_NAMES = ["labels"]


In [None]:
### THIS CELL is for loading datasets made with different ASR's ###

import torch
from datasets import load_from_disk

# Load the combined ASR dataset from disk and carve out a validation split.
ds = load_from_disk(DATASET_DIR)
# Use the configured split size and seed instead of re-typing the literals, so
# the config cell stays the single source of truth for these values.
splits = ds.train_test_split(test_size=VAL_SIZE, seed=SEED)
training_data = splits["train"]
val_data      = splits["test"]

# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def _row_ok(x):
    src = (x.get("asr_output") or "").strip()
    tgt = (x.get("sentence") or "").strip()
    return len(src) > 0 and len(tgt) > 0

# Drop rows with a blank ASR output or blank reference from both splits.
training_data_raw, val_data = (
    split.filter(_row_ok) for split in (training_data, val_data)
)



In [None]:
### THIS CELL is for filtering out long sentences - they would exceed MAX_INPUT_LEN / MAX_TARGET_LEN ###

def _len_bytes(s):
    return len((s or "").encode("utf-8"))

def _keep_example(ex):
    return (
        _len_bytes(ex["asr_output"]) <= MAX_INPUT_LEN 
        and _len_bytes(ex["sentence"]) <= MAX_TARGET_LEN
    )

# Apply the length filter to the training split and report how many rows survive.
training_data = training_data_raw.filter(_keep_example)

for label, split in (("Before:", training_data_raw), ("After :", training_data)):
    print(label, len(split))


In [None]:
### THIS CELL is for tokenizer adjustments ###
# AutoTokenizer is otherwise only imported in the final upload cell, so running
# the notebook top to bottom raised NameError here; import it where it is used.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def _clean(s):
    s = "" if s is None else str(s)
    return " ".join(s.split()).strip()

def preprocess_batch(batch):
    """Tokenize a batch of (ASR output, reference sentence) pairs.

    Each ASR output is whitespace-cleaned and prefixed with the task prompt;
    each reference is whitespace-cleaned and tokenized as the target. No
    padding is applied here.

    Args:
        batch: Mapping with "asr_output" and "sentence" lists.

    Returns:
        The tokenizer output for the inputs, with the target token ids
        stored under "labels".
    """
    prompts = ["fix mistakes: " + _clean(hypothesis) for hypothesis in batch["asr_output"]]
    references = list(map(_clean, batch["sentence"]))

    # Encoder side: prefixed ASR hypotheses, truncated to the input cap.
    encoded = tokenizer(
        prompts,
        max_length=MAX_INPUT_LEN,
        truncation=True,
        padding=False,
        return_attention_mask=True,
    )

    # Decoder side: cleaned references, truncated to the target cap.
    label_ids = tokenizer(
        text_target=references,
        max_length=MAX_TARGET_LEN,
        truncation=True,
        padding=False,
    )["input_ids"]

    encoded["labels"] = label_ids
    return encoded

def _has_both_fields(ex):
    """Truthy when the row has a non-empty ASR output and reference sentence."""
    return ex["asr_output"] and ex["sentence"]


# Keep only rows where both text fields are present.
training_data = training_data.filter(_has_both_fields)
val_data = val_data.filter(_has_both_fields)

# Tokenize both splits, replacing the original columns with the model inputs.
tokenized_training_dataset = training_data.map(
    preprocess_batch,
    batched=True,
    remove_columns=training_data.column_names,
)
# The validation split is stored under the "test" name used by later cells.
tokenized_test_dataset = val_data.map(
    preprocess_batch,
    batched=True,
    remove_columns=val_data.column_names,
)

print("dataset tokenized")



dataset tokenized


In [None]:
### THIS CELL is for defining what's needed for training ###

import torch
from evaluate import load as load_metric
from transformers import (
    T5ForConditionalGeneration,
    DataCollatorForSeq2Seq,
)
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import get_scheduler
import math

torch.cuda.empty_cache()

# Load the pretrained model in the configured precision and move it to the device.
model = T5ForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    dtype=PRECISION,
    ignore_mismatched_sizes=True,
)
model = model.to(device)

# Collator that pads each batch dynamically; label padding uses -100.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    label_pad_token_id=-100,
    pad_to_multiple_of=PAD_TO_MULTIPLE_OF,
)

# Loaders batch the tokenized datasets and collate with the padder above.
train_loader = DataLoader(
    tokenized_training_dataset,
    shuffle=True,
    batch_size=TRAIN_BATCH_SIZE,
    collate_fn=data_collator,
)
test_loader = DataLoader(
    tokenized_test_dataset,
    shuffle=False,
    batch_size=EVAL_BATCH_SIZE,
    collate_fn=data_collator,
)

# Schedule sizing: steps per epoch, total steps and warmup steps.
train_size = len(tokenized_training_dataset)
effective_batch = TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
steps_per_epoch = math.ceil(train_size / effective_batch)
if MAX_STEPS > 0:
    # An explicit step budget overrides the epoch-based count.
    total_steps = MAX_STEPS
else:
    total_steps = steps_per_epoch * NUM_TRAIN_EPOCHS
warmup_steps = int(WARMUP_RATIO * total_steps)
print(train_size, steps_per_epoch, warmup_steps)

# Optimizer and learning-rate schedule.
optimizer = AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
lr_scheduler = get_scheduler(
    LR_SCHEDULER,
    optimizer=optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

print("finished defining what's needed for training")





89948 1406 253
finished defining what's needed for training


In [None]:
### THIS CELL is for function saving model during training ###

from pathlib import Path

def save_epoch_checkpoint(model, tokenizer, epoch, base_dir="checkpoints", model_tag=MODEL_TAG, gpu_tag=GPU_TAG):
    """Save the model and tokenizer for one finished epoch.

    The files go to ``<base_dir>/<model_tag>_<gpu_tag>_epoch<epoch+1>``; the
    directory is created if it does not exist.

    Args:
        model: Object exposing ``save_pretrained``.
        tokenizer: Object exposing ``save_pretrained``.
        epoch: Zero-based epoch index; the folder name uses ``epoch + 1``.
        base_dir: Parent directory for all checkpoints.
        model_tag: Model label used in the folder name.
        gpu_tag: GPU label used in the folder name.
    """
    epoch_number = epoch + 1
    save_path = Path(base_dir).joinpath(f"{model_tag}_{gpu_tag}_epoch{epoch_number}")
    save_path.mkdir(parents=True, exist_ok=True)
    for artifact in (model, tokenizer):
        artifact.save_pretrained(save_path)
    print(f"[checkpoint] Saved epoch {epoch_number} to {save_path.resolve()}")
    
    
    
    

In [None]:
### THIS CELL is for training on H200 - new ###
import os
import torch
from tqdm import tqdm
from jiwer import wer  # <-- added
import unicodedata, re # <-- added for light normalization

# NOTE(review): the CUDA allocator presumably reads this setting when CUDA is
# first initialised; the model was already moved to the device in an earlier
# cell, so this may have no effect in the same session - confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


def tqdm_print(*args, **kwargs):
    """Print through tqdm.write; args are joined with spaces, like print()."""
    message = " ".join(str(part) for part in args)
    tqdm.write(message, **kwargs)


torch.cuda.empty_cache()

# ---- light, consistent normalization for WER ----
def _norm(s: str) -> str:
    s = unicodedata.normalize("NFKC", (s or "").strip())
    s = re.sub(r"\s+", " ", s)
    return s

# ---------- ONE-TIME GENERATION DEFAULTS ----------
# Take the decoding settings from the config cell instead of repeating the
# literals here, so editing the config actually changes generation.
model.generation_config.num_beams = GEN_NUM_BEAMS            # 1 = greedy
model.generation_config.do_sample = DO_SAMPLE
model.generation_config.max_new_tokens = GEN_MAX_NEW_TOKENS  # cap decoder output
model.generation_config.eos_token_id = tokenizer.eos_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id

# Validation stops after this many batches per epoch, so the reported
# validation loss and WER cover only a subset of the validation split.
MAX_EVAL_BATCHES = 50
# train_losses: one entry per training batch; val_losses / wer_scores: one per epoch.
train_losses, val_losses, wer_scores = [], [], []

print("training started")

# Number of micro-batches accumulated before one optimizer update. The LR
# scheduler was sized with TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS, so
# the loop must update at that cadence; stepping on every batch would exhaust
# the schedule early and leave the learning rate at zero.
accum_steps = max(1, int(GRADIENT_ACCUMULATION_STEPS))

for epoch in range(NUM_TRAIN_EPOCHS):
    # ------------------- TRAIN ------------------- #
    train_loop = tqdm(train_loader, leave=True, desc=f"Train | Epoch {epoch}", dynamic_ncols=True, position=0)
    model.train()
    batches_in_epoch = len(train_loader)
    optimizer.zero_grad()

    for batch_idx, batch in enumerate(train_loop, start=1):
        outputs = model(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            labels=batch["labels"].to(device),
        )
        loss = outputs.loss
        # Scale so the accumulated gradient is the mean over the micro-batches.
        (loss / accum_steps).backward()

        # Update once per accumulation window, and flush a trailing partial
        # window at the end of the epoch so no gradients are carried over.
        if batch_idx % accum_steps == 0 or batch_idx == batches_in_epoch:
            # Apply the configured gradient clipping (previously declared but unused).
            if GRADIENT_CLIP_NORM and GRADIENT_CLIP_NORM > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP_NORM)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # Read the loss once; it is logged unscaled, per micro-batch.
        loss_value = float(loss.item())
        train_loop.set_postfix(loss=loss_value)
        train_losses.append(loss_value)

    tqdm_print("Finished train epoch: ", epoch)

    # ------------------- VALIDATION ------------------- #
    model.eval()
    pred_texts, ref_texts = [], []
    epoch_val_losses = []
    batch_wers = []  # batch-level WERs, averaged into the per-epoch figure

    with torch.inference_mode():
        val_loop = tqdm(test_loader, leave=True, desc=f"Val | Epoch {epoch}")
        for i, batch in enumerate(val_loop):
            # Only the first MAX_EVAL_BATCHES batches are evaluated.
            if i >= MAX_EVAL_BATCHES:
                val_loop.set_postfix(info=f"stopped at {MAX_EVAL_BATCHES} batches")
                break

            # Move the inputs once and reuse them for the loss pass and generation.
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=batch["labels"].to(device),
            )
            val_loss = float(outputs.loss.item())
            epoch_val_losses.append(val_loss)

            # Decoding for WER; settings come from model.generation_config.
            gen_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict_in_generate=False,
            )

            # Decode predictions
            preds = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)

            # Prepare and decode references (replace -100 with pad before decoding)
            labels = batch["labels"].clone()
            labels[labels == -100] = tokenizer.pad_token_id
            refs = tokenizer.batch_decode(labels, skip_special_tokens=True)

            # strip + normalize
            preds = [_norm(p) for p in preds]
            refs  = [_norm(r) for r in refs]

            # accumulate for optional later use
            pred_texts.extend(preds)
            ref_texts.extend(refs)

            # ---- batch-level WER (jiwer supports list-of-strings) ----
            batch_wer = wer(refs, preds)
            batch_wers.append(batch_wer)

            val_loop.set_postfix(loss=val_loss, batch_WER=f"{batch_wer:.4f}", n_preds=len(pred_texts))

    # Aggregate val loss once per epoch (on the capped subset)
    mean_val_loss = (sum(epoch_val_losses) / len(epoch_val_losses)) if epoch_val_losses else float("nan")
    val_losses.append(mean_val_loss)

    # ---- epoch-level WER: unweighted mean of the batch WERs, not a corpus-level WER ----
    epoch_mean_wer = (sum(batch_wers) / len(batch_wers)) if batch_wers else float("nan")
    wer_scores.append(epoch_mean_wer)

    print(f"Finished validation epoch {epoch} | ValLoss(subset): {mean_val_loss:.4f} | MeanWER(epoch): {epoch_mean_wer:.4f}")

    # ------------------- CHECKPOINT -------------------
    save_epoch_checkpoint(
        model, tokenizer, epoch,
        base_dir="checkpoints",
        model_tag=MODEL_TAG,
        gpu_tag=GPU_TAG
    )
    print(f"saved model at epoch {epoch}")

# Optional: final summary
last_val = val_losses[-1] if len(val_losses) else float("nan")
last_wer = wer_scores[-1] if len(wer_scores) else float("nan")
print(f"Training complete. Epochs: {NUM_TRAIN_EPOCHS}, Last ValLoss(subset): {last_val:.4f}, Last MeanWER(epoch): {last_wer:.4f}")


In [None]:
### THIS CELL is for saving model and upload to HF ###

from pathlib import Path
import re
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Derive a short, filesystem-safe GPU tag (or set GPU_TAG manually)
def _gpu_tag():
    if torch.cuda.is_available():
        name = torch.cuda.get_device_name(0)
    else:
        name = "cpu"
    tag = re.sub(r"[^A-Za-z0-9]+", "_", name).strip("_").lower()
    return tag

# Rebuild the tags from the model name and the detected device; this
# overrides the MODEL_TAG / GPU_TAG values set in the config cell.
MODEL_TAG = Path(MODEL_NAME).name
GPU_TAG = _gpu_tag()

# Final artifacts go to <cwd>/checkpoints/<model>__<gpu>.
SAVE_DIR = Path.cwd().joinpath("checkpoints", f"{MODEL_TAG}__{GPU_TAG}")
SAVE_DIR.mkdir(parents=True, exist_ok=True)

for _artifact in (model, tokenizer):
    _artifact.save_pretrained(SAVE_DIR)

print(f"Saved to: {SAVE_DIR.resolve()}")

!huggingface-cli login --token YOUR_HF_KEY

ckpt_dir = f"checkpoints/{OUTPUT_DIR}-new"   
repo_id  = "Gal-Jakob/byt5small-h200-new"  

# Load from your saved folder
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
model     = AutoModelForSeq2SeqLM.from_pretrained(ckpt_dir)

# Create repo if needed and push
tokenizer.push_to_hub(repo_id)
model.push_to_hub(repo_id, commit_message="Initial upload of fine-tuned ByT5-small")
