In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torchaudio
import pandas as pd
from tqdm import tqdm
from jiwer import wer, cer
import gc
import types
import warnings
import logging
import json
import shutil

# Hugging Face Imports
from transformers import (
    Wav2Vec2ForCTC, 
    Wav2Vec2FeatureExtractor,
    get_linear_schedule_with_warmup,
    logging as hf_logging
)
from transformers import get_cosine_schedule_with_warmup
from peft import get_peft_model, LoraConfig, PeftModel
from huggingface_hub import login
from torch.cuda.amp import autocast, GradScaler

# --- Suppress Warnings ---
warnings.filterwarnings("ignore")
hf_logging.set_verbosity_error()
logging.getLogger("torchaudio").setLevel(logging.ERROR)

# --- Configuration ---
CONFIG = {
    "train_csv": "geo/train.csv",
    "val_csv": "geo/dev.csv",
    "hf_token": None,  
    
    # Hyperparameters
    "mask_time_prob": 0.025,
    "mask_time_length": 10,
    "mask_feature_prob": 0.025,
    "mask_feature_length": 10,
    
    "lora_rank": 64,             
    "lora_alpha": 128,
    "lora_dropout": 0.1,
    
    "batch_size": 6,           
    "grad_accum_steps": 16,     
    "max_audio_len": 160000,   
    "learning_rate": 5e-6,     
    "num_epochs": 30,            
    
    "base_model": "facebook/wav2vec2-large-xlsr-53",
    
    "target_modules": ["q_proj", "k_proj", "v_proj", "out_proj", "intermediate_dense", "output_dense"],
    "checkpoint_path": "xlsr_lora_gibberish_best",
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if CONFIG["hf_token"]:
    login(token=CONFIG["hf_token"])

# --- Vocabulary Builder ---
class Vocabulary:
    def __init__(self, csv_paths):
        self.vocab = {"<pad>": 0, "<s>": 1, "</s>": 2, "<unk>": 3, "|": 4}
        self.idx = 5
        self._build_vocab(csv_paths)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        
    def _build_vocab(self, paths):
        chars = set()
        for path in paths:
            if not os.path.exists(path): continue
            df = pd.read_csv(path)
            all_text = "".join(df['transcript'].fillna("").astype(str).tolist())
            chars.update(list(all_text))
        
        for c in sorted(chars):
            if c not in self.vocab:
                self.vocab[c] = self.idx
                self.idx += 1
    
    def encode(self, text):
        text = str(text).replace(" ", "|")
        return [self.vocab.get(c, self.vocab["<unk>"]) for c in text]
    
    def decode(self, tokens):
        res = []
        for t in tokens:
            if t == 0: continue 
            char = self.inv_vocab.get(t, "")
            res.append(char)
        return "".join(res).replace("|", " ").replace("<s>", "").replace("</s>", "")

# --- Dataset ---
class EsperantoDataset(Dataset):
    def __init__(self, csv_path, vocab, processor, max_len=None):
        self.df = pd.read_csv(csv_path)
        initial_len = len(self.df)
        self.df = self.df[self.df['file'].apply(os.path.exists)].reset_index(drop=True)
        print(f"Loaded {len(self.df)} valid samples from {csv_path} (Filtered out {initial_len - len(self.df)})")
        self.vocab = vocab
        self.processor = processor 
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        audio_path = self.df.iloc[idx]['file']
        transcript = self.df.iloc[idx]['transcript']
        if pd.isna(transcript): transcript = ""
        transcript = str(transcript)

        try:
            waveform, sr = torchaudio.load(audio_path)
            if sr != 16000:
                waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)
            waveform = waveform.squeeze()
            if self.max_len and waveform.size(0) > self.max_len:
                waveform = waveform[:self.max_len]

            features = self.processor(waveform, sampling_rate=16000, return_tensors="pt").input_values[0]
            labels = torch.tensor(self.vocab.encode(transcript), dtype=torch.long)
            return features, labels
        except Exception as e:
            print(f"Error loading {audio_path}: {e}")
            return None

def collate_fn(batch):
    batch = [item for item in batch if item is not None]
    if not batch: return torch.empty(0), torch.empty(0), torch.empty(0)
    features, labels = zip(*batch)
    features_padded = pad_sequence(features, batch_first=True, padding_value=0.0)
    attention_mask = (features_padded != 0).long() 
    labels_padded = pad_sequence(labels, batch_first=True, padding_value=-100)
    return features_padded, attention_mask, labels_padded

