# Hyperparameter tuning search

---



In [None]:
def hyperparameter_search(train_data, val_data, base_output_dir, n_trials=4):
    """Tune learning rate, batch size and gradient accumulation with Optuna.

    Every trial fine-tunes ``Salesforce/codet5-base`` for 3 epochs on at most
    the first 300 training and 100 validation examples. It is scored by the
    highest ``eval_rougeL`` found in the trainer's log history. A trial that
    raises, or logs no ``eval_rougeL``, scores 0.0. Once the search is done,
    an 8-epoch model is trained on the full ``train_data``/``val_data`` with
    the winning parameters.

    Artifacts go to ``<base_output_dir>/hparam_search_<timestamp>/``:

    - ``trial_<n>/params.json`` and ``trial_<n>/result.json`` for each trial,
    - ``best_params.json``,
    - ``best_model/`` for the final model.

    Relies on ``os``, ``json`` and ``train_on_a100`` from the notebook's
    global namespace.

    Args:
        train_data: Training examples (sliceable sequence).
        val_data: Validation examples (sliceable sequence).
        base_output_dir: Parent directory for the search artifacts.
        n_trials: Number of Optuna trials to run.

    Returns:
        Tuple ``(best_model_dir, best_params)``.
    """

    import optuna
    from datetime import datetime

    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    search_dir = f"{base_output_dir}/hparam_search_{stamp}"
    os.makedirs(search_dir, exist_ok=True)

    def objective(trial):
        # Dict literal evaluates in order, so the suggest_* calls keep their sequence.
        sampled = {
            "learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True),
            "batch_size": trial.suggest_categorical("batch_size", [2, 4, 8]),
            "gradient_accumulation_steps": trial.suggest_categorical("gradient_accumulation_steps", [2, 4, 8]),
        }
        lr = sampled["learning_rate"]
        bs = sampled["batch_size"]
        grad_accum = sampled["gradient_accumulation_steps"]

        trial_dir = f"{search_dir}/trial_{trial.number}"
        os.makedirs(trial_dir, exist_ok=True)

        # Small slices keep each exploratory run cheap.
        train_slice = train_data[:300]
        val_slice = val_data[:100]

        with open(f"{trial_dir}/params.json", 'w') as fh:
            json.dump(sampled, fh, indent=2)

        try:
            print(f"\nTrial {trial.number}: lr={lr}, bs={bs}, grad_accum={grad_accum}")
            _, _, trainer = train_on_a100(
                train_data=train_slice,
                val_data=val_slice,
                output_dir=trial_dir,
                model_name="Salesforce/codet5-base",
                epochs=3,
                batch_size=bs,
                gradient_accumulation_steps=grad_accum,
                learning_rate=lr,
                fp16=True,
            )

            rouge_logs = [entry for entry in trainer.state.log_history if 'eval_rougeL' in entry]
            if not rouge_logs:
                return 0.0

            peak = max(entry['eval_rougeL'] for entry in rouge_logs)
            with open(f"{trial_dir}/result.json", 'w') as fh:
                json.dump({
                    "best_eval_rougeL": peak,
                    "final_metrics": rouge_logs[-1],
                }, fh, indent=2)
            return peak

        except Exception as err:
            # Any failure just scores the trial as 0.0 so the search keeps going.
            print(f"Error in trial {trial.number}: {err}")
            return 0.0

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials)

    with open(f"{search_dir}/best_params.json", 'w') as fh:
        json.dump({
            "best_params": study.best_params,
            "best_value": study.best_value,
            "best_trial": study.best_trial.number,
        }, fh, indent=2)

    print(f"\nBest hyperparameters: {study.best_params}")
    print(f"Best RougeL: {study.best_value:.4f}")

    best_model_dir = f"{search_dir}/best_model"
    os.makedirs(best_model_dir, exist_ok=True)

    winner = study.best_params
    _, _, _ = train_on_a100(
        train_data=train_data,
        val_data=val_data,
        output_dir=best_model_dir,
        model_name="Salesforce/codet5-base",
        epochs=8,
        batch_size=winner['batch_size'],
        gradient_accumulation_steps=winner['gradient_accumulation_steps'],
        learning_rate=winner['learning_rate'],
        fp16=True,
    )

    return best_model_dir, study.best_params

