In [11]:
# ============================================================
# Create Evaluation Dataset from Training Data
# Splits code into prompt and reference for evaluation
# ============================================================

import json
from datasets import load_dataset
import random

def create_eval_dataset_from_training(
    training_dataset_path="/content/drive/MyDrive/code_dataset.json",
    output_path="/content/eval_dataset.json",
    test_split=0.2,
    max_samples=None
):
    """
    Create evaluation dataset from training data
    Extracts prompts and references from code samples
    """

    print("="*60)
    print("CREATING EVALUATION DATASET")
    print("="*60 + "\n")

    # Load training dataset
    print(f"Loading training data from {training_dataset_path}...")
    dataset = load_dataset("json", data_files=training_dataset_path, split="train")
    print(f"✓ Loaded {len(dataset)} samples\n")

    # Create eval samples
    eval_samples = []

    for item in dataset:
        # Get the code text
        if "text" in item:
            code = item["text"]
            language = item.get("language", "python")
        elif "output" in item:
            code = item["output"]
            language = item.get("language", "python")
        else:
            continue

        # Extract prompt and reference
        prompt, reference = extract_prompt_and_reference(code, language)

        if prompt and reference:
            eval_samples.append({
                "prompt": prompt,
                "reference": reference,
                "language": language,
                "file_name": item.get("file_name", "unknown")
            })

    print(f"✓ Created {len(eval_samples)} evaluation samples\n")

    # Shuffle and limit
    random.shuffle(eval_samples)

    if max_samples:
        eval_samples = eval_samples[:max_samples]
        print(f"Limited to {max_samples} samples")

    # Save
    with open(output_path, "w") as f:
        json.dump(eval_samples, f, indent=2)

    print(f"\n✓ Evaluation dataset saved to {output_path}")
    print(f"Total samples: {len(eval_samples)}")

    # Print statistics
    print("\nDataset Statistics:")
    print("-" * 60)
    languages = {}
    for sample in eval_samples:
        lang = sample["language"]
        languages[lang] = languages.get(lang, 0) + 1

    for lang, count in sorted(languages.items()):
        print(f"  {lang}: {count} samples")

    return eval_samples

def extract_prompt_and_reference(code, language="python"):
    """
    Extract prompt and reference from code

    Strategies:
    1. For functions: Use signature as prompt, full code as reference
    2. For classes: Use class declaration as prompt, full code as reference
    3. For other code: Use first line as prompt, full code as reference
    """

    lines = code.strip().split('\n')

    if not lines:
        return None, None

    if language == "python":
        # Function definition
        if lines[0].strip().startswith("def "):
            prompt = lines[0].strip()
            reference = code
            return prompt, reference

        # Class definition
        elif lines[0].strip().startswith("class "):
            prompt = lines[0].strip()
            reference = code
            return prompt, reference

        # Import or other
        else:
            # Use first significant line as prompt
            for line in lines:
                if line.strip() and not line.strip().startswith("#"):
                    prompt = line.strip()
                    reference = code
                    return prompt, reference

    elif language == "javascript" or language == "ts":
        # Function definition
        if "function " in lines[0] or "=>" in lines[0] or "const " in lines[0]:
            prompt = lines[0].strip()
            reference = code
            return prompt, reference

        # Class definition
        elif lines[0].strip().startswith("class "):
            prompt = lines[0].strip()
            reference = code
            return prompt, reference

    elif language == "java":
        # Method or class
        for i, line in enumerate(lines):
            if "public " in line or "private " in line or "protected " in line:
                if "class " in line or "interface " in line:
                    prompt = line.strip()
                    reference = code
                    return prompt, reference
                elif "(" in line:  # Method
                    prompt = line.strip()
                    reference = code
                    return prompt, reference

    elif language == "cpp":
        # Function definition
        for line in lines:
            if "(" in line and "{" not in line:
                prompt = line.strip()
                reference = code
                return prompt, reference

    # Default: use first line as prompt
    prompt = lines[0].strip()
    reference = code
    return prompt, reference

def split_train_eval(
    dataset_path="code_dataset.json",
    train_output="train_dataset.json",
    eval_output="eval_dataset.json",
    test_split=0.2
):
    """
    Split existing dataset into train and eval sets
    """

    print("="*60)
    print("SPLITTING DATASET INTO TRAIN/EVAL")
    print("="*60 + "\n")

    # Load dataset
    print(f"Loading dataset from {dataset_path}...")
    with open(dataset_path, "r") as f:
        data = json.load(f)

    print(f"✓ Loaded {len(data)} samples\n")

    # Shuffle
    random.shuffle(data)

    # Split
    split_idx = int(len(data) * (1 - test_split))
    train_data = data[:split_idx]
    eval_data = data[split_idx:]

    # Save
    with open(train_output, "w") as f:
        json.dump(train_data, f, indent=2)

    with open(eval_output, "w") as f:
        json.dump(eval_data, f, indent=2)

    print(f"✓ Train set saved to {train_output} ({len(train_data)} samples)")
    print(f"✓ Eval set saved to {eval_output} ({len(eval_data)} samples)")
    print(f"\nSplit ratio: {(1-test_split)*100:.0f}% train, {test_split*100:.0f}% eval")

