# In [None]:
import torch
import torchaudio
import json
import os
import pandas as pd
from tqdm import tqdm
from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor
from peft import PeftModel

# --- Configuration ---
# Central settings for this inference script. Relative paths are resolved
# against the current working directory; "device" selects CUDA when available.
CONFIG = dict(
    base_model="facebook/wav2vec2-large-xlsr-53",  # Pretrained base model ID for from_pretrained
    checkpoint_path="xlsr_lora_gibberish_best",    # Folder holding the saved LoRA adapters
    vocab_path="vocab.json",                       # Saved vocabulary JSON (token -> id)
    input_csv="geo/test.csv",                      # CSV listing the audio files to transcribe
    output_csv="geo/test_predictions.csv",         # Destination for the predictions CSV
    device="cuda" if torch.cuda.is_available() else "cpu",
)

def load_resources():
    """Load the vocabulary, the LoRA-adapted CTC model and the feature extractor.

    All paths and the base model ID are read from the module-level ``CONFIG``.

    Returns:
        tuple: ``(model, processor, inv_vocab)`` where ``model`` is the
        ``PeftModel``-wrapped ``Wav2Vec2ForCTC`` moved to ``CONFIG["device"]``
        and switched to eval mode, ``processor`` is the base model's
        ``Wav2Vec2FeatureExtractor``, and ``inv_vocab`` maps token id -> token.

    Raises:
        FileNotFoundError: if the vocab file or the LoRA checkpoint folder
            does not exist.
    """
    print(f"Using device: {CONFIG['device']}")

    # Fail fast on missing local artifacts *before* spending time
    # downloading/loading the (large) base model.
    if not os.path.exists(CONFIG["vocab_path"]):
        raise FileNotFoundError(f"Vocab file not found at {CONFIG['vocab_path']}. Did you save it during training?")
    if not os.path.exists(CONFIG["checkpoint_path"]):
        raise FileNotFoundError(f"Checkpoint folder not found at {CONFIG['checkpoint_path']}")

    # 1. Load Vocabulary
    print("Loading vocabulary...")
    # Explicit UTF-8: the vocab may contain non-ASCII characters, and the
    # platform default encoding (e.g. cp1252 on Windows) could fail to decode it.
    with open(CONFIG["vocab_path"], "r", encoding="utf-8") as f:
        vocab_dict = json.load(f)
    # Inverse vocab (ID -> token) used for decoding predictions.
    inv_vocab = {v: k for k, v in vocab_dict.items()}

    # 2. Load Base Model
    print("Loading Base Model...")
    # ignore_mismatched_sizes lets the CTC head be resized to our vocab size.
    # NOTE(review): the resized head starts freshly initialised; presumably the
    # trained head is restored from the LoRA checkpoint (e.g. modules_to_save) — verify.
    model = Wav2Vec2ForCTC.from_pretrained(
        CONFIG["base_model"],
        vocab_size=len(vocab_dict),
        pad_token_id=vocab_dict.get("<pad>", 0),
        ignore_mismatched_sizes=True,
        token=False,  # Force anonymous access to avoid 401 errors
    )

    # 3. Load LoRA Adapters
    print(f"Loading LoRA Adapters from {CONFIG['checkpoint_path']}...")
    model = PeftModel.from_pretrained(model, CONFIG["checkpoint_path"])
    model.to(CONFIG["device"])
    model.eval()

    # 4. Load Processor
    processor = Wav2Vec2FeatureExtractor.from_pretrained(CONFIG["base_model"], token=False)

    return model, processor, inv_vocab

