In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torchaudio
import pandas as pd
from tqdm import tqdm
from jiwer import wer, cer
import gc
import types
import warnings
import logging
import json
import shutil

# Hugging Face Imports
from transformers import (
    Wav2Vec2ForCTC, 
    Wav2Vec2FeatureExtractor,
    get_linear_schedule_with_warmup,
    logging as hf_logging
)
from peft import get_peft_model, LoraConfig, PeftModel
from huggingface_hub import login
from torch.cuda.amp import autocast, GradScaler

# --- Suppress Warnings ---
# Silence Python warnings, Transformers log output below ERROR, and
# torchaudio log records below ERROR so the training progress bars stay readable.
warnings.filterwarnings("ignore")
hf_logging.set_verbosity_error()
logging.getLogger("torchaudio").setLevel(logging.ERROR)

# --- Configuration ---
# Single source of settings read by get_lora_model() and main().
CONFIG = {
    # CSVs are read with pandas and must provide 'file' and 'transcript' columns.
    "train_csv": "geo/train.csv",
    "val_csv": "geo/dev.csv",
    # Optional Hugging Face token; when set it is used for login() and passed
    # as `token=` to from_pretrained(). When falsy, `token=False` is passed.
    "hf_token": None,

    # Hyperparameters
    # The four mask_* values are forwarded unchanged to
    # Wav2Vec2ForCTC.from_pretrained() as model config overrides.
    "mask_time_prob": 0.05,
    "mask_time_length": 10,
    "mask_feature_prob": 0.05,
    "mask_feature_length": 10,

    # LoRA settings, used only when no saved adapter is found
    # (a resumed adapter is loaded from checkpoint_path instead).
    "lora_rank": 64,
    "lora_alpha": 128,
    "lora_dropout": 0.1,

    # batch_size samples per forward pass; gradients are accumulated over
    # grad_accum_steps batches, i.e. 3 * 8 = 24 samples per optimizer step.
    "batch_size": 3,
    "grad_accum_steps": 8,
    # Training clips are truncated to this many samples after resampling to
    # 16 kHz (160000 samples = 10 s). The validation set is not truncated.
    "max_audio_len": 160000,
    # Peak learning rate of the linear warmup/decay schedule built in main().
    "learning_rate": 1e-4,
    "num_epochs": 50,

    "base_model": "facebook/wav2vec2-xls-r-1b",

    # Module names handed to LoraConfig(target_modules=...).
    "target_modules": ["q_proj", "k_proj", "v_proj", "out_proj", "intermediate_dense", "output_dense"],
    # Directory the best adapter is saved to; the presence of
    # adapter_model.safetensors inside it switches the script to resume mode.
    "checkpoint_path": "xlsr_lora_1b_best",
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Only authenticate against the Hub when a token was configured.
if CONFIG["hf_token"]:
    login(token=CONFIG["hf_token"])

# --- Vocabulary Builder ---
class Vocabulary:
    """Character-level vocabulary built from the 'transcript' column of CSV files.

    Ids 0-4 are reserved for ``<pad>``, ``<s>``, ``</s>``, ``<unk>`` and the
    word delimiter ``|``. Every other distinct character found in the
    transcripts receives the next free id, assigned in sorted character order.

    Attributes:
        vocab: symbol -> id mapping.
        inv_vocab: id -> symbol mapping (inverse of ``vocab``).
        idx: next id that would be assigned.
    """

    def __init__(self, csv_paths):
        """Build the mapping from every CSV in *csv_paths* that exists on disk."""
        self.vocab = {"<pad>": 0, "<s>": 1, "</s>": 2, "<unk>": 3, "|": 4}
        self.idx = 5
        self._build_vocab(csv_paths)
        self.inv_vocab = {index: symbol for symbol, index in self.vocab.items()}

    def _build_vocab(self, paths):
        """Collect all transcript characters and append unseen ones to ``self.vocab``."""
        seen = set()
        for csv_file in paths:
            # Paths that do not exist are skipped without raising.
            if os.path.exists(csv_file):
                frame = pd.read_csv(csv_file)
                # Missing transcripts are read as NaN; treat them as empty text.
                for line in frame['transcript'].fillna("").astype(str).tolist():
                    seen.update(line)

        # Sorting keeps id assignment deterministic for a given set of characters.
        for symbol in sorted(seen):
            if symbol in self.vocab:
                continue
            self.vocab[symbol] = self.idx
            self.idx += 1

    def encode(self, text):
        """Return the id sequence for *text*; spaces map to ``|``, unknown chars to ``<unk>``."""
        unk_id = self.vocab["<unk>"]
        delimited = str(text).replace(" ", "|")
        return [self.vocab.get(ch, unk_id) for ch in delimited]

    def decode(self, tokens):
        """Return the text for *tokens*, dropping id 0 and ids not in the vocabulary.

        ``|`` is turned back into a space and the ``<s>`` / ``</s>`` markers are
        removed from the result.
        """
        pieces = [self.inv_vocab.get(tok, "") for tok in tokens if tok != 0]
        text = "".join(pieces)
        for marker, replacement in (("|", " "), ("<s>", ""), ("</s>", "")):
            text = text.replace(marker, replacement)
        return text

# --- Dataset ---
class EsperantoDataset(Dataset):
    """Audio/transcript pairs listed in a CSV with 'file' and 'transcript' columns.

    Rows whose audio file does not exist on disk are dropped at construction
    time. Each item is ``(input_values, label_ids)``, or ``None`` when the
    sample could not be loaded or processed.
    """

    def __init__(self, csv_path, vocab, processor, max_len=None):
        """Read *csv_path* and keep only the rows whose 'file' path exists.

        Args:
            csv_path: CSV file to read.
            vocab: object whose ``encode(text)`` returns a list of label ids.
            processor: feature extractor called on the waveform.
            max_len: optional maximum number of samples kept per clip.
        """
        table = pd.read_csv(csv_path)
        initial_len = len(table)
        on_disk = table['file'].apply(os.path.exists)
        self.df = table[on_disk].reset_index(drop=True)
        kept = len(self.df)
        print(f"Loaded {kept} valid samples from {csv_path} (Filtered out {initial_len - kept})")
        self.vocab = vocab
        self.processor = processor
        self.max_len = max_len

    def __len__(self):
        """Return the number of rows that survived the file-existence filter."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, normalise and encode sample *idx*; return ``None`` on any failure."""
        row = self.df.iloc[idx]
        audio_path = row['file']
        raw_text = row['transcript']
        # A missing transcript becomes an empty label sequence.
        transcript = "" if pd.isna(raw_text) else str(raw_text)

        try:
            waveform, sr = torchaudio.load(audio_path)
            # Bring everything to 16 kHz, the rate passed to the processor below.
            if sr != 16000:
                resampler = torchaudio.transforms.Resample(sr, 16000)
                waveform = resampler(waveform)
            # Average multi-channel audio down to a single channel.
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)
            waveform = waveform.squeeze()
            # NOTE: only the audio is truncated; the transcript is kept whole.
            if self.max_len and waveform.size(0) > self.max_len:
                waveform = waveform[:self.max_len]

            processed = self.processor(waveform, sampling_rate=16000, return_tensors="pt")
            features = processed.input_values[0]
            label_ids = self.vocab.encode(transcript)
            labels = torch.tensor(label_ids, dtype=torch.long)
            return features, labels
        except Exception as e:
            # Best effort: a bad sample is reported and later dropped by collate_fn.
            print(f"Error loading {audio_path}: {e}")
            return None