# ============================================================
# EXAMPLE USAGE
# ============================================================
if __name__ == "__main__":

    print("\n" + "="*60)
    print("OPTION 1: Create eval dataset with prompts extracted")
    print("="*60)

    eval_samples = create_eval_dataset_from_training(
        training_dataset_path="/content/drive/MyDrive/code_dataset.json",
        output_path="eval_dataset.json",
        max_samples=100  # Limit to 100 samples for faster evaluation
    )

    # Show examples
    print("\n" + "="*60)
    print("SAMPLE EVALUATION ENTRIES")
    print("="*60)

    for i, sample in enumerate(eval_samples[:3]):
        print(f"\nSample {i+1}:")
        print(f"Language: {sample['language']}")
        print(f"Prompt: {sample['prompt']}")
        print(f"Reference (first 100 chars): {sample['reference'][:100]}...")
        print("-" * 60)

    print("\n" + "="*60)
    print("OPTION 2: Simple train/eval split")
    print("="*60)

    # Uncomment to use simple split instead
    # split_train_eval(
    #     dataset_path="code_dataset.json",
    #     train_output="train_dataset.json",
    #     eval_output="eval_dataset.json",
    #     test_split=0.2
    # )

    print("\n✓ Dataset preparation complete!")
    print("\nNext step: Run the evaluation script:")
    print("  python evaluate_model.py")


OPTION 1: Create eval dataset with prompts extracted
CREATING EVALUATION DATASET

Loading training data from /content/drive/MyDrive/code_dataset.json...
✓ Loaded 42 samples

✓ Created 42 evaluation samples

Limited to 100 samples

✓ Evaluation dataset saved to eval_dataset.json
Total samples: 42

Dataset Statistics:
------------------------------------------------------------
  cpp: 42 samples

SAMPLE EVALUATION ENTRIES

Sample 1:
Language: cpp
Prompt: UNIT_TEST(Assert_Smoke)
Reference (first 100 chars): #include "testing/testing.hpp"

#include "base/base.hpp"
#include "base/exception.hpp"
#include "bas...
------------------------------------------------------------

