In [None]:
from tree_sitter import Language, Parser
import tree_sitter_python, tree_sitter_javascript
import os
import json
from helix import Client

# Wrap the grammar handles exposed by the tree-sitter language packages.
PY_LANGUAGE = Language(tree_sitter_python.language())
JS_LANGUAGE = Language(tree_sitter_javascript.language())

# One parser per language; only py_parser is passed to parse_file in the cells below.
py_parser = Parser(PY_LANGUAGE)
js_parser = Parser(JS_LANGUAGE)

# Helix client used for every client.query(...) call in this notebook.
# NOTE(review): local=True / verbose=False presumably select a local instance
# with quiet logging - confirm against the helix client documentation.
client = Client(local=True, verbose=False)

In [None]:
def parse_file(file_path, parser):
    """Read a file as raw bytes and hand it to ``parser``.

    Args:
        file_path: Path of the source file to read (opened in binary mode).
        parser: Object exposing ``parse(bytes)``; callers in this notebook
            pass a tree-sitter ``Parser``.

    Returns:
        ``(tree, source_bytes)`` on success. If reading or parsing raises,
        the error is printed and ``(None, None)`` is returned instead.
    """
    try:
        with open(file_path, 'rb') as handle:
            raw_bytes = handle.read()
        syntax_tree = parser.parse(raw_bytes)
    except Exception as e:
        # Best-effort: report the failure and let the caller skip this file.
        print(f"Error parsing {file_path}: {e}")
        return None, None
    return syntax_tree, raw_bytes

def node_to_dict(node, source_code, order:int=1):
    """Convert a syntax-tree node and all of its descendants into nested dicts.

    Args:
        node: Object with ``type``, ``start_byte``, ``end_byte`` and
            ``children`` attributes (a tree-sitter node in this notebook).
        source_code: The bytes the tree was parsed from; sliced by the node's
            byte range and decoded as UTF-8 to produce ``text``.
        order: 1-based position of this node among its siblings.

    Returns:
        dict with keys ``type``, ``start_byte``, ``end_byte``, ``order``,
        ``text`` and ``children`` (a list of dicts of the same shape).
    """
    begin, finish = node.start_byte, node.end_byte
    # Decode this node's text before descending.
    text = source_code[begin:finish].decode('utf8')

    child_dicts = []
    for position, child in enumerate(node.children, start=1):
        child_dicts.append(node_to_dict(child, source_code, position))

    return {
        "type": node.type,
        "start_byte": begin,
        "end_byte": finish,
        "order": order,
        "text": text,
        "children": child_dicts,
    }

def scan_directory(root_path):
    """List the immediate sub-folders and files of ``root_path``.

    Entries are reported by bare name (not full path) in the order
    ``os.listdir`` yields them. Anything that is neither a directory nor a
    regular file is left out.

    Returns:
        dict with keys ``root`` (the path passed in), ``folders`` and ``files``.
    """
    listing = {"root": root_path, "folders": [], "files": []}

    for name in os.listdir(root_path):
        candidate = os.path.join(root_path, name)
        if os.path.isdir(candidate):
            listing["folders"].append(name)
        elif os.path.isfile(candidate):
            listing["files"].append(name)

    return listing

def process_path(path: str, code_base: dict):
    """Recursively mirror the directory at ``path`` into the nested dict ``code_base``.

    Each component of ``path`` (split on ``os.sep``, empty parts ignored)
    becomes a nested dict key. Inside the dict for ``path``:

    * every ``.py`` file that parses maps to its ``node_to_dict`` tree;
    * every other file maps to the placeholder string ``"file_content"``;
    * every sub-folder maps to a dict filled in by a recursive call.

    A ``.py`` file that fails to parse is skipped (``parse_file`` has already
    printed the error).

    Args:
        path: Directory to scan.
        code_base: Dict to populate; mutated in place.

    Returns:
        The same ``code_base`` object that was passed in.
    """
    dir_dict = scan_directory(path)

    # Navigate to the dict for this directory, creating levels as needed.
    current_level = code_base
    for part in path.split(os.sep):
        if part:
            current_level = current_level.setdefault(part, {})

    # Process files.
    # No per-iteration `del tree / code / tree_dict` here: those names are
    # unbound for non-Python files and failed parses, so deleting them raised
    # UnboundLocalError and aborted the whole walk.
    for file in dir_dict["files"]:
        if not file.endswith('.py'):
            current_level[file] = "file_content"
            continue

        tree, code = parse_file(os.path.join(path, file), py_parser)
        if tree:
            current_level[file] = node_to_dict(tree.root_node, code, 1)

    # Process subdirectories.
    for folder in dir_dict["folders"]:
        current_level.setdefault(folder, {})
        process_path(os.path.join(path, folder), code_base)

    return code_base