# --- Model Loading Logic (Modified for Resume) ---
def get_lora_model(vocab_size):
    print(f"Loading Base Model: {CONFIG['base_model']}...")
    use_auth_token = CONFIG["hf_token"] if CONFIG["hf_token"] else False

    model = Wav2Vec2ForCTC.from_pretrained(
        CONFIG["base_model"], 
        ctc_loss_reduction="mean", 
        pad_token_id=0,
        vocab_size=vocab_size,
        ignore_mismatched_sizes=True,
        token=use_auth_token,
        mask_time_prob=CONFIG["mask_time_prob"],
        mask_time_length=CONFIG["mask_time_length"],
        mask_feature_prob=CONFIG["mask_feature_prob"],
        mask_feature_length=CONFIG["mask_feature_length"],
    )
    
    # Gradient Checkpointing Hook
    if not hasattr(model, "enable_input_require_grads"):
        def enable_input_require_grads(self):
            def make_inputs_require_grads(module, input, output):
                output.requires_grad_(True)
            self.wav2vec2.feature_projection.register_forward_hook(make_inputs_require_grads)
        model.enable_input_require_grads = types.MethodType(enable_input_require_grads, model)

    # Check for existing checkpoint to RESUME
    adapter_path = os.path.join(CONFIG["checkpoint_path"], "adapter_model.safetensors")
    
    if os.path.exists(adapter_path):
        print(f"\n[INFO] Found existing checkpoint at {CONFIG['checkpoint_path']}. RESUMING TRAINING...")
        # is_trainable=True is REQUIRED to continue training
        model = PeftModel.from_pretrained(model, CONFIG["checkpoint_path"], is_trainable=True)
    else:
        print(f"\n[INFO] No checkpoint found. Initializing FRESH LoRA model...")
        modules_to_save = ["lm_head", "layer_norm"] 
        peft_config = LoraConfig(
            inference_mode=False,
            r=CONFIG["lora_rank"],
            lora_alpha=CONFIG["lora_alpha"],
            lora_dropout=CONFIG["lora_dropout"],
            target_modules=CONFIG["target_modules"],
            modules_to_save=modules_to_save
        )
        model = get_peft_model(model, peft_config)
    
    model.gradient_checkpointing_enable()
    model.print_trainable_parameters()
    return model

# --- Helper: Evaluation Function ---
def evaluate(model, dataloader, vocab):
    model.eval()
    refs, preds = [], []
    
    with torch.no_grad():
        for features, mask, labels in dataloader:
            if features.size(0) == 0: continue
            features = features.to(device)
            mask = mask.to(device)
            
            with autocast():
                outputs = model(features, attention_mask=mask)
            
            logits = outputs.logits.float().cpu().numpy()
            pred_ids = logits.argmax(axis=-1)
            
            for i in range(len(labels)):
                label_idx = labels[i][labels[i] != -100].cpu().tolist()
                refs.append(vocab.decode(label_idx))
                
                pred_str_raw = []
                prev_token = -1
                for token in pred_ids[i]:
                    if token != prev_token and token != 0: 
                        pred_str_raw.append(token)
                    prev_token = token
                preds.append(vocab.decode(pred_str_raw))

    epoch_wer = wer(refs, preds) if len(refs) > 0 else 1.0
    epoch_cer = cer(refs, preds) if len(refs) > 0 else 1.0
    return epoch_wer, epoch_cer

