In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torchaudio
import pandas as pd
from tqdm import tqdm
from jiwer import wer, cer
import gc
import types
import warnings
import logging
import json
import shutil
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Gain
import numpy as np

# Hugging Face Imports
from transformers import (
    Wav2Vec2ForCTC, 
    Wav2Vec2FeatureExtractor,
    get_cosine_schedule_with_warmup,
    logging as hf_logging
)
from peft import get_peft_model, LoraConfig, PeftModel
from huggingface_hub import login
from torch.cuda.amp import autocast, GradScaler

# --- Waveform augmentation (training split only) ---
# EsperantoDataset applies this pipeline when constructed with augment=True.
# Each transform carries its own p=0.3.
augmentations = Compose([
    AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.3),
    TimeStretch(min_rate=0.85, max_rate=1.15, p=0.3),
    PitchShift(min_semitones=-2, max_semitones=2, p=0.3),
])
# --- Suppress Warnings ---
# NOTE: this silences *all* Python warnings process-wide (including library
# deprecation notices), not only the noisy ones.
warnings.filterwarnings("ignore")
hf_logging.set_verbosity_error()
logging.getLogger("torchaudio").setLevel(logging.ERROR)

# --- Configuration for 1B Model on RTX 3080 ---
CONFIG = {
    "train_csv": "geo/train.csv",
    "val_csv": "geo/dev.csv",
    "hf_token": None,  # set to a token string to log in to the Hugging Face Hub
    
    # --- 1B MODEL SETTINGS ---
    "base_model": "facebook/wav2vec2-xls-r-1b", # <--- The 1B Model
    "checkpoint_path": "xlsr_1b_Ngram", # LoRA adapter dir: resumed from if present, best model saved here
    
    "batch_size": 1,             # Must be 1 to fit in memory
    "grad_accum_steps": 16,      # Effective batch size = batch_size * grad_accum_steps = 16
    "max_audio_len": 160000,     # In samples: 160000 / 16 kHz = 10 seconds (limits OOM)
    
    "lora_rank": 64,             
    "lora_alpha": 128,
    "lora_dropout": 0.05,
    
    "learning_rate": 5e-5,       
    "num_epochs": 50,            
    
    # SpecAugment-style masking passed to the model config at load time.
    "mask_time_prob": 0.15,      
    "mask_time_length": 10,
    "mask_feature_prob": 0.10,   
    "mask_feature_length": 64,
    
    "target_modules": ["q_proj", "k_proj", "v_proj", "out_proj", "intermediate_dense", "output_dense"],
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Only log in when a token is explicitly configured.
if CONFIG["hf_token"]:
    login(token=CONFIG["hf_token"])

# --- Vocabulary Builder ---
class Vocabulary:
    """Character-level CTC vocabulary built from transcript CSV files.

    Ids 0-4 are reserved for ``<pad>``, ``<s>``, ``</s>``, ``<unk>`` and
    ``|`` (the word separator that replaces spaces). Every other character
    found in the ``transcript`` column of the given CSVs receives the next
    free id, in sorted character order. Paths that do not exist are skipped.

    Attributes:
        vocab: mapping token -> id.
        inv_vocab: mapping id -> token.
        idx: next unassigned id.
    """

    def __init__(self, csv_paths):
        specials = ("<pad>", "<s>", "</s>", "<unk>", "|")
        self.vocab = {token: token_id for token_id, token in enumerate(specials)}
        self.idx = len(specials)
        self._build_vocab(csv_paths)
        self.inv_vocab = {token_id: token for token, token_id in self.vocab.items()}

    def _build_vocab(self, paths):
        """Assign ids to every new character seen in the CSVs' transcripts."""
        seen = set()
        for csv_path in paths:
            if not os.path.exists(csv_path):
                continue
            transcripts = pd.read_csv(csv_path)['transcript'].fillna("").astype(str)
            for line in transcripts:
                seen.update(line)

        # NOTE: a literal " " is collected here too, although encode() never
        # emits it (spaces become "|"); kept so vocab size matches saved heads.
        for ch in sorted(seen):
            if ch in self.vocab:
                continue
            self.vocab[ch] = self.idx
            self.idx += 1

    def encode(self, text):
        """Map ``text`` to ids, spaces -> ``|``, unseen characters -> ``<unk>``."""
        unk_id = self.vocab["<unk>"]
        lookup = self.vocab.get
        return [lookup(ch, unk_id) for ch in str(text).replace(" ", "|")]

    def decode(self, tokens):
        """Map ids back to text, dropping pad (0), unknown ids and ``<s>``/``</s>``."""
        pieces = [self.inv_vocab.get(t, "") for t in tokens if t != 0]
        text = "".join(pieces).replace("|", " ")
        return text.replace("<s>", "").replace("</s>", "")

# --- Dataset ---
class EsperantoDataset(Dataset):
    """Audio/transcript pairs for CTC fine-tuning.

    Reads a CSV with ``file`` (audio path) and ``transcript`` columns and keeps
    only rows whose audio file exists. Each item is ``(features, labels)``:
    ``features`` is ``processor(...).input_values[0]`` for 16 kHz mono audio and
    ``labels`` is a LongTensor of ``vocab.encode(transcript)``. Items that
    cannot be used are returned as ``None`` (``collate_fn`` drops them).

    Args:
        csv_path: path to the CSV file.
        vocab: object providing ``encode(text) -> list[int]``.
        processor: feature extractor called as
            ``processor(waveform, sampling_rate=16000, return_tensors="pt")``.
        max_len: maximum clip length in 16 kHz samples. Longer clips are
            skipped, not truncated: cutting the audio while keeping the full
            transcript would give CTC targets the audio no longer contains.
            ``None`` (or 0) disables the limit.
        augment: apply the module-level ``augmentations`` pipeline plus a
            random +/-6 dB gain (training only).
    """

    _TARGET_SR = 16000

    def __init__(self, csv_path, vocab, processor, max_len=None, augment=False):
        self.df = pd.read_csv(csv_path)
        self.df = self.df[self.df['file'].apply(os.path.exists)].reset_index(drop=True)
        print(f"Loaded {len(self.df)} valid samples from {csv_path}")
        self.vocab = vocab
        self.processor = processor
        self.max_len = max_len
        self.augment = augment

    def __len__(self):
        return len(self.df)

    def _too_long(self, waveform):
        """True when a length limit is set and ``waveform`` exceeds it."""
        return bool(self.max_len) and waveform.size(0) > self.max_len

    def _load_waveform(self, audio_path):
        """Load ``audio_path`` as a 1-D float tensor, resampled to 16 kHz mono."""
        waveform, sr = torchaudio.load(audio_path)
        if sr != self._TARGET_SR:
            waveform = torchaudio.transforms.Resample(sr, self._TARGET_SR)(waveform)
        # Average the channel axis; unlike squeeze(), this stays 1-D even for a
        # single-sample mono file.
        return waveform.mean(dim=0)

    def _augment(self, waveform):
        """Return an augmented copy of the 1-D ``waveform`` as a float tensor."""
        wav_numpy = waveform.numpy()
        if len(wav_numpy) > 0:
            try:
                wav_numpy = augmentations(samples=wav_numpy, sample_rate=self._TARGET_SR)
            except Exception as e:
                print(f"Augmentation warning: {e}")

        # Manual gain (instead of audiomentations' Gain): with p=0.3 scale by a
        # random factor in [-6 dB, +6 dB] (~0.5x to ~2.0x).
        if np.random.random() < 0.3:
            gain_factor = 10 ** (np.random.uniform(-6, 6) / 20)
            wav_numpy = wav_numpy * gain_factor

        return torch.from_numpy(wav_numpy).float()

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        audio_path = row['file']
        transcript = row['transcript']
        transcript = "" if pd.isna(transcript) else str(transcript)

        try:
            waveform = self._load_waveform(audio_path)
            if waveform.numel() == 0 or self._too_long(waveform):
                return None

            if self.augment:
                augmented = self._augment(waveform)
                # TimeStretch can lengthen a clip past max_len; keep the clean
                # clip in that case instead of truncating it.
                if not self._too_long(augmented):
                    waveform = augmented

            features = self.processor(
                waveform, sampling_rate=self._TARGET_SR, return_tensors="pt"
            ).input_values[0]
            labels = torch.tensor(self.vocab.encode(transcript), dtype=torch.long)
            return features, labels
        except Exception as e:
            # Best-effort: one bad file should not abort the epoch (collate_fn
            # drops None), but report it rather than failing silently.
            print(f"Skipping {audio_path}: {e}")
            return None

def collate_fn(batch):
    """Collate ``(features, labels)`` items into padded batch tensors.

    ``None`` items (samples the dataset could not use) are dropped. Returns
    ``(features_padded, attention_mask, labels_padded)``:

    - features are right-padded with 0.0 to the longest item;
    - attention_mask (long) is 1 for real samples and 0 for padding. It is
      derived from each item's length, not from its values, so genuine 0.0
      samples inside a clip are not masked out;
    - labels are right-padded with -100.

    If every item is ``None``, three empty tensors are returned.
    """
    batch = [item for item in batch if item is not None]
    if not batch:
        return torch.empty(0), torch.empty(0), torch.empty(0)
    features, labels = zip(*batch)
    features_padded = pad_sequence(features, batch_first=True, padding_value=0.0)
    lengths = torch.tensor([f.size(0) for f in features])
    positions = torch.arange(features_padded.size(1))
    attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()
    labels_padded = pad_sequence(labels, batch_first=True, padding_value=-100)
    return features_padded, attention_mask, labels_padded

# --- Model Loading (1B Configured) ---
def get_lora_model(vocab_size):
    """Load the base Wav2Vec2ForCTC model from CONFIG and wrap it with LoRA.

    The CTC head is re-created for ``vocab_size`` classes
    (``ignore_mismatched_sizes=True``). ``pad_token_id=0`` is the id
    Wav2Vec2ForCTC uses as the CTC blank; evaluate() also treats 0 as blank.
    If ``adapter_model.safetensors`` exists in ``CONFIG["checkpoint_path"]``,
    that adapter is loaded in trainable mode. Otherwise a fresh LoRA config
    built from CONFIG is applied, with ``lm_head`` and ``layer_norm`` modules
    listed in ``modules_to_save``. Gradient checkpointing is enabled.

    Args:
        vocab_size: number of output classes (``len(vocab.vocab)``).

    Returns:
        The PEFT-wrapped model. It is not yet moved to ``device``.
    """
    print(f"Loading Base Model: {CONFIG['base_model']}...")
    # token=False when no hf_token is configured.
    use_auth_token = CONFIG["hf_token"] if CONFIG["hf_token"] else False

    model = Wav2Vec2ForCTC.from_pretrained(
        CONFIG["base_model"], 
        ctc_loss_reduction="mean", 
        # A transcript longer than the model's output frames has infinite CTC
        # loss. Zeroing it avoids inf/NaN grads that would make GradScaler
        # discard the whole accumulated optimizer step.
        ctc_zero_infinity=True,
        pad_token_id=0,
        vocab_size=vocab_size,
        ignore_mismatched_sizes=True,
        token=use_auth_token,
        mask_time_prob=CONFIG["mask_time_prob"],
        mask_time_length=CONFIG["mask_time_length"],
        mask_feature_prob=CONFIG["mask_feature_prob"],
        mask_feature_length=CONFIG["mask_feature_length"],
    )

    # We use a new checkpoint path for 1B
    adapter_path = os.path.join(CONFIG["checkpoint_path"], "adapter_model.safetensors")
    
    if os.path.exists(adapter_path):
        print(f"\n[INFO] Found existing 1B checkpoint. Resuming...")
        model = PeftModel.from_pretrained(model, CONFIG["checkpoint_path"], is_trainable=True)
    else:
        print(f"\n[INFO] Initializing FRESH 1B LoRA model...")
        modules_to_save = ["lm_head", "layer_norm"] 
        peft_config = LoraConfig(
            inference_mode=False,
            r=CONFIG["lora_rank"],
            lora_alpha=CONFIG["lora_alpha"],
            lora_dropout=CONFIG["lora_dropout"],
            target_modules=CONFIG["target_modules"],
            modules_to_save=modules_to_save
        )
        model = get_peft_model(model, peft_config)
    
    model.gradient_checkpointing_enable()
    model.print_trainable_parameters()
    return model

def _ctc_collapse(frame_ids, blank_id=0):
    """Greedy CTC post-processing: merge repeated ids, then drop blanks."""
    collapsed = []
    prev = None
    for token in frame_ids:
        if token != prev and token != blank_id:
            collapsed.append(token)
        prev = token
    return collapsed


def evaluate(model, dataloader, vocab):
    """Return ``(wer, cer)`` of greedy-CTC predictions on ``dataloader``.

    Switches the model to eval mode and leaves it there; callers switch back
    with ``model.train()``. Each frame takes its argmax id, then repeats are
    merged and blank id 0 is dropped. Samples whose reference text is empty
    or whitespace-only are skipped, because jiwer rejects empty references.
    Returns ``(1.0, 1.0)`` when nothing could be scored.
    """
    model.eval()
    refs, preds = [], []

    with torch.no_grad():
        for features, mask, labels in dataloader:
            if features.size(0) == 0:
                continue
            features = features.to(device)
            mask = mask.to(device)

            with autocast():
                outputs = model(features, attention_mask=mask)

            # Argmax on-device; only the (batch, frames) id matrix goes to host.
            pred_ids = outputs.logits.argmax(dim=-1).cpu().tolist()

            for label_row, frame_ids in zip(labels, pred_ids):
                ref = vocab.decode(label_row[label_row != -100].tolist())
                if not ref.strip():
                    continue
                refs.append(ref)
                preds.append(vocab.decode(_ctc_collapse(frame_ids)))

    if not refs:
        return 1.0, 1.0
    return wer(refs, preds), cer(refs, preds)

def _optimizer_step(model, optimizer, scaler, scheduler):
    """Apply accumulated grads: unscale, clip to norm 1.0, step, clear grads."""
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()
    optimizer.zero_grad()


def _train_one_epoch(model, train_loader, optimizer, scaler, scheduler, epoch):
    """Run one fp16-autocast epoch with gradient accumulation.

    An optimizer step is taken after every ``CONFIG["grad_accum_steps"]``
    non-empty micro-batches. A trailing partial group is applied at the end
    of the epoch, so gradients never carry over into the next epoch.
    Returns the mean per-batch loss over non-empty batches (0.0 if none).
    """
    accum_steps = CONFIG["grad_accum_steps"]
    model.train()
    optimizer.zero_grad()
    total_loss = 0.0
    num_batches = 0
    pending = 0  # micro-batches whose gradients have not been applied yet

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['num_epochs']}")
    for features, mask, labels in pbar:
        if features.size(0) == 0:
            continue

        features = features.to(device)
        mask = mask.to(device)
        labels = labels.to(device)

        with autocast():
            outputs = model(features, attention_mask=mask, labels=labels)
            loss = outputs.loss / accum_steps

        scaler.scale(loss).backward()
        pending += 1

        if pending == accum_steps:
            _optimizer_step(model, optimizer, scaler, scheduler)
            pending = 0

        batch_loss = loss.item() * accum_steps
        total_loss += batch_loss
        num_batches += 1
        pbar.set_postfix({"Loss": f"{batch_loss:.4f}", "LR": f"{scheduler.get_last_lr()[0]:.2e}"})

    if pending:
        _optimizer_step(model, optimizer, scaler, scheduler)

    return total_loss / num_batches if num_batches else 0.0