# Initialize and run
# Build the nested dict for the "sample" directory (relative to the working
# directory). process_path returns the dict it was given, so test_dict and
# code_base are the same object. The last line displays the tree stored for
# sample/lab00.py.
code_base = {}
test_dict = process_path("sample", code_base)
test_dict['sample']['lab00.py']

In [None]:
# Serialize the nested code-base dict to test.json (2-space indentation).
# NOTE(review): json.dump writes to the file object and returns None, so
# test_json is always None after this cell.
with open("test.json", "w") as f:
    test_json = json.dump(test_dict, f, indent=2)

In [None]:
def scan_directory(root_path):
    """Split the immediate entries of ``root_path`` into folders and files.

    Names are returned without the directory prefix, in ``os.listdir`` order.
    Entries that are neither a directory nor a regular file are ignored.

    Returns:
        dict with keys ``folders`` and ``files`` (lists of entry names).
    """
    folders, files = [], []

    for name in os.listdir(root_path):
        location = os.path.join(root_path, name)
        if os.path.isdir(location):
            bucket = folders
        elif os.path.isfile(location):
            bucket = files
        else:
            continue
        bucket.append(name)

    return dict(folders=folders, files=files)

def process_entities(parent_dict, parent_id):
    """Store every descendant of ``parent_dict`` as a sub entity, depth-first.

    For each dict in ``parent_dict['children']`` a ``createSubEntity`` query
    is issued with ``parent_id`` as its ``entity_id``; the id read from the
    response then becomes the parent id for that child's own children.

    Args:
        parent_dict: Node dict as produced by ``node_to_dict``.
        parent_id: Id of the already-stored entity for ``parent_dict``.
    """
    for entity in parent_dict['children']:
        payload = {
            'entity_id': parent_id,
            'entity_type': entity['type'],
            'start_byte': entity['start_byte'],
            'end_byte': entity['end_byte'],
            'order': entity['order'],
            'text': entity['text'],
        }
        response = client.query('createSubEntity', payload)

        # Descend using the id of the entity that was just created.
        process_entities(entity, response[0]['entity'][0]['id'])

