# In [None]:
import json
import torch
import gc
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# --- CONFIGURATION ---
# Identifiers handed to `from_pretrained`: the base causal LM (also used for
# the tokenizer) and the PEFT adapter that is attached on top of it.
BASE_MODEL_ID = "defog/llama-3-sqlcoder-8b"
ADAPTER_ID = "Sourish-Kanna/CenQuery"
# Evaluation set, read as JSONL (one JSON object per line).
# Change this to the path of your actual 350-example eval file.
EVAL_DATASET_PATH = "Dataset/eval_data/eval_(Mixed).jsonl"
# ---------------------

def load_model():
    """Load the quantized base model, its tokenizer, and the PEFT adapter.

    The base model named by ``BASE_MODEL_ID`` is loaded in 4-bit mode and the
    adapter named by ``ADAPTER_ID`` is attached to it in non-trainable form.

    Returns:
        A ``(model, tokenizer)`` tuple: the adapter-wrapped model switched to
        eval mode, and the base model's tokenizer with EOS reused as the pad
        token and right-side padding.
    """
    print("‚è≥ Loading Base Model and Adapter for Evaluation...")

    # 4-bit NF4 quantization with double quantization and float16 compute dtype.
    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    backbone = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        quantization_config=quantization,
        device_map="auto",
    )

    tok = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
    # Reuse EOS as the padding token.
    tok.pad_token = tok.eos_token
    tok.padding_side = "right"

    # Attach the adapter for inference only.
    adapted = PeftModel.from_pretrained(backbone, ADAPTER_ID, is_trainable=False)
    adapted.eval()
    return adapted, tok

def _read_eval_records(path):
    """Parse a JSONL file into a list of records, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _extract_generated_sql(completion_text):
    """Reduce decoded model output to one lower-cased SQL statement without ';'."""
    sql = completion_text.strip()
    if "### SQL" in sql:
        sql = sql.split("### SQL")[-1].strip()
    # Drop anything that looks like the start of a following chat turn.
    # NOTE(review): splitting on the bare word "assistant" also truncates SQL
    # that legitimately contains it (e.g. an identifier); kept as in the
    # original pipeline — confirm whether it is still needed.
    sql = sql.split("assistant")[0].split("<|start_header_id|>")[0].strip().lower()
    # Keep only the first statement; the terminator itself is not compared.
    return sql.split(";")[0].strip()


def evaluate():
    """Run greedy generation over the eval set and print exact-match accuracy.

    Each record of ``EVAL_DATASET_PATH`` must provide a ``'prompt'`` and a
    ``'completion'`` key. The prediction is the first SQL statement of the
    newly generated text; prediction and reference are compared lower-cased,
    stripped, and without a trailing ``;`` so that a missing or extra
    statement terminator alone does not count as a mismatch.

    The dataset is read before the model is loaded so that a bad path or a
    malformed file fails fast, and an empty dataset is reported instead of
    raising ``ZeroDivisionError``.

    Returns:
        None. Results are printed to stdout.

    Raises:
        FileNotFoundError: if ``EVAL_DATASET_PATH`` does not exist.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks ``'prompt'`` or ``'completion'``.
    """
    eval_data = _read_eval_records(EVAL_DATASET_PATH)
    total = len(eval_data)
    print(f"üìã Found {total} questions for evaluation.")
    if total == 0:
        print("No evaluation examples found; nothing to evaluate.")
        return

    model, tokenizer = load_model()
    terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

    exact_matches = 0

    # Loop through the dataset with a progress bar
    for item in tqdm(eval_data, desc="Evaluating"):
        prompt = item['prompt']  # Adjust key if your JSONL uses a different key for the prompt
        # Adjust key for target SQL. A trailing ';' is dropped to mirror the
        # normalization applied to the prediction.
        ground_truth_sql = item['completion'].strip().lower().rstrip(";").rstrip()

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=300, eos_token_id=terminators,
                pad_token_id=tokenizer.eos_token_id, do_sample=False
            )

        # generate() returns prompt + continuation for a causal LM. Slice off
        # the prompt tokens instead of string-replacing the prompt, which
        # fails whenever the decoded prompt does not round-trip verbatim
        # (e.g. special tokens removed by skip_special_tokens).
        completion = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        generated_sql = _extract_generated_sql(completion)

        # Check Exact Match
        if generated_sql == ground_truth_sql:
            exact_matches += 1

    exact_match_accuracy = (exact_matches / total) * 100
    print("\n" + "="*50)
    print("üéØ EVALUATION COMPLETE")
    print(f"Total Questions: {total}")
    print(f"Exact Match Accuracy: {exact_match_accuracy:.2f}%")
    print("="*50)

# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    evaluate()