In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torchaudio
import pandas as pd
from tqdm import tqdm
from jiwer import wer, cer
import gc
import types
from peft import PeftModel
import warnings

# --- Suppress Warnings ---
warnings.filterwarnings("ignore")

# Hugging Face Imports
from transformers import (
    Wav2Vec2ConformerForCTC, 
    Wav2Vec2FeatureExtractor,
    get_linear_schedule_with_warmup,
    logging as hf_logging
)
from peft import get_peft_model, LoraConfig
from huggingface_hub import login
from torch.cuda.amp import autocast, GradScaler

hf_logging.set_verbosity_error()

# --- Configuration ---
# Central run configuration; every value below is read via CONFIG[...] elsewhere in this file.
CONFIG = {
    # CSV manifests read with pandas; the code uses their 'file' and 'transcript' columns.
    "train_csv": "geo/train.csv",
    "val_csv": "geo/dev.csv",
    # Optional Hugging Face token; when truthy, login() is called at import time.
    "hf_token": None,
    
    # --- TUNING V3 ---
    # Masking settings forwarded unchanged to Wav2Vec2ConformerForCTC.from_pretrained.
    "mask_time_prob": 0.075,     
    "mask_time_length": 10,
    "mask_feature_prob": 0.075,
    "mask_feature_length": 10,
    
    # LoRA hyper-parameters; only used when a fresh adapter is created (no checkpoint found).
    "lora_rank": 64,             
    "lora_alpha": 128,
    "lora_dropout": 0.1,
    
    # Per-step batch size for both the training and validation DataLoader.
    "batch_size": 2,           
    # Number of batches whose (scaled) losses are accumulated before one optimizer step.
    "grad_accum_steps": 4,     
    # Training clips are truncated to this many samples after resampling to 16 kHz
    # (160000 samples = 10 s). The validation set is built without this limit.
    "max_audio_len": 160000,   
    # Peak learning rate handed to AdamW; a linear warmup/decay schedule is applied on top.
    "learning_rate": 1e-4,       
    "num_epochs": 25,            
    # Pretrained checkpoint used for both the model and the feature extractor.
    "base_model": "facebook/wav2vec2-conformer-rope-large-960h-ft",
    # Module names that receive LoRA adapters.
    "target_modules": ["linear_q", "linear_k", "linear_v", "linear_out", "intermediate_dense", "output_dense"],
    
    # Checkpoint Configuration
    # Folder the best adapter is saved to with save_pretrained() and, if it already
    # exists, loaded from on start-up (it should contain the saved adapter files).
    "checkpoint_path": "conformer_lora_best", 
    # WER baseline used as the initial "best" when training resumes from a checkpoint;
    # a resumed epoch must beat this value before the checkpoint is overwritten.
    "previous_best_wer": 0.38
}

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# --- Auth ---
# The flag is later passed as `token=` to from_pretrained; it is True only
# when a Hugging Face token was configured and a login was performed.
auth_token_arg = bool(CONFIG["hf_token"])
if auth_token_arg:
    login(token=CONFIG["hf_token"])

# --- Vocabulary Builder ---
class Vocabulary:
    """Character-level vocabulary built from the 'transcript' column of CSV files.

    Ids 0-4 are reserved for ``<pad>``, ``<s>``, ``</s>``, ``<unk>`` and the
    word delimiter ``|``; every other character found in the transcripts is
    assigned the next free id in sorted order.
    """

    def __init__(self, csv_paths):
        """Build the forward (``vocab``) and reverse (``inv_vocab``) mappings."""
        self.vocab = {"<pad>": 0, "<s>": 1, "</s>": 2, "<unk>": 3, "|": 4}
        self.idx = 5  # next id to hand out
        self._build_vocab(csv_paths)
        self.inv_vocab = {index: symbol for symbol, index in self.vocab.items()}

    def _build_vocab(self, paths):
        """Register every character seen in the given CSVs; missing files are skipped."""
        seen = set()
        for csv_file in paths:
            if os.path.exists(csv_file):
                frame = pd.read_csv(csv_file)
                for line in frame['transcript'].fillna("").astype(str):
                    seen.update(line)

        # Sorted order keeps the id assignment deterministic across runs.
        for symbol in sorted(seen):
            if symbol in self.vocab:
                continue
            self.vocab[symbol] = self.idx
            self.idx += 1

    def encode(self, text):
        """Map text to ids; spaces become ``|`` and unknown characters ``<unk>``."""
        unknown_id = self.vocab["<unk>"]
        return [self.vocab.get(symbol, unknown_id) for symbol in text.replace(" ", "|")]

    def decode(self, tokens):
        """Map ids back to text, dropping id 0, unknown ids and the ``<s>``/``</s>`` markers."""
        pieces = [self.inv_vocab.get(token, "") for token in tokens if token != 0]
        text = "".join(pieces)
        for old, new in (("|", " "), ("<s>", ""), ("</s>", "")):
            text = text.replace(old, new)
        return text