def collate_fn(batch):
    """Pad a list of ``(features, labels)`` samples into batch tensors.

    Samples that failed to load arrive as ``None`` and are dropped. Features
    are right-padded with 0.0 and labels with -100.

    The attention mask is derived from each sample's true length rather than
    from ``features_padded != 0``: real audio can contain exact zeros, and
    those positions must not be treated as padding.

    Args:
        batch: list of ``(features, labels)`` tuples or ``None`` entries;
            features are expected to be 1-D tensors.

    Returns:
        ``(features_padded, attention_mask, labels_padded)``. When every
        sample in the batch is ``None``, three empty tensors are returned so
        callers can detect the case via ``features.size(0) == 0``.
    """
    samples = [item for item in batch if item is not None]
    if not samples:
        return torch.empty(0), torch.empty(0), torch.empty(0)

    features, labels = zip(*samples)
    features_padded = pad_sequence(features, batch_first=True, padding_value=0.0)

    # 1 for every position before the sample's own length, 0 for padding.
    lengths = torch.tensor([f.size(0) for f in features], dtype=torch.long)
    positions = torch.arange(features_padded.size(1)).unsqueeze(0)
    attention_mask = (positions < lengths.unsqueeze(1)).long()

    labels_padded = pad_sequence(labels, batch_first=True, padding_value=-100)
    return features_padded, attention_mask, labels_padded