# --- Main ---
def main():
    torch.cuda.empty_cache()
    gc.collect()

    vocab = Vocabulary([CONFIG["train_csv"], CONFIG["val_csv"]])
    use_auth_token = CONFIG["hf_token"] if CONFIG["hf_token"] else False

    print("Initializing Processor...")
    processor = Wav2Vec2FeatureExtractor.from_pretrained(CONFIG["base_model"], token=use_auth_token)
    
    train_dataset = EsperantoDataset(CONFIG["train_csv"], vocab, processor, max_len=CONFIG["max_audio_len"])
    val_dataset = EsperantoDataset(CONFIG["val_csv"], vocab, processor) 

    train_loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True, collate_fn=collate_fn, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG["batch_size"], shuffle=False, collate_fn=collate_fn, num_workers=1)

    # LOAD MODEL (Handles Resume)
    model = get_lora_model(len(vocab.vocab))
    model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=CONFIG["learning_rate"])
    scaler = GradScaler()
    
    num_training_steps = len(train_loader) * CONFIG["num_epochs"] // CONFIG["grad_accum_steps"]
    num_warmup_steps = int(0.1 * num_training_steps) 
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)

    print(f"\nStarting Training Loop...")
    
    # --- CRITICAL: Establish Baseline ---
    # If we resumed, we must know the current WER. If we don't check,
    # we might overwrite a good model (WER 0.20) with a bad first epoch (WER 0.25)
    # because best_wer would start at infinity.
    best_wer = float('inf')
    
    if os.path.exists(os.path.join(CONFIG["checkpoint_path"], "adapter_model.safetensors")):
        print("Checking baseline performance of loaded model... ")
        baseline_wer, baseline_cer = evaluate(model, val_loader, vocab)
        best_wer = baseline_wer
        print(f"Baseline restored -> WER: {baseline_wer:.4f} | CER: {baseline_cer:.4f}")
    
    for epoch in range(CONFIG["num_epochs"]):
        model.train()
        total_loss = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['num_epochs']}")
        
        for step, (features, mask, labels) in enumerate(pbar):
            if features.size(0) == 0: continue 

            features = features.to(device)
            mask = mask.to(device)
            labels = labels.to(device)

            with autocast():
                outputs = model(features, attention_mask=mask, labels=labels)
                loss = outputs.loss / CONFIG["grad_accum_steps"]

            scaler.scale(loss).backward()

            if (step + 1) % CONFIG["grad_accum_steps"] == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()

            total_loss += loss.item() * CONFIG["grad_accum_steps"]
            pbar.set_postfix({"Loss": f"{loss.item()*CONFIG['grad_accum_steps']:.4f}", "LR": f"{scheduler.get_last_lr()[0]:.2e}"})

        avg_train_loss = total_loss / len(train_loader)

        # --- Evaluation ---
        print(f"Evaluating Epoch {epoch+1}...")
        epoch_wer, epoch_cer = evaluate(model, val_loader, vocab)
        print(f"Epoch {epoch+1} | Loss: {avg_train_loss:.4f} | WER: {epoch_wer:.4f} | CER: {epoch_cer:.4f}")
        
        if epoch_wer < best_wer:
            best_wer = epoch_wer
            model.save_pretrained(CONFIG["checkpoint_path"])
            print(f"New Best Model Saved! (WER: {best_wer:.4f})")
            
            # Also save vocab whenever we save a new best model
            with open("vocab.json", "w") as f:
                json.dump(vocab.vocab, f)
            
    torch.cuda.empty_cache()
    print("Training Complete.")

if __name__ == "__main__":
    main()

2025-11-23 04:20:04.150263: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-23 04:20:04.189586: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Using device: cuda
Initializing Processor...
Loaded 6000 valid samples from geo/train.csv (Filtered out 0)
Loaded 1000 valid samples from geo/dev.csv (Filtered out 0)
Loading Base Model: facebook/wav2vec2-large-xlsr-53...

[INFO] Found existing checkpoint at xlsr_lora_gibberish_best. RESUMING TRAINING...
trainable params: 28,458,021 || all params: 343,934,666 || trainable%: 8.2743

Starting Training Loop...
Checking baseline performance of loaded model... 
Baseline restored -> WER: 0.2982 | CER: 0.0748


Epoch 1/30: 100%|█████████████████████████████████████████| 1000/1000 [05:40<00:00,  2.93it/s, Loss=0.2482, LR=1.66e-06]

Evaluating Epoch 1...





Epoch 1 | Loss: 0.2166 | WER: 0.2994 | CER: 0.0748


Epoch 2/30: 100%|█████████████████████████████████████████| 1000/1000 [05:40<00:00,  2.93it/s, Loss=0.1789, LR=3.32e-06]

Evaluating Epoch 2...





Epoch 2 | Loss: 0.2143 | WER: 0.2998 | CER: 0.0747


Epoch 3/30: 100%|█████████████████████████████████████████| 1000/1000 [05:41<00:00,  2.93it/s, Loss=0.1928, LR=4.97e-06]

Evaluating Epoch 3...





Epoch 3 | Loss: 0.2116 | WER: 0.3000 | CER: 0.0749


Epoch 4/30: 100%|█████████████████████████████████████████| 1000/1000 [05:42<00:00,  2.92it/s, Loss=0.1534, LR=4.98e-06]

Evaluating Epoch 4...





Epoch 4 | Loss: 0.2137 | WER: 0.3002 | CER: 0.0750


Epoch 5/30:  55%|███████████████████████▏                  | 552/1000 [03:08<02:33,  2.93it/s, Loss=0.1528, LR=4.96e-06]


KeyboardInterrupt: 