# --- Dataset ---
class EsperantoDataset(Dataset):
    """Audio/transcript dataset backed by a CSV with 'file' and 'transcript' columns.

    Each item is a ``(features, labels)`` pair: the feature-extractor output
    for the mono 16 kHz waveform and the transcript encoded with ``vocab``.
    Samples that cannot be loaded or processed are replaced by a sentinel
    (one second of zeros with the single label ``0``) that the collate
    function filters out.
    """

    def __init__(self, csv_path, vocab, processor, max_len=None):
        """Read the manifest and keep the helpers needed to build items.

        Args:
            csv_path: Path of the CSV manifest.
            vocab: Object providing ``encode(text) -> list[int]``.
            processor: Callable feature extractor returning ``input_values``.
            max_len: Optional maximum number of waveform samples (after
                resampling); longer clips are truncated.
        """
        self.df = pd.read_csv(csv_path)
        self.vocab = vocab
        self.processor = processor 
        self.max_len = max_len

    def __len__(self):
        """Return the number of rows in the manifest."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(features, labels)`` for row ``idx`` or the sentinel on failure."""
        row = self.df.iloc[idx]
        audio_path = row['file']
        transcript = str(row['transcript']) if pd.notna(row['transcript']) else ""

        try:
            waveform, sr = torchaudio.load(audio_path)
            if sr != 16000:
                waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
            if waveform.shape[0] > 1:
                # Down-mix multi-channel audio to mono.
                waveform = torch.mean(waveform, dim=0, keepdim=True)
            # Drop only the channel axis so a one-sample clip stays 1-D.
            waveform = waveform.squeeze(0)

            if self.max_len and waveform.size(0) > self.max_len:
                waveform = waveform[:self.max_len]

            features = self.processor(waveform, sampling_rate=16000, return_tensors="pt").input_values[0]
            labels = torch.tensor(self.vocab.encode(transcript), dtype=torch.long)
            return features, labels
        except Exception as exc:
            # Best-effort loading is intentional (one bad file must not stop
            # training), but the failure is reported so a systematically
            # broken dataset does not go unnoticed.
            import logging

            logging.getLogger(__name__).warning(
                "Skipping sample %r (index %s): %s: %s",
                audio_path, idx, type(exc).__name__, exc,
            )
            return torch.zeros(16000), torch.tensor([0], dtype=torch.long)

def collate_fn(batch):
    """Pad a list of ``(features, labels)`` pairs into batch tensors.

    Items whose labels sum to zero (the dataset's failure sentinel, or an
    empty transcript) are dropped first.

    Args:
        batch: Sequence of ``(features, labels)`` pairs of 1-D tensors.

    Returns:
        ``(features_padded, attention_mask, labels_padded)``. Features are
        zero-padded, labels are padded with -100, and the attention mask is 1
        for every real sample and 0 for padding. If no item survives the
        filter, three empty tensors are returned.
    """
    batch = [item for item in batch if item[1].sum() != 0]
    if not batch:
        return torch.empty(0), torch.empty(0), torch.empty(0)

    features, labels = zip(*batch)
    features_padded = pad_sequence(features, batch_first=True, padding_value=0.0)

    # Derive the mask from the true lengths. Comparing the padded values with
    # zero would also mask out genuine samples that happen to be exactly 0.
    lengths = torch.tensor([feat.size(0) for feat in features], dtype=torch.long)
    positions = torch.arange(features_padded.size(1)).unsqueeze(0)
    attention_mask = (positions < lengths.unsqueeze(1)).long()

    labels_padded = pad_sequence(labels, batch_first=True, padding_value=-100)
    return features_padded, attention_mask, labels_padded