Sample 2:
Language: cpp
Prompt: * Copyright (c) 2004-present, The University of Notre Dame. All rights
Reference (first 100 chars): /*
 * Copyright (c) 2004-present, The University of Notre Dame. All rights
 * reserved.
 *
 * Redist...
------------------------------------------------------------

Sample 3:
Language: cpp
P

In [10]:
# ============================================================
# Comprehensive Model Evaluation for Code Generation
# Metrics: CodeBLEU, BLEU, Exact Match, Pass@k, etc.
# ============================================================

# Install dependencies
import subprocess
import sys

def install_packages():
    packages = [
        "transformers",
        "datasets",
        "peft",
        "torch",
        "evaluate",
        "sacrebleu",
        "codebleu",
        "tree-sitter",
        "tree-sitter-python",
        "tree-sitter-java",
        "tree-sitter-javascript",
        "tree-sitter-cpp",
        "tree-sitter-c-sharp",
        "tree-sitter-go",
        "nltk",
        "rouge_score" # Added rouge_score dependency
    ]

    print("Installing evaluation packages...")
    for package in packages:
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-q", package],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL
        )
    print("✓ Packages installed\n")

install_packages()

Installing evaluation packages...
✓ Packages installed



In [1]:
# Imports
import torch
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from datasets import load_dataset
from codebleu import calc_codebleu
import evaluate
from tqdm import tqdm
import numpy as np
from collections import defaultdict
import time

In [2]:
# ============================================================
# LOAD MODEL
# ============================================================
def load_finetuned_model(base_model_name, lora_path, use_gpu=True):
    """Load the fine-tuned model for evaluation"""

    print("Loading model for evaluation...")

    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"  # Fix for decoder-only models

    if use_gpu and torch.cuda.is_available():
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )
    else:
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.float32,
            trust_remote_code=True
        )

    # Load LoRA weights and merge
    model = PeftModel.from_pretrained(base_model, lora_path)
    model = model.merge_and_unload()
    model.eval()

    print("Model loaded\n")
    return model, tokenizer

In [3]:
# ============================================================
# LOAD EVALUATION DATASET
# ============================================================
def load_eval_dataset(dataset_path, max_samples=None):
    """
    Load evaluation dataset
    Expected format: {"prompt": "...", "reference": "..."}
    or {"text": "..."}
    """

    print(f"Loading evaluation dataset from {dataset_path}...")
    dataset = load_dataset("json", data_files=dataset_path, split="train")

    if max_samples:
        dataset = dataset.select(range(min(max_samples, len(dataset))))

    print(f"Loaded {len(dataset)} samples\n")
    return dataset

def prepare_eval_data(dataset):
    """
    Prepare evaluation data: extract prompts and references
    """
    eval_data = []

    for item in dataset:
        if "prompt" in item and "reference" in item:
            # Instruction format
            eval_data.append({
                "prompt": item["prompt"],
                "reference": item["reference"]
            })
        elif "text" in item:
            # Extract prompt and reference from full text
            # Assume format: prompt + reference code
            text = item["text"]
            lines = text.split('\n')

            # Use first line/function signature as prompt
            prompt = lines[0] if lines else text[:100]
            reference = text

            eval_data.append({
                "prompt": prompt,
                "reference": reference
            })
        elif "instruction" in item and "output" in item:
            eval_data.append({
                "prompt": item["instruction"],
                "reference": item["output"]
            })

    return eval_data

In [4]:
# ============================================================
# GENERATE PREDICTIONS
# ============================================================
def generate_predictions(model, tokenizer, eval_data, max_length=256, batch_size=8):
    """Generate predictions for all prompts"""

    print("Generating predictions...")
    predictions = []

    use_gpu = torch.cuda.is_available()

    # Process in batches
    for i in tqdm(range(0, len(eval_data), batch_size)):
        batch = eval_data[i:i+batch_size]
        prompts = [item["prompt"] for item in batch]

        # Tokenize
        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )

        if use_gpu:
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        # Generate
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                do_sample=False,  # Greedy for consistency
                pad_token_id=tokenizer.eos_token_id,
                use_cache=True,
            )

        # Decode
        batch_predictions = [
            tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]
        predictions.extend(batch_predictions)

    print("Predictions generated\n")
    return predictions

In [5]:
# ============================================================
# EVALUATION METRICS
# ============================================================

def calculate_codebleu(predictions, references, language="python"):
    """
    Calculate CodeBLEU score
    CodeBLEU considers: BLEU, syntax match, dataflow match
    """
    print("Calculating CodeBLEU...")

    # Map language names to codebleu format
    lang_mapping = {
        "python": "python",
        "py": "python",
        "java": "java",
        "javascript": "javascript",
        "js": "javascript",
        "ts": "javascript",
        "cpp": "cpp",
        "c": "c",
        "c++": "cpp",
        "csharp": "c_sharp",
        "cs": "c_sharp",
        "go": "go"
    }

    codebleu_lang = lang_mapping.get(language.lower(), "cpp")

    try:
        result = calc_codebleu(
            references=[[ref] for ref in references],
            predictions=predictions,
            lang=codebleu_lang,
            weights=(0.25, 0.25, 0.25, 0.25),
            tokenizer=None
        )

        return {
            "codebleu": result["codebleu"],
            "ngram_match_score": result["ngram_match_score"],
            "weighted_ngram_match_score": result["weighted_ngram_match_score"],
            "syntax_match_score": result["syntax_match_score"],
            "dataflow_match_score": result["dataflow_match_score"]
        }
    except Exception as e:
        print(f"⚠ CodeBLEU calculation failed: {e}")
        print("  Skipping CodeBLEU metric...")
        return {
            "codebleu": None,
            "ngram_match_score": None,
            "weighted_ngram_match_score": None,
            "syntax_match_score": None,
            "dataflow_match_score": None
        }

def calculate_bleu(predictions, references):
    """Calculate BLEU score"""
    print("Calculating BLEU...")

    bleu = evaluate.load("bleu")

    # IMPORTANT: References must be list of lists
    # Each prediction maps to a list of reference(s)
    references_formatted = [[ref] for ref in references]

    result = bleu.compute(
        predictions=predictions,
        references=references_formatted
    )

    return result

def calculate_exact_match(predictions, references):
    """Calculate exact match accuracy"""
    print("Calculating Exact Match...")

    matches = sum(
        1 for pred, ref in zip(predictions, references)
        if pred.strip() == ref.strip()
    )

    accuracy = matches / len(predictions) if predictions else 0

    return {
        "exact_match": accuracy,
        "total_matches": matches,
        "total_samples": len(predictions)
    }

def calculate_chrf(predictions, references):
    """Calculate ChrF score (character n-gram F-score)"""
    print("Calculating ChrF...")

    chrf = evaluate.load("chrf")
    result = chrf.compute(
        predictions=predictions,
        references=[[ref] for ref in references]
    )

    return result

def calculate_rouge(predictions, references):
    """Calculate ROUGE scores"""
    print("Calculating ROUGE...")

    rouge = evaluate.load("rouge")
    result = rouge.compute(
        predictions=predictions,
        references=references
    )

    return result

def calculate_edit_distance(predictions, references):
    """Calculate average normalized edit distance"""
    print("Calculating Edit Distance...")

    from difflib import SequenceMatcher

    distances = []
    for pred, ref in zip(predictions, references):
        ratio = SequenceMatcher(None, pred, ref).ratio()
        distances.append(ratio)

    return {
        "similarity_score": np.mean(distances),
        "avg_edit_distance": 1 - np.mean(distances)
    }

def calculate_syntax_validity(predictions, language="python"):
    """Check if generated code is syntactically valid"""
    print("Calculating Syntax Validity...")

    import ast

    valid_count = 0
    syntax_errors = []

    for i, pred in enumerate(predictions):
        try:
            if language == "python":
                ast.parse(pred)
                valid_count += 1
        except SyntaxError as e:
            syntax_errors.append((i, str(e)))

    validity_rate = valid_count / len(predictions) if predictions else 0

    return {
        "syntax_validity": validity_rate,
        "valid_samples": valid_count,
        "total_samples": len(predictions),
        "sample_errors": syntax_errors[:5]  # First 5 errors
    }


In [6]:
# ============================================================
# COMPREHENSIVE EVALUATION
# ============================================================
def evaluate_model(
    model,
    tokenizer,
    eval_data,
    language="cpp",
    max_length=256,
    batch_size=8
):
    """
    Comprehensive model evaluation
    """

    print("="*60)
    print("STARTING COMPREHENSIVE EVALUATION")
    print("="*60 + "\n")

    start_time = time.time()

    # Generate predictions
    predictions = generate_predictions(
        model, tokenizer, eval_data, max_length, batch_size
    )

    # Extract references
    references = [item["reference"] for item in eval_data]

    # Calculate all metrics
    results = {}

    # 1. CodeBLEU (most important for code)
    codebleu_results = calculate_codebleu(predictions, references, language)
    results.update(codebleu_results)

    # 2. BLEU
    bleu_results = calculate_bleu(predictions, references)
    results["bleu"] = bleu_results["bleu"]
    results["bleu_precisions"] = bleu_results["precisions"]

    # 3. Exact Match
    exact_match_results = calculate_exact_match(predictions, references)
    results.update(exact_match_results)

    # 4. ChrF
    chrf_results = calculate_chrf(predictions, references)
    results["chrf"] = chrf_results["score"]

    # 5. ROUGE
    rouge_results = calculate_rouge(predictions, references)
    results["rouge1"] = rouge_results["rouge1"]
    results["rouge2"] = rouge_results["rouge2"]
    results["rougeL"] = rouge_results["rougeL"]

    # 6. Edit Distance
    edit_results = calculate_edit_distance(predictions, references)
    results.update(edit_results)

    # 7. Syntax Validity
    syntax_results = calculate_syntax_validity(predictions, language)
    results.update(syntax_results)

    # Evaluation time
    eval_time = time.time() - start_time
    results["evaluation_time_seconds"] = eval_time
    results["samples_per_second"] = len(eval_data) / eval_time

    return results, predictions, references

In [7]:
# ============================================================
# PRINT RESULTS
# ============================================================
def print_results(results):
    """Print evaluation results in a formatted way"""

    print("\n" + "="*60)
    print("EVALUATION RESULTS")
    print("="*60 + "\n")

    # Helper to format metric
    def fmt(key, default=0.0):
        val = results.get(key, default)
        return f"{val:.4f}" if val is not None else "N/A"

    print("CODE-SPECIFIC METRICS:")
    print("-" * 60)
    print(f"  CodeBLEU:              {fmt('codebleu')}")
    print(f"    - N-gram Match:      {fmt('ngram_match_score')}")
    print(f"    - Weighted N-gram:   {fmt('weighted_ngram_match_score')}")
    print(f"    - Syntax Match:      {fmt('syntax_match_score')}")
    print(f"    - Dataflow Match:    {fmt('dataflow_match_score')}")
    print(f"  Syntax Validity:       {fmt('syntax_validity')}")

    print("\nGENERAL METRICS:")
    print("-" * 60)
    print(f"  BLEU:                  {fmt('bleu')}")
    print(f"  ChrF:                  {fmt('chrf')}")
    print(f"  ROUGE-1:               {fmt('rouge1')}")
    print(f"  ROUGE-2:               {fmt('rouge2')}")
    print(f"  ROUGE-L:               {fmt('rougeL')}")

    print("\nACCURACY METRICS:")
    print("-" * 60)
    print(f"  Exact Match:           {fmt('exact_match')}")
    print(f"  Similarity Score:      {fmt('similarity_score')}")
    print(f"  Edit Distance:         {fmt('avg_edit_distance')}")

    print("\nPERFORMANCE:")
    print("-" * 60)
    print(f"  Evaluation Time:       {results.get('evaluation_time_seconds', 0):.2f}s")
    print(f"  Samples/Second:        {results.get('samples_per_second', 0):.2f}")
    print(f"  Total Samples:         {results.get('total_samples', 0)}")

    print("\n" + "="*60)

def save_results(results, predictions, references, output_path="evaluation_results.json"):
    """Save evaluation results to file"""

    output = {
        "metrics": results,
        "samples": [
            {
                "prediction": pred,
                "reference": ref
            }
            for pred, ref in zip(predictions[:10], references[:10])  # Save first 10 samples
        ]
    }

    with open(output_path, "w") as f:
        json.dump(output, f, indent=2)

    print(f"\n✓ Results saved to {output_path}")

In [14]:
# ============================================================
# MAIN EXECUTION
# ============================================================
# Configuration
BASE_MODEL = "/content/drive/MyDrive/starcoder2-3b"
LORA_PATH = "/content/starcoder-finetuned"
EVAL_DATASET_PATH = "eval_dataset_fixed.json"  # Your evaluation dataset
LANGUAGE = "cpp"  # or "java", "javascript", "go", etc.
MAX_SAMPLES = 100  # Set to None to evaluate all samples

print("="*60)
print("MODEL EVALUATION PIPELINE")
print("="*60 + "\n")

# Load model
model, tokenizer = load_finetuned_model(BASE_MODEL, LORA_PATH)

# Load evaluation dataset
eval_dataset = load_eval_dataset(EVAL_DATASET_PATH, max_samples=MAX_SAMPLES)
eval_data = prepare_eval_data(eval_dataset)

print(f"Evaluating on {len(eval_data)} samples...\n")

MODEL EVALUATION PIPELINE

Loading model for evaluation...
Model loaded

Loading evaluation dataset from eval_dataset_fixed.json...


Generating train split: 0 examples [00:00, ? examples/s]

Loaded 42 samples

Evaluating on 42 samples...



In [15]:
# Run evaluation
results, predictions, references = evaluate_model(
    model=model,
    tokenizer=tokenizer,
    eval_data=eval_data,
    language=LANGUAGE,
    max_length=256,
    batch_size=8
)

# Print results
print_results(results)

# Save results
save_results(results, predictions, references)

STARTING COMPREHENSIVE EVALUATION

Generating predictions...


100%|██████████| 6/6 [00:54<00:00,  9.02s/it]


Predictions generated

Calculating CodeBLEU...
⚠ CodeBLEU calculation failed: an integer is required
  Skipping CodeBLEU metric...
Calculating BLEU...
Calculating Exact Match...
Calculating ChrF...
Calculating ROUGE...
Calculating Edit Distance...
Calculating Syntax Validity...

EVALUATION RESULTS

CODE-SPECIFIC METRICS:
------------------------------------------------------------
  CodeBLEU:              N/A
    - N-gram Match:      N/A
    - Weighted N-gram:   N/A
    - Syntax Match:      N/A
    - Dataflow Match:    N/A
  Syntax Validity:       0.0000

GENERAL METRICS:
------------------------------------------------------------
  BLEU:                  0.0000
  ChrF:                  3.3889
  ROUGE-1:               0.1652
  ROUGE-2:               0.0966
  ROUGE-L:               0.1353

ACCURACY METRICS:
------------------------------------------------------------
  Exact Match:           0.0000
  Similarity Score:      0.0961
  Edit Distance:         0.9039

PERFORMANCE:
----------

In [16]:
# Print sample comparisons
print("\n" + "="*60)
print("SAMPLE PREDICTIONS (first 3)")
print("="*60)

for i in range(min(3, len(predictions))):
    print(f"\n--- Sample {i+1} ---")
    print(f"Prompt:\n{eval_data[i]['prompt']}")
    print(f"\nReference:\n{references[i][:200]}...")
    print(f"\nPrediction:\n{predictions[i][:200]}...")
    print("-" * 60)

print("\n✓ Evaluation complete!")


SAMPLE PREDICTIONS (first 3)

--- Sample 1 ---
Prompt:
UNIT_TEST(Assert_Smoke)

Reference:
#include "testing/testing.hpp"

#include "base/base.hpp"
#include "base/exception.hpp"
#include "base/logging.hpp"


UNIT_TEST(Assert_Smoke)
{
  int x = 5;
#ifdef RELE...

Prediction:
UNIT_TEST(Assert_Smoke)
{
    // Arrange
    auto const expected = std::vector<std::string>{ "a", "b", "c" };
    auto const actual = std::vector<std::string>{ "a", "b", "c" };

    // Act
    Assert:...
------------------------------------------------------------

--- Sample 2 ---
Prompt:
void SectionParser::parse(std::istream& input, ForceField& ff, int lineNo)

Reference:
/*
 * Copyright (c) 2004-present, The University of Notre Dame. All rights
 * reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided...

Prediction:
void SectionParser::parse(std::istream& input, ForceField& ff, int lineNo)
{
    std::string line;
    std::getline(input, line);

In [12]:
# ============================================================
# Diagnose Evaluation Dataset Issues
# ============================================================

import json
from collections import Counter

def diagnose_eval_dataset(file_path="eval_dataset.json"):
    """
    Analyze evaluation dataset to identify issues
    """

    print("="*60)
    print("EVALUATION DATASET DIAGNOSTICS")
    print("="*60 + "\n")

    # Load dataset
    with open(file_path, "r") as f:
        data = json.load(f)

    print(f"Total samples: {len(data)}\n")

    # Check structure
    print("Dataset Structure:")
    print("-" * 60)
    if data:
        print(f"Keys in first sample: {list(data[0].keys())}")
        print()

    # Check languages
    languages = Counter()
    for item in data:
        lang = item.get("language", "unknown")
        languages[lang] += 1

    print("Language Distribution:")
    print("-" * 60)
    for lang, count in languages.most_common():
        print(f"  {lang}: {count} samples ({count/len(data)*100:.1f}%)")
    print()

    # Check for mixed content
    has_prompt = sum(1 for item in data if "prompt" in item)
    has_reference = sum(1 for item in data if "reference" in item)
    has_text = sum(1 for item in data if "text" in item)

    print("Field Coverage:")
    print("-" * 60)
    print(f"  Has 'prompt': {has_prompt}/{len(data)} ({has_prompt/len(data)*100:.1f}%)")
    print(f"  Has 'reference': {has_reference}/{len(data)} ({has_reference/len(data)*100:.1f}%)")
    print(f"  Has 'text': {has_text}/{len(data)} ({has_text/len(data)*100:.1f}%)")
    print()

    # Check for empty or very short samples
    empty_prompts = sum(1 for item in data if not item.get("prompt", "").strip())
    empty_refs = sum(1 for item in data if not item.get("reference", "").strip())
    short_refs = sum(1 for item in data if len(item.get("reference", "")) < 10)

    print("Data Quality:")
    print("-" * 60)
    print(f"  Empty prompts: {empty_prompts}")
    print(f"  Empty references: {empty_refs}")
    print(f"  Very short references (<10 chars): {short_refs}")
    print()

    # Sample prompts and references
    print("Sample Entries:")
    print("-" * 60)
    for i in range(min(3, len(data))):
        item = data[i]
        print(f"\nSample {i+1}:")
        print(f"  Language: {item.get('language', 'N/A')}")
        print(f"  Prompt: {item.get('prompt', 'N/A')[:80]}...")
        print(f"  Reference length: {len(item.get('reference', ''))} chars")
        ref_preview = item.get('reference', '')[:100].replace('\n', ' ')
        print(f"  Reference preview: {ref_preview}...")

    print("\n" + "="*60)
    print("RECOMMENDATIONS:")
    print("="*60)

    # Recommendations
    if len(languages) > 3:
        print("⚠ You have many languages mixed. Consider:")
        print("  1. Evaluate each language separately")
        print("  2. Use the primary language for CodeBLEU")

    if empty_prompts > 0 or empty_refs > 0:
        print("⚠ Some samples have empty prompts/references")
        print("  Run: clean_eval_dataset()")

    if short_refs > len(data) * 0.3:
        print("⚠ Many references are very short")
        print("  This may affect metric reliability")

    print()

    return data

def clean_eval_dataset(input_file="eval_dataset.json", output_file="eval_dataset_clean.json"):
    """
    Clean evaluation dataset by removing problematic samples
    """

    print("Cleaning evaluation dataset...")

    with open(input_file, "r") as f:
        data = json.load(f)

    original_count = len(data)

    # Filter out problematic samples
    clean_data = []
    for item in data:
        prompt = item.get("prompt", "").strip()
        reference = item.get("reference", "").strip()

        # Keep if both prompt and reference are non-empty
        if prompt and reference and len(reference) >= 10:
            clean_data.append(item)

    # Save cleaned data
    with open(output_file, "w") as f:
        json.dump(clean_data, f, indent=2)

    removed = original_count - len(clean_data)
    print(f"✓ Cleaned dataset saved to {output_file}")
    print(f"  Original: {original_count} samples")
    print(f"  Cleaned: {len(clean_data)} samples")
    print(f"  Removed: {removed} samples")

    return clean_data

def split_by_language(input_file="eval_dataset.json"):
    """
    Split evaluation dataset by language for separate evaluation
    """

    print("Splitting dataset by language...")

    with open(input_file, "r") as f:
        data = json.load(f)

    # Group by language
    by_language = {}
    for item in data:
        lang = item.get("language", "unknown")
        if lang not in by_language:
            by_language[lang] = []
        by_language[lang].append(item)

    # Save each language separately
    for lang, items in by_language.items():
        output_file = f"eval_dataset_{lang}.json"
        with open(output_file, "w") as f:
            json.dump(items, f, indent=2)
        print(f"✓ {lang}: {len(items)} samples → {output_file}")

    print(f"\nTotal languages: {len(by_language)}")

    return by_language

# ============================================================
# RUN DIAGNOSTICS
# ============================================================
if __name__ == "__main__":

    # Diagnose the dataset
    data = diagnose_eval_dataset("eval_dataset.json")

    # Optional: Clean the dataset
    print("\n" + "="*60)
    print("Would you like to clean the dataset? (Uncomment below)")
    print("="*60)
    # clean_data = clean_eval_dataset("eval_dataset.json", "eval_dataset_clean.json")

    # Optional: Split by language
    print("\n" + "="*60)
    print("Would you like to split by language? (Uncomment below)")
    print("="*60)
    # split_by_language("eval_dataset.json")

EVALUATION DATASET DIAGNOSTICS

Total samples: 42

Dataset Structure:
------------------------------------------------------------
Keys in first sample: ['prompt', 'reference', 'language', 'file_name']

Language Distribution:
------------------------------------------------------------
  cpp: 42 samples (100.0%)

Field Coverage:
------------------------------------------------------------
  Has 'prompt': 42/42 (100.0%)
  Has 'reference': 42/42 (100.0%)
  Has 'text': 0/42 (0.0%)

Data Quality:
------------------------------------------------------------
  Empty prompts: 0
  Empty references: 0
  Very short references (<10 chars): 0

Sample Entries:
------------------------------------------------------------

Sample 1:
  Language: cpp
  Prompt: UNIT_TEST(Assert_Smoke)...
  Reference length: 722 chars
  Reference preview: #include "testing/testing.hpp"  #include "base/base.hpp" #include "base/exception.hpp" #include "bas...

Sample 2:
  Language: cpp
  Prompt: * Copyright (c) 2004-presen

In [13]:
# ============================================================
# Fix C++ Evaluation Dataset - Extract Proper Code Prompts
# ============================================================

import json
import re

def extract_cpp_prompt(code):
    """
    Extract a meaningful C++ prompt from code

    Strategy:
    1. Skip comments and copyright headers
    2. Find function signatures
    3. Find class declarations
    4. Use first meaningful code line
    """

    lines = code.split('\n')

    # Skip initial comments and includes
    code_start = 0
    for i, line in enumerate(lines):
        stripped = line.strip()

        # Skip empty lines, comments, includes, copyright
        if (not stripped or
            stripped.startswith('//') or
            stripped.startswith('/*') or
            stripped.startswith('*') or
            stripped.startswith('#include') or
            stripped.startswith('#define') or
            stripped.startswith('#pragma')):
            continue

        code_start = i
        break

    # Get code without headers
    code_lines = lines[code_start:]

    if not code_lines:
        # Fallback: use first line
        return lines[0] if lines else ""

    # Strategy 1: Find function signature
    for i, line in enumerate(code_lines):
        stripped = line.strip()

        # Function declaration/definition patterns
        if (('(' in stripped and ')' in stripped and
             not stripped.startswith('//') and
             not stripped.startswith('/*')) and
            any(keyword in stripped for keyword in ['void', 'int', 'bool', 'double',
                                                     'float', 'char', 'auto', 'string',
                                                     'TEST', 'UNIT_TEST'])):

            # Get function signature (might span multiple lines)
            signature = stripped

            # If line doesn't end with { or ;, it might continue
            if not (signature.endswith('{') or signature.endswith(';')):
                # Look for continuation
                for j in range(i+1, min(i+3, len(code_lines))):
                    signature += ' ' + code_lines[j].strip()
                    if '{' in code_lines[j] or ';' in code_lines[j]:
                        break

            # Clean up the signature - remove { at end
            signature = signature.replace('{', '').strip()

            return signature

    # Strategy 2: Find class declaration
    for line in code_lines:
        stripped = line.strip()
        if stripped.startswith('class ') or stripped.startswith('struct '):
            return stripped.replace('{', '').strip()

    # Strategy 3: Find template or namespace
    for line in code_lines:
        stripped = line.strip()
        if stripped.startswith('template') or stripped.startswith('namespace'):
            return stripped

    # Fallback: use first non-empty line
    for line in code_lines:
        stripped = line.strip()
        if stripped:
            return stripped

    return code_lines[0] if code_lines else ""

def fix_eval_dataset(
    input_file="eval_dataset.json",
    output_file="eval_dataset_fixed.json"
):
    """
    Fix evaluation dataset by extracting proper C++ prompts
    """

    print("="*60)
    print("FIXING C++ EVALUATION DATASET")
    print("="*60 + "\n")

    # Load dataset
    with open(input_file, "r") as f:
        data = json.load(f)

    print(f"Original dataset: {len(data)} samples\n")

    # Fix each sample
    fixed_data = []
    skipped = 0

    for i, item in enumerate(data):
        reference = item["reference"]

        # Extract better prompt
        new_prompt = extract_cpp_prompt(reference)

        # Skip if prompt is still a comment or too short
        if (new_prompt.startswith('//') or
            new_prompt.startswith('/*') or
            new_prompt.startswith('*') or
            len(new_prompt) < 10):
            print(f"⚠ Skipping sample {i+1}: Invalid prompt extracted")
            skipped += 1
            continue

        # Create fixed item
        fixed_item = {
            "prompt": new_prompt,
            "reference": reference,
            "language": item["language"],
            "file_name": item["file_name"],
            "original_prompt": item["prompt"]  # Keep original for reference
        }

        fixed_data.append(fixed_item)

        # Show progress for first few
        if i < 3:
            print(f"Sample {i+1}:")
            print(f"  Old prompt: {item['prompt'][:60]}...")
            print(f"  New prompt: {new_prompt[:60]}...")
            print()

    # Save fixed dataset
    with open(output_file, "w") as f:
        json.dump(fixed_data, f, indent=2)

    print("="*60)
    print("RESULTS")
    print("="*60)
    print(f"Original samples: {len(data)}")
    print(f"Fixed samples: {len(fixed_data)}")
    print(f"Skipped: {skipped}")
    print(f"\n✓ Fixed dataset saved to {output_file}")
    print("="*60 + "\n")

    # Show some examples
    print("Sample Fixed Entries:")
    print("-" * 60)
    for i in range(min(3, len(fixed_data))):
        item = fixed_data[i]
        print(f"\nSample {i+1}:")
        print(f"  Prompt: {item['prompt'][:100]}")
        print(f"  Reference length: {len(item['reference'])} chars")

    return fixed_data

def create_simplified_eval_dataset(
    input_file="eval_dataset.json",
    output_file="eval_dataset_simple.json",
    max_samples=50,
    max_reference_length=500
):
    """
    Create simplified evaluation dataset
    - Better prompts
    - Shorter references (for faster evaluation)
    - Filter for quality
    """

    print("="*60)
    print("CREATING SIMPLIFIED EVALUATION DATASET")
    print("="*60 + "\n")

    with open(input_file, "r") as f:
        data = json.load(f)

    simplified_data = []

    for item in data:
        reference = item["reference"]

        # Extract prompt
        prompt = extract_cpp_prompt(reference)

        # Skip bad prompts
        if (prompt.startswith('//') or prompt.startswith('/*') or
            prompt.startswith('*') or len(prompt) < 10):
            continue

        # Truncate long references (keep first function/class only)
        if len(reference) > max_reference_length:
            # Try to find end of first function/class
            lines = reference.split('\n')
            truncated_lines = []
            brace_count = 0
            started = False

            for line in lines:
                truncated_lines.append(line)

                # Count braces
                if '{' in line:
                    brace_count += line.count('{')
                    started = True
                if '}' in line:
                    brace_count -= line.count('}')

                # Stop when we've closed all braces
                if started and brace_count == 0:
                    break

                # Safety: don't go too long
                if len('\n'.join(truncated_lines)) > max_reference_length:
                    break

            reference = '\n'.join(truncated_lines)

        simplified_data.append({
            "prompt": prompt,
            "reference": reference,
            "language": "cpp"
        })

        if len(simplified_data) >= max_samples:
            break

    # Save
    with open(output_file, "w") as f:
        json.dump(simplified_data, f, indent=2)

    print(f"✓ Created {len(simplified_data)} simplified samples")
    print(f"  Saved to {output_file}")
    print(f"  Average reference length: {sum(len(x['reference']) for x in simplified_data) / len(simplified_data):.0f} chars")

    return simplified_data

# ============================================================
# MAIN EXECUTION
# ============================================================
if __name__ == "__main__":

    print("\n" + "="*60)
    print("OPTION 1: Fix existing dataset (better prompts)")
    print("="*60)
    fixed_data = fix_eval_dataset(
        "eval_dataset.json",
        "eval_dataset_fixed.json"
    )

    print("\n" + "="*60)
    print("OPTION 2: Create simplified dataset (shorter, faster)")
    print("="*60)
    simple_data = create_simplified_eval_dataset(
        "eval_dataset.json",
        "eval_dataset_simple.json",
        max_samples=30,
        max_reference_length=300
    )

    print("\n" + "="*60)
    print("NEXT STEPS")
    print("="*60)
    print("""
1. Use fixed dataset for comprehensive evaluation:
   python evaluate_model.py  # Update: EVAL_DATASET_PATH = "eval_dataset_fixed.json"

2. Or use simplified dataset for quick evaluation:
   python evaluate_model.py  # Update: EVAL_DATASET_PATH = "eval_dataset_simple.json"

3. The simplified dataset will be faster (shorter code = faster generation)
    """)


OPTION 1: Fix existing dataset (better prompts)
FIXING C++ EVALUATION DATASET

Original dataset: 42 samples

Sample 1:
  Old prompt: UNIT_TEST(Assert_Smoke)...
  New prompt: UNIT_TEST(Assert_Smoke)...

Sample 2:
  Old prompt: * Copyright (c) 2004-present, The University of Notre Dame. ...
  New prompt: void SectionParser::parse(std::istream& input, ForceField& f...

Sample 3:
  Old prompt: // Copyright (c) 2013-2020 Baptiste Wicht....
  New prompt: auto today = budget::local_day();...

RESULTS
Original samples: 42
Fixed samples: 42
Skipped: 0

✓ Fixed dataset saved to eval_dataset_fixed.json

Sample Fixed Entries:
------------------------------------------------------------

Sample 1:
  Prompt: UNIT_TEST(Assert_Smoke)
  Reference length: 722 chars

Sample 2:
  Prompt: void SectionParser::parse(std::istream& input, ForceField& ff, int lineNo)
  Reference length: 4779 chars

Sample 3:
  Prompt: auto today = budget::local_day();
  Reference length: 4160 chars

OPTION 2: Create simplified