def main():
    """Fine-tune the LoRA-wrapped model and keep the adapter with the best WER.

    Builds the vocabulary, datasets and loaders from CONFIG, and loads (or
    resumes) the adapter. It then trains for ``CONFIG["num_epochs"]`` epochs
    with a warmup + cosine LR schedule. After each epoch the model is scored
    on the validation set. Whenever WER improves, the adapter is saved to
    ``CONFIG["checkpoint_path"]`` and the vocabulary to ``vocab_1bfr.json``.
    When resuming, the restored adapter's validation WER is the bar to beat.

    NOTE(review): optimizer/scheduler/scaler state is not checkpointed, so a
    resumed run restarts the LR warmup from zero.
    """
    torch.cuda.empty_cache()
    gc.collect()

    vocab = Vocabulary([CONFIG["train_csv"], CONFIG["val_csv"]])
    # token=False when no hf_token is configured.
    use_auth_token = CONFIG["hf_token"] if CONFIG["hf_token"] else False

    print("Initializing Processor...")
    processor = Wav2Vec2FeatureExtractor.from_pretrained(
        CONFIG["base_model"], 
        token=use_auth_token
    )
    
    train_dataset = EsperantoDataset(CONFIG["train_csv"], vocab, processor, max_len=CONFIG["max_audio_len"], augment=True)
    val_dataset = EsperantoDataset(CONFIG["val_csv"], vocab, processor, augment=False) # Keep validation clean!

    train_loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True, collate_fn=collate_fn, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG["batch_size"], shuffle=False, collate_fn=collate_fn, num_workers=1)

    model = get_lora_model(len(vocab.vocab))
    model.to(device)

    # Only LoRA / modules_to_save parameters are trainable.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=CONFIG["learning_rate"])
    scaler = GradScaler()
    
    accum_steps = CONFIG["grad_accum_steps"]
    # Ceil: the partial accumulation group at each epoch's end also steps.
    steps_per_epoch = -(-len(train_loader) // accum_steps)
    num_training_steps = steps_per_epoch * CONFIG["num_epochs"]
    num_warmup_steps = int(0.1 * num_training_steps) 
    
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
    )

    print(f"\nStarting 1B Model Training...")
    best_wer = float('inf')
    
    if os.path.exists(os.path.join(CONFIG["checkpoint_path"], "adapter_model.safetensors")):
        baseline_wer, _ = evaluate(model, val_loader, vocab)
        best_wer = baseline_wer
        print(f"Baseline restored -> WER: {baseline_wer:.4f}")
    
    for epoch in range(CONFIG["num_epochs"]):
        avg_train_loss = _train_one_epoch(model, train_loader, optimizer, scaler, scheduler, epoch)

        print(f"Evaluating Epoch {epoch+1}...")
        epoch_wer, epoch_cer = evaluate(model, val_loader, vocab)

        print(f"Epoch {epoch+1} | Loss: {avg_train_loss:.4f} | WER: {epoch_wer:.4f} | CER: {epoch_cer:.4f}")
        
        if epoch_wer < best_wer:
            best_wer = epoch_wer
            model.save_pretrained(CONFIG["checkpoint_path"])
            print(f"New Best Model Saved! (WER: {best_wer:.4f})")
            
            with open("vocab_1bfr.json", "w") as f:
                json.dump(vocab.vocab, f)
            
    torch.cuda.empty_cache()
    print("Training Complete.")

if __name__ == "__main__":
    main()

2025-12-01 01:25:36.410759: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-12-01 01:25:36.449455: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Using device: cuda
Initializing Processor...
Loaded 6000 valid samples from geo/train.csv
Loaded 1000 valid samples from geo/dev.csv
Loading Base Model: facebook/wav2vec2-xls-r-1b...

[INFO] Found existing 1B checkpoint. Resuming...
trainable params: 71,082,789 || all params: 1,033,627,594 || trainable%: 6.8770

Starting 1B Model Training...
Baseline restored -> WER: 0.1994


Epoch 1/50:   2%|▋                                            | 97/6000 [00:19<20:14,  4.86it/s, Loss=0.2617, LR=1.60e-07]