## Trying to improve RAG document selection

## Retrieval Improvements

1. **Hybrid Retrieval**: Instead of relying solely on embedding-based retrieval, implement a hybrid system that combines:
   - Dense retrieval (your current embedding approach)
   - Sparse retrieval (BM25 or TF-IDF for terminology matching)
   - Structural retrieval (specifically targeting OML syntax patterns)

2. **Chunking Strategy**: Rather than using entire examples as retrieval units, consider:
   - Breaking examples into functional components (e.g., aspect definitions, relation definitions)
   - Creating a hierarchical retrieval system that first retrieves relevant examples, then relevant components

3. **Domain-Specific Embeddings**: Fine-tune your embedding model specifically on OML code to better capture the semantic relationships between OML concepts.
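
The chunking idea in item 2 can be sketched as follows. This is a minimal illustration, not a real OML parser: it assumes every top-level statement starts on its own line with one of a few keywords, and the function name `split_oml_components` and the keyword list are hypothetical. Any trailing text, such as a vocabulary's closing brace, stays attached to the last component.

```python
import re

# Keywords assumed to open a top-level OML statement (illustrative, not exhaustive).
# "relation entity" comes first so it wins over the shorter "relation".
COMPONENT_KEYWORDS = ("relation entity", "aspect", "concept", "scalar", "relation")
_START = re.compile(r"^\s*(" + "|".join(COMPONENT_KEYWORDS) + r")\b")


def split_oml_components(oml_code):
    """Split OML text into components, returned as {'type', 'code'} dicts."""
    components = []
    current = None
    for line in oml_code.splitlines():
        match = _START.match(line)
        if match:
            # A new statement begins: close the previous component, if any
            if current is not None:
                components.append(current)
            current = {"type": match.group(1), "lines": [line]}
        elif current is not None:
            # Continuation line (e.g. the body of a relation entity)
            current["lines"].append(line)
        # Lines before the first keyword (e.g. the vocabulary header) are skipped
    if current is not None:
        components.append(current)
    return [
        {"type": c["type"], "code": "\n".join(c["lines"]).strip()}
        for c in components
    ]
```

Each returned dict can then be embedded and indexed on its own, with a pointer back to the parent example for the hierarchical lookup.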

## Prompt Engineering

1. **Task-Specific Instructions**: Add explicit instructions in your prompt that specifically ask for code generation, not explanation:
   ```
   Your task is to generate ONLY syntactically valid OML code based on the request.
   DO NOT explain the code or provide any text outside the code block.
   ```

2. **Format Forcing**: Use format-forcing techniques to increase the likelihood of getting code:
   ```
   You MUST respond with a code block that starts with ```oml and ends with ```
   DO NOT include any explanatory text before or after the code block.
   ```

3. **Few-Shot Examples**: Include 2-3 direct examples of input-output pairs in the prompt itself (not just in the retrieved content).
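
The three prompt techniques above can be combined in one prompt builder. The sketch below is one possible layout, assuming examples are dicts with `input` and `output` keys (as in the JSONL data used later); the name `build_prompt` and the `Request:` label are hypothetical choices.

```python
FENCE = "`" * 3  # built at runtime so this example contains no nested literal fence

INSTRUCTIONS = (
    "Your task is to generate ONLY syntactically valid OML code based on the request.\n"
    f"You MUST respond with a code block that starts with {FENCE}oml and ends with {FENCE}.\n"
    "DO NOT include any explanatory text before or after the code block.\n"
)


def build_prompt(request, examples, max_examples=3):
    """Assemble instructions, a few input/output pairs, and the new request."""
    shots = [
        f"Request: {ex['input']}\n{FENCE}oml\n{ex['output']}\n{FENCE}"
        for ex in examples[:max_examples]
    ]
    return INSTRUCTIONS + "\n" + "\n\n".join(shots) + f"\n\nRequest: {request}\n"
```

The few-shot pairs here can be fixed, hand-picked examples, kept separate from whatever the retriever returns for the query.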

## Model Selection & Engineering

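1. **Model Fine-Tuning**: Consider fine-tuning a smaller model specifically for OML code generation using your 97 examples. A small model fine-tuned on domain-specific data can outperform larger general models on a narrow task, although 97 examples is a small training set and will likely need the augmentation described under Data Augmentation.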

2. **Output Parsing**: Implement a robust post-processing layer that:
   - Extracts code blocks from mixed outputs
   - Validates OML syntax
   - Gracefully handles and fixes common errors

3. **Temperature Setting**: For code generation, use a lower temperature setting (0.1-0.3) to produce more deterministic, focused outputs.
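
The output-parsing layer in item 2 could start from something like the sketch below. It assumes the model wraps code in a Markdown fence (with or without an `oml` tag) and falls back to the raw text otherwise. `looks_balanced` is only a cheap bracket check, not an OML validator, and it will be confused by brackets inside string literals. Both function names are hypothetical.

```python
import re

FENCE = "`" * 3  # built at runtime so this example contains no nested literal fence
# Opening fence with an optional language tag, then everything up to the next fence
_BLOCK = re.compile(FENCE + r"[\w-]*[ \t]*\n(.*?)" + FENCE, re.DOTALL)


def extract_oml(raw_output):
    """Return the first fenced block's content, or the whole text if none is found."""
    match = _BLOCK.search(raw_output)
    code = match.group(1) if match else raw_output
    return code.strip()


def looks_balanced(code):
    """Check that (), [] and {} are closed in the right order."""
    pairs = {")": "(", "]": "[", "}": "{"}
    stack = []
    for ch in code:
        if ch in "([{":
            stack.append(ch)
        elif ch in pairs:
            if not stack or stack.pop() != pairs[ch]:
                return False
    return not stack
```

A failed `looks_balanced` check is a reasonable trigger for a retry or a repair prompt before the output reaches a real OML validator.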

## Data Augmentation

1. **Synthetic Example Generation**: Use your existing examples to generate additional synthetic examples that cover a wider range of OML patterns and domains.

2. **Error Variants**: Create examples that include common errors and their corrections to help the model learn proper syntax.

3. **Decontextualized Examples**: Create versions of your examples that are stripped of domain-specific content, focusing purely on OML syntax patterns.

## Architecture Improvements

1. **Two-Stage Generation**: Split the generation into two steps:
   - First, generate a "skeleton" of the OML structure
   - Then, fill in the details with content-specific elements

2. **Self-Consistency**: Generate multiple candidates and select the one that passes OML syntax validation, or the one with the fewest syntax errors if none passes.

3. **Human Feedback Loop**: Implement a feedback mechanism where failed generations are collected, corrected, and added back to the training data.
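
The self-consistency step in item 2 can be sketched as a small selection function. This version assumes you supply a scoring function (ideally backed by a real OML validator) and breaks ties by how often the same candidate was generated. `select_candidate` and the toy `bracket_score` are hypothetical names used only for illustration.

```python
from collections import Counter


def select_candidate(candidates, score_fn):
    """Pick the highest-scoring candidate; break ties by generation frequency."""
    if not candidates:
        raise ValueError("no candidates to choose from")
    # Collapse whitespace so trivially different copies count as the same answer
    normalized = [" ".join(c.split()) for c in candidates]
    counts = Counter(normalized)
    best_index = max(
        range(len(candidates)),
        key=lambda i: (score_fn(candidates[i]), counts[normalized[i]]),
    )
    return candidates[best_index]


def bracket_score(code):
    """Toy scorer (an assumption): 1.0 if each bracket type has matching counts."""
    return float(all(code.count(o) == code.count(c) for o, c in ("{}", "[]", "()")))
```

For example, `select_candidate(["concept A {", "concept B", "concept A", "concept A"], bracket_score)` discards the unbalanced candidate and prefers the answer that was generated twice.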

Implementing some of these improvements could significantly enhance the reliability and quality of your OML code generation system, especially when working with smaller models like Mistral and CodeLlama-7B.

### Implementing Hybrid Retrieval for OML Code Generation

A hybrid retrieval system combines multiple retrieval methods to leverage their complementary strengths. For your OML code generation task, here is a detailed approach to implementing hybrid retrieval.

#### 1. Components of Hybrid Retrieval

**Dense Retrieval (Semantic Search)**

- Current approach: you use SentenceTransformer (`intfloat/multilingual-e5-large-instruct`) to create embeddings for query and document matching.
- Enhancement: keep this as your semantic base, and supplement it with the sparse and structural methods described next.

**Sparse Retrieval (Lexical Matching)**

- Implementation: add BM25 or TF-IDF based retrieval.
- Benefits:
  - Better at exact terminology matching
  - Captures OML-specific keywords that may be diluted in embeddings
  - Less sensitive to context window limitations

**Structural Retrieval (Syntax-Aware)**

- Implementation: create a structure-based index that captures OML patterns.
- Benefits:
  - Identifies examples with similar structural patterns regardless of domain
  - Helps find the right OML syntax templates even when domains differ

In [None]:
!pip install rank_bm25


In [None]:
# Install required packages if not already installed
!pip install -q sentence-transformers rank-bm25 ipywidgets
!jupyter nbextension enable --py widgetsnbextension
!pip install -q tqdm


In [None]:
class VocabularyManager:
    def __init__(self, vocab_file_path):
        self.vocab_index = {}
        self.vocab_file = vocab_file_path
        self._build_index()

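    def _build_index(self):
        """Build an index of vocabulary terms and their offsets in the file."""
        # Read with readline() rather than `for line in f`: in text mode,
        # f.tell() raises OSError while the file is being iterated.
        with open(self.vocab_file, 'r', encoding='utf-8') as f:
            while True:
                position = f.tell()  # offset of the line about to be read
                line = f.readline()
                if not line:
                    break
                if not line.strip():
                    continue
                try:
                    term_data = json.loads(line)
                except json.JSONDecodeError:
                    continue  # skip malformed lines
                term = term_data.get('term', '')
                if term:
                    self.vocab_index[term] = position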

    def get_definition(self, term):
        """Retrieve a specific definition by term."""
        if term not in self.vocab_index:
            return None

        with open(self.vocab_file, 'r', encoding='utf-8') as f:
            f.seek(self.vocab_index[term])
            definition_line = f.readline()
            try:
                return json.loads(definition_line)
            except json.JSONDecodeError:
                return None

    def get_relevant_definitions(self, text):
        """Find all vocabulary terms mentioned in a text and return their definitions."""
        relevant_terms = [term for term in self.vocab_index.keys() if term in text]
        return {term: self.get_definition(term) for term in relevant_terms}

In [None]:
import json
import numpy as np
import re
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
import argparse
import os
import torch
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm  # Colab-friendly progress bars
import ipywidgets as widgets
from IPython.display import display, clear_output

# Check if GPU is available (Colab usually provides GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


# sentence-transformers and rank-bm25 are installed by the pip cell above and
# already imported at the top of this cell, so no fallback install is needed here.

# Helper function for loading JSONL data - moved outside the class
def load_jsonl_data(file_path):
    """Load JSONL data from file."""
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                try:
                    data.append(json.loads(line))
                except json.JSONDecodeError as e:
                    print(f"Error parsing JSON line: {e}")
                    print(f"Problematic line: {line[:100]}...")
    return data

# Helper function to truncate text
def truncate_text(text, max_length=100):
    """Truncate text to a maximum length."""
    if len(text) <= max_length:
        return text
    return text[:max_length] + "..."

class HybridRetriever:
    def __init__(self, jsonl_data):
        self.examples = jsonl_data

        print("Initializing embedding model...")
        self.embedding_model = SentenceTransformer('intfloat/multilingual-e5-large-instruct', device=device)

        print("Building dense index...")
        self.dense_index = self._build_dense_index()

        print("Building sparse index...")
        self.bm25 = self._build_sparse_index()

        print("Extracting structure patterns...")
        self.structure_patterns = self._extract_structure_patterns()

        print("Retriever initialized!")

    def _build_dense_index(self):
        # Embed all examples using the E5 model
        dense_index = []
        for i, example in enumerate(tqdm(self.examples, desc="Creating embeddings")):
            input_text = example['input']
            output_text = example['output']

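            # Prefix inputs with "query: " and outputs with "passage: ". This is the
            # prefix scheme of the base E5 models; the model card for the -instruct
            # variant instead describes an "Instruct: ...\nQuery: ..." query format
            # with no prefix on documents, so check which one works better for you.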
            input_embedding = self.embedding_model.encode(f"query: {input_text}", convert_to_numpy=True)
            output_embedding = self.embedding_model.encode(f"passage: {output_text}", convert_to_numpy=True)

            dense_index.append({
                'id': example.get('id', str(i)),
                'input_embedding': input_embedding,
                'output_embedding': output_embedding,
                'example': example
            })
        return dense_index

    def _build_sparse_index(self):
        # Create BM25 index for lexical matching
        corpus = []
        for example in self.examples:
            # Combine input and output with emphasis on OML syntax patterns
            # \b keeps keywords from matching inside longer identifiers (e.g. "key" in "keyword")
            oml_terms = re.findall(r'\b(aspect|concept|scalar|relation|property|vocabulary|extends|key)\b',
                                  example['output'])
            enhanced_text = f"{example['input']} {example['output']} {' '.join(oml_terms)}"
            corpus.append(enhanced_text)

        tokenized_corpus = [doc.lower().split() for doc in corpus]
        return BM25Okapi(tokenized_corpus)

    def _extract_structure_patterns(self):
        # Extract structural patterns from OML code
        patterns = []
        for i, example in enumerate(self.examples):
            output = example['output']

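            # Extract OML structural patterns. The has_vocabulary, has_annotation,
            # has_rule and has_restriction flags are read by _structure_retrieve;
            # their keyword patterns are rough assumptions about OML text.
            structure = {
                'has_aspect': bool(re.search(r'aspect\s+\w+', output)),
                'has_concept': bool(re.search(r'concept\s+\w+', output)),
                'has_scalar': bool(re.search(r'scalar\s+\w+', output)),
                'has_relation': bool(re.search(r'relation\s+\w+', output)),
                'has_relation_entity': bool(re.search(r'relation\s+entity\s+\w+', output)),
                'has_vocabulary': bool(re.search(r'\bvocabulary\b', output)),
                'has_annotation': bool(re.search(r'@\w+', output)),
                'has_rule': bool(re.search(r'\brule\s+\w+', output)),
                'has_restriction': bool(re.search(r'\brestricts\b', output)),
                'inheritance_depth': len(re.findall(r'<\s*\w+', output)),
                'property_count': len(re.findall(r'property\s+\w+', output)),
                'annotation_count': len(re.findall(r'@\w+', output))
            }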

            patterns.append({
                'id': example.get('id', str(i)),
                'structure': structure,
                'example': example
            })
        return patterns

    def retrieve(self, query, top_k=5, weights=None, vocab_manager=None):
        if weights is None:
            weights = {
                'dense': 0.5,
                'sparse': 0.3,
                'structure': 0.2
            }

        print("\nRetrieving using all methods...")
        # Get results from each retrieval method
        dense_results = self._dense_retrieve(query, top_k)
        sparse_results = self._sparse_retrieve(query, top_k)
        structure_results = self._structure_retrieve(query, top_k)

        # Show individual method results
        print("\nDense Retrieval Results:")
        for i, (example, score) in enumerate(dense_results[:3]):
            print(f"  {i+1}. Score: {score:.4f} - {example['input'][:100]}...")

        print("\nSparse Retrieval Results:")
        for i, (example, score) in enumerate(sparse_results[:3]):
            print(f"  {i+1}. Score: {score:.4f} - {example['input'][:100]}...")

        print("\nStructure Retrieval Results:")
        for i, (example, score) in enumerate(structure_results[:3]):
            print(f"  {i+1}. Score: {score:.4f} - {example['input'][:100]}...")

        # Combine results
        print("\nCombining results with weights:", weights)
        combined_results = self._combine_results(dense_results, sparse_results,
                                           structure_results, weights)

        # Add relevant vocabulary if available
        if vocab_manager:
            for result in combined_results[:top_k]:
                example = result[0]
                # Find relevant vocabulary for this example
                relevant_vocab = vocab_manager.get_relevant_definitions(
                    example['input'] + ' ' + example['output']
                )
                # Attach to the result
                example['relevant_vocabulary'] = relevant_vocab

        return combined_results[:top_k]

    def _dense_retrieve(self, query, top_k):
        query_embedding = self.embedding_model.encode(f"query: {query}")

        results = []
        for item in self.dense_index:
            # Calculate similarity to both input and output embeddings
            input_similarity = self._cosine_similarity(query_embedding, item['input_embedding'])
            output_similarity = self._cosine_similarity(query_embedding, item['output_embedding'])

            # Use the maximum similarity as the score
            similarity = max(input_similarity, output_similarity)
            results.append((item['example'], similarity))

        # Sort by similarity
        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]

    def _sparse_retrieve(self, query, top_k):
        # Use BM25 for lexical matching
        tokenized_query = query.lower().split()
        bm25_scores = self.bm25.get_scores(tokenized_query)

        # Pair examples with their BM25 scores
        results = [(self.examples[i], bm25_scores[i]) for i in range(len(self.examples))]

        # Sort by BM25 score
        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]

    def _structure_retrieve(self, query, top_k):
        # Extract structure hints from the query
        query_structure = {
            'wants_aspect': bool(re.search(r'aspect|property|attribute', query, re.I)),
            'wants_concept': bool(re.search(r'concept|class|type', query, re.I)),
            'wants_scalar': bool(re.search(r'scalar|datatype|value', query, re.I)),
            'wants_relation': bool(re.search(r'relation|relationship|connection', query, re.I)),
            'wants_vocabulary': bool(re.search(r'vocabulary|ontology|vocab', query, re.I)),
            'wants_annotation': bool(re.search(r'annotation|comment|label|description', query, re.I)),
            'wants_rule': bool(re.search(r'rule|infer|constraint', query, re.I)),
            'wants_inheritance': bool(re.search(r'inherit|extends|specializes|subclass', query, re.I)),
            # Word boundaries keep short keywords such as "all" or "min" from
            # matching inside unrelated words ("small", "admin").
            'wants_restriction': bool(re.search(
                r'restrict|\b(?:exactly|min(?:imum)?|max(?:imum)?|all|some)\b', query, re.I))
        }

        # Domain keywords depend only on the query, so compute them once
        domain_keywords = self._extract_domain_keywords(query)

        # Score examples based on structural similarity
        results = []
        for pattern in self.structure_patterns:
            structure = pattern['structure']
            # Simple matching score
            score = 0
            if query_structure['wants_aspect'] and structure['has_aspect']:
                score += 1
            if query_structure['wants_concept'] and structure['has_concept']:
                score += 1
            if query_structure['wants_scalar'] and structure['has_scalar']:
                score += 1
            if query_structure['wants_relation'] and structure['has_relation']:
                score += 1
            if query_structure['wants_vocabulary'] and structure.get('has_vocabulary', False):
                score += 1
            if query_structure['wants_annotation'] and structure.get('has_annotation', False):
                score += 1
            if query_structure['wants_rule'] and structure.get('has_rule', False):
                score += 1
            if query_structure['wants_inheritance'] and structure.get('inheritance_depth', 0) > 0:
                score += 1
            if query_structure['wants_restriction'] and structure.get('has_restriction', False):
                score += 1

            # Contextual domain match (optional)
            if any(kw in pattern['example'].get('tags', []) for kw in domain_keywords):
                score += 0.5

            results.append((pattern['example'], score))

        # Sort by structure match score
        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]

    def _extract_domain_keywords(self, query):
        # Extract domain keywords from query
        domains = {
            'pizza': ['pizza', 'food', 'topping', 'cheese', 'base'],
            'security': ['security', 'cyber', 'defense', 'attack', 'mitigation'],
            'system': ['system', 'component', 'architecture', 'interface'],
            'metamodel': ['metamodel', 'capella', 'modeling', 'uml']
        }

        keywords = []
        for domain, terms in domains.items():
            if any(term in query.lower() for term in terms):
                keywords.append(domain)
                keywords.extend(terms)

        return keywords

    def _combine_results(self, dense_results, sparse_results, structure_results, weights):
        # Combine and normalize scores
        result_map = {}

        # Process dense results
        max_dense_score = max([score for _, score in dense_results]) if dense_results else 1.0
        for example, score in dense_results:
            example_id = example.get('id', str(id(example)))
            if example_id not in result_map:
                result_map[example_id] = {'example': example, 'score': 0}
            # Normalize dense scores
            normalized_score = score / max_dense_score if max_dense_score > 0 else 0
            result_map[example_id]['score'] += normalized_score * weights['dense']

        # Process sparse results
        max_sparse_score = max([score for _, score in sparse_results]) if sparse_results else 1.0
        for example, score in sparse_results:
            example_id = example.get('id', str(id(example)))
            if example_id not in result_map:
                result_map[example_id] = {'example': example, 'score': 0}
            # Normalize BM25 scores
            normalized_score = score / max_sparse_score if max_sparse_score > 0 else 0
            result_map[example_id]['score'] += normalized_score * weights['sparse']

        # Process structure results
        max_structure_score = max([score for _, score in structure_results]) if structure_results else 1.0
        for example, score in structure_results:
            example_id = example.get('id', str(id(example)))
            if example_id not in result_map:
                result_map[example_id] = {'example': example, 'score': 0}
            # Normalize structure scores
            normalized_score = score / max_structure_score if max_structure_score > 0 else 0
            result_map[example_id]['score'] += normalized_score * weights['structure']

        # Convert to list and sort
        combined_results = [(item['example'], item['score']) for item in result_map.values()]
        combined_results.sort(key=lambda x: x[1], reverse=True)

        return combined_results

    def _cosine_similarity(self, a, b):
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


# Interactive Colab-friendly function
def run_interactive_demo(jsonl_path):
    # Load data
    print(f"Loading data from {jsonl_path}...")
    examples = load_jsonl_data(jsonl_path)
    print(f"Loaded {len(examples)} examples")

    # Initialize retriever
    retriever = HybridRetriever(examples)

    # Create interactive widgets for Colab
    query_widget = widgets.Text(
        value='',
        placeholder='Type your query here',
        description='Query:',
        disabled=False,
        layout=widgets.Layout(width='80%')
    )

    dense_weight = widgets.FloatSlider(
        value=0.5,
        min=0,
        max=1.0,
        step=0.1,
        description='Dense:',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='.1f',
    )

    sparse_weight = widgets.FloatSlider(
        value=0.3,
        min=0,
        max=1.0,
        step=0.1,
        description='Sparse:',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='.1f',
    )

    structure_weight = widgets.FloatSlider(
        value=0.2,
        min=0,
        max=1.0,
        step=0.1,
        description='Structure:',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='.1f',
    )

    top_k = widgets.IntSlider(
        value=5,
        min=1,
        max=10,
        step=1,
        description='Top K:',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='d',
    )

    output = widgets.Output()

    def on_button_clicked(b):
        with output:
            clear_output()

        query = query_widget.value

        if not query:
            with output:
                print("Please enter a query")
            return

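        # Normalize weights (fall back to the defaults if every slider is at 0,
        # which would otherwise divide by zero)
        total = dense_weight.value + sparse_weight.value + structure_weight.value
        if total == 0:
            weights = {'dense': 0.5, 'sparse': 0.3, 'structure': 0.2}
        else:
            weights = {
                'dense': dense_weight.value / total,
                'sparse': sparse_weight.value / total,
                'structure': structure_weight.value / total
            }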

        with output:
            print(f"Query: {query}")
            print(f"Weights: Dense={weights['dense']:.2f}, Sparse={weights['sparse']:.2f}, Structure={weights['structure']:.2f}")
            print(f"Retrieving top {top_k.value} results...")

            # Retrieve results
            results = retriever.retrieve(query, top_k.value, weights)

            # Display results
            print("\n=== COMBINED RETRIEVAL RESULTS ===")
            for i, (example, score) in enumerate(results):
                print(f"\n{i+1}. Score: {score:.4f}")
                print(f"   Title: {example.get('title', 'N/A')}")
                print(f"   Input: {truncate_text(example['input'])}")

                # Only show short preview in the list
                output_text = example['output']
                print(f"   Output Preview: {truncate_text(output_text, 150)}")

            # Create view buttons outside the loop
            for i in range(len(results)):
                view_btn = widgets.Button(
                    description=f"View full #{i+1}",
                    layout=widgets.Layout(width='120px')
                )

                # Create a closure to capture the current index
                def make_view_handler(idx):
                    def view_handler(b):
                        with output:
                            print(f"\n=== FULL OUTPUT FOR RESULT #{idx+1} ===")
                            print(results[idx][0]['output'])
                    return view_handler

                view_btn.on_click(make_view_handler(i))
                display(view_btn)

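                # Create a collapsible section for the full output.
                # Escape the code first: OML uses '<' for specialization,
                # which would otherwise be parsed as HTML markup.
                import html  # stdlib; imported here so this fix is self-contained
                escaped_output = html.escape(results[i][0]['output'])
                details = widgets.Accordion(
                    children=[widgets.HTML(value=f"<pre>{escaped_output}</pre>")],
                    selected_index=None
                )
                details.set_title(0, f"Full Output for Result #{i+1}")
                display(details)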

    button = widgets.Button(description="Retrieve")
    button.on_click(on_button_clicked)

    display(widgets.HTML(value="<h3>Hybrid Retrieval for OML Code Generation</h3>"))
    display(widgets.HBox([query_widget, button]))
    display(widgets.HTML(value="<h4>Retrieval Weights</h4>"))
    display(widgets.VBox([dense_weight, sparse_weight, structure_weight, top_k]))
    display(output)


# Standard main function for direct execution
def main():
    # Modified for Colab compatibility - don't use argparse in Colab
    # Check if running in Colab
    try:
        import google.colab
        from google.colab import files
        in_colab = True

        # Default values for Colab
        jsonl_path = '3.jsonl'
        top_k = 5

        # If file doesn't exist, ask user to upload
        if not os.path.exists(jsonl_path):
            print(f"File {jsonl_path} not found. Please upload it.")
            uploaded = files.upload()
            # Use the first uploaded file
            if uploaded:
                jsonl_path = next(iter(uploaded.keys()))

        # Run interactive demo in Colab
        run_interactive_demo(jsonl_path)

    except ImportError:
        # Not in Colab, use argparse for command line
        parser = argparse.ArgumentParser(description='Test hybrid retrieval for OML code generation')
        parser.add_argument('--jsonl', type=str, default='3.jsonl',
                            help='Path to JSONL file containing examples')
        parser.add_argument('--top_k', type=int, default=5,
                            help='Number of examples to retrieve')
        args = parser.parse_args()

        # Original command-line version
        # Load data
        print(f"Loading data from {args.jsonl}...")
        examples = load_jsonl_data(args.jsonl)
        print(f"Loaded {len(examples)} examples")

        # Initialize retriever
        retriever = HybridRetriever(examples)

        # Interactive query loop
        while True:
            query = input("\nEnter your query (or 'q' to quit): ")
            if query.lower() == 'q':
                break

            # Allow adjusting weights
            use_custom_weights = input("Use custom weights? (y/n, default: n): ").lower() == 'y'
            if use_custom_weights:
                try:
                    dense_weight = float(input("Dense weight (0-1, default: 0.5): ") or 0.5)
                    sparse_weight = float(input("Sparse weight (0-1, default: 0.3): ") or 0.3)
                    structure_weight = float(input("Structure weight (0-1, default: 0.2): ") or 0.2)

                    # Normalize weights
                    total = dense_weight + sparse_weight + structure_weight
                    weights = {
                        'dense': dense_weight / total,
                        'sparse': sparse_weight / total,
                        'structure': structure_weight / total
                    }
                except (ValueError, ZeroDivisionError):  # non-numeric input, or all weights 0
                    print("Invalid weights, using defaults")
                    weights = {'dense': 0.5, 'sparse': 0.3, 'structure': 0.2}
            else:
                weights = {'dense': 0.5, 'sparse': 0.3, 'structure': 0.2}

            # Retrieve results
            results = retriever.retrieve(query, args.top_k, weights)

            # Display results
            print("\n=== COMBINED RETRIEVAL RESULTS ===")
            for i, (example, score) in enumerate(results):
                print(f"\n{i+1}. Score: {score:.4f}")
                print(f"   Title: {example.get('title', 'N/A')}")
                print(f"   Input: {truncate_text(example['input'])}")

                # Format output to be more readable
                output_text = example['output']
                # If output is long, truncate with indicator
                if len(output_text) > 300:
                    output_preview = output_text[:300] + "...[truncated]"
                else:
                    output_preview = output_text

                print(f"   Output Preview: {output_preview}")

            # Option to see full output for a result
            view_full = input("\nView full output for a result (number or 'n')? ")
            if view_full.isdigit() and 1 <= int(view_full) <= len(results):
                idx = int(view_full) - 1
                print("\n=== FULL OUTPUT ===")
                print(results[idx][0]['output'])


# For Colab compatibility: Run directly
if __name__ == "__main__":
    try:
        # First, check if we're in Colab
        import google.colab
        from google.colab import files
        print("Detected Google Colab environment. Starting interactive demo...")

        # Default jsonl path
        jsonl_path = '3.jsonl'

        # If file doesn't exist, ask user to upload
        if not os.path.exists(jsonl_path):
            print(f"File {jsonl_path} not found. Please upload it.")
            uploaded = files.upload()
            # Use the first uploaded file
            if uploaded:
                jsonl_path = next(iter(uploaded.keys()))

        # Run interactive demo directly in Colab
        run_interactive_demo(jsonl_path)
    except ImportError:
        # Not in Colab, run normal main function
        main()

Using device: cuda
Detected Google Colab environment. Starting interactive demo...
Loading data from 3.jsonl...
Loaded 97 examples
Initializing embedding model...
Building dense index...


Building sparse index...
Extracting structure patterns...
Retriever initialized!



## Adding component chunks for additional context

In [None]:
# The methods in this cell use self.examples and self.embedding_model, so they
# extend HybridRetriever from the cell above rather than VocabularyManager.
# The subclass name is illustrative.
class ComponentHybridRetriever(HybridRetriever):
    def _build_component_index(self):
        """Build an index of individual OML components extracted from examples."""
        self.component_index = []

        print("Building component-level index...")
        for i, example in enumerate(tqdm(self.examples, desc="Extracting components")):
            # Extract components from the example
            components = self._extract_components(example)

            for component in components:
                component_code = component['code']
                component_type = component['type']

                # Format the component for embedding
                component_text = f"{component_type}: {component_code}"

                # Create embeddings using the E5 model
                component_embedding = self.embedding_model.encode(
                    f"query: {component_text}",
                    convert_to_numpy=True
                )

                # Generate sparse representation for the component
                component_tokens = component_text.lower().split()

                # Extract structural patterns specific to this component
                structure = self._extract_component_structure(component)

                # Add to component index with all necessary metadata
                self.component_index.append({
                    'id': f"{example.get('id', str(i))}_component_{len(self.component_index)}",
                    'type': component_type,
                    'code': component_code,
                    'embedding': component_embedding,
                    'tokens': component_tokens,
                    'structure': structure,
                    'parent_example': example
                })

            # Report once per example, after all of its components are indexed
            print(f"Extracted {len(components)} components from example {i+1}/{len(self.examples)}")

        # Report the total once, after every example has been processed
        print(f"Built component index with {len(self.component_index)} total components")

        # Create BM25 index for component-level sparse retrieval,
        # reusing the tokens computed above
        tokenized_components = [c['tokens'] for c in self.component_index]
        self.component_bm25 = BM25Okapi(tokenized_components)

def _extract_component_structure(self, component):
    """Extract structural information specific to this component type."""
    code = component['code']
    component_type = component['type']

    structure = {
        'component_type': component_type,
        'has_inheritance': '<' in code,
        'has_constraints': '[' in code and ']' in code,
        'name': re.search(r'\b(\w+)', code.split()[1]).group(1) if len(code.split()) > 1 else '',
    }

    # Extract specific patterns based on component type
    if component_type == 'concept':
        structure['parent_concepts'] = re.findall(r'<\s*(\w+)', code)
        structure['restricts'] = re.findall(r'restricts\s+(\w+)', code)
    elif component_type == 'relation' or component_type == 'relation_entity':
        structure['from_type'] = re.search(r'from\s+(\w+)', code).group(1) if re.search(r'from\s+(\w+)', code) else None
        structure['to_type'] = re.search(r'to\s+(\w+)', code).group(1) if re.search(r'to\s+(\w+)', code) else None
        structure['is_functional'] = 'functional' in code
    elif component_type == 'scalar_property':
        structure['domain'] = re.search(r'domain\s+(\w+)', code).group(1) if re.search(r'domain\s+(\w+)', code) else None
        structure['range'] = re.search(r'range\s+(\w+)', code).group(1) if re.search(r'range\s+(\w+)', code) else None

    return structure

    def get_definition(self, term):
        """Retrieve a specific definition by term."""
        if term not in self.vocab_index:
            return None

        with open(self.vocab_file, 'r', encoding='utf-8') as f:
            f.seek(self.vocab_index[term])
            definition_line = f.readline()
            try:
                return json.loads(definition_line)
            except json.JSONDecodeError:
                return None

    def get_relevant_definitions(self, text):
        """Find all vocabulary terms mentioned in a text and return their definitions."""
        relevant_terms = [term for term in self.vocab_index.keys() if term in text]
        return {term: self.get_definition(term) for term in relevant_terms}
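
The lookup pattern in `VocabularyManager` (index each term's file position once, then seek to it on demand) can be tried in isolation. The sketch below is a minimal standalone version under stated assumptions: `build_offset_index`, `read_record`, and the sample entries are hypothetical names and data, and the file is opened in binary mode so that offsets are plain byte counts.

```python
import json
import os
import tempfile


def build_offset_index(path, key="term"):
    """Map each record's key to the byte offset where its line starts."""
    index = {}
    with open(path, "rb") as f:  # binary mode gives exact byte offsets
        offset = 0
        for raw_line in f:
            if raw_line.strip():
                try:
                    record = json.loads(raw_line)
                except json.JSONDecodeError:
                    record = None  # skip malformed lines
                if isinstance(record, dict) and record.get(key):
                    index[record[key]] = offset
            offset += len(raw_line)
    return index


def read_record(path, offset):
    """Seek straight to a stored offset and parse that single line."""
    with open(path, "rb") as f:
        f.seek(offset)
        return json.loads(f.readline())


# Hypothetical vocabulary entries, written to a temporary file for the demo.
entries = [
    {"term": "aspect", "definition": "placeholder definition A"},
    {"term": "concept", "definition": "placeholder definition B"},
]
with tempfile.NamedTemporaryFile(
    "w", suffix=".jsonl", delete=False, encoding="utf-8"
) as tmp:
    for entry in entries:
        tmp.write(json.dumps(entry) + "\n")
    demo_path = tmp.name

offsets = build_offset_index(demo_path)
found = read_record(demo_path, offsets["concept"])
print(found)
os.remove(demo_path)
```

Only the offsets stay in memory, so the vocabulary file can grow without the index holding every definition.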

In [None]:
# Install the retrieval dependencies first if they are missing (notebook-only syntax),
# so that the imports below cannot fail before the install has a chance to run
try:
    import sentence_transformers
    import rank_bm25
except ImportError:
    !pip install -q sentence-transformers rank-bm25

import argparse
import json
import os
import re

import numpy as np
import torch
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
from tqdm.notebook import tqdm  # Colab-friendly progress bars
import ipywidgets as widgets
from IPython.display import display, clear_output

# Check if GPU is available (Colab usually provides GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Helper function for loading JSONL data - moved outside the class
def load_jsonl_data(file_path):
    """Load JSONL data from file."""
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                try:
                    data.append(json.loads(line))
                except json.JSONDecodeError as e:
                    print(f"Error parsing JSON line: {e}")
                    print(f"Problematic line: {line[:100]}...")
    return data

# Helper function to truncate text
def truncate_text(text, max_length=100):
    """Truncate text to a maximum length."""
    if len(text) <= max_length:
        return text
    return text[:max_length] + "..."

class HybridRetriever:
    def __init__(self, jsonl_data):
        self.examples = jsonl_data

        print("Initializing embedding model...")
        self.embedding_model = SentenceTransformer('intfloat/multilingual-e5-large-instruct', device=device)

        print("Building dense index...")
        self.dense_index = self._build_dense_index()

        print("Building sparse index...")
        self.bm25 = self._build_sparse_index()

        # _build_component_index() logs its own progress and sets self.component_index
        self._build_component_index()


        print("Extracting structure patterns...")
        self.structure_patterns = self._extract_structure_patterns()

        print("Retriever initialized!")

    def _build_component_index(self):
      """Build an index of individual OML components extracted from examples."""
      self.component_index = []

      print("Building component-level index...")
      for i, example in enumerate(tqdm(self.examples, desc="Extracting components")):
        # Extract components from the example
        components = self._extract_components(example)

        for component in components:
          component_code = component['code']
          component_type = component['type']

          # Format the component for embedding
          component_text = f"{component_type}: {component_code}"

          # Components are indexed documents, so embed them with the "passage: "
          # prefix, matching how _build_dense_index embeds example outputs
          component_embedding = self.embedding_model.encode(
              f"passage: {component_text}",
              convert_to_numpy=True
          )

          # Generate sparse representation for the component
          component_tokens = component_text.lower().split()

          # Extract structural patterns specific to this component
          structure = self._extract_component_structure(component)

          # Add to component index with all necessary metadata
          self.component_index.append({
              'id': f"{example.get('id', str(i))}_component_{len(self.component_index)}",
              'type': component_type,
              'code': component_code,
              'embedding': component_embedding,
              'tokens': component_tokens,
              'structure': structure,
              'parent_example': example
          })

      # Report once, after all examples have been processed
      print(f"Built component index with {len(self.component_index)} total components")

      # Create BM25 index for component-level sparse retrieval, reusing the stored tokens.
      # Guard against an empty corpus, which BM25Okapi does not handle.
      tokenized_components = [c['tokens'] for c in self.component_index]
      self.component_bm25 = BM25Okapi(tokenized_components) if tokenized_components else None

    def _extract_component_structure(self, component):
      """Extract structural information specific to this component type."""
      code = component['code']
      component_type = component['type']

      def first_group(pattern):
        # Return the first capture group of the pattern in the code, or None
        match = re.search(pattern, code)
        return match.group(1) if match else None

      structure = {
        'component_type': component_type,
        'has_inheritance': '<' in code,
        'has_constraints': '[' in code and ']' in code,
        # The name follows the keyword(s), e.g. "relation entity Name" or "concept Name"
        'name': first_group(
            r'^\s*(?:relation\s+entity|scalar\s+property|annotation\s+property|\w+)\s+(\w+)'
        ) or '',
      }

      # Extract specific patterns based on component type
      if component_type == 'concept':
        structure['parent_concepts'] = re.findall(r'<\s*(\w+)', code)
        structure['restricts'] = re.findall(r'restricts\s+(\w+)', code)
      elif component_type in ('relation', 'relation_entity'):
        structure['from_type'] = first_group(r'\bfrom\s+(\w+)')
        structure['to_type'] = first_group(r'\bto\s+(\w+)')
        structure['is_functional'] = 'functional' in code
      elif component_type == 'scalar_property':
        structure['domain'] = first_group(r'\bdomain\s+(\w+)')
        structure['range'] = first_group(r'\brange\s+(\w+)')

      return structure

    def _build_dense_index(self):
      # Embed all examples using the E5 model
      dense_index = []
      for i, example in enumerate(tqdm(self.examples, desc="Creating embeddings")):
        input_text = example['input']
        output_text = example['output']

        # Create embeddings following E5's recommended format
        input_embedding = self.embedding_model.encode(f"query: {input_text}", convert_to_numpy=True)
        output_embedding = self.embedding_model.encode(f"passage: {output_text}", convert_to_numpy=True)

        dense_index.append({
                'id': example.get('id', str(i)),
                'input_embedding': input_embedding,
                'output_embedding': output_embedding,
                'example': example
            })
      # Return after the loop; returning inside it would index only the first example
      return dense_index

    def _build_sparse_index(self):
        # Create BM25 index for lexical matching
        corpus = []
        for example in self.examples:
            # Combine input and output with emphasis on OML syntax patterns
            oml_terms = re.findall(r'(aspect|concept|scalar|relation|property|vocabulary|extends|key)',
                                  example['output'])
            enhanced_text = f"{example['input']} {example['output']} {' '.join(oml_terms)}"
            corpus.append(enhanced_text)

        tokenized_corpus = [doc.lower().split() for doc in corpus]
        return BM25Okapi(tokenized_corpus)

    def _extract_structure_patterns(self):
        # Extract structural patterns from OML code
        patterns = []
        for i, example in enumerate(self.examples):
            output = example['output']

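            # Extract OML structural patterns; _structure_retrieve reads these flags
            structure = {
                'has_aspect': bool(re.search(r'aspect\s+\w+', output)),
                'has_concept': bool(re.search(r'concept\s+\w+', output)),
                'has_scalar': bool(re.search(r'scalar\s+\w+', output)),
                'has_relation': bool(re.search(r'relation\s+\w+', output)),
                'has_relation_entity': bool(re.search(r'relation\s+entity\s+\w+', output)),
                'has_vocabulary': bool(re.search(r'\bvocabulary\b', output)),
                'has_annotation': bool(re.search(r'@\w+', output)),
                'has_rule': bool(re.search(r'rule\s+\w+', output)),
                'has_restriction': bool(re.search(r'restricts\s+\w+', output)),
                'inheritance_depth': len(re.findall(r'<\s*\w+', output)),
                'property_count': len(re.findall(r'property\s+\w+', output)),
                'annotation_count': len(re.findall(r'@\w+', output))
            }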

            patterns.append({
                'id': example.get('id', str(i)),
                'structure': structure,
                'example': example
            })
        return patterns

    def _extract_components(self, example):
      """Extract individual OML components from an example."""
      output = example['output']
      components = []

      # Extract aspects
      aspect_matches = re.finditer(r'(aspect\s+\w+(?:\s*<[^[{]*)?(?:\s*\[[^]]*\])?)', output)
      for match in aspect_matches:
        components.append({
            'type': 'aspect',
            'code': match.group(0),
            'parent_example': example
        })

      # Extract concepts
      concept_matches = re.finditer(r'(concept\s+\w+(?:\s*<[^[{]*)?(?:\s*\[[^]]*\])?)', output)
      for match in concept_matches:
        components.append({
            'type': 'concept',
            'code': match.group(0),
            'parent_example': example
        })

      # Extract relation entities
      relation_entity_matches = re.finditer(r'(relation\s+entity\s+\w+\s*\[[^\]]*\](?:\s*<[^[{]*)?)', output)
      for match in relation_entity_matches:
        components.append({
            'type': 'relation_entity',
            'code': match.group(0),
            'parent_example': example
        })

      # Extract regular relations
      relation_matches = re.finditer(r'(relation\s+\w+\s*\[[^\]]*\](?:\s*<[^[{]*)?)', output)
      for match in relation_matches:
        if match.group(0).split()[1] != "entity":
            # Skip if it's a relation entity
            components.append({
                'type': 'relation',
                'code': match.group(0),
                'parent_example': example
            })

      # Extract scalar properties
      scalar_property_matches = re.finditer(r'(scalar\s+property\s+\w+\s*\[[^\]]*\])', output)
      for match in scalar_property_matches:
        components.append({
            'type': 'scalar_property',
            'code': match.group(0),
            'parent_example': example
        })

      # Extract scalar types
      scalar_type_matches = re.finditer(r'(scalar\s+\w+\s*(?:\[[^\]]*\])?)', output)
      for match in scalar_type_matches:
        if match.group(0).split()[1] != "property":  # Skip if it's a scalar property
            components.append({
                'type': 'scalar_type',
                'code': match.group(0),
                'parent_example': example
            })

      # Extract rules
      rule_matches = re.finditer(r'(rule\s+\w+\s*\[[^\]]*\])', output)
      for match in rule_matches:
        components.append({
            'type': 'rule',
            'code': match.group(0),
            'parent_example': example
        })

      # Extract annotation properties
      annotation_matches = re.finditer(r'(annotation\s+property\s+\w+(?:\s*<[^[{]*)?)', output)
      for match in annotation_matches:
        components.append({
            'type': 'annotation_property',
            'code': match.group(0),
            'parent_example': example
        })

      return components

    def retrieve(self, query, top_k=5, weights=None, vocab_manager=None):
      """
      Hierarchical retrieval using both example-level and component-level indexes.
      """
      if weights is None:
        weights = {
            'dense': 0.5,
            'sparse': 0.3,
            'structure': 0.2
        }

      print("\nRetrieving using hierarchical approach...")

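      # FIRST TIER: Retrieve most relevant examples
      print("\nTIER 1: Retrieving most relevant examples...")
      # Fuse the three rankings here so the caller's weights take effect
      # (_retrieve_examples always applies its own fixed 0.5/0.3/0.2 weights)
      candidate_k = top_k * 2  # Get more examples than needed
      example_results = self._combine_results(
          self._dense_retrieve(query, candidate_k),
          self._sparse_retrieve(query, candidate_k),
          self._structure_retrieve(query, candidate_k),
          weights
      )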

      # SECOND TIER: Extract and rank components from relevant examples
      print("\nTIER 2: Retrieving relevant components...")
      component_results = self._retrieve_components(query, example_results, top_k * 3)

      # Combine and rank final results
      print("\nCombining and ranking final results...")
      combined_results = self._prepare_final_results(component_results, top_k)

      # Add relevant vocabulary if available
      if vocab_manager:
        for result in combined_results:
          # Find relevant vocabulary for this result
          relevant_vocab = vocab_manager.get_relevant_definitions(
                result['component_code'] + ' ' + result['context']
            )
          # Attach to the result
          result['relevant_vocabulary'] = relevant_vocab

      return combined_results

    def _dense_retrieve(self, query, top_k):
        query_embedding = self.embedding_model.encode(f"query: {query}")

        results = []
        for item in self.dense_index:
            # Calculate similarity to both input and output embeddings
            input_similarity = self._cosine_similarity(query_embedding, item['input_embedding'])
            output_similarity = self._cosine_similarity(query_embedding, item['output_embedding'])

            # Use the maximum similarity as the score
            similarity = max(input_similarity, output_similarity)
            results.append((item['example'], similarity))

        # Sort by similarity
        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]

    def _sparse_retrieve(self, query, top_k):
        # Use BM25 for lexical matching
        tokenized_query = query.lower().split()
        bm25_scores = self.bm25.get_scores(tokenized_query)

        # Pair examples with their BM25 scores
        results = [(self.examples[i], bm25_scores[i]) for i in range(len(self.examples))]

        # Sort by BM25 score
        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]

    def _structure_retrieve(self, query, top_k):

      # Extract structure hints from the query
      query_structure = {
        'wants_aspect': bool(re.search(r'aspect|property|attribute', query, re.I)),
        'wants_concept': bool(re.search(r'concept|class|type', query, re.I)),
        'wants_scalar': bool(re.search(r'scalar|datatype|value', query, re.I)),
        'wants_relation': bool(re.search(r'relation|relationship|connection', query, re.I)),
        'wants_vocabulary': bool(re.search(r'vocabulary|ontology|vocab', query, re.I)),
        'wants_annotation': bool(re.search(r'annotation|comment|label|description', query, re.I)),
        'wants_rule': bool(re.search(r'rule|infer|constraint', query, re.I)),
        'wants_inheritance': bool(re.search(r'inherit|extends|specializes|subclass', query, re.I)),
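        # Word boundaries stop short hints such as "all" or "min" from matching inside longer words
        'wants_restriction': bool(re.search(r'\b(?:restrict\w*|exactly|min|max|all|some)\b', query, re.I))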
    }

      # Each query hint paired with the structure flag that satisfies it
      hint_flags = [
        ('wants_aspect', 'has_aspect'),
        ('wants_concept', 'has_concept'),
        ('wants_scalar', 'has_scalar'),
        ('wants_relation', 'has_relation'),
        ('wants_vocabulary', 'has_vocabulary'),
        ('wants_annotation', 'has_annotation'),
        ('wants_rule', 'has_rule'),
        ('wants_restriction', 'has_restriction'),
      ]

      # Domain keywords depend only on the query, so compute them once
      domain_keywords = self._extract_domain_keywords(query)

      # Score examples based on structural similarity
      results = []
      for pattern in self.structure_patterns:
        structure = pattern['structure']
        # One point per requested construct that the example contains
        score = sum(
          1 for hint, flag in hint_flags
          if query_structure[hint] and structure.get(flag, False)
        )
        if query_structure['wants_inheritance'] and structure.get('inheritance_depth', 0) > 0:
          score += 1

        # Contextual domain match (optional)
        if any(kw in pattern['example'].get('tags', []) for kw in domain_keywords):
          score += 0.5

        results.append((pattern['example'], score))

      # Sort by structure match score
      results.sort(key=lambda x: x[1], reverse=True)
      return results[:top_k]

    def _extract_domain_keywords(self, query):
      # Extract domain keywords from query
      domains = {
        'pizza': ['pizza', 'food', 'topping', 'cheese', 'base'],
        'security': ['security', 'cyber', 'defense', 'attack', 'mitigation'],
        'system': ['system', 'component', 'architecture', 'interface'],
        'metamodel': ['metamodel', 'capella', 'modeling', 'uml']
    }

      keywords = []
      for domain, terms in domains.items():
        if any(term in query.lower() for term in terms):
          keywords.append(domain)
          keywords.extend(terms)

      return keywords

    def _combine_results(self, dense_results, sparse_results, structure_results, weights):
        # Combine and normalize scores
        result_map = {}

        # Process dense results
        max_dense_score = max([score for _, score in dense_results]) if dense_results else 1.0
        for example, score in dense_results:
            example_id = example.get('id', str(id(example)))
            if example_id not in result_map:
                result_map[example_id] = {'example': example, 'score': 0}
            # Normalize dense scores
            normalized_score = score / max_dense_score if max_dense_score > 0 else 0
            result_map[example_id]['score'] += normalized_score * weights['dense']

        # Process sparse results
        max_sparse_score = max([score for _, score in sparse_results]) if sparse_results else 1.0
        for example, score in sparse_results:
            example_id = example.get('id', str(id(example)))
            if example_id not in result_map:
                result_map[example_id] = {'example': example, 'score': 0}
            # Normalize BM25 scores
            normalized_score = score / max_sparse_score if max_sparse_score > 0 else 0
            result_map[example_id]['score'] += normalized_score * weights['sparse']

        # Process structure results
        max_structure_score = max([score for _, score in structure_results]) if structure_results else 1.0
        for example, score in structure_results:
            example_id = example.get('id', str(id(example)))
            if example_id not in result_map:
                result_map[example_id] = {'example': example, 'score': 0}
            # Normalize structure scores
            normalized_score = score / max_structure_score if max_structure_score > 0 else 0
            result_map[example_id]['score'] += normalized_score * weights['structure']

        # Convert to list and sort
        combined_results = [(item['example'], item['score']) for item in result_map.values()]
        combined_results.sort(key=lambda x: x[1], reverse=True)

        return combined_results

    def _cosine_similarity(self, a, b):
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


    def _retrieve_examples(self, query, top_k):
      """First tier retrieval of most relevant examples."""
      # Get results from each retrieval method at example level
      dense_results = self._dense_retrieve(query, top_k)
      sparse_results = self._sparse_retrieve(query, top_k)
      structure_results = self._structure_retrieve(query, top_k)

      # Show individual method results at example level
      print("\nExample Dense Retrieval Results:")
      for i, (example, score) in enumerate(dense_results[:3]):
        print(f"  {i+1}. Score: {score:.4f} - {example['input'][:100]}...")

      print("\nExample Sparse Retrieval Results:")
      for i, (example, score) in enumerate(sparse_results[:3]):
        print(f"  {i+1}. Score: {score:.4f} - {example['input'][:100]}...")

      print("\nExample Structure Retrieval Results:")
      for i, (example, score) in enumerate(structure_results[:3]):
        print(f"  {i+1}. Score: {score:.4f} - {example['input'][:100]}...")

      # Combine results at example level (using existing _combine_results method)
      example_results = self._combine_results(dense_results, sparse_results, structure_results,
                                      {'dense': 0.5, 'sparse': 0.3, 'structure': 0.2})

      return example_results

    def _retrieve_components(self, query, example_results, max_components):
      """Second tier retrieval of components from the most relevant examples."""
      # Keep only top examples to limit search space. Match components to their
      # parents by object identity, which also works when examples have no 'id'.
      relevant_ids = {id(result[0]) for result in example_results[:10]}

      # 1. Collect components from relevant examples together with their position
      #    in self.component_index, because the BM25 scores follow that order
      candidates = [
          (idx, component)
          for idx, component in enumerate(self.component_index)
          if id(component['parent_example']) in relevant_ids
      ]
      if not candidates:
        return []

      # 2. Now rank components by relevance to query
      query_lower = query.lower()
      query_embedding = self.embedding_model.encode(f"query: {query}")

      # Score the query against all components once, then normalize by the best
      # candidate so BM25 is on a scale comparable to cosine similarity
      bm25_scores = self.component_bm25.get_scores(query_lower.split())
      max_sparse = max(bm25_scores[idx] for idx, _ in candidates)

      component_scores = []
      for idx, component in candidates:
        # Dense similarity
        dense_similarity = self._cosine_similarity(query_embedding, component['embedding'])

        # Sparse (BM25) similarity
        sparse_score = bm25_scores[idx] / max_sparse if max_sparse > 0 else 0

        # Structure similarity (match component type to query)
        component_type = component['type']
        structure_score = 0
        if 'aspect' in query_lower and component_type == 'aspect':
            structure_score += 1
        if 'concept' in query_lower and component_type == 'concept':
            structure_score += 1
        if 'relation' in query_lower and 'relation' in component_type:
            structure_score += 1
        if 'property' in query_lower and 'property' in component_type:
            structure_score += 1
        if 'scalar' in query_lower and 'scalar' in component_type:
            structure_score += 1

        # Calculate weighted score
        weighted_score = (
            dense_similarity * 0.6 +
            sparse_score * 0.3 +
            structure_score * 0.1
        )

        component_scores.append((component, weighted_score))

      # Sort by score and return top components
      component_scores.sort(key=lambda x: x[1], reverse=True)
      return component_scores[:max_components]

    def _prepare_final_results(self, component_results, top_k):
      """Format final results with components and their context."""
      final_results = []

      for component, score in component_results[:top_k]:
        parent_example = component['parent_example']

        # Create a result with both component code and context from parent example.
        # This must stay inside the loop so every component yields a result.
        result = {
            'score': score,
            'component_type': component['type'],
            'component_code': component['code'],
            'component_structure': component['structure'],
            'context': truncate_text(parent_example['output'], 500),
            'parent_example': {
                'id': parent_example.get('id', ''),
                'title': parent_example.get('title', 'Untitled'),
                'input': parent_example['input'],
                'tags': parent_example.get('tags', [])
            }
        }

        final_results.append(result)

      return final_results

# Interactive Colab-friendly function
def run_interactive_demo(jsonl_path):
    # Load data
    print(f"Loading data from {jsonl_path}...")
    examples = load_jsonl_data(jsonl_path)
    print(f"Loaded {len(examples)} examples")

    # Initialize retriever
    retriever = HybridRetriever(examples)

    # Create interactive widgets for Colab
    query_widget = widgets.Text(
        value='',
        placeholder='Type your query here',
        description='Query:',
        disabled=False,
        layout=widgets.Layout(width='80%')
    )

    dense_weight = widgets.FloatSlider(
        value=0.5,
        min=0,
        max=1.0,
        step=0.1,
        description='Dense:',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='.1f',
    )

    sparse_weight = widgets.FloatSlider(
        value=0.3,
        min=0,
        max=1.0,
        step=0.1,
        description='Sparse:',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='.1f',
    )

    structure_weight = widgets.FloatSlider(
        value=0.2,
        min=0,
        max=1.0,
        step=0.1,
        description='Structure:',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='.1f',
    )

    top_k = widgets.IntSlider(
        value=5,
        min=1,
        max=10,
        step=1,
        description='Top K:',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='d',
    )

    output = widgets.Output()

    def on_button_clicked(b):
        with output:
            clear_output()

        query = query_widget.value

        if not query:
            with output:
                print("Please enter a query")
            return

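        # Normalize weights (all sliders at zero would otherwise divide by zero)
        total = dense_weight.value + sparse_weight.value + structure_weight.value
        if total == 0:
            with output:
                print("Set at least one retrieval weight above zero")
            return
        weights = {
            'dense': dense_weight.value / total,
            'sparse': sparse_weight.value / total,
            'structure': structure_weight.value / total
        }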

        with output:
            print(f"Query: {query}")
            print(f"Weights: Dense={weights['dense']:.2f}, Sparse={weights['sparse']:.2f}, Structure={weights['structure']:.2f}")
            print(f"Retrieving top {top_k.value} results...")

            # Retrieve results
            results = retriever.retrieve(query, top_k.value, weights)

            # Display results
            print("\n=== COMBINED RETRIEVAL RESULTS ===")
            for i, result in enumerate(results):
              print(f"\n{i+1}. Score: {result['score']:.4f}")
              print(f"   Component Type: {result['component_type']}")
              print(f"   Title: {result['parent_example'].get('title', 'N/A')}")
              print(f"   Input: {truncate_text(result['parent_example']['input'])}")
              print(f"   Component Code: {truncate_text(result['component_code'], 150)}")

              # Show a longer preview, truncating very long components
              component_code = result['component_code']
              if len(component_code) > 300:
                code_preview = component_code[:300] + "...[truncated]"
              else:
                code_preview = component_code
              print(f"   Component Code Preview: {code_preview}")

              # Format context for display
              context = result['context']
              print(f"   Context Preview: {truncate_text(context, 150)}")

            # Create view buttons outside the loop
            for i in range(len(results)):
                view_btn = widgets.Button(
                    description=f"View full #{i+1}",
                    layout=widgets.Layout(width='120px')
                )

                # Create a closure to capture the current index
                def make_view_handler(idx):
                  def view_handler(b):
                    with output:
                      print(f"\n=== FULL COMPONENT AND CONTEXT FOR RESULT #{idx+1} ===")
                      print(f"Component Code:\n{results[idx]['component_code']}\n")
                      print(f"Context:\n{results[idx]['context']}")
                  return view_handler

                view_btn.on_click(make_view_handler(i))
                display(view_btn)

                # Create a collapsible section for the full output. Escape the text,
                # since '<' in OML code would otherwise be parsed as HTML markup.
                import html
                full_text = html.escape(
                    f"Component: {results[i]['component_code']}\n\nContext: {results[i]['context']}"
                )
                details = widgets.Accordion(
                    children=[widgets.HTML(value=f"<pre>{full_text}</pre>")],
                    selected_index=None
                )
                details.set_title(0, f"Full Output for Result #{i+1}")
                display(details)

    button = widgets.Button(description="Retrieve")
    button.on_click(on_button_clicked)

    display(widgets.HTML(value="<h3>Hybrid Retrieval for OML Code Generation</h3>"))
    display(widgets.HBox([query_widget, button]))
    display(widgets.HTML(value="<h4>Retrieval Weights</h4>"))
    display(widgets.VBox([dense_weight, sparse_weight, structure_weight, top_k]))
    display(output)


# Standard main function for direct execution
def main():
    # Modified for Colab compatibility - don't use argparse in Colab
    # Check if running in Colab
    try:
        import google.colab  # noqa: F401 (imported only to detect Colab)
        from google.colab import files

        # Default path for Colab
        jsonl_path = '3.jsonl'

        # If file doesn't exist, ask user to upload
        if not os.path.exists(jsonl_path):
            print(f"File {jsonl_path} not found. Please upload it.")
            uploaded = files.upload()
            # Use the first uploaded file
            if uploaded:
                jsonl_path = next(iter(uploaded.keys()))

        # Run interactive demo in Colab
        run_interactive_demo(jsonl_path)

    except ImportError:
        # Not in Colab, use argparse for command line
        parser = argparse.ArgumentParser(description='Test hybrid retrieval for OML code generation')
        parser.add_argument('--jsonl', type=str, default='3.jsonl',
                            help='Path to JSONL file containing examples')
        parser.add_argument('--top_k', type=int, default=5,
                            help='Number of examples to retrieve')
        args = parser.parse_args()

        # Original command-line version
        # Load data
        print(f"Loading data from {args.jsonl}...")
        examples = load_jsonl_data(args.jsonl)
        print(f"Loaded {len(examples)} examples")

        # Initialize retriever
        retriever = HybridRetriever(examples)

        # Interactive query loop
        while True:
            query = input("\nEnter your query (or 'q' to quit): ")
            if query.lower() == 'q':
                break

            # Allow adjusting weights
            use_custom_weights = input("Use custom weights? (y/n, default: n): ").lower() == 'y'
            if use_custom_weights:
                try:
                    dense_weight = float(input("Dense weight (0-1, default: 0.5): ") or 0.5)
                    sparse_weight = float(input("Sparse weight (0-1, default: 0.3): ") or 0.3)
                    structure_weight = float(input("Structure weight (0-1, default: 0.2): ") or 0.2)

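                    # Normalize weights so they sum to 1; reject negative or all-zero input
                    total = dense_weight + sparse_weight + structure_weight
                    if total <= 0 or min(dense_weight, sparse_weight, structure_weight) < 0:
                        raise ValueError("weights must be non-negative and sum to a positive value")
                    weights = {
                        'dense': dense_weight / total,
                        'sparse': sparse_weight / total,
                        'structure': structure_weight / total
                    }
                except ValueError:
                    print("Invalid weights, using defaults")
                    weights = {'dense': 0.5, 'sparse': 0.3, 'structure': 0.2}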
            else:
                weights = {'dense': 0.5, 'sparse': 0.3, 'structure': 0.2}

            # Retrieve results
            results = retriever.retrieve(query, args.top_k, weights)

            # Display results
            print("\n=== COMBINED RETRIEVAL RESULTS ===")
            for i, result in enumerate(results):
                print(f"\n{i+1}. Score: {result['score']:.4f}")
                print(f"   Component Type: {result['component_type']}")
                print(f"   Title: {result['parent_example'].get('title', 'N/A')}")
                print(f"   Input: {truncate_text(result['parent_example']['input'])}")
                print(f"   Component Code: {truncate_text(result['component_code'], 150)}")

                # Format the component code for display
                component_code = result['component_code']

                # If component code is long, truncate with indicator
                if len(component_code) > 300:
                    code_preview = component_code[:300] + "...[truncated]"
                else:
                    code_preview = component_code

                print(f"   Component Code Preview: {code_preview}")

                # Format context for display
                context = result['context']
                print(f"   Context Preview: {truncate_text(context, 150)}")

            # Option to see full output for a result (asked once, after all results are listed)
            view_full = input("\nView full output for a result (number or 'n')? ")
            if view_full.isdigit() and 1 <= int(view_full) <= len(results):
                idx = int(view_full) - 1
                print("\n=== FULL COMPONENT AND CONTEXT ===")
                print(f"Component Code:\n{results[idx]['component_code']}\n")
                print(f"Full Context:\n{results[idx]['context']}")


# For Colab compatibility: Run directly
if __name__ == "__main__":
    try:
        # First, check if we're in Colab
        import google.colab
        from google.colab import files
        print("Detected Google Colab environment. Starting interactive demo...")

        # Default jsonl path
        jsonl_path = '3.jsonl'

        # If file doesn't exist, ask user to upload
        if not os.path.exists(jsonl_path):
            print(f"File {jsonl_path} not found. Please upload it.")
            uploaded = files.upload()
            # Use the first uploaded file
            if uploaded:
                jsonl_path = next(iter(uploaded.keys()))

        # Run interactive demo directly in Colab
        run_interactive_demo(jsonl_path)
    except ImportError:
        # Not in Colab, run normal main function
        main()

Using device: cuda
Detected Google Colab environment. Starting interactive demo...
Loading data from 3.jsonl...
Loaded 97 examples
Initializing embedding model...
Building dense index...


Creating embeddings:   0%|          | 0/97 [00:00<?, ?it/s]

Building sparse index...
Building component-level index...
Building component-level index...


Extracting components:   0%|          | 0/97 [00:00<?, ?it/s]

Extracted 2 components from example 1/97
Extracted 2 components from example 1/97
Built component index with 2 total components
Extracted 3 components from example 2/97
Extracted 3 components from example 2/97
Extracted 3 components from example 2/97
Built component index with 5 total components
Extracted 4 components from example 3/97
Extracted 4 components from example 3/97
Extracted 4 components from example 3/97
Extracted 4 components from example 3/97
Built component index with 9 total components
Extracted 5 components from example 4/97
Extracted 5 components from example 4/97
Extracted 5 components from example 4/97
Extracted 5 components from example 4/97
Extracted 5 components from example 4/97
Built component index with 14 total components
Extracted 6 components from example 5/97
Extracted 6 components from example 5/97
Extracted 6 components from example 5/97
Extracted 6 components from example 5/97
Extracted 6 components from example 5/97
Extracted 6 components from example 

HTML(value='<h3>Hybrid Retrieval for OML Code Generation</h3>')

HBox(children=(Text(value='', description='Query:', layout=Layout(width='80%'), placeholder='Type your query h…

HTML(value='<h4>Retrieval Weights</h4>')

VBox(children=(FloatSlider(value=0.5, continuous_update=False, description='Dense:', max=1.0, readout_format='…

Output()
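
The command-line loop in the script divides the dense, sparse, and structure weights by their sum before passing them to `retriever.retrieve`. A minimal sketch of how that step could be factored into one helper is shown below. The names `normalize_weights` and `DEFAULT_WEIGHTS` are hypothetical and do not appear in the original script; the default values are the ones the script already uses (0.5 / 0.3 / 0.2). Falling back to the defaults on negative or all-zero input is an assumption about the desired behavior.

```python
# Hypothetical helper, not part of the original script.
DEFAULT_WEIGHTS = {'dense': 0.5, 'sparse': 0.3, 'structure': 0.2}


def normalize_weights(dense, sparse, structure, defaults=None):
    """Scale the three weights so they sum to 1.

    Returns a copy of the defaults when any weight is negative or when
    the weights sum to zero, since neither case can be normalized.
    """
    fallback = dict(defaults or DEFAULT_WEIGHTS)
    raw = {'dense': dense, 'sparse': sparse, 'structure': structure}
    total = sum(raw.values())

    if total <= 0 or any(w < 0 for w in raw.values()):
        return fallback

    return {name: w / total for name, w in raw.items()}
```

For example, `normalize_weights(1, 1, 2)` returns `{'dense': 0.25, 'sparse': 0.25, 'structure': 0.5}`, and `normalize_weights(0, 0, 0)` returns the defaults instead of raising `ZeroDivisionError`.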