def chunk_entity(text:str, chunk_size:int=50):
    """Split ``text`` into consecutive pieces of at most ``chunk_size`` characters.

    Every returned chunk is non-empty: text whose length is an exact multiple
    of ``chunk_size`` produces no trailing ``''`` chunk, and empty text
    produces an empty list.

    Args:
        text: String to split.
        chunk_size: Maximum characters per chunk (default 50).

    Returns:
        list[str]: The chunks, in order; joining them reproduces ``text``.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    # The stop bound is len(text), not len(text) + 1: the extra step produced
    # an empty final slice whenever the length was a multiple of the chunk size.
    return [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]

from random import random
def random_embedding(text:str, dim:int=768):
    """Return a placeholder embedding of ``dim`` random floats.

    ``text`` is accepted for interface compatibility but is not used: the
    vector is drawn from ``random()`` and carries no information about it.

    Args:
        text: The text that would be embedded (ignored).
        dim: Length of the returned vector (default 768).

    Returns:
        list[float]: ``dim`` values, each in ``[0.0, 1.0)``.
    """
    return [random() for _ in range(dim)]

def populate(full_path, curr_type='root', parent_id=None):
    """Recursively store the folders and Python files under ``full_path`` via ``client``.

    Folders are handled first, each one recursed into immediately after it is
    created; files follow. For every ``.py`` file that parses, the file
    itself, each top-level syntax node ("super entity"), one embedding per
    text chunk of that node, and all nested nodes are stored. Files that are
    not ``.py`` or that fail to parse are only reported with ``print``.

    Args:
        full_path: Directory to ingest.
        curr_type: ``'root'`` when ``parent_id`` is a root id; any other value
            makes ``parent_id`` be treated as a folder id.
        parent_id: Id of the root or folder that owns this directory's entries.
    """
    dir_dict = scan_directory(full_path)
    at_root = curr_type == 'root'

    print(f'Processing {len(dir_dict["folders"])} folders')
    for folder in dir_dict["folders"]:
        print(f"Reached {folder}")
        # Top-level and nested folders use different queries and response keys.
        if at_root:
            query_name, result_key = 'createSuperFolder', 'folder'
            payload = {'root_id': parent_id, 'name': folder}
        else:
            query_name, result_key = 'createSubFolder', 'subfolder'
            payload = {'folder_id': parent_id, 'name': folder}
        folder_id = client.query(query_name, payload)[0][result_key][0]['id']

        # Recursively populate the stuff in the folder
        populate(os.path.join(full_path, folder), curr_type='folder', parent_id=folder_id)

    print(f'Processing {len(dir_dict["files"])} files')
    for file in dir_dict["files"]:
        print(f"Reached {file}")
        if not file.endswith('.py'):
            print(f'Not python file: {file}')
            continue

        # Extract python code structure with tree-sitter
        tree, code = parse_file(os.path.join(full_path, file), py_parser)
        if not tree:
            print(f'Failed to parse file: {file}')
            continue

        tree_dict = node_to_dict(tree.root_node, code, 0)
        # The parse tree and raw bytes are no longer needed once converted.
        del tree, code

        if at_root:
            created = client.query('createSuperFile', {'root_id': parent_id, 'name': file, 'text': tree_dict['text']})
        else:
            created = client.query('createFile', {'folder_id': parent_id, 'name': file, 'text': tree_dict['text']})
        file_id = created[0]['file'][0]['id']

        children = tree_dict['children']
        del tree_dict

        print(f"Process {len(children)} super entities")
        for superentity in children:
            print(f"Reached {superentity['type']}")
            entity_fields = {
                'file_id': file_id,
                'entity_type': superentity['type'],
                'start_byte': superentity['start_byte'],
                'end_byte': superentity['end_byte'],
                'order': superentity['order'],
                'text': superentity['text'],
            }
            super_entity_id = client.query('createSuperEntity', entity_fields)[0]['entity'][0]['id']

            # One embedding per text chunk of the super entity.
            for chunk in chunk_entity(superentity['text']):
                client.query('embedSuperEntity', {'entity_id':super_entity_id, 'vector': random_embedding(chunk)})

            process_entities(superentity, super_entity_id)


# Register the current working directory as the root (its absolute path is
# used as the root's name), then ingest everything beneath it.
root_name = os.getcwd()
root_id = client.query('createRoot', {'name': root_name})[0]['root'][0]['id']
populate(root_name, parent_id=root_id)

In [None]:
def check_gitignore():
    """Print the working directory's ``.gitignore`` and return an empty ignore spec.

    The file's contents are only echoed; nothing is parsed into the result,
    so both lists in the returned dict are always empty.

    Returns:
        dict: ``{'files': [], 'folders': []}``.

    Raises:
        OSError: If ``.gitignore`` cannot be opened (e.g. it does not exist).
    """
    with open('.gitignore', 'r') as handle:
        contents = handle.read()
    print(contents)

    return {'files': [], 'folders': []}

# Prints the .gitignore in the working directory and displays the (currently always empty) result.
check_gitignore()

---

In [None]:
import json
from pathlib import Path
from tree_sitter import Language, Parser
import tree_sitter_python


# Python grammar and a parser bound to it.
PY_LANGUAGE = Language(tree_sitter_python.language())
py_parser = Parser(PY_LANGUAGE)

# Load node-types.json from the working directory and display it.
# NOTE(review): presumably the node-types.json shipped with the tree-sitter
# Python grammar - confirm which grammar's file is in place before running.
grammar = json.loads(Path("node-types.json").read_text())
grammar

In [None]:
# Start from the two structural node types, then add every grammar entry that
# declares subtypes. Only `_compound_statement` also contributes its subtypes.
nodes = ['module', 'block']

for parent in grammar:
    if 'subtypes' not in parent:
        continue
    nodes.append(parent['type'])
    if parent['type'] != '_compound_statement':
        continue
    for child in parent['subtypes']:
        nodes.append(child['type'])
nodes

In [None]:
# Hard-coded list of node types; rebinding `nodes` here replaces the list
# computed from `grammar`.
# NOTE(review): presumably the output of the previous cell pasted in for the
# Python grammar - confirm it still matches node-types.json.
nodes = [
        "module",
        "block",
        "_compound_statement",
        "class_definition",
        "decorated_definition",
        "for_statement",
        "function_definition",
        "if_statement",
        "match_statement",
        "try_statement",
        "while_statement",
        "with_statement",
        "_simple_statement",
        "expression",
        "parameter",
        "pattern",
        "primary_expression"
    ]

In [None]:
def node_to_dict(node, source_code, order:int=1):
    """Convert a syntax-tree node and its descendants into nested dicts.

    Children are converted first and numbered from 1 by their position in
    ``node.children``; any child conversion that yields ``None`` is dropped.
    No filtering on the module-level ``nodes`` list is applied, so as written
    this function never returns ``None`` and the drop step removes nothing.

    Args:
        node: Object with ``type``, ``start_byte``, ``end_byte`` and ``children``.
        source_code: Bytes the tree was parsed from; the node's byte range is
            decoded as UTF-8 to produce ``text``.
        order: 1-based position of this node among its siblings.

    Returns:
        dict with keys ``type``, ``start_byte``, ``end_byte``, ``order``,
        ``text`` and ``children``.
    """
    children = []
    for index, child in enumerate(node.children):
        child_dict = node_to_dict(child, source_code, index + 1)
        if child_dict is not None:
            children.append(child_dict)

    begin, finish = node.start_byte, node.end_byte
    return {
        "type": node.type,
        "start_byte": begin,
        "end_byte": finish,
        "order": order,
        "text": source_code[begin:finish].decode('utf8'),
        "children": children,
    }

In [None]:
# Scratch check: None entries still count toward a list's length (this is 2).
len([None, None])

In [None]:
# Minimal Python snippet used to sanity-check node_to_dict.
code = """def main():
    print("Hello, world!")