# --- Model Setup ---
def get_lora_conformer(vocab_size):
    """Build the CTC conformer with LoRA adapters, resuming from a checkpoint if present.

    The base model is loaded with a freshly sized CTC head (``vocab_size``
    outputs), fully frozen, and then wrapped with PEFT. Afterwards the LoRA
    weights, all LayerNorm parameters and the CTC head are made trainable.

    Args:
        vocab_size: Number of output classes of the CTC head.

    Returns:
        ``(model, checkpoint_exists)`` where ``checkpoint_exists`` tells
        whether adapter weights were loaded from ``CONFIG["checkpoint_path"]``.
    """
    print(f"Loading Base Model: {CONFIG['base_model']}...")
    
    model = Wav2Vec2ConformerForCTC.from_pretrained(
        CONFIG["base_model"], 
        ctc_loss_reduction="mean", 
        pad_token_id=0,
        vocab_size=vocab_size,
        ignore_mismatched_sizes=True,
        token=auth_token_arg,
        mask_time_prob=CONFIG["mask_time_prob"],
        mask_time_length=CONFIG["mask_time_length"],
        mask_feature_prob=CONFIG["mask_feature_prob"],
        mask_feature_length=CONFIG["mask_feature_length"],
    )
    
    # --- Monkey Patching ---
    # The model gets its own enable_input_require_grads / get_input_embeddings,
    # both bound to the feature projection, so that the projection output
    # carries requires_grad when the hook is enabled.
    def enable_input_require_grads(self):
        def make_inputs_require_grads(module, input, output):
            if isinstance(output, tuple):
                output[0].requires_grad_(True)
            else:
                output.requires_grad_(True)
        self._require_grads_hook = self.wav2vec2_conformer.feature_projection.register_forward_hook(make_inputs_require_grads)

    model.enable_input_require_grads = types.MethodType(enable_input_require_grads, model)

    def get_input_embeddings(self):
        return self.wav2vec2_conformer.feature_projection

    model.get_input_embeddings = types.MethodType(get_input_embeddings, model)
    
    model.gradient_checkpointing_enable()
    
    # Freeze Logic - Start by freezing everything
    for param in model.parameters():
        param.requires_grad = False

    # Only treat the folder as a checkpoint when it really holds a saved
    # adapter; a bare or unrelated directory would make the load below fail.
    checkpoint_exists = os.path.isfile(
        os.path.join(CONFIG["checkpoint_path"], "adapter_config.json")
    )

    if checkpoint_exists:
        print(f"Checkpoint found! Loading LoRA weights from {CONFIG['checkpoint_path']}...")
        # Load adapters and ensure they are trainable
        model = PeftModel.from_pretrained(model, CONFIG["checkpoint_path"], is_trainable=True)
        
        # FORCE TRAINABLE: Sometimes is_trainable=True isn't enough if base is frozen
        for name, param in model.named_parameters():
            if "lora" in name:
                param.requires_grad = True

        # Checkpoints written without modules_to_save contain no CTC head, so
        # the head created above (random, because the vocab size differs from
        # the pretrained one) is what the resumed run starts from.
        saved_modules = getattr(model.active_peft_config, "modules_to_save", None) or []
        if "lm_head" not in saved_modules:
            print(
                "WARNING: checkpoint does not contain the CTC head (lm_head); it is "
                "randomly initialized, so 'previous_best_wer' does not describe this model."
            )
    else:
        print("Initializing new LoRA adapters...")
        peft_config = LoraConfig(
            inference_mode=False,
            r=CONFIG["lora_rank"],
            lora_alpha=CONFIG["lora_alpha"],
            lora_dropout=CONFIG["lora_dropout"],
            target_modules=CONFIG["target_modules"],
            # The CTC head is re-initialized for the new vocabulary and trained
            # in full; listing it here makes save_pretrained() store it with the
            # adapter and from_pretrained() restore it on resume.
            modules_to_save=["lm_head"],
        )
        model = get_peft_model(model, peft_config)
    
    # Re-apply Unfreezing Logic (Must be done AFTER loading PEFT)

    # 1. Unfreeze LayerNorms and the LM head.
    # NOTE(review): the LayerNorm weights trained here are not listed in
    # modules_to_save, so they are presumably not written to the adapter
    # checkpoint and revert to the base-model values on resume - confirm and
    # persist them if that matters.
    for name, param in model.named_parameters():
        if "original_module" in name:
            # Frozen reference copy kept by PEFT's modules_to_save wrapper.
            continue
        if "layer_norm" in name or "lm_head" in name:
            param.requires_grad = True
            
    # 2. Unfreeze Feature Extractor LayerNorms
    # Accessing the base model inside PeftModel wrapper
    base = model.base_model.model if hasattr(model.base_model, "model") else model.base_model
    for name, param in base.wav2vec2_conformer.feature_extractor.named_parameters():
        if "layer_norm" in name or "ln" in name:
            param.requires_grad = True

    model.print_trainable_parameters()
    return model, checkpoint_exists