def transcribe_single_audio(audio_path, model, processor, inv_vocab):
    """Transcribe one audio file with greedy CTC decoding.

    The audio is loaded with torchaudio, down-mixed to mono, resampled to
    16 kHz, featurised by ``processor`` and run through ``model``. The argmax
    token sequence is collapsed CTC-style (merge repeats, drop the blank and
    special tokens), and ``|`` word delimiters become spaces.

    Args:
        audio_path: path to the audio file. Non-path values (e.g. NaN/None
            coming from an empty CSV cell) are treated as a missing file.
        model: CTC model whose output exposes ``.logits``.
        processor: feature extractor returning ``input_values``.
        inv_vocab: mapping token id -> token string.

    Returns:
        str: the transcription, or an ``"[ERROR: ...]"`` marker string when the
        file is missing or any processing step fails. Errors are reported per
        file so a batch run is not aborted.
    """
    # os.path.exists raises TypeError (outside the try below) for NaN/None,
    # which would abort the whole batch; treat such values as missing.
    if not isinstance(audio_path, (str, bytes, os.PathLike)) or not os.path.exists(audio_path):
        return "[ERROR: FILE NOT FOUND]"

    # The CTC blank is the pad token: the model is built with
    # pad_token_id = id of "<pad>" (default 0), so look it up instead of
    # hard-coding 0 (which would drop a real character when <pad> != 0).
    blank_id = next((idx for idx, tok in inv_vocab.items() if tok == "<pad>"), 0)
    special_tokens = {"<s>", "</s>", "<unk>", "<pad>"}

    try:
        # Load Audio
        waveform, sr = torchaudio.load(audio_path)

        # Down-mix first so only one channel is resampled (both steps are
        # linear, so the order only matters up to float rounding).
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample to 16k if needed
        if sr != 16000:
            waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)

        # Preprocess. waveform[0] (not squeeze()) keeps a 1-D array even for
        # single-sample clips.
        input_values = processor(
            waveform[0].numpy(),
            sampling_rate=16000,
            return_tensors="pt",
        ).input_values.to(CONFIG["device"])

        # Inference
        with torch.no_grad():
            logits = model(input_values).logits

        # Greedy Decode (Argmax) for the single item in the batch
        pred_ids = torch.argmax(logits, dim=-1)[0].cpu().tolist()

        # CTC collapse: emit a token only when it differs from the previous
        # frame and is neither the blank nor a special token.
        pred_str = []
        prev_token = -1
        for token in pred_ids:
            if token != prev_token and token != blank_id:
                char = inv_vocab.get(token, "")
                if char not in special_tokens:
                    pred_str.append(char)
            prev_token = token

        # "|" is the word delimiter in the vocab
        return "".join(pred_str).replace("|", " ")

    except Exception as e:  # best-effort per file: report instead of aborting
        return f"[ERROR: {e}]"

def main():
    """Transcribe every audio file listed in the input CSV and save the results.

    Reads ``CONFIG["input_csv"]``, which must contain a ``file`` column of audio
    paths. It then fills (or overwrites) a ``transcript`` column with one
    transcription or error marker per row, and writes the table to
    ``CONFIG["output_csv"]``.

    Raises:
        KeyError: if the input CSV has no ``file`` column.
    """
    # 1. Load and validate the CSV first, so a bad input fails before the
    #    expensive model load.
    print(f"Reading input CSV: {CONFIG['input_csv']}")
    df = pd.read_csv(CONFIG["input_csv"], delimiter=",")
    if "file" not in df.columns:
        raise KeyError(
            f"Input CSV {CONFIG['input_csv']} has no 'file' column "
            f"(found: {list(df.columns)})"
        )

    # 2. Setup Resources
    model, processor, inv_vocab = load_resources()

    # 3. Iterate and Predict
    print(f"Starting inference on {len(df)} files...")
    transcripts = [
        transcribe_single_audio(file_path, model, processor, inv_vocab)
        for file_path in tqdm(df["file"], total=len(df), unit="file")
    ]
    # Assign the whole column at once. An empty 'transcript' column in the
    # input is parsed as all-NaN float64, and writing strings into it
    # cell-by-cell (df.at) hits pandas' deprecated incompatible-dtype setitem.
    df["transcript"] = transcripts

    # 4. Save Results
    print(f"Saving results to {CONFIG['output_csv']}...")
    df.to_csv(CONFIG["output_csv"], index=False)
    print("Done!")

# Run the batch transcription when executed as a script (or notebook cell),
# but not when this module is imported.
if __name__ == "__main__":
    main()