"""

# Parse the snippet's UTF-8 bytes and display the resulting nested dict.
tree = py_parser.parse(code.encode("utf-8"))
tree_dict = node_to_dict(tree.root_node, code.encode("utf-8"))
tree_dict

In [None]:
# Read ingestion.py from the working directory as raw bytes.
with open('ingestion.py', 'rb') as file:
    ingest_code = file.read()


# Parse it with the Python parser; this rebinds tree and tree_dict.
tree = py_parser.parse(ingest_code)
tree_dict = node_to_dict(tree.root_node, ingest_code)

In [None]:
# Display the dicts for the root node's direct children.
tree_dict['children']

In [None]:
# Number of direct children of the root node.
len(tree_dict['children'])

---

In [None]:
import json
from pathlib import Path
from tree_sitter import Parser, Language
import tree_sitter_rust

# Create a Language object from the Rust grammar first, then bind a parser to it.
RS_LANGUAGE = Language(tree_sitter_rust.language())
rs_parser = Parser(RS_LANGUAGE)


In [None]:
# Load node-types.json from the working directory and display it.
# NOTE(review): same filename as in the Python section; presumably the Rust
# grammar's node-types.json must be in place here - confirm before running.
grammar = json.loads(Path("node-types.json").read_text())
grammar

In [None]:
# Collect, without duplicates and in first-seen order, every grammar entry
# that declares subtypes together with all of its subtypes (every such entry
# contributes its subtypes, not just `_compound_statement`).
nodes = []

for parent in grammar:
    if 'subtypes' not in parent:
        continue
    # The supertype itself goes first, followed by each of its subtypes.
    for entry in (parent, *parent['subtypes']):
        if entry['type'] not in nodes:
            nodes.append(entry['type'])
nodes

In [None]:
def node_to_dict(node, source_code, order:int=1):
    """Turn a syntax-tree node, with all of its descendants, into nested dicts.

    Child nodes are converted before the node's own text is decoded and are
    numbered from 1 in ``node.children`` order. ``None`` results are
    discarded; since no filtering on ``nodes`` is applied, the function never
    returns ``None`` as written and nothing is actually discarded.

    Args:
        node: Object with ``type``, ``start_byte``, ``end_byte`` and ``children``.
        source_code: Bytes the tree was parsed from (decoded as UTF-8 per node).
        order: 1-based position of this node among its siblings.

    Returns:
        dict with keys ``type``, ``start_byte``, ``end_byte``, ``order``,
        ``text`` and ``children``.
    """
    converted = (
        node_to_dict(child, source_code, position)
        for position, child in enumerate(node.children, start=1)
    )
    kept = [entry for entry in converted if entry is not None]

    span = slice(node.start_byte, node.end_byte)
    return {
        "type": node.type,
        "start_byte": span.start,
        "end_byte": span.stop,
        "order": order,
        "text": source_code[span].decode('utf8'),
        "children": kept,
    }

In [None]:
code = """use heed3::{types::*, Database, Env, RoTxn, RwTxn};
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, collections::HashMap};