In [None]:
def evaluate_test_set(model_path, test_data, output_dir=None):
    """Perform comprehensive evaluation on the full test set with metrics and examples.

    For each example, the model generates documentation from
    ``example['input']`` (4-beam search, ``max_length=256``). The output is
    scored against ``example['output']`` with ROUGE-1/2/L F-measure. Each
    example is bucketed by ROUGE-L:

    - excellent: > 0.9
    - good: > 0.7
    - moderate: > 0.5
    - poor: everything else, including exactly 0.0

    Writes two files to ``output_dir``:

    - ``test_results.json``: averages, bucket counts/percentages, and the
      first 20 detailed results.
    - ``evaluation_report.html``: a UTF-8 report with HTML-escaped code
      samples.

    Args:
        model_path: Directory holding the saved model and tokenizer.
        test_data: Sequence of dicts with ``'input'`` (code) and ``'output'``
            (reference documentation).
        output_dir: Where to write results. Defaults to
            ``<dirname(model_path)>/test_evaluation``.

    Returns:
        Tuple ``(avg_metrics, categories)``.
    """
    import html
    import time

    if output_dir is None:
        output_dir = os.path.dirname(model_path) + "/test_evaluation"
    os.makedirs(output_dir, exist_ok=True)

    # Load model and tokenizer
    tokenizer = RobertaTokenizer.from_pretrained(model_path)
    model = T5ForConditionalGeneration.from_pretrained(model_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Track metrics
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    all_scores = {
        'rouge1': [], 'rouge2': [], 'rougeL': [],
        'code_length': [], 'reference_length': [], 'generated_length': []
    }

    # ROUGE-L cut-offs, checked in order; a score above none of them is 'poor'.
    category_thresholds = (('excellent', 0.9), ('good', 0.7), ('moderate', 0.5))
    categories = {'excellent': 0, 'good': 0, 'moderate': 0, 'poor': 0}

    # Detailed results for each example
    detailed_results = []

    print(f"Evaluating model on {len(test_data)} test examples...")
    batch_size = 8

    for i in range(0, len(test_data), batch_size):
        batch_end = min(i + batch_size, len(test_data))
        batch = test_data[i:batch_end]

        for example in batch:
            code = example['input']
            reference = example['output']

            # Generate documentation
            input_text = f"Generate documentation for TypeScript code: {code}"
            inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)

            with torch.no_grad():
                outputs = model.generate(
                    inputs.input_ids,
                    max_length=256,
                    num_beams=4,
                    early_stopping=True
                )

            generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
            scores = scorer.score(reference, generated)

            all_scores['code_length'].append(len(code.split()))
            all_scores['reference_length'].append(len(reference.split()))
            all_scores['generated_length'].append(len(generated.split()))
            for key in ['rouge1', 'rouge2', 'rougeL']:
                all_scores[key].append(scores[key].fmeasure)

            # Bucket once so the counter and the per-example label always agree.
            # The 'poor' default also covers a score of exactly 0.0.
            rouge_l = scores['rougeL'].fmeasure
            category = next(
                (name for name, cutoff in category_thresholds if rouge_l > cutoff),
                'poor',
            )
            categories[category] += 1

            detailed_results.append({
                'code': code,
                'reference': reference,
                'generated': generated,
                'rouge1': scores['rouge1'].fmeasure,
                'rouge2': scores['rouge2'].fmeasure,
                'rougeL': rouge_l,
                'category': category,
            })

        # Log progress every 5 batches
        if (i // batch_size) % 5 == 0:
            print(f"Processed {batch_end}/{len(test_data)} examples")

    # Calculate average metrics (0 for an empty list)
    avg_metrics = {
        key: sum(values) / len(values) if values else 0
        for key, values in all_scores.items()
    }

    # Category percentages; guard against an empty test set.
    total = len(test_data)
    category_pcts = {k: (100 * v / total if total else 0.0) for k, v in categories.items()}

    results = {
        'avg_metrics': avg_metrics,
        'categories': categories,
        'category_percentages': category_pcts,
        'detailed_results': detailed_results[:20]  # Save only first 20 for space
    }

    with open(f"{output_dir}/test_results.json", 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    # Generate HTML report
    html_report = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="utf-8">
        <title>TypeScript Documentation Model Evaluation</title>
        <style>
            body {{ font-family: Arial, sans-serif; margin: 20px; }}
            .header {{ background-color: #f0f0f0; padding: 15px; border-radius: 5px; }}
            .metrics {{ margin: 20px 0; }}
            .categories {{ display: flex; justify-content: space-between; margin: 20px 0; }}
            .category {{ text-align: center; padding: 10px; border-radius: 5px; }}
            .excellent {{ background-color: #d4edda; }}
            .good {{ background-color: #d1ecf1; }}
            .moderate {{ background-color: #fff3cd; }}
            .poor {{ background-color: #f8d7da; }}
            table {{ width: 100%; border-collapse: collapse; }}
            th, td {{ border: 1px solid #ddd; padding: 8px; }}
            tr:nth-child(even) {{ background-color: #f2f2f2; }}
            th {{ background-color: #4CAF50; color: white; }}
            .example {{ margin: 20px 0; padding: 15px; border: 1px solid #ddd; border-radius: 5px; }}
            pre {{ background-color: #f5f5f5; padding: 10px; border-radius: 5px; overflow-x: auto; }}
        </style>
    </head>
    <body>
        <div class="header">
            <h1>TypeScript Documentation Model Evaluation</h1>
            <p>Evaluation Date: {time.strftime('%Y-%m-%d %H:%M:%S')}</p>
        </div>

        <div class="metrics">
            <h2>Overall Metrics</h2>
            <table>
                <tr>
                    <th>Metric</th>
                    <th>Value</th>
                </tr>
                <tr>
                    <td>Average ROUGE-1</td>
                    <td>{avg_metrics['rouge1']:.4f}</td>
                </tr>
                <tr>
                    <td>Average ROUGE-2</td>
                    <td>{avg_metrics['rouge2']:.4f}</td>
                </tr>
                <tr>
                    <td>Average ROUGE-L</td>
                    <td>{avg_metrics['rougeL']:.4f}</td>
                </tr>
                <tr>
                    <td>Average Code Length (words)</td>
                    <td>{avg_metrics['code_length']:.1f}</td>
                </tr>
                <tr>
                    <td>Average Reference Doc Length (words)</td>
                    <td>{avg_metrics['reference_length']:.1f}</td>
                </tr>
                <tr>
                    <td>Average Generated Doc Length (words)</td>
                    <td>{avg_metrics['generated_length']:.1f}</td>
                </tr>
            </table>
        </div>

        <div class="categories">
            <div class="category excellent">
                <h3>Excellent</h3>
                <p>{categories['excellent']} examples</p>
                <p>({category_pcts['excellent']:.1f}%)</p>
            </div>
            <div class="category good">
                <h3>Good</h3>
                <p>{categories['good']} examples</p>
                <p>({category_pcts['good']:.1f}%)</p>
            </div>
            <div class="category moderate">
                <h3>Moderate</h3>
                <p>{categories['moderate']} examples</p>
                <p>({category_pcts['moderate']:.1f}%)</p>
            </div>
            <div class="category poor">
                <h3>Poor</h3>
                <p>{categories['poor']} examples</p>
                <p>({category_pcts['poor']:.1f}%)</p>
            </div>
        </div>

        <h2>Example Results</h2>
    """

    # Add up to two examples from each category
    for category in ['excellent', 'good', 'moderate', 'poor']:
        examples = [r for r in detailed_results if r['category'] == category][:2]

        html_report += f"""
        <h3>{category.title()} Examples</h3>
        """

        for idx, example in enumerate(examples):
            # Escape text so TypeScript generics like Array<string> are shown, not parsed as tags.
            html_report += f"""
            <div class="example {category}">
                <h4>Example {idx+1}</h4>
                <h5>Code:</h5>
                <pre>{html.escape(example['code'])}</pre>

                <h5>Generated Documentation:</h5>
                <pre>{html.escape(example['generated'])}</pre>

                <h5>Reference Documentation:</h5>
                <pre>{html.escape(example['reference'])}</pre>

                <p>
                    <strong>ROUGE-1:</strong> {example['rouge1']:.4f}
                    <strong>ROUGE-2:</strong> {example['rouge2']:.4f}
                    <strong>ROUGE-L:</strong> {example['rougeL']:.4f}
                </p>
            </div>
            """

    html_report += """
    </body>
    </html>
    """

    with open(f"{output_dir}/evaluation_report.html", 'w', encoding='utf-8') as f:
        f.write(html_report)

    print(f"\nTest evaluation completed. Results saved to {output_dir}")
    print(f"Summary metrics:")
    print(f"- ROUGE-L: {avg_metrics['rougeL']:.4f}")
    print(f"- Quality breakdown: {categories['excellent']} excellent, {categories['good']} good, " +
          f"{categories['moderate']} moderate, {categories['poor']} poor")

    return avg_metrics, categories

In [None]:
def implement_enhanced_typescript_documentation():
    """Implement all enhancements for the TypeScript documentation generator.

    Steps:
      1. Install optuna if it is missing, then run ``hyperparameter_search``.
         That function limits each trial to a small slice of the data itself,
         and then trains the final model on whatever ``train_data`` it
         receives. The full training split is therefore passed here, so the
         final model is not limited to a 500-example subset.
      2. Evaluate the best model on the full test split (``evaluate_test_set``).
      3. Print documentation for a sample function in four styles
         (``generate_comprehensive_docs``, defined elsewhere in the notebook).

    Reads ``full_{train,val,test}_split.json`` from ``<base_dir>/data`` and
    writes everything under ``<base_dir>/models/enhanced_codet5``.

    Returns:
        Dict with ``best_model_dir``, ``best_params`` and ``test_metrics``.
    """
    import importlib
    import importlib.util
    import subprocess
    import sys

    # Set up directories
    base_dir = '/content/drive/MyDrive/ts_documentation'
    enhanced_model_dir = f'{base_dir}/models/enhanced_codet5'
    os.makedirs(enhanced_model_dir, exist_ok=True)

    # Load data
    with open(f'{base_dir}/data/full_train_split.json', 'r', encoding='utf-8') as f:
        train_data = json.load(f)

    with open(f'{base_dir}/data/full_val_split.json', 'r', encoding='utf-8') as f:
        val_data = json.load(f)

    with open(f'{base_dir}/data/full_test_split.json', 'r', encoding='utf-8') as f:
        test_data = json.load(f)

    print(f"Loaded {len(train_data)} training, {len(val_data)} validation, and {len(test_data)} test examples")

    # STEP 1: Find optimal hyperparameters (smaller search for demonstration)
    print("\n=== ENHANCEMENT 1: Hyperparameter Optimization ===")
    # Install optuna if needed. A plain-Python pip call also works outside IPython.
    # Invalidating the import caches lets the freshly installed package be imported.
    if importlib.util.find_spec("optuna") is None:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "optuna"])
        importlib.invalidate_caches()

    # Trials inside hyperparameter_search already use at most 300/100 examples;
    # the full split here only affects the final retrain with the best params.
    best_model_dir, best_params = hyperparameter_search(
        train_data=train_data,
        val_data=val_data,
        base_output_dir=enhanced_model_dir,
        n_trials=4  # Increase for better results
    )

    # STEP 2: Comprehensive test evaluation
    print("\n=== ENHANCEMENT 2: Comprehensive Test Evaluation ===")
    avg_metrics, categories = evaluate_test_set(
        model_path=best_model_dir,
        test_data=test_data,
        output_dir=f"{enhanced_model_dir}/test_evaluation"
    )

    # STEP 3: Generate enhanced documentation examples
    print("\n=== ENHANCEMENT 3: Enhanced Documentation Generation ===")
    example_code = """
    export function formatDate(date: Date, format: string = 'YYYY-MM-DD'): string {
      const year = date.getFullYear();
      const month = String(date.getMonth() + 1).padStart(2, '0');
      const day = String(date.getDate()).padStart(2, '0');

      let result = format;
      result = result.replace('YYYY', String(year));
      result = result.replace('MM', month);
      result = result.replace('DD', day);

      return result;
    }
    """

    print("\nOriginal code:")
    print(example_code)

    # Generate different documentation styles
    for style in ["standard", "jsdoc", "detailed", "markdown"]:
        print(f"\n--- {style.upper()} Documentation Style ---")
        enhanced_doc = generate_comprehensive_docs(
            model_path=best_model_dir,
            code_example=example_code,
            doc_style=style
        )
        print(enhanced_doc)

    print("\nEnhancements implementation completed!")
    return {
        "best_model_dir": best_model_dir,
        "best_params": best_params,
        "test_metrics": avg_metrics
    }

# Run the implementation
# implementation_results = implement_enhanced_typescript_documentation()

In [None]:
# Install optuna only when it is missing. This is the plain-Python equivalent of
# IPython's `!pip install optuna`, so the cell is also valid outside a notebook.
import importlib
import importlib.util
import subprocess
import sys

if importlib.util.find_spec("optuna") is None:  # needed for hyperparameter search
    subprocess.check_call([sys.executable, "-m", "pip", "install", "optuna"])
    importlib.invalidate_caches()
implementation_results = implement_enhanced_typescript_documentation()

In [None]:
def create_enhanced_templates():
    """Return the prompt template and post-processor for each documentation style.

    The result maps a style name ("standard", "jsdoc", "markdown") to a dict
    with two entries:

    - ``"prompt"``: a ``str.format`` template with a ``{code}`` placeholder.
    - ``"postprocess"``: the function that cleans up the model output, called
      as ``postprocess(generated_doc, original_code)``.
    """
    standard_prompt = (
        "Generate clear, accurate TypeScript documentation for this code. "
        "Focus only on the function or class name, parameters, return type, and purpose. "
        "Do not include any import paths or external references. CODE: {code}"
    )
    jsdoc_prompt = (
        "Generate JSDoc documentation for this TypeScript code. "
        "Include properly formatted @param tags for each parameter with accurate type information, "
        "an @returns tag with the correct return type, and a clear description of functionality. "
        "Exclude any import paths. Always extract the exact function name. CODE: {code}"
    )
    markdown_prompt = (
        "Generate markdown documentation with precise headings for this TypeScript code. "
        "Include function name as title, followed by Parameters section with bullet points "
        "for each parameter showing name and type, Returns section with correct type, "
        "and Example section with sample usage code. Ensure exact function name is used. CODE: {code}"
    )

    return {
        "standard": {"prompt": standard_prompt, "postprocess": standard_postprocessor},
        "jsdoc": {"prompt": jsdoc_prompt, "postprocess": jsdoc_postprocessor},
        "markdown": {"prompt": markdown_prompt, "postprocess": markdown_postprocessor},
    }

def standard_postprocessor(generated_doc, original_code):
    """Post-process standard documentation to fix common issues.

    The entity name is the first ``function``, ``class`` or ``interface``
    declaration found in ``original_code``, checked in that priority order.
    It is used to repair names the model got wrong in two places:

    - bold spans such as ``**name**``;
    - inline-code calls such as a backtick followed by ``name(``.

    A name is only replaced when it does not occur anywhere in
    ``original_code`` as a word. Parameters, types and built-ins the code
    really uses (e.g. ``**date**`` or ``String(``) are left intact. Inline
    ``import("...").`` type prefixes are removed.

    Args:
        generated_doc: Documentation text produced by the model.
        original_code: TypeScript source the documentation describes.

    Returns:
        The corrected documentation string.
    """
    import re

    # Word boundaries stop e.g. "subclass Foo" from being read as a class declaration.
    declaration_patterns = (
        r'\bfunction\s+(\w+)',
        r'\bclass\s+(\w+)',
        r'\binterface\s+(\w+)',
    )
    entity_name = None
    for pattern in declaration_patterns:
        match = re.search(pattern, original_code)
        if match:
            entity_name = match.group(1)
            break

    if entity_name:
        # Identifiers that genuinely exist in the code are never "incorrect names".
        known_identifiers = set(re.findall(r'\w+', original_code))

        def _fix_bold(m):
            return m.group(0) if m.group(1) in known_identifiers else f'**{entity_name}**'

        def _fix_call(m):
            return m.group(0) if m.group(1) in known_identifiers else f'`{entity_name}('

        # e.g. formatFormat -> formatDate, without touching **date** or `String(
        generated_doc = re.sub(r'\*\*(\w+)\*\*', _fix_bold, generated_doc)
        generated_doc = re.sub(r'`(\w+)\(', _fix_call, generated_doc)

    # Remove any import paths
    generated_doc = re.sub(r'import\(".*?"\)\.', '', generated_doc)

    return generated_doc

def jsdoc_postprocessor(generated_doc, original_code):
    """Post-process JSDoc documentation to fix common issues.

    1. Finds the first ``function`` (else ``class``) name in ``original_code``.
       Bold ``**name**`` spans and inline-code calls to ``name(`` whose name
       does not appear in the code are rewritten to it.
    2. If the doc has no ``@param`` tag, contains a ``*/`` terminator, and the
       code's first parenthesised list is non-empty, one ``@param`` line per
       parameter is inserted before the last ``*/``. Optional parameters are
       written ``[name]``; defaulted ones ``[name=default]``.
    3. Removes inline ``import("...").`` type prefixes.

    Args:
        generated_doc: JSDoc text produced by the model.
        original_code: TypeScript source the documentation describes.

    Returns:
        The corrected documentation string.
    """
    import re

    def _split_top_level(text):
        # Split on commas outside <>, (), [], {} and string literals, so that
        # e.g. "m: Map<string, number>" stays one parameter. '>' of '=>' is not a closer.
        parts, buf = [], []
        depth, quote, prev = 0, None, ''
        for ch in text:
            if quote:
                if ch == quote and prev != '\\':
                    quote = None
            elif ch in '\'"`':
                quote = ch
            elif ch in '<([{':
                depth += 1
            elif ch in ')]}' or (ch == '>' and prev != '='):
                depth = max(depth - 1, 0)
            elif ch == ',' and depth == 0:
                parts.append(''.join(buf))
                buf, prev = [], ch
                continue
            buf.append(ch)
            prev = ch
        parts.append(''.join(buf))
        return parts

    def _parse_param(param):
        # Separate the default (first lone '=', not '==', '=>', '<=', ...) from the declaration.
        pieces = re.split(r'(?<![=!<>])=(?![=>])', param, maxsplit=1)
        default = pieces[1].strip() if len(pieces) > 1 else None
        name, _, type_info = pieces[0].partition(':')
        name = name.strip()
        optional = name.endswith('?')
        name = name.rstrip('?')
        if default is not None:
            name = f"[{name}={default}]"
        elif optional:
            name = f"[{name}]"
        return name, type_info.strip() or "any"

    # Find function or class name
    fn_match = re.search(r'\bfunction\s+(\w+)', original_code)
    class_match = re.search(r'\bclass\s+(\w+)', original_code)

    entity_name = None
    if fn_match:
        entity_name = fn_match.group(1)
    elif class_match:
        entity_name = class_match.group(1)

    if entity_name:
        # Only rewrite names that do not exist in the code (hallucinated names).
        known_identifiers = set(re.findall(r'\w+', original_code))
        generated_doc = re.sub(
            r'\*\*(\w+)\*\*',
            lambda m: m.group(0) if m.group(1) in known_identifiers else f'**{entity_name}**',
            generated_doc,
        )
        generated_doc = re.sub(
            r'`(\w+)\(',
            lambda m: m.group(0) if m.group(1) in known_identifiers else f'`{entity_name}(',
            generated_doc,
        )

    # Add @param tags from the code's first parameter list when the model left them out.
    params_match = re.search(r'\(([^)]*)\)', original_code)
    if params_match and entity_name and '@param' not in generated_doc and '*/' in generated_doc:
        param_list = [
            _parse_param(param)
            for param in _split_top_level(params_match.group(1))
            if param.strip()
        ]
        if param_list:
            param_block = "".join(
                f" * @param {name} - Parameter of type {type_info}\n"
                for name, type_info in param_list
            )
            # Insert before the last terminator only, on its own lines.
            head, _, tail = generated_doc.rpartition('*/')
            generated_doc = f"{head.rstrip()}\n{param_block} */{tail}"

    # Remove any import paths
    generated_doc = re.sub(r'import\(".*?"\)\.', '', generated_doc)

    return generated_doc

def markdown_postprocessor(generated_doc, original_code):
    """Post-process markdown documentation to fix common issues.

    - Title: the leading ``# word`` heading is replaced with the first
      ``function`` (else ``class``) name from ``original_code``, or that
      heading is prepended if missing.
    - Import paths: inline ``import("...").`` prefixes are removed.
    - Parameters: if ``## Parameters`` is missing, a section built from the
      code's first parenthesised list is appended. Defaults are shown as
      ``type = default``, and untyped parameters are shown as ``any``.
    - Returns: if ``## Returns`` is missing and a function was found, a
      section built from the ``): Type`` annotation is appended, falling back
      to ``void``.
    - Example: if ``## Example`` is missing and an entity name was found, a
      placeholder example block is appended.

    Args:
        generated_doc: Markdown text produced by the model.
        original_code: TypeScript source the documentation describes.

    Returns:
        The corrected markdown string.
    """
    import re

    def _split_top_level(text):
        # Split on commas outside <>, (), [], {} and string literals, so that
        # e.g. "m: Map<string, number>" stays one parameter. '>' of '=>' is not a closer.
        parts, buf = [], []
        depth, quote, prev = 0, None, ''
        for ch in text:
            if quote:
                if ch == quote and prev != '\\':
                    quote = None
            elif ch in '\'"`':
                quote = ch
            elif ch in '<([{':
                depth += 1
            elif ch in ')]}' or (ch == '>' and prev != '='):
                depth = max(depth - 1, 0)
            elif ch == ',' and depth == 0:
                parts.append(''.join(buf))
                buf, prev = [], ch
                continue
            buf.append(ch)
            prev = ch
        parts.append(''.join(buf))
        return parts

    # Find function or class name
    fn_match = re.search(r'\bfunction\s+(\w+)', original_code)
    class_match = re.search(r'\bclass\s+(\w+)', original_code)

    entity_name = None
    if fn_match:
        entity_name = fn_match.group(1)
    elif class_match:
        entity_name = class_match.group(1)

    # Fix (or add) the title
    if entity_name:
        if re.search(r'^#\s+\w+', generated_doc):
            generated_doc = re.sub(r'^#\s+\w+', f'# {entity_name}', generated_doc)
        else:
            generated_doc = f"# {entity_name}\n\n" + generated_doc

    # Remove any import paths
    generated_doc = re.sub(r'import\(".*?"\)\.', '', generated_doc)

    # Ensure parameters section exists
    if "## Parameters" not in generated_doc:
        params_match = re.search(r'\(([^)]*)\)', original_code)
        if params_match:
            params_text = params_match.group(1).strip()
            param_section = "\n## Parameters\n\n"

            if params_text:
                for param in _split_top_level(params_text):
                    param = param.strip()
                    if not param:
                        continue
                    # Split off a default at the first lone '=' (not '==', '=>', '<=', ...).
                    pieces = re.split(r'(?<![=!<>])=(?![=>])', param, maxsplit=1)
                    param_name, _, param_type = pieces[0].partition(':')
                    param_name = param_name.strip()
                    param_type = param_type.strip() or "any"
                    if len(pieces) > 1:
                        param_type += f" = {pieces[1].strip()}"

                    param_section += f"- **{param_name}** (`{param_type}`): Description of parameter\n"
            else:
                param_section += "- No parameters\n"

            generated_doc += param_section

    # Ensure returns section exists for functions
    if "## Returns" not in generated_doc and fn_match:
        return_match = re.search(r'\):\s*([^{]+)', original_code)
        returns_section = "\n## Returns\n\n"

        if return_match:
            return_type = return_match.group(1).strip()
            returns_section += f"- `{return_type}`: Return value description\n"
        else:
            returns_section += "- `void`: This function doesn't return a value\n"

        generated_doc += returns_section

    # Ensure example section exists
    if "## Example" not in generated_doc and entity_name:
        example_section = "\n## Example\n\n```typescript\n// Example usage of " + entity_name + "\n```\n"
        generated_doc += example_section

    return generated_doc

In [None]:
import os
import json
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
from transformers import RobertaTokenizer, T5ForConditionalGeneration
from rouge_score import rouge_scorer
from collections import defaultdict

def run_comprehensive_evaluation(model_path, test_data_path, output_dir, max_examples=200, seed=None):
    """
    Perform a detailed evaluation of the TypeScript documentation model.

    Each evaluated example is documented with four prompt styles ("standard",
    "jsdoc", "detailed", "markdown"). Every generation is scored against the
    reference with ROUGE-1/2/L F-measure. Examples are bucketed by word count:
    simple < 50, medium < 150, otherwise complex.

    Files written to ``output_dir``:

    - ``results_partial.json``: checkpoint every 20 examples;
    - ``evaluation_results.json``;
    - ``avg_metrics.json``.

    The results are then passed to ``generate_visualizations`` and
    ``create_html_report``.

    Generation uses beam sampling (``do_sample=True`` with 4 beams), so
    scores differ between runs unless ``seed`` is given.

    Args:
        model_path: Path to the fine-tuned model
        test_data_path: Path to the test data JSON (records with 'input' code
            and 'output' reference documentation)
        output_dir: Directory to save evaluation results
        max_examples: Evaluate at most this many examples from the start of
            the test set (default 200).
        seed: If not None, passed to ``torch.manual_seed`` before generating
            so the sampled outputs are reproducible.

    Returns:
        Dict mapping each style to its average rouge1/rouge2/rougeL/length
        (0.0 when no examples were evaluated).
    """
    print(f"Starting comprehensive evaluation of model at {model_path}")
    os.makedirs(output_dir, exist_ok=True)

    # Load test data
    with open(test_data_path, 'r', encoding='utf-8') as f:
        test_data = json.load(f)

    print(f"Loaded {len(test_data)} test examples")

    # Load model and tokenizer
    tokenizer = RobertaTokenizer.from_pretrained(model_path)
    model = T5ForConditionalGeneration.from_pretrained(model_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    if seed is not None:
        torch.manual_seed(seed)

    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

    all_results = []
    # ROUGE-L per complexity bucket, pooled across all four styles.
    metrics_by_complexity = defaultdict(list)

    # Define documentation styles for comparison
    doc_styles = {
        "standard": "Generate documentation for TypeScript code: ",
        "jsdoc": "Generate JSDoc-style documentation with @param, @returns, and @example tags for TypeScript code: ",
        "detailed": "Generate detailed documentation explaining purpose, parameters, return types, and usage examples for TypeScript code: ",
        "markdown": "Generate markdown documentation with sections for Parameters, Returns, and Examples for TypeScript code: "
    }

    print("Starting evaluation...")
    for idx, example in enumerate(tqdm(test_data[:max_examples])):
        code = example['input']
        reference = example['output']

        # Calculate code complexity (using length as a simple proxy)
        code_length = len(code.split())
        if code_length < 50:
            complexity = "simple"
        elif code_length < 150:
            complexity = "medium"
        else:
            complexity = "complex"

        style_results = {}
        for style_name, style_prompt in doc_styles.items():
            input_text = f"{style_prompt}{code}"
            inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)

            with torch.no_grad():
                output_sequences = model.generate(
                    inputs.input_ids,
                    max_length=256,
                    num_beams=4,
                    do_sample=True,
                    temperature=0.7 if style_name != "standard" else 0.5,  # Allow more creativity for detailed styles
                    early_stopping=True
                )

            generated = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
            scores = scorer.score(reference, generated)

            style_results[style_name] = {
                "generated": generated,
                "rouge1": scores['rouge1'].fmeasure,
                "rouge2": scores['rouge2'].fmeasure,
                "rougeL": scores['rougeL'].fmeasure,
                "length": len(generated.split())
            }

            metrics_by_complexity[complexity].append(scores['rougeL'].fmeasure)

        all_results.append({
            "id": idx,
            "code": code,
            "reference": reference,
            "code_length": code_length,
            "complexity": complexity,
            "styles": style_results
        })

        # Save results periodically
        if (idx + 1) % 20 == 0:
            with open(f"{output_dir}/results_partial.json", 'w', encoding='utf-8') as f:
                json.dump(all_results, f, indent=2)

    with open(f"{output_dir}/evaluation_results.json", 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2)

    def _mean(values):
        # np.mean([]) is NaN, which json.dump would write as invalid JSON.
        return float(np.mean(values)) if values else 0.0

    # Average metrics by style
    avg_metrics = {}
    for style in doc_styles:
        per_example = [result["styles"][style] for result in all_results]
        avg_metrics[style] = {
            key: _mean([m[key] for m in per_example])
            for key in ("rouge1", "rouge2", "rougeL", "length")
        }

    with open(f"{output_dir}/avg_metrics.json", 'w', encoding='utf-8') as f:
        json.dump(avg_metrics, f, indent=2)

    generate_visualizations(all_results, avg_metrics, metrics_by_complexity, output_dir)
    create_html_report(all_results, avg_metrics, metrics_by_complexity, output_dir)

    print(f"Evaluation complete! Results saved to {output_dir}")
    return avg_metrics

def generate_visualizations(results, avg_metrics, metrics_by_complexity, output_dir):
    """Generate visualization charts for the evaluation results.

    Saves three 300-dpi PNGs under ``<output_dir>/charts``:

    - ``style_comparison.png``: average ROUGE-L per documentation style.
    - ``complexity_comparison.png``: average ROUGE-L per complexity bucket
      (simple/medium/complex).
    - ``style_complexity_heatmap.png``: mean ROUGE-L for each
      (complexity, style) pair.

    A bucket with no scores is plotted as 0 in every chart. Without this, the
    bar chart would show NaN bars and labels.

    Args:
        results: Per-example dicts with ``"complexity"`` and
            ``"styles"[style]["rougeL"]``.
        avg_metrics: Mapping style -> dict containing ``"rougeL"``.
        metrics_by_complexity: Mapping complexity -> list of ROUGE-L scores;
            missing buckets are treated as empty.
        output_dir: Base output directory.
    """
    charts_dir = f"{output_dir}/charts"
    os.makedirs(charts_dir, exist_ok=True)
    complexity_levels = ["simple", "medium", "complex"]

    def _mean_or_zero(values):
        return np.mean(values) if len(values) else 0

    def _bar_chart(labels, values, color, figsize, title, xlabel, filename):
        # Bar chart with a 4-decimal value label above each bar.
        plt.figure(figsize=figsize)
        bars = plt.bar(labels, values, color=color)
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                     f'{height:.4f}', ha='center', va='bottom', fontsize=10)
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel('Average ROUGE-L Score')
        plt.ylim(0, 1.0)
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        plt.savefig(f"{charts_dir}/{filename}", dpi=300, bbox_inches='tight')

    # Style comparison bar chart
    styles = list(avg_metrics.keys())
    _bar_chart(
        styles,
        [avg_metrics[style]["rougeL"] for style in styles],
        'skyblue', (12, 6),
        'ROUGE-L Scores by Documentation Style', 'Documentation Style',
        "style_comparison.png",
    )

    # Complexity comparison (empty buckets -> 0 instead of NaN)
    _bar_chart(
        complexity_levels,
        [_mean_or_zero(metrics_by_complexity.get(level, [])) for level in complexity_levels],
        'lightgreen', (10, 6),
        'ROUGE-L Scores by Code Complexity', 'Code Complexity',
        "complexity_comparison.png",
    )

    # Style vs. complexity heatmap
    complexity_style_matrix = {
        level: {
            style: _mean_or_zero([
                result["styles"][style]["rougeL"]
                for result in results
                if result["complexity"] == level
            ])
            for style in avg_metrics
        }
        for level in complexity_levels
    }
    df = pd.DataFrame(complexity_style_matrix).T

    plt.figure(figsize=(10, 8))
    sns.heatmap(df, annot=True, cmap="YlGnBu", fmt=".4f", vmin=0, vmax=1)
    plt.title('ROUGE-L Scores by Style and Complexity')
    plt.ylabel('Code Complexity')
    plt.xlabel('Documentation Style')
    plt.savefig(f"{charts_dir}/style_complexity_heatmap.png", dpi=300, bbox_inches='tight')

def create_html_report(results, avg_metrics, metrics_by_complexity, output_dir):
    """Create an HTML report with evaluation results and examples.

    The report is written to ``{output_dir}/evaluation_report.html`` (UTF-8).
    It contains a table of averaged metrics per documentation style, image
    links to the charts under ``charts/`` (relative to ``output_dir``), and a
    selection of high-, average- and low-scoring examples ranked by the
    ``"standard"`` style's ROUGE-L score.

    All example text (source code, reference and generated documentation,
    style names) is HTML-escaped. TypeScript generics such as
    ``Array<string>`` therefore render literally instead of being parsed as
    tags.

    Args:
        results: Sequence of per-example dicts with keys ``"complexity"``,
            ``"code"``, ``"reference"`` and ``"styles"``. ``styles`` maps a
            style name to a dict with ``"rougeL"`` and ``"generated"``. Every
            example must have a ``"standard"`` entry because it is used for
            ranking.
        avg_metrics: Mapping of style name to a dict with ``"rouge1"``,
            ``"rouge2"``, ``"rougeL"`` and ``"length"``.
        metrics_by_complexity: Not read by this function; kept for interface
            compatibility.
        output_dir: Directory the HTML file is written into.
    """
    from html import escape

    def esc(value):
        # Escape any value for safe embedding as HTML element content.
        return escape(str(value))

    # Rank examples best-first by the ROUGE-L of the "standard" style.
    sorted_results = sorted(results, key=lambda x: x["styles"]["standard"]["rougeL"], reverse=True)
    n = len(sorted_results)
    shown = set()

    def pick(indices):
        # Return the ranked examples at `indices`, skipping out-of-range
        # positions and ones already used by an earlier tier. With fewer than
        # 8 examples the top / 25%-mark / bottom windows overlap, and the same
        # example would otherwise appear under contradictory quality labels.
        picked = []
        for i in indices:
            if 0 <= i < n and i not in shown:
                shown.add(i)
                picked.append(sorted_results[i])
        return picked

    excellent_examples = pick(range(0, 2))          # top 2
    poor_examples = pick(range(n - 2, n))            # bottom 2
    good_examples = pick(range(n // 4, n // 4 + 2))  # around the 25% mark

    def render_example(example, css_class, extra_html=""):
        # Render one example card showing its "standard"-style generation.
        standard = example['styles']['standard']
        return f"""
            <div class="example {css_class}">
                <h4>Code (Complexity: {esc(example['complexity'])})</h4>
                <pre class="code-block">{esc(example['code'])}</pre>

                <h4>Reference Documentation</h4>
                <pre class="code-block">{esc(example['reference'])}</pre>

                <h4>Generated Documentation (ROUGE-L: {standard['rougeL']:.4f})</h4>
                <pre class="code-block">{esc(standard['generated'])}</pre>
{extra_html}
            </div>
        """

    # Static hints appended to each low-quality example card.
    improvement_tips = """
                <h4>Possible Improvements</h4>
                <ul>
                    <li>Better capture of parameter types and descriptions</li>
                    <li>More accurate return type documentation</li>
                    <li>Better understanding of the function's purpose</li>
                </ul>"""

    # Collect fragments and join once at the end, instead of repeated `+=`.
    parts = [f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="utf-8">
        <title>TypeScript Documentation Model Evaluation Report</title>
        <style>
            body {{
                font-family: Arial, sans-serif;
                line-height: 1.6;
                margin: 0;
                padding: 20px;
                color: #333;
            }}
            .container {{
                max-width: 1200px;
                margin: 0 auto;
            }}
            header {{
                background-color: #f8f9fa;
                padding: 20px;
                margin-bottom: 30px;
                border-radius: 5px;
                box-shadow: 0 2px 5px rgba(0,0,0,0.1);
            }}
            h1 {{
                color: #2c3e50;
                margin-top: 0;
            }}
            h2 {{
                margin-top: 30px;
                border-bottom: 2px solid #eee;
                padding-bottom: 10px;
                color: #3498db;
            }}
            h3 {{
                color: #2980b9;
            }}
            .metrics-table {{
                width: 100%;
                border-collapse: collapse;
                margin: 20px 0;
            }}
            .metrics-table th, .metrics-table td {{
                border: 1px solid #ddd;
                padding: 12px;
                text-align: left;
            }}
            .metrics-table th {{
                background-color: #f2f2f2;
            }}
            .metrics-table tr:nth-child(even) {{
                background-color: #f9f9f9;
            }}
            .example {{
                background-color: #f8f9fa;
                padding: 15px;
                margin: 20px 0;
                border-radius: 5px;
                border-left: 5px solid #3498db;
            }}
            .code-block {{
                background-color: #f5f5f5;
                padding: 15px;
                border-radius: 5px;
                overflow-x: auto;
                font-family: monospace;
                font-size: 14px;
                line-height: 1.4;
            }}
            .chart {{
                width: 100%;
                max-width: 800px;
                margin: 20px auto;
                display: block;
            }}
            .excellent {{ border-left-color: #2ecc71; }}
            .good {{ border-left-color: #f39c12; }}
            .poor {{ border-left-color: #e74c3c; }}
        </style>
    </head>
    <body>
        <div class="container">
            <header>
                <h1>TypeScript Documentation Model Evaluation Report</h1>
                <p>Evaluation date: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M')}</p>
                <p>Number of examples evaluated: {len(results)}</p>
            </header>

            <h2>Overall Metrics</h2>
            <table class="metrics-table">
                <tr>
                    <th>Documentation Style</th>
                    <th>ROUGE-1</th>
                    <th>ROUGE-2</th>
                    <th>ROUGE-L</th>
                    <th>Avg. Length (words)</th>
                </tr>
    """]

    # One table row per documentation style.
    for style, metrics in avg_metrics.items():
        parts.append(f"""
                <tr>
                    <td>{esc(style.capitalize())}</td>
                    <td>{metrics['rouge1']:.4f}</td>
                    <td>{metrics['rouge2']:.4f}</td>
                    <td>{metrics['rougeL']:.4f}</td>
                    <td>{metrics['length']:.1f}</td>
                </tr>
        """)

    # Chart images are expected under {output_dir}/charts/.
    parts.append("""
            </table>

            <h2>Visualizations</h2>
            <h3>Style Comparison</h3>
            <img src="charts/style_comparison.png" alt="Style Comparison Chart" class="chart">

            <h3>Complexity Analysis</h3>
            <img src="charts/complexity_comparison.png" alt="Complexity Comparison Chart" class="chart">

            <h3>Style vs. Complexity</h3>
            <img src="charts/style_complexity_heatmap.png" alt="Style vs. Complexity Heatmap" class="chart">

            <h2>Example Outputs</h2>
    """)

    parts.append("""
            <h3>High-Quality Examples</h3>
    """)
    parts.extend(render_example(ex, "excellent") for ex in excellent_examples)

    parts.append("""
            <h3>Average-Quality Examples</h3>
    """)
    parts.extend(render_example(ex, "good") for ex in good_examples)

    parts.append("""
            <h3>Low-Quality Examples</h3>
    """)
    parts.extend(render_example(ex, "poor", improvement_tips) for ex in poor_examples)

    parts.append("""
            <h3>Documentation Style Comparison</h3>
    """)

    # Show every generated style side by side for the first medium-complexity example.
    medium_examples = [r for r in results if r['complexity'] == 'medium']
    if medium_examples:
        style_example = medium_examples[0]

        parts.append(f"""
            <div class="example">
                <h4>Original Code</h4>
                <pre class="code-block">{esc(style_example['code'])}</pre>

                <h4>Reference Documentation</h4>
                <pre class="code-block">{esc(style_example['reference'])}</pre>
        """)

        for style, metrics in style_example['styles'].items():
            parts.append(f"""
                <h4>{esc(style.capitalize())} Style (ROUGE-L: {metrics['rougeL']:.4f})</h4>
                <pre class="code-block">{esc(metrics['generated'])}</pre>
            """)

        parts.append("""
            </div>
        """)

    # NOTE(review): this conclusion is static text; it is not derived from the
    # metrics computed above and may contradict them.
    parts.append("""
            <h2>Conclusion</h2>
            <p>The evaluation indicates that the model performs well across different types of TypeScript code and documentation styles.
            The best performance is observed with the standard documentation style, while more complex formats like JSDoc and Markdown show slightly lower ROUGE scores but provide more structured and detailed documentation.</p>

            <p>Performance tends to decrease as code complexity increases, which is expected. Future improvements could focus on better handling of complex TypeScript constructs and more accurate parameter inference.</p>
        </div>
    </body>
    </html>
    """)

    # Write as UTF-8 explicitly: generated text may contain non-ASCII characters
    # that the platform default encoding cannot represent.
    with open(f"{output_dir}/evaluation_report.html", 'w', encoding='utf-8') as f:
        f.write("".join(parts))

In [None]:
# Evaluation driver cell.
# Define paths: the trained model to evaluate, the held-out test split, and the
# directory that receives the evaluation artifacts (charts, HTML report).
# NOTE(review): model_path embeds the timestamped directory of one specific
# hyperparameter-search run; point it at the best model from your own run.
model_path = '/content/drive/MyDrive/ts_documentation/models/enhanced_codet5/hparam_search_20250509_080333/best_model'  # Update to your best model path
test_data_path = '/content/drive/MyDrive/ts_documentation/data/full_test_split.json'
output_dir = '/content/drive/MyDrive/ts_documentation/evaluation_results'

# Install required libraries if not already present.
# `!pip` is IPython/Colab shell magic, so this cell only runs inside a notebook.
# NOTE(review): the plotting code references `plt`/`sns` as globals; if they were
# already imported in an earlier cell, installing here comes too late to help
# that import — confirm cell ordering.
!pip install -q matplotlib seaborn

# Run the comprehensive evaluation. run_comprehensive_evaluation is defined
# elsewhere in this notebook; the summary loop below assumes it returns a
# mapping of style name -> metrics dict containing at least 'rougeL' and 'length'.
avg_metrics = run_comprehensive_evaluation(model_path, test_data_path, output_dir)

# Print a short per-style summary of results.
print("\nEvaluation Summary:")
for style, metrics in avg_metrics.items():
    print(f"- {style.capitalize()} Style:")
    print(f"  - ROUGE-L: {metrics['rougeL']:.4f}")
    print(f"  - Avg Length: {metrics['length']:.1f} words")