# --- Main ---
def _greedy_ctc_collapse(frame_ids, blank_id=0):
    """Collapse a per-frame argmax path into token ids (merge repeats, drop blanks)."""
    collapsed = []
    previous = None
    for token in frame_ids:
        if token != previous and token != blank_id:
            collapsed.append(token)
        previous = token
    return collapsed


def main():
    """Train (or resume) the LoRA conformer and keep the adapter with the best WER.

    Builds the vocabulary, datasets and loaders from CONFIG, trains with mixed
    precision and gradient accumulation, evaluates WER/CER with greedy CTC
    decoding after every epoch, and saves the model to
    ``CONFIG["checkpoint_path"]`` whenever the validation WER improves.
    """
    torch.cuda.empty_cache()
    gc.collect()

    vocab = Vocabulary([CONFIG["train_csv"], CONFIG["val_csv"]])
    processor = Wav2Vec2FeatureExtractor.from_pretrained(CONFIG["base_model"], token=auth_token_arg)
    
    train_dataset = EsperantoDataset(CONFIG["train_csv"], vocab, processor, max_len=CONFIG["max_audio_len"])
    val_dataset = EsperantoDataset(CONFIG["val_csv"], vocab, processor) 

    train_loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True, collate_fn=collate_fn, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG["batch_size"], shuffle=False, collate_fn=collate_fn, num_workers=1)

    model, checkpoint_loaded = get_lora_conformer(len(vocab.vocab))
    model.to(device)

    # Optimize and clip only what is actually trainable; the frozen base
    # weights never receive gradients.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=CONFIG["learning_rate"])
    scaler = GradScaler()
    
    # Resumed runs are treated as a new run of 'num_epochs' epochs.
    accum_steps = CONFIG["grad_accum_steps"]
    # Ceil division: a trailing partial accumulation group is also stepped.
    steps_per_epoch = (len(train_loader) + accum_steps - 1) // accum_steps
    num_training_steps = steps_per_epoch * CONFIG["num_epochs"]
    num_warmup_steps = int(0.1 * num_training_steps) 
    
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
    )

    status = "RESUMED" if checkpoint_loaded else "NEW"
    print(f"\nStarting {status} Training V3...")
    
    if checkpoint_loaded:
        best_wer = CONFIG.get("previous_best_wer", float('inf'))
        print(f"Resuming with previous best WER baseline: {best_wer}")
    else:
        best_wer = float('inf')

    # Maps raw-sample counts to encoder frame counts; looked up defensively so
    # evaluation still works (untrimmed) if the wrapper does not expose it.
    frames_for = getattr(model, "_get_feat_extract_output_lengths", None)

    for epoch in range(CONFIG["num_epochs"]):
        model.train()
        optimizer.zero_grad()
        total_loss = 0.0
        counted_batches = 0  # batches that contributed a loss
        pending = 0          # micro-batches accumulated since the last optimizer step
        final_step = len(train_loader) - 1
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['num_epochs']}")
        
        for step, (features, mask, labels) in enumerate(pbar):
            has_data = features.size(0) > 0
            batch_loss = None

            if has_data:
                features = features.to(device)
                mask = mask.to(device)
                labels = labels.to(device)

                with autocast():
                    outputs = model(features, attention_mask=mask, labels=labels)
                    loss = outputs.loss / accum_steps

                scaler.scale(loss).backward()
                pending += 1
                counted_batches += 1
                batch_loss = loss.item() * accum_steps
                total_loss += batch_loss

            # Step on real accumulated micro-batches rather than on the loader
            # index, so empty batches neither skip nor stretch a step, and
            # flush the remainder at the end of the epoch so no gradients leak
            # into the next one.
            if pending and (pending == accum_steps or step == final_step):
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()
                pending = 0

            if has_data:
                pbar.set_postfix({"Loss": f"{batch_loss:.4f}", "LR": f"{scheduler.get_last_lr()[0]:.2e}"})

        # Average over the batches that actually produced a loss.
        avg_train_loss = total_loss / max(counted_batches, 1)

        # Evaluation
        model.eval()
        refs, preds = [], []
        
        print(f"Evaluating Epoch {epoch+1}...")
        with torch.no_grad():
            for features, mask, labels in val_loader:
                if features.size(0) == 0:
                    continue
                features = features.to(device)
                mask = mask.to(device)
                
                with autocast():
                    outputs = model(features, attention_mask=mask)
                
                logits = outputs.logits.float()
                pred_ids = torch.argmax(logits, dim=-1).cpu()

                # Frames that correspond to padding carry no speech; decoding
                # them can append spurious tokens to the shorter utterances.
                if frames_for is not None:
                    frame_counts = frames_for(mask.sum(-1)).cpu().tolist()
                else:
                    frame_counts = [pred_ids.size(1)] * pred_ids.size(0)
                
                for frame_row, n_frames, label_row in zip(pred_ids, frame_counts, labels):
                    label_idx = label_row[label_row != -100].cpu().tolist()
                    ref_text = vocab.decode(label_idx)
                    if not ref_text.strip():
                        # A blank reference cannot be scored (jiwer rejects it).
                        continue

                    pred_tokens = _greedy_ctc_collapse(frame_row[:int(n_frames)].tolist())
                    refs.append(ref_text)
                    preds.append(vocab.decode(pred_tokens))

        epoch_wer = wer(refs, preds) if refs else 1.0
        epoch_cer = cer(refs, preds) if refs else 1.0

        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | WER: {epoch_wer:.4f} | CER: {epoch_cer:.4f}")
        
        if epoch_wer < best_wer:
            best_wer = epoch_wer
            model.save_pretrained(CONFIG["checkpoint_path"])
            print(f"New Best Model Saved! (WER: {best_wer:.4f})")
        else:
            print(f"No improvement over best WER: {best_wer:.4f}")
            
    torch.cuda.empty_cache()

# Start training only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()

2025-11-22 21:41:39.045793: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-22 21:41:39.084854: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Using device: cuda
Loading Base Model: facebook/wav2vec2-conformer-rope-large-960h-ft...
Checkpoint found in conformer_lora_best! Attempting to load...
Weight Sum BEFORE Load (LM Head): 1.6663
Weight Sum AFTER Load (LM Head): 1.6663
trainable params: 293,925 || all params: 604,392,101 || trainable%: 0.0486

Starting RESUMED Training V3...
Resuming with previous best WER baseline: 0.38


Epoch 1/25: 100%|█████████████████████████████████████████| 3000/3000 [13:18<00:00,  3.76it/s, Loss=3.2242, LR=4.00e-05]

Evaluating Epoch 1...





Epoch 1 | Train Loss: 13.9236 | WER: 1.0000 | CER: 1.0000
No improvement over best WER: 0.3800


Epoch 2/25:   5%|█▉                                        | 142/3000 [00:31<10:24,  4.58it/s, Loss=3.1312, LR=4.19e-05]


KeyboardInterrupt: 