use crate::{
    helix_engine::{
        storage_core::storage_core::HelixGraphStorage,
        types::GraphError,
        vector_core::{hnsw::HNSW, vector::HVector},
    },
    protocol::value::Value,
};

const DB_BM25_INVERTED_INDEX: &str = "bm25_inverted_index"; // term -> list of (doc_id, tf)
const DB_BM25_DOC_LENGTHS: &str = "bm25_doc_lengths"; // doc_id -> document length
const DB_BM25_TERM_FREQUENCIES: &str = "bm25_term_frequencies"; // term -> document frequency
const DB_BM25_METADATA: &str = "bm25_metadata"; // stores total docs, avgdl, etc.

#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct BM25Metadata {
    pub total_docs: u64,
    pub avgdl: f64,
    pub k1: f32,
    pub b: f32,
}

#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct PostingListEntry {
    pub doc_id: u128,
    pub term_frequency: u32,
}

pub trait BM25 {
    fn tokenize<const SHOULD_FILTER: bool>(&self, text: &str) -> Vec<Cow<'_, str>>;
    fn insert_doc(&self, txn: &mut RwTxn, doc_id: u128, doc: &str) -> Result<(), GraphError>;
    fn update_doc(&self, txn: &mut RwTxn, doc_id: u128, doc: &str) -> Result<(), GraphError>;
    fn delete_doc(&self, txn: &mut RwTxn, doc_id: u128) -> Result<(), GraphError>;
    fn search(
        &self,
        txn: &RoTxn,
        query: &str,
        limit: usize,
    ) -> Result<Vec<(u128, f32)>, GraphError>;
    fn calculate_bm25_score(
        &self,
        term: &str,
        doc_id: u128,
        tf: u32,
        doc_length: u32,
        df: u32,
        total_docs: u64,
        avgdl: f64,
    ) -> f32;
}

pub struct HBM25Config {
    pub graph_env: Env,
    pub inverted_index_db: Database<Bytes, Bytes>,
    pub doc_lengths_db: Database<U128<heed3::byteorder::BE>, U32<heed3::byteorder::BE>>,
    pub term_frequencies_db: Database<Bytes, U32<heed3::byteorder::BE>>,
    pub metadata_db: Database<Bytes, Bytes>,
}

impl HBM25Config {
    pub fn new(graph_env: &Env, wtxn: &mut RwTxn) -> Result<HBM25Config, GraphError> {
        let inverted_index_db: Database<Bytes, Bytes> = graph_env
            .database_options()
            .types::<Bytes, Bytes>()
            .flags(heed3::DatabaseFlags::DUP_SORT)
            .name(DB_BM25_INVERTED_INDEX)
            .create(wtxn)?;

        let doc_lengths_db: Database<U128<heed3::byteorder::BE>, U32<heed3::byteorder::BE>> =
            graph_env
                .database_options()
                .types::<U128<heed3::byteorder::BE>, U32<heed3::byteorder::BE>>()
                .name(DB_BM25_DOC_LENGTHS)
                .create(wtxn)?;

        let term_frequencies_db: Database<Bytes, U32<heed3::byteorder::BE>> = graph_env
            .database_options()
            .types::<Bytes, U32<heed3::byteorder::BE>>()
            .name(DB_BM25_TERM_FREQUENCIES)
            .create(wtxn)?;

        let metadata_db: Database<Bytes, Bytes> = graph_env
            .database_options()
            .types::<Bytes, Bytes>()
            .name(DB_BM25_METADATA)
            .create(wtxn)?;

        Ok(HBM25Config {
            graph_env: graph_env.clone(),
            inverted_index_db,
            doc_lengths_db,
            term_frequencies_db,
            metadata_db,
        })
    }
}