# --- Model Loading Logic (Modified for Resume) ---
def get_lora_model(vocab_size):
    """Load the base Wav2Vec2 CTC model and wrap it as a trainable PEFT model.

    If ``adapter_model.safetensors`` exists under ``CONFIG["checkpoint_path"]``
    the saved adapter is loaded in trainable mode (resume); otherwise a fresh
    LoRA configuration is applied. Gradient checkpointing is enabled in both
    cases.

    Args:
        vocab_size: size of the CTC output layer.

    Returns:
        The PEFT-wrapped model (not yet moved to a device).
    """
    print(f"Loading Base Model: {CONFIG['base_model']}...")
    auth_token = CONFIG["hf_token"] or False

    # Masking settings are passed straight through as model config overrides.
    masking_kwargs = {
        key: CONFIG[key]
        for key in ("mask_time_prob", "mask_time_length", "mask_feature_prob", "mask_feature_length")
    }
    model = Wav2Vec2ForCTC.from_pretrained(
        CONFIG["base_model"],
        ctc_loss_reduction="mean",
        pad_token_id=0,
        vocab_size=vocab_size,
        # The output head is sized by vocab_size, which may differ from the checkpoint.
        ignore_mismatched_sizes=True,
        token=auth_token,
        **masking_kwargs,
    )

    # Gradient Checkpointing Hook
    # Fallback installed only when the model has no such attribute; this
    # function itself never calls it.
    if not hasattr(model, "enable_input_require_grads"):
        def _enable_input_require_grads(self):
            def _mark_output_trainable(module, inputs, output):
                output.requires_grad_(True)
            self.wav2vec2.feature_projection.register_forward_hook(_mark_output_trainable)
        model.enable_input_require_grads = types.MethodType(_enable_input_require_grads, model)

    # Check for existing checkpoint to RESUME
    checkpoint_dir = CONFIG["checkpoint_path"]
    has_adapter = os.path.exists(os.path.join(checkpoint_dir, "adapter_model.safetensors"))

    if not has_adapter:
        print(f"\n[INFO] No checkpoint found. Initializing FRESH LoRA model...")
        lora_settings = LoraConfig(
            inference_mode=False,
            r=CONFIG["lora_rank"],
            lora_alpha=CONFIG["lora_alpha"],
            lora_dropout=CONFIG["lora_dropout"],
            target_modules=CONFIG["target_modules"],
            # Modules listed here are stored with the adapter in addition to the LoRA weights.
            modules_to_save=["lm_head", "layer_norm"],
        )
        model = get_peft_model(model, lora_settings)
    else:
        print(f"\n[INFO] Found existing checkpoint at {CONFIG['checkpoint_path']}. RESUMING TRAINING...")
        # is_trainable=True is REQUIRED to continue training
        model = PeftModel.from_pretrained(model, checkpoint_dir, is_trainable=True)

    model.gradient_checkpointing_enable()
    model.print_trainable_parameters()
    return model