impl BM25 for HBM25Config {
    fn tokenize<const SHOULD_FILTER: bool>(&self, text: &str) -> Vec<Cow<'_, str>> {
        text.to_lowercase()
            .replace(|c: char| !c.is_alphanumeric(), " ")
            .split_whitespace()
            .filter(|s| !SHOULD_FILTER || s.len() > 2)
            .map(|s| Cow::Owned(s.to_string()))
            .collect()
    }

    fn insert_doc(&self, txn: &mut RwTxn, doc_id: u128, doc: &str) -> Result<(), GraphError> {
        let tokens = self.tokenize::<true>(doc);
        let doc_length = tokens.len() as u32;

        let mut term_counts: HashMap<Cow<'_, str>, u32> = HashMap::new();
        for token in tokens {
            *term_counts.entry(token).or_insert(0) += 1;
        }

        self.doc_lengths_db.put(txn, &doc_id, &doc_length)?;

        for (term, tf) in term_counts {
            let term_bytes = term.as_bytes();

            let posting_entry = PostingListEntry {
                doc_id,
                term_frequency: tf,
            };

            let posting_bytes = bincode::serialize(&posting_entry)?;

            self.inverted_index_db
                .put(txn, term_bytes, &posting_bytes)?;

            let current_df = self.term_frequencies_db.get(txn, term_bytes)?.unwrap_or(0);
            self.term_frequencies_db
                .put(txn, term_bytes, &(current_df + 1))?;
        }
        let metadata_key = b"metadata";
        let mut metadata = if let Some(data) = self.metadata_db.get(txn, metadata_key)? {
            bincode::deserialize::<BM25Metadata>(data)?
        } else {
            BM25Metadata {
                total_docs: 0,
                avgdl: 0.0,
                k1: 1.2,
                b: 0.75,
            }
        };

        let old_total_docs = metadata.total_docs;
        metadata.total_docs += 1;
        metadata.avgdl = (metadata.avgdl * old_total_docs as f64 + doc_length as f64)
            / metadata.total_docs as f64;

        let metadata_bytes = bincode::serialize(&metadata)?;
        self.metadata_db.put(txn, metadata_key, &metadata_bytes)?;

        // txn.commit()?;
        Ok(())
    }

    fn update_doc(&self, txn: &mut RwTxn, doc_id: u128, doc: &str) -> Result<(), GraphError> {
        // For simplicity, delete and re-insert
        self.delete_doc(txn, doc_id)?;
        self.insert_doc(txn, doc_id, doc)
    }

    fn delete_doc(&self, txn: &mut RwTxn, doc_id: u128) -> Result<(), GraphError> {
        let terms_to_update = {
            let mut terms = Vec::new();
            let mut iter = self.inverted_index_db.iter(txn)?;

            while let Some((term_bytes, posting_bytes)) = iter.next().transpose()? {
                let posting: PostingListEntry = bincode::deserialize(posting_bytes)?;
                if posting.doc_id == doc_id {
                    terms.push(term_bytes.to_vec());
                }
            }
            terms
        };

        // Remove postings and update term frequencies
        for term_bytes in terms_to_update {
            // Collect entries to keep
            let entries_to_keep = {
                let mut entries = Vec::new();
                if let Some(duplicates) = self.inverted_index_db.get_duplicates(txn, &term_bytes)? {
                    for result in duplicates {
                        let (_, posting_bytes) = result?;
                        let posting: PostingListEntry = bincode::deserialize(posting_bytes)?;
                        if posting.doc_id != doc_id {
                            entries.push(posting_bytes.to_vec());
                        }
                    }
                }
                entries
            };

            // Delete all entries for this term
            self.inverted_index_db.delete(txn, &term_bytes)?;

            // Re-add the entries we want to keep
            for entry_bytes in entries_to_keep {
                self.inverted_index_db.put(txn, &term_bytes, &entry_bytes)?;
            }

            // Update document frequency
            let current_df = self.term_frequencies_db.get(txn, &term_bytes)?.unwrap_or(0);
            if current_df > 0 {
                self.term_frequencies_db
                    .put(txn, &term_bytes, &(current_df - 1))?;
            }
        }

        // Get document length before deleting it
        let doc_length = self.doc_lengths_db.get(txn, &doc_id)?.unwrap_or(0);

        self.doc_lengths_db.delete(txn, &doc_id)?;

        // Update metadata
        let metadata_key = b"metadata";
        let metadata_data = self
            .metadata_db
            .get(txn, metadata_key)?
            .map(|data| data.to_vec());

        if let Some(data) = metadata_data {
            let mut metadata: BM25Metadata = bincode::deserialize(&data)?;
            if metadata.total_docs > 0 {
                // Update average document length
                metadata.avgdl = if metadata.total_docs > 1 {
                    (metadata.avgdl * metadata.total_docs as f64 - doc_length as f64)
                        / (metadata.total_docs - 1) as f64
                } else {
                    0.0
                };
                metadata.total_docs -= 1;

                let metadata_bytes = bincode::serialize(&metadata)?;
                self.metadata_db.put(txn, metadata_key, &metadata_bytes)?;
            }
        }

        Ok(())
    }

    fn search(
        &self,
        txn: &RoTxn,
        query: &str,
        limit: usize,
    ) -> Result<Vec<(u128, f32)>, GraphError> {
        let query_terms = self.tokenize::<true>(query);
        let mut doc_scores: HashMap<u128, f32> = HashMap::with_capacity(limit);

        let metadata_key = b"metadata";
        let metadata = self
            .metadata_db
            .get(txn, metadata_key)?
            .ok_or(GraphError::New("BM25 metadata not found".to_string()))?;
        let metadata: BM25Metadata = bincode::deserialize(metadata)?;

        // For each query term, calculate scores
        for term in query_terms {
            let term_bytes = term.as_bytes();

            // Get document frequency for this term
            let df = self.term_frequencies_db.get(txn, term_bytes)?.unwrap_or(0);
            // if df == 0 {
            //     continue; // Term not in index
            // }

            // Get all documents containing this term
            if let Some(duplicates) = self.inverted_index_db.get_duplicates(txn, term_bytes)? {
                for result in duplicates {
                    let (_, posting_bytes) = result?;
                    let posting: PostingListEntry = bincode::deserialize(posting_bytes)?;

                    // Get document length
                    let doc_length = self.doc_lengths_db.get(txn, &posting.doc_id)?.unwrap_or(0);

                    // Calculate BM25 score for this term in this document
                    let score = self.calculate_bm25_score(
                        &term,
                        posting.doc_id,
                        posting.term_frequency,
                        doc_length,
                        df,
                        metadata.total_docs,
                        metadata.avgdl,
                    );

                    // Add to document's total score
                    *doc_scores.entry(posting.doc_id).or_insert(0.0) += score;
                }
            }
        }

        // Sort by score and return top results
        let mut results: Vec<(u128, f32)> = doc_scores.into_iter().collect();
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        results.truncate(limit);

        Ok(results)
    }

    fn calculate_bm25_score(
        &self,
        _term: &str,
        _doc_id: u128,
        tf: u32,
        doc_length: u32,
        df: u32,
        total_docs: u64,
        avgdl: f64,
    ) -> f32 {
        let k1 = 1.2;
        let b = 0.75;

        // Ensure we don't have division by zero
        let df = df.max(1);
        let total_docs = total_docs.max(1);

        // Calculate IDF: log((N - df + 0.5) / (df + 0.5))
        // This can be negative when df is high relative to N, which is mathematically correct
        let idf = ((total_docs as f64 - df as f64 + 0.5) / (df as f64 + 0.5)).ln();

        // Ensure avgdl is not zero
        let avgdl = if avgdl > 0.0 {
            avgdl
        } else {
            doc_length as f64
        };

        // Calculate BM25 score
        let tf_component = (tf as f64 * (k1 as f64 + 1.0))
            / (tf as f64 + k1 as f64 * (1.0 - b as f64 + b as f64 * (doc_length as f64 / avgdl)));

        let score = (idf * tf_component) as f32;

        // The score can be negative when IDF is negative (term appears in most documents)
        // This is mathematically correct - such terms have low discriminative power
        // But documents with higher tf should still score higher than those with lower tf
        score
    }
}
"""

# Parse the embedded Rust source with the Rust parser and display the nested dict.
tree = rs_parser.parse(code.encode("utf-8"))
tree_dict = node_to_dict(tree.root_node, code.encode("utf-8"))
tree_dict

In [None]:
# Print the type of every direct child of the root node that is missing from `nodes`.
for child in tree_dict['children']:
    if child['type'] not in nodes:
        print(child['type'])

In [None]:
# Check whether the `line_comment` node type was collected into `nodes`.
'line_comment' in nodes