# --- Helper: Evaluation Function ---
def evaluate(model, dataloader, vocab):
    """Greedy-CTC decode every batch of *dataloader* and score it against the labels.

    The model is switched to eval mode and left in that mode; callers that
    continue training must call ``model.train()`` again.

    Args:
        model: callable as ``model(features, attention_mask=mask)`` returning
            an object with a ``logits`` tensor of shape (batch, time, vocab).
        dataloader: yields ``(features, mask, labels)``; labels use -100 as padding.
        vocab: object whose ``decode(list_of_ids)`` returns a string.

    Returns:
        ``(wer, cer)`` over all decoded samples, or ``(1.0, 1.0)`` when the
        dataloader produced no samples.
    """
    model.eval()
    refs, preds = [], []

    with torch.no_grad():
        for features, mask, labels in dataloader:
            # collate_fn signals "every sample failed to load" with empty tensors.
            if features.size(0) == 0:
                continue
            features = features.to(device)
            mask = mask.to(device)

            with autocast():
                outputs = model(features, attention_mask=mask)

            # Take the argmax on the device so only the (batch, time) id matrix
            # is copied to the host, not the full (batch, time, vocab) logits.
            pred_ids = outputs.logits.argmax(dim=-1).cpu()

            for label_row, pred_row in zip(labels, pred_ids):
                refs.append(vocab.decode(label_row[label_row != -100].tolist()))

                # Greedy CTC: merge consecutive repeats, then drop blanks (id 0).
                collapsed = torch.unique_consecutive(pred_row)
                preds.append(vocab.decode(collapsed[collapsed != 0].tolist()))

    if not refs:
        return 1.0, 1.0
    return wer(refs, preds), cer(refs, preds)

# --- Main ---
def main():
    """Fine-tune the LoRA-wrapped Wav2Vec2 CTC model and keep the best adapter.

    Steps: build the vocabulary from the train/val CSVs, create datasets and
    loaders, load (or resume) the PEFT model, then train with mixed precision
    and gradient accumulation. After every epoch the model is evaluated on the
    validation set; whenever WER improves, the adapter is saved to
    ``CONFIG["checkpoint_path"]`` and the vocabulary to ``vocab_1b.json``.

    Gradient accumulation details:
        * An optimizer step is taken every ``grad_accum_steps`` batches and
          also after the last batch of an epoch, so a trailing partial window
          is applied instead of leaking into the next epoch.
        * A step is only taken when at least one batch contributed gradients
          since the previous step.
    """
    torch.cuda.empty_cache()
    gc.collect()

    vocab = Vocabulary([CONFIG["train_csv"], CONFIG["val_csv"]])
    use_auth_token = CONFIG["hf_token"] if CONFIG["hf_token"] else False

    print("Initializing Processor...")
    processor = Wav2Vec2FeatureExtractor.from_pretrained(CONFIG["base_model"], token=use_auth_token)

    # Only training clips are length-capped; validation audio is used in full.
    train_dataset = EsperantoDataset(CONFIG["train_csv"], vocab, processor, max_len=CONFIG["max_audio_len"])
    val_dataset = EsperantoDataset(CONFIG["val_csv"], vocab, processor)

    train_loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True, collate_fn=collate_fn, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG["batch_size"], shuffle=False, collate_fn=collate_fn, num_workers=1)

    # LOAD MODEL (Handles Resume)
    model = get_lora_model(len(vocab.vocab))
    model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=CONFIG["learning_rate"])
    scaler = GradScaler()

    accum_steps = CONFIG["grad_accum_steps"]
    batches_per_epoch = len(train_loader)
    # Ceil division: the trailing partial window of each epoch is also stepped.
    updates_per_epoch = (batches_per_epoch + accum_steps - 1) // accum_steps
    num_training_steps = updates_per_epoch * CONFIG["num_epochs"]
    num_warmup_steps = int(0.1 * num_training_steps)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
    )

    print(f"\nStarting Training Loop...")

    # --- CRITICAL: Establish Baseline ---
    # If we resumed, we must know the current WER. If we don't check,
    # we might overwrite a good model (WER 0.20) with a bad first epoch (WER 0.25)
    # because best_wer would start at infinity.
    best_wer = float('inf')

    if os.path.exists(os.path.join(CONFIG["checkpoint_path"], "adapter_model.safetensors")):
        print("Checking baseline performance of loaded model... ")
        baseline_wer, baseline_cer = evaluate(model, val_loader, vocab)
        best_wer = baseline_wer
        print(f"Baseline restored -> WER: {baseline_wer:.4f} | CER: {baseline_cer:.4f}")

    for epoch in range(CONFIG["num_epochs"]):
        # evaluate() leaves the model in eval mode, so re-enable training each epoch.
        model.train()
        optimizer.zero_grad()
        total_loss = 0.0
        batches_seen = 0
        pending_grads = False
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['num_epochs']}")

        for step, (features, mask, labels) in enumerate(pbar):
            batch_loss = None

            # An empty batch means every sample in it failed to load; it adds
            # no gradients but must not prevent the window from being closed.
            if features.size(0) > 0:
                features = features.to(device)
                mask = mask.to(device)
                labels = labels.to(device)

                with autocast():
                    outputs = model(features, attention_mask=mask, labels=labels)
                    # Scale so the accumulated gradient averages over the window.
                    loss = outputs.loss / accum_steps

                scaler.scale(loss).backward()
                pending_grads = True

                batch_loss = loss.item() * accum_steps
                total_loss += batch_loss
                batches_seen += 1

            window_closed = (step + 1) % accum_steps == 0 or (step + 1) == batches_per_epoch
            if pending_grads and window_closed:
                # Unscale before clipping so the threshold applies to true gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()
                pending_grads = False

            if batch_loss is not None:
                pbar.set_postfix({"Loss": f"{batch_loss:.4f}", "LR": f"{scheduler.get_last_lr()[0]:.2e}"})

        # Average over batches that actually produced a loss.
        avg_train_loss = total_loss / max(batches_seen, 1)

        # --- Evaluation ---
        print(f"Evaluating Epoch {epoch+1}...")
        epoch_wer, epoch_cer = evaluate(model, val_loader, vocab)

        print(f"Epoch {epoch+1} | Loss: {avg_train_loss:.4f} | WER: {epoch_wer:.4f} | CER: {epoch_cer:.4f}")

        if epoch_wer < best_wer:
            best_wer = epoch_wer
            model.save_pretrained(CONFIG["checkpoint_path"])
            print(f"New Best Model Saved! (WER: {best_wer:.4f})")

            # Also save vocab whenever we save a new best model
            with open("vocab_1b.json", "w") as f:
                json.dump(vocab.vocab, f)

    torch.cuda.empty_cache()
    print("Training Complete.")

if __name__ == "__main__":
    main()

2025-11-23 05:20:07.547151: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-23 05:20:07.585439: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Using device: cuda
Initializing Processor...
Loaded 6000 valid samples from geo/train.csv (Filtered out 0)
Loaded 1000 valid samples from geo/dev.csv (Filtered out 0)
Loading Base Model: facebook/wav2vec2-xls-r-1b...

[INFO] Found existing checkpoint at xlsr_lora_1b_best. RESUMING TRAINING...
trainable params: 71,082,789 || all params: 1,033,627,594 || trainable%: 6.8770

Starting Training Loop...
Checking baseline performance of loaded model... 
Baseline restored -> WER: 1.0000 | CER: 1.0000


Epoch 1/50: 100%|█████████████████████████████████████████| 2000/2000 [13:40<00:00,  2.44it/s, Loss=2.9223, LR=2.00e-05]

Evaluating Epoch 1...





Epoch 1 | Loss: 3.0872 | WER: 1.0000 | CER: 1.0000


Epoch 2/50: 100%|█████████████████████████████████████████| 2000/2000 [13:38<00:00,  2.44it/s, Loss=1.0335, LR=4.00e-05]

Evaluating Epoch 2...





Epoch 2 | Loss: 2.5313 | WER: 0.8409 | CER: 0.2856
New Best Model Saved! (WER: 0.8409)


Epoch 3/50: 100%|█████████████████████████████████████████| 2000/2000 [13:36<00:00,  2.45it/s, Loss=0.4988, LR=6.00e-05]

Evaluating Epoch 3...





Epoch 3 | Loss: 1.0273 | WER: 0.6217 | CER: 0.1595
New Best Model Saved! (WER: 0.6217)


Epoch 4/50: 100%|█████████████████████████████████████████| 2000/2000 [13:35<00:00,  2.45it/s, Loss=0.5830, LR=8.00e-05]

Evaluating Epoch 4...





Epoch 4 | Loss: 0.7475 | WER: 0.5483 | CER: 0.1390
New Best Model Saved! (WER: 0.5483)


Epoch 5/50: 100%|█████████████████████████████████████████| 2000/2000 [13:41<00:00,  2.44it/s, Loss=0.6796, LR=1.00e-04]

Evaluating Epoch 5...





Epoch 5 | Loss: 0.5831 | WER: 0.4812 | CER: 0.1253
New Best Model Saved! (WER: 0.4812)


Epoch 6/50: 100%|█████████████████████████████████████████| 2000/2000 [13:37<00:00,  2.45it/s, Loss=0.5366, LR=9.78e-05]

Evaluating Epoch 6...





Epoch 6 | Loss: 0.4887 | WER: 0.4674 | CER: 0.1230
New Best Model Saved! (WER: 0.4674)


Epoch 7/50: 100%|█████████████████████████████████████████| 2000/2000 [13:36<00:00,  2.45it/s, Loss=0.3604, LR=9.56e-05]

Evaluating Epoch 7...





Epoch 7 | Loss: 0.4270 | WER: 0.4310 | CER: 0.1114
New Best Model Saved! (WER: 0.4310)


Epoch 8/50: 100%|█████████████████████████████████████████| 2000/2000 [13:39<00:00,  2.44it/s, Loss=0.3222, LR=9.33e-05]

Evaluating Epoch 8...





Epoch 8 | Loss: 0.3727 | WER: 0.4078 | CER: 0.1060
New Best Model Saved! (WER: 0.4078)


Epoch 9/50: 100%|█████████████████████████████████████████| 2000/2000 [13:37<00:00,  2.45it/s, Loss=0.8841, LR=9.11e-05]

Evaluating Epoch 9...





Epoch 9 | Loss: 0.3487 | WER: 0.3790 | CER: 0.0987
New Best Model Saved! (WER: 0.3790)


Epoch 10/50: 100%|████████████████████████████████████████| 2000/2000 [13:32<00:00,  2.46it/s, Loss=0.2533, LR=8.89e-05]

Evaluating Epoch 10...





Epoch 10 | Loss: 0.3199 | WER: 0.3605 | CER: 0.0933
New Best Model Saved! (WER: 0.3605)


Epoch 11/50: 100%|████████████████████████████████████████| 2000/2000 [13:37<00:00,  2.45it/s, Loss=0.2373, LR=8.67e-05]

Evaluating Epoch 11...





Epoch 11 | Loss: 0.3013 | WER: 0.3432 | CER: 0.0870
New Best Model Saved! (WER: 0.3432)


Epoch 12/50: 100%|████████████████████████████████████████| 2000/2000 [13:34<00:00,  2.45it/s, Loss=0.3431, LR=8.44e-05]

Evaluating Epoch 12...





Epoch 12 | Loss: 0.2869 | WER: 0.3351 | CER: 0.0857
New Best Model Saved! (WER: 0.3351)


Epoch 13/50: 100%|████████████████████████████████████████| 2000/2000 [13:36<00:00,  2.45it/s, Loss=0.1637, LR=8.22e-05]

Evaluating Epoch 13...





Epoch 13 | Loss: 0.2747 | WER: 0.3182 | CER: 0.0802
New Best Model Saved! (WER: 0.3182)


Epoch 14/50: 100%|████████████████████████████████████████| 2000/2000 [13:36<00:00,  2.45it/s, Loss=0.1857, LR=8.00e-05]

Evaluating Epoch 14...





Epoch 14 | Loss: 0.2566 | WER: 0.3092 | CER: 0.0785
New Best Model Saved! (WER: 0.3092)


Epoch 15/50: 100%|████████████████████████████████████████| 2000/2000 [13:33<00:00,  2.46it/s, Loss=0.2771, LR=7.78e-05]

Evaluating Epoch 15...





Epoch 15 | Loss: 0.2468 | WER: 0.3021 | CER: 0.0758
New Best Model Saved! (WER: 0.3021)


Epoch 16/50: 100%|████████████████████████████████████████| 2000/2000 [13:37<00:00,  2.45it/s, Loss=0.2112, LR=7.56e-05]

Evaluating Epoch 16...





Epoch 16 | Loss: 0.2298 | WER: 0.3042 | CER: 0.0776


Epoch 17/50: 100%|████████████████████████████████████████| 2000/2000 [13:34<00:00,  2.46it/s, Loss=0.2446, LR=7.33e-05]

Evaluating Epoch 17...





Epoch 17 | Loss: 0.2280 | WER: 0.2942 | CER: 0.0741
New Best Model Saved! (WER: 0.2942)


Epoch 18/50: 100%|████████████████████████████████████████| 2000/2000 [13:37<00:00,  2.45it/s, Loss=0.2886, LR=7.11e-05]

Evaluating Epoch 18...





Epoch 18 | Loss: 0.2223 | WER: 0.2940 | CER: 0.0735
New Best Model Saved! (WER: 0.2940)


Epoch 19/50: 100%|████████████████████████████████████████| 2000/2000 [13:37<00:00,  2.45it/s, Loss=0.0551, LR=6.89e-05]

Evaluating Epoch 19...





Epoch 19 | Loss: 0.2139 | WER: 0.2818 | CER: 0.0721
New Best Model Saved! (WER: 0.2818)


Epoch 20/50: 100%|████████████████████████████████████████| 2000/2000 [13:38<00:00,  2.44it/s, Loss=0.2631, LR=6.67e-05]

Evaluating Epoch 20...





Epoch 20 | Loss: 0.2070 | WER: 0.2761 | CER: 0.0700
New Best Model Saved! (WER: 0.2761)


Epoch 21/50: 100%|████████████████████████████████████████| 2000/2000 [13:37<00:00,  2.45it/s, Loss=0.0884, LR=6.44e-05]

Evaluating Epoch 21...





Epoch 21 | Loss: 0.1982 | WER: 0.2839 | CER: 0.0710


Epoch 22/50: 100%|████████████████████████████████████████| 2000/2000 [13:37<00:00,  2.45it/s, Loss=0.4031, LR=6.22e-05]

Evaluating Epoch 22...





Epoch 22 | Loss: 0.1878 | WER: 0.2665 | CER: 0.0665
New Best Model Saved! (WER: 0.2665)


Epoch 23/50: 100%|████████████████████████████████████████| 2000/2000 [13:36<00:00,  2.45it/s, Loss=0.2669, LR=6.00e-05]

Evaluating Epoch 23...





Epoch 23 | Loss: 0.1881 | WER: 0.2651 | CER: 0.0670
New Best Model Saved! (WER: 0.2651)


Epoch 24/50: 100%|████████████████████████████████████████| 2000/2000 [13:36<00:00,  2.45it/s, Loss=0.2518, LR=5.78e-05]

Evaluating Epoch 24...





Epoch 24 | Loss: 0.1876 | WER: 0.2484 | CER: 0.0617
New Best Model Saved! (WER: 0.2484)


Epoch 25/50: 100%|████████████████████████████████████████| 2000/2000 [13:35<00:00,  2.45it/s, Loss=0.2881, LR=5.56e-05]

Evaluating Epoch 25...





Epoch 25 | Loss: 0.1716 | WER: 0.2556 | CER: 0.0628


Epoch 26/50: 100%|████████████████████████████████████████| 2000/2000 [13:41<00:00,  2.43it/s, Loss=0.0593, LR=5.33e-05]

Evaluating Epoch 26...





Epoch 26 | Loss: 0.1694 | WER: 0.2484 | CER: 0.0621


Epoch 27/50: 100%|████████████████████████████████████████| 2000/2000 [13:38<00:00,  2.44it/s, Loss=0.3493, LR=5.11e-05]

Evaluating Epoch 27...





Epoch 27 | Loss: 0.1682 | WER: 0.2453 | CER: 0.0605
New Best Model Saved! (WER: 0.2453)


Epoch 28/50:   4%|█▌                                        | 77/2000 [00:31<14:21,  2.23it/s, Loss=0.0723, LR=5.10e-05]