In [None]:
from dotenv import load_dotenv
import os


# Load variables from a local .env file into the process environment, then
# read the Groq key from it (api_key is None when GROQ_API_KEY is not set).
load_dotenv()
api_key = os.environ.get('GROQ_API_KEY')


In [None]:
from git import Repo
import stat, os, shutil

def get_repo(repo_url: str, path_url: str = './temp_repo'):
    """Shallow-clone ``repo_url`` into ``path_url``, replacing any previous checkout.

    Args:
        repo_url: Location of the Git repository, passed to ``Repo.clone_from``.
        path_url: Destination directory. If it already exists it is deleted
            recursively before cloning.

    Returns:
        None. Progress is reported with ``print``.

    Raises:
        Any error raised by ``shutil.rmtree`` (after the read-only retry) or by
        ``Repo.clone_from`` propagates to the caller.
    """
    import sys

    def remove_readonly(func, path, exc_info):
        # Retry hook for rmtree: make the offending path writable and call the
        # failed operation again (presumably for read-only files such as Git
        # object files on Windows).
        os.chmod(path, stat.S_IWRITE)
        func(path)

    if os.path.exists(path_url):
        # ``onexc`` was added in Python 3.12; earlier interpreters accept the
        # same three-argument callback through ``onerror``.
        if sys.version_info >= (3, 12):
            shutil.rmtree(path_url, onexc=remove_readonly)
        else:
            shutil.rmtree(path_url, onerror=remove_readonly)

    print('Cloning the repo...')

    # depth=1 fetches only the latest commit.
    Repo.clone_from(repo_url, path_url, depth=1)
    print('Clone completed !!!')


In [None]:
from tree_sitter import Parser, Language
from pathlib import Path
import tree_sitter_python as tspython
import tree_sitter_javascript as tsjavascript
import tree_sitter_typescript as tstypescript

In [None]:
from tree_sitter import Parser, Language
import tree_sitter_python as tspython
import tree_sitter_javascript as tsjavascript
import tree_sitter_typescript as tstypescript
import tree_sitter_c as tsc
import tree_sitter_cpp as tscpp

# Grammar lookup table: maps a file suffix to the tree-sitter Language object
# used to parse files with that suffix.
LANGUAGES = {
    suffix: Language(load_grammar())
    for suffix, load_grammar in (
        (".py", tspython.language),
        (".js", tsjavascript.language),
        (".jsx", tsjavascript.language),
        (".ts", tstypescript.language_typescript),
        (".tsx", tstypescript.language_tsx),
        # C sources and headers
        (".c", tsc.language),
        (".h", tsc.language),
        # C++ sources and headers
        (".cpp", tscpp.language),
        (".hpp", tscpp.language),
        (".cc", tscpp.language),
        (".cxx", tscpp.language),
        (".hh", tscpp.language),
    )
}

# Per-suffix mapping from a language-independent label ("function", "class",
# ...) to the tree-sitter node type that represents it in that grammar.
# Every suffix registered in LANGUAGES needs an entry here: walk() looks the
# suffix up in this table and extracts nothing when the entry is missing.
NODE_TYPES = {
    ".py": {
        "function": "function_definition",
        "class": "class_definition",
    },
    ".js": {
        "function": "function_declaration",
        "method": "method_definition",
        "arrow": "arrow_function",
        "class": "class_declaration",
    },
    # .jsx is parsed with the JavaScript grammar, so it shares the .js node types.
    ".jsx": {
        "function": "function_declaration",
        "method": "method_definition",
        "arrow": "arrow_function",
        "class": "class_declaration",
    },
    ".ts": {
        "function": "function_declaration",
        "method": "method_definition",
        "arrow": "arrow_function",
        "class": "class_declaration",
        "interface": "interface_declaration"
    },
    ".tsx": {
        "function": "function_declaration",
        "method": "method_definition",
        "arrow": "arrow_function",
        "class": "class_declaration",
    },

    # C sources and headers
    ".c": {
        "function": "function_definition",
        "struct": "struct_specifier",
        "typedef": "type_definition"
    },
    ".h": {
        "function": "function_definition",
        "struct": "struct_specifier",
        "typedef": "type_definition"
    },

    # C++ sources and headers
    ".cpp": {
        "function": "function_definition",
        "class": "class_specifier",
        "struct": "struct_specifier",
        "template": "template_declaration"
    },
    ".hpp": {
        "function": "function_definition",
        "class": "class_specifier",
        "struct": "struct_specifier",
        "template": "template_declaration"
    },
    # .hh is a C++ header suffix registered in LANGUAGES; treat it like .hpp.
    ".hh": {
        "function": "function_definition",
        "class": "class_specifier",
        "struct": "struct_specifier",
        "template": "template_declaration"
    },

    ".cc": { "function": "function_definition", "class": "class_specifier" },
    ".cxx": { "function": "function_definition", "class": "class_specifier" },
}

In [None]:
def get_node_name(node, source):
    """Return the identifier text of a definition node, or None if it has none.

    The node's ``name`` field is used when present. Otherwise the chain of
    ``declarator`` fields is followed until an identifier-like node is reached
    or the chain ends; the last node reached supplies the name.

    Args:
        node: Node exposing ``child_by_field_name`` (children additionally
            expose ``type``, ``start_byte`` and ``end_byte``).
        source: Bytes of the file the node was parsed from.

    Returns:
        The name decoded as UTF-8 (invalid bytes replaced), or None.
    """
    leaf_types = ('identifier', 'field_identifier', 'type_identifier')

    target = node.child_by_field_name('name')

    if not target:
        current = node.child_by_field_name('declarator')
        while current:
            # Stop as soon as the declarator itself is an identifier.
            if current.type in leaf_types:
                target = current
                break

            inner = current.child_by_field_name('declarator')
            if not inner:
                # End of the chain: use whatever node we ended on.
                target = current
                break
            current = inner

    if not target:
        return None

    raw_name = source[target.start_byte:target.end_byte]
    return raw_name.decode('utf8', errors='replace')

def walk(node, results, ext, source):
    """Collect every definition node under ``node`` (inclusive) into ``results``.

    The tree is visited in pre-order (a node before its children, children left
    to right) with an explicit stack, so deeply nested syntax trees cannot
    exhaust Python's recursion limit.

    Args:
        node: Root node of the (sub)tree to scan.
        results: List that receives one dict per matching node with the keys
            ``type``, ``name``, ``code``, ``start_line`` and ``end_line``
            (line numbers are 1-based). It is mutated in place.
        ext: File suffix used to select the entry of ``NODE_TYPES``; an
            unknown suffix matches nothing.
        source: Bytes of the parsed file, used to slice names and code.

    Returns:
        None.
    """
    node_map = NODE_TYPES.get(ext, {})

    # Reverse lookup (node type -> common label), built once per call. If two
    # labels ever shared a node type, the first one listed wins.
    label_by_type = {}
    for common_name, type_name in node_map.items():
        label_by_type.setdefault(type_name, common_name)

    pending = [node]
    while pending:
        current = pending.pop()

        target_type = label_by_type.get(current.type)
        if target_type:
            name = get_node_name(current, source)

            # Fallback for nodes without a resolvable name; the row embedded
            # here is the 0-based start row of the node.
            if not name:
                name = f"<anonymous_{target_type}_L{current.start_point[0]}>"

            code = source[current.start_byte:current.end_byte].decode('utf8', errors="replace")

            results.append({
                'type': target_type,
                'name': name,
                'code': code,
                'start_line': current.start_point[0] + 1,
                'end_line': current.end_point[0] + 1
            })

        # Push children reversed so the leftmost child is popped first.
        pending.extend(reversed(current.children))

In [None]:
def parse_file(path_url):
    """Parse one source file and return the code chunks found in it.

    Args:
        path_url: Path of the file to parse.

    Returns:
        The list filled by ``walk`` for the file's syntax tree. An empty list
        is returned when the suffix has no entry in ``LANGUAGES`` or when any
        exception occurs while reading or parsing (the exception is printed).
    """
    file_path = Path(path_url)
    ext = file_path.suffix

    if ext in LANGUAGES:
        try:
            tree_parser = Parser(LANGUAGES[ext])
            raw_bytes = file_path.read_bytes()
            syntax_tree = tree_parser.parse(raw_bytes)

            found = []
            walk(syntax_tree.root_node, found, ext, raw_bytes)
            return found

        except Exception as err:
            print (err)
            return []

    return []

In [None]:
# Directory names that are pruned from the repository walk in build_graph.
IGNORE_DIRS = set((
    # version control / caches
    ".git",
    "__pycache__",
    # dependency and virtual-environment folders
    "node_modules",
    "venv",
    ".venv",
    "env",
    # build output
    "dist",
    "build",
    # editor settings
    ".idea",
    ".vscode",
))

In [None]:
import networkx as nx, matplotlib.pyplot as plt

# Module-level state shared by the ingestion and retrieval cells.
graph = nx.MultiDiGraph()   # one node per extracted chunk, edges added by build_graph
function_map = dict()       # chunk name -> graph node id
all_chunks = list()         # every chunk dict collected by build_graph
global_bm25 = str()         # empty-string placeholder until ingest_repository stores the BM25 index

In [None]:
def build_graph(repo_path:str):
    """Parse every supported file under ``repo_path`` and build the call graph.

    Works in two passes over the module-level ``graph``, ``function_map`` and
    ``all_chunks`` (all three are mutated in place):

    1. Walk the directory tree (skipping ``IGNORE_DIRS``), parse each file and
       add one graph node per chunk. The node id is ``"<file path>::<name>"``
       and the node carries ``name``, ``code``, ``type`` and ``file``.
    2. Once every file has been collected, add an edge
       ``caller node id -> target node id`` whenever a chunk's code mentions
       another chunk's name as a whole identifier.

    Note: ``function_map`` is keyed by bare name, so when two chunks share a
    name the one parsed last is the link target.

    Args:
        repo_path: Root directory of the checked-out repository.

    Returns:
        None.
    """
    import re

    print('Building graph for your code base...')

    # Pass 1: collect chunks and create nodes.
    for root, dirs, files in os.walk(repo_path):
        # Prune in place so os.walk does not descend into ignored folders.
        dirs[:] = [d for d in dirs if d not in IGNORE_DIRS]

        for file in files:
            full_path = os.path.join(root, file)
            print(full_path)

            for chunk in parse_file(full_path):
                func_name = chunk['name']
                node_id = full_path + '::' + func_name
                chunk['file_path'] = full_path
                graph.add_node(
                    node_id,
                    name = chunk['name'],
                    code = chunk['code'],
                    type = chunk['type'],
                    file = full_path
                )

                function_map[func_name] = node_id
                all_chunks.append(chunk)

    # Pass 2: link calls. This runs once, after the whole tree has been walked;
    # running it per directory would add duplicate edges to the multigraph.
    print('Linking the function calls...')
    print(f'Collected {len(all_chunks)} chunks.')

    # Whole-identifier matchers, so that e.g. "get" does not match "get_repo".
    matchers = {
        name: re.compile(r'(?<![\w$])' + re.escape(name) + r'(?![\w$])')
        for name in function_map
    }

    for chunk in all_chunks:
        caller_name = chunk['name']
        caller_code = chunk['code']
        # Edges must start at the caller's node id (not its bare name) so that
        # graph lookups by "<file path>::<name>" find the outgoing edges.
        caller_id = chunk['file_path'] + '::' + caller_name

        for target_name, target_id in function_map.items():
            # A chunk's code always contains its own name; skip it.
            if target_name == caller_name:
                continue

            # Cheap substring test first, regex only to confirm the boundary.
            if target_name in caller_code and matchers[target_name].search(caller_code):
                graph.add_edge(caller_id, target_id)
                print(f"   -> {caller_name} calls {target_id}")

    print('The Graph has been built successFully !!')
    

In [None]:
# build_graph(repo_path='./temp_repo')
# nx.draw_forceatlas2(graph, with_labels=True)
# plt.show()

In [None]:
import torch
from transformers import AutoModel

# Use the GPU when PyTorch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Load the base code-embedding model.
model_id = "jinaai/jina-embeddings-v2-base-code" 
print(f"Loading model: {model_id}...")

# trust_remote_code=True allows Python code shipped with the model repository
# to run locally while the model is loaded.
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
model.to(device)
model.eval()  # switch to evaluation mode

print("Base Code Model loaded successfully!")

In [None]:
def get_embeddings(chunks):
    """Embed a batch of texts with the module-level ``model``.

    Args:
        chunks: Sequence of strings to embed.

    Returns:
        Whatever ``model.encode(chunks)`` returns, computed with gradient
        tracking disabled; an empty list when ``chunks`` is empty or None.
    """
    if chunks:
        with torch.no_grad():
            vectors = model.encode(chunks)
        return vectors

    return []

In [None]:
import chromadb

# On-disk Chroma client rooted at ./chromadb, plus the collection that stores
# the code-chunk embeddings (created on first use).
chromadb_client = chromadb.PersistentClient(path='./chromadb')
collection = chromadb_client.get_or_create_collection(name='codebase-vectors')
print ("Chromadb's setup successfully !!")

In [None]:
from rank_bm25 import BM25Okapi

# Maps a position in the BM25 corpus back to the chunk dict it was built from.
bm25_mapping = {} 

def create_bm25_index(chunks):
    """Build a BM25 keyword index over code chunks.

    Each chunk contributes one document made of its name followed by its code,
    lower-cased and split on whitespace. ``bm25_mapping`` is reset and refilled
    so that corpus position ``i`` maps to ``chunks[i]``.

    Args:
        chunks: List of chunk dicts; each must have ``name`` and ``code``.

    Returns:
        The ``BM25Okapi`` index built from the tokenized corpus.
    """
    tokenized_corpus = []

    print(f"Building BM25 Index for {len(chunks)} items...")

    # Drop entries from an earlier build so no stale chunk of a previously
    # indexed repository survives in the mapping.
    bm25_mapping.clear()

    for i, chunk in enumerate(chunks):
        # Search against the function NAME and the CODE
        text = chunk['name'] + " " + chunk['code']

        tokenized_corpus.append(text.lower().split())
        bm25_mapping[i] = chunk

    # Create the Index
    bm25 = BM25Okapi(tokenized_corpus)
    return bm25


In [None]:
def index_codebase(chunks, vectors):
    """Store chunk embeddings in the vector collection and build the BM25 index.

    Record ids are the chunk positions ("0", "1", ...), so records left in the
    persistent collection by an earlier run would collide with the new ones.
    The collection is therefore emptied first, and the new records are written
    in fixed-size batches to stay below the store's per-call size limit.

    Args:
        chunks: List of chunk dicts with ``name``, ``file_path``, ``type``,
            ``start_line``, ``end_line`` and ``code``.
        vectors: Embeddings aligned with ``chunks`` (same length and order);
            must support slicing.

    Returns:
        The BM25 index returned by ``create_bm25_index(chunks)``.
    """
    batch_size = 1000

    print(f" Indexing {len(chunks)} items...")

    ids = [str(i) for i in range(len(chunks))]
    metadatas = [
        {
            "name": chunk['name'],
            "file_path": chunk['file_path'],
            "type": chunk['type'],
            "start_line": int(chunk['start_line']),
            "end_line": int(chunk['end_line'])
        }
        for chunk in chunks
    ]
    documents = [chunk['code'] for chunk in chunks]

    # Remove whatever a previous ingestion left behind (ids only are fetched).
    stale_ids = collection.get(include=[])['ids']
    for start in range(0, len(stale_ids), batch_size):
        collection.delete(ids=stale_ids[start:start + batch_size])

    for start in range(0, len(ids), batch_size):
        stop = start + batch_size
        collection.add(
            ids=ids[start:stop],
            embeddings=vectors[start:stop],
            metadatas=metadatas[start:stop],
            documents=documents[start:stop]
        )
    print("   -> Vector Store Updated.")

    bm25_index = create_bm25_index(chunks)
    print("   -> Keyword Index Built.")

    return bm25_index



In [None]:
import networkx as nx

def ingest_repository(github_url, repo_path='./temp_repo'):
    """Clone a repository and rebuild every index used by the RAG pipeline.

    Steps: clone into ``repo_path``, reset the module-level ``graph``,
    ``function_map`` and ``all_chunks``, rebuild the graph, embed every chunk's
    code and store the result of ``index_codebase`` in ``global_bm25``.

    Args:
        github_url: Repository URL handed to ``get_repo``.
        repo_path: Local checkout directory; the same path is used for cloning
            and for graph building. Defaults to ``'./temp_repo'``.

    Returns:
        None. When no chunks are found a warning is printed and the function
        returns before embedding, leaving ``global_bm25`` untouched.
    """
    global graph, function_map, all_chunks, global_bm25

    print(f"\n{'='*50}")
    print(f"INITIALIZING NEW REPO: {github_url}")
    print(f"{'='*50}")

    get_repo(github_url, path_url=repo_path)

    # Start from empty state so nothing from a previous repository leaks in.
    graph = nx.MultiDiGraph()
    function_map = {}
    all_chunks = []

    build_graph(repo_path=repo_path)

    if not all_chunks:
        print("Warning: No code found! Check if the repo path is correct.")
        return

    print(f"Embedding {len(all_chunks)} functions...")
    codes = [c['code'] for c in all_chunks]
    vectors = get_embeddings(codes)

    global_bm25 = index_codebase(all_chunks, vectors)

    print("\nSUCCESS! Repository loaded. The Agent is ready to answer questions.")

In [None]:
def deduplicate_results(results_a, results_b):
    """Merge two result lists, keeping the first occurrence of each chunk.

    Two results are the same chunk when they share ``file_path`` and ``name``.
    Entries of ``results_a`` come before those of ``results_b`` and the
    relative order within each list is preserved.

    Args:
        results_a: First list of result dicts.
        results_b: Second list of result dicts.

    Returns:
        A new list without duplicate chunks.
    """
    first_seen = {}

    for entry in results_a + results_b:
        # "<file path>::<name>" identifies a chunk; setdefault keeps the first
        # entry seen for each key and dicts preserve insertion order.
        signature = f"{entry['file_path']}::{entry['name']}"
        first_seen.setdefault(signature, entry)

    return list(first_seen.values())

In [None]:
def hybrid_search(query, k = 5):
    """Run vector search and BM25 keyword search and merge the results.

    Args:
        query: Natural-language or code query string.
        k: Maximum number of hits requested from each of the two searches.

    Returns:
        A de-duplicated list of hit dicts (vector hits first) with the keys
        ``name``, ``file_path``, ``code``, ``type``, ``score`` and ``source``
        (``"vector"`` or ``"keyword"``). For vector hits ``score`` is the
        distance reported by the collection; for keyword hits it is the BM25
        score. When no BM25 index has been built yet, only vector hits are
        returned.
    """
    print (f'Searching for the given {query}')

    # 1) Vector search
    query_vector = get_embeddings([query])[0]
    chroma_results = collection.query(
        query_embeddings=query_vector,
        n_results=k
    )

    # Results are nested one level per query; only one query is sent, so
    # index [0] is used throughout.
    vector_hits = []
    if chroma_results['ids']:
        for i in range(len(chroma_results["ids"][0])):
            meta = chroma_results['metadatas'][0][i]
            doc = chroma_results['documents'][0][i]
            vector_hits.append({
                "name": meta['name'],
                "file_path": meta['file_path'],
                "code": doc,
                "type": meta.get('type', 'code'),
                "score": chroma_results['distances'][0][i],
                "source": "vector"
            })

    # 2) Keyword Search (BM25)
    keyword_hits = []
    # global_bm25 is a placeholder until a repository has been ingested;
    # skip the keyword pass instead of failing on the placeholder.
    if hasattr(global_bm25, 'get_scores'):
        tokenized_query = query.lower().split()
        bm25_scores = global_bm25.get_scores(tokenized_query)
        top_n = sorted(range(len(bm25_scores)), key=lambda i: bm25_scores[i], reverse=True)[:k]

        for idx in top_n:
            # Filter out results with 0 score (no keyword match)
            if bm25_scores[idx] > 0:
                original_chunk = bm25_mapping[idx]  # look up the original data
                keyword_hits.append({
                    "name": original_chunk['name'],
                    "file_path": original_chunk['file_path'],
                    "code": original_chunk['code'],
                    "type": original_chunk.get('type', 'code'),
                    "score": bm25_scores[idx],
                    "source": "keyword"
                })
    else:
        print('BM25 index is not built yet; using vector results only.')

    final_results = deduplicate_results(vector_hits, keyword_hits)

    return final_results
    

In [None]:
def expand_context(initial_results, graph, max_depth=1):
    """Add the graph successors of the search hits to the result list.

    Each hit is identified in the graph by ``"<file_path>::<name>"``. Starting
    from the hits, successors are followed for up to ``max_depth`` levels and
    every node not seen before is appended once as a dependency chunk.

    Args:
        initial_results: Hit dicts with at least ``file_path`` and ``name``.
        graph: Graph supporting ``in``, ``successors(node_id)`` and
            ``nodes[node_id]`` (a mapping of node attributes).
        max_depth: Number of successor levels to follow; values below 1 add
            nothing.

    Returns:
        A new list: the unique initial hits in their original order, followed
        by the dependency chunks (``type`` "dependency", ``score`` 0.0,
        ``source`` "graph").
    """
    expanded_results = []
    seen = set()
    frontier = []  # (node id, display name) pairs still to be expanded

    for x in initial_results:
        node_id = f"{x['file_path']}::{x['name']}"
        if node_id not in seen:
            expanded_results.append(x)
            seen.add(node_id)
            frontier.append((node_id, x['name']))

    for _ in range(max_depth):
        next_frontier = []

        for node_id, label in frontier:
            if node_id not in graph:
                continue

            neighbors = list(graph.successors(node_id))
            if not neighbors:
                continue
            print(f" '{label}' calls: {len(neighbors)} dependencies")

            for neighbor_id in neighbors:
                # Only nodes that have not been emitted yet produce a chunk.
                if neighbor_id in seen:
                    continue
                seen.add(neighbor_id)

                data = graph.nodes[neighbor_id]
                # Node ids look like "<file path>::<name>"; split on the first
                # separator so names that themselves contain "::" stay intact.
                id_path, separator, id_name = neighbor_id.partition('::')
                if not separator:
                    id_path, id_name = 'unknown', neighbor_id

                expanded_results.append({
                    "name": data.get('name') or id_name,
                    "file_path": data.get('file') or id_path,
                    "code": data.get('code', ''),
                    "type": "dependency",
                    "score": 0.0,
                    "source": "graph"
                })
                next_frontier.append((neighbor_id, data.get('name') or id_name))

        frontier = next_frontier
        if not frontier:
            break

    print(f"Context grew from {len(initial_results)} to {len(expanded_results)} chunks.")
    return expanded_results

In [None]:
from sentence_transformers import CrossEncoder


# Cross-encoder shared by rerank_results to score (query, code) pairs; it is
# constructed once, when this cell runs.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")


def rerank_results(query: str,chunks: list,top_k: int = 5,max_rerank: int = 200):
    """Score chunks against the query with the cross-encoder and keep the best.

    Args:
        query: User question.
        chunks: Chunk dicts produced by hybrid search and graph expansion.
        top_k: Number of chunks to return.
        max_rerank: Only the first ``max_rerank`` chunks are scored.

    Returns:
        Up to ``top_k`` chunks sorted by ``rerank_score`` (highest first).
        The scored chunk dicts are the caller's objects: ``rerank_score`` is
        written into them in place.
    """

    print(f"Reranking {len(chunks)} chunks...")

    if not chunks:
        return []

    candidates = chunks[:max_rerank]

    # Give the model the chunk type and file alongside the code.
    pairs = []
    for c in candidates:
        structured_code = (
            f"Type: {c.get('type', 'code')}\n"
            f"File: {c.get('file_path', 'unknown')}\n\n"
            f"{c.get('code', '')}"
        )
        pairs.append([query, structured_code])

    scores = reranker.predict(pairs)

    for candidate, score in zip(candidates, scores):
        final_score = float(score)

        # Graph-derived chunks get a 10% penalty. Scores may be negative, and
        # multiplying a negative score by 0.9 would raise it, so the penalty
        # is subtracted relative to the score's magnitude instead.
        if candidate.get("source") == "graph":
            final_score -= 0.1 * abs(final_score)

        candidate["rerank_score"] = final_score

    candidates.sort(key=lambda x: x["rerank_score"],reverse=True)

    final_results = candidates[:top_k]

    if final_results:
        print(
            f"Kept top {len(final_results)} chunks "
            f"(Best Score: {final_results[0]['rerank_score']:.4f})"
        )
    else:
        # Happens when top_k is 0; there is no best score to report.
        print("Kept top 0 chunks")

    return final_results


In [None]:
import os
from groq import Groq

# Groq API client shared by the router, grader and answer-generation nodes.
client = Groq(api_key=api_key)

def route_question(state):
    """Decide whether a question needs code retrieval or a direct reply.

    Sends the question to the chat model with a routing prompt and inspects
    the lower-cased reply.

    Args:
        state: Mapping with the key ``question``.

    Returns:
        ``"retrieve"`` when the reply contains the word "retrieve" or when the
        request fails for any reason; ``"direct_answer"`` otherwise.
    """
    question = state["question"]
    print(f"Routing Query: '{question}'")

    system_prompt = """
    You are an intelligent router for a RAG system.
    
    CRITICAL RULE:
    If the user asks a technical question (about code, logic, functions, files) mixed with a greeting, YOU MUST CHOOSE 'retrieve'.
    
    Examples:
    - "Hi, how does login work?" -> retrieve (Technical intent exists)
    - "Hello!" -> direct_answer (Pure greeting)
    - "Thanks, but where is the main file?" -> retrieve (Technical intent exists)
    - "What is the weather?" -> direct_answer (Not about code)
    
    Answer with ONE WORD ONLY: 'retrieve' or 'direct_answer'.
    """

    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question}
    ]

    try:
        # temperature=0 and a small token budget: only one word is expected.
        completion = client.chat.completions.create(
            model="llama-3.1-8b-instant",
            messages=conversation,
            temperature=0,
            max_tokens=10
        )

        verdict = completion.choices[0].message.content.strip().lower()

        if "retrieve" not in verdict:
            print(f" Decision: DIRECT ANSWER (Chitchat)")
            return "direct_answer"

        print(f" Decision: RETRIEVE (Database Search)")
        return "retrieve"

    except Exception as e:
        # On any failure fall back to retrieval.
        print(f"Router Error: {e}")
        return "retrieve"


In [None]:
def grade_documents(state):
    """Keep only the retrieved chunks that the grader model judges relevant.

    Each document is graded with one chat-completion call that sees the
    question and the first 1000 characters of the document's code. A document
    is kept when the lower-cased reply contains "yes". If the call fails, the
    document is kept, so a transient API error cannot abort the whole run or
    discard context (the router uses the same fall-back-to-retrieval policy).

    Args:
        state: Mapping with ``question`` (str) and ``documents`` (list of
            dicts that have ``code`` and ``name``).

    Returns:
        ``{"documents": <kept documents>, "question": <question>}``.
    """
    print("GRADER: Filtering documents...")
    
    question = state["question"]
    documents = state["documents"]
    kept_docs = []


    for doc in documents:
        prompt = f"""
    You are a strict relevance grader for a codebase question-answering system.

    TASK:
    Decide whether the given code snippet is useful for answering the user's question.

    USER QUESTION:
    {question}

    CODE SNIPPET:
    {doc['code'][:1000]}

    DECISION RULES:
    - Answer "yes" ONLY if the code directly helps explain, implement, or understand the question.
    - Answer "no" if the code is unrelated, generic, or does not help answer the question.
    - Do NOT explain your decision.
    - Do NOT add extra text.

    Answer ONLY one word: yes or no

    """
        try:
            response = client.chat.completions.create(
                model="llama-3.1-8b-instant",
                messages=[{"role": "user", "content": prompt}],
                temperature=0
            )
            # content can be None; treat that as an empty (i.e. "no") reply.
            decision = (response.choices[0].message.content or "").strip().lower()
        except Exception as e:
            # Fail open: an ungraded document is kept rather than lost.
            print(f"   Grader Error for {doc['name']}: {e}")
            kept_docs.append(doc)
            continue

        if "yes" in decision:
            print(f"   KEEP: {doc['name']}")
            kept_docs.append(doc)
        else:
            print(f"   DROP: {doc['name']}")

    return {
        "documents": kept_docs,
        "question": question
    }


In [None]:
from typing import TypedDict, List

class GraphState(TypedDict):
    """
    State dictionary used as the schema of the StateGraph workflow.

    Attributes:
        question: The user's original question.
        generation: The final answer text.
        documents: Code chunks found so far, one dict per chunk.
        search_count: Safety counter intended to prevent infinite loops.
            NOTE(review): no node visible in this notebook reads or writes
            it - confirm whether it is still needed.
    """
    question: str
    generation: str
    documents: List[dict] 
    search_count: int

In [None]:
from langgraph.graph import END, StateGraph

def generate(state):
    """Write the final answer from the question and the filtered documents.

    Args:
        state: Mapping with ``question`` (str) and ``documents`` (list of
            dicts with ``code`` and, optionally, ``file_path``).

    Returns:
        ``{"generation": <model reply text>}``.
    """
    print("GENERATOR: Writing final answer...")
    question = state["question"]
    documents = state["documents"]

    # One "File / Code" section per document, each closed by a dashed rule.
    if documents:
        context_text = "".join(
            f"\nFile: {doc.get('file_path', 'unknown')}\nCode:\n{doc['code']}\n{'-'*20}"
            for doc in documents
        )
    else:
        context_text = "No relevant code found."

    prompt = f"""
    You are a Senior Software Engineer. Use the provided code context to answer the user's question.
    
    RULES:
    1. Base your answer ONLY on the context.
    2. Cite file names and function names.
    3. Be concise and technical.
    4. If the code doesn't answer the question, admit it.
    
    QUESTION: {question}
    
    CONTEXT:
    {context_text}
    """

    completion = client.chat.completions.create(
        model="llama-3.3-70b-versatile",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.2
    )
    answer = completion.choices[0].message.content

    return {"generation": answer}


def retrieve_node(state):
    """Run the retrieval pipeline for the question: search, expand, rerank.

    Args:
        state: Mapping with the key ``question``.

    Returns:
        ``{"documents": <reranked chunks>}``.
    """
    question = state["question"]
    print(f"RETRIEVER: Processing '{question}'")

    # Hybrid search (k=5) -> graph expansion -> cross-encoder rerank
    # (top 5 out of at most 50 candidates).
    reranked_docs = rerank_results(
        question,
        expand_context(hybrid_search(question, k=5), graph),
        top_k=5,
        max_rerank=50,
    )

    return {"documents": reranked_docs}

def direct_answer_node(state):
    """Answer greetings and general questions without any retrieval.

    Args:
        state: Mapping with the key ``question``.

    Returns:
        ``{"generation": <model reply text>}``.
    """
    print("DIRECT ANSWER: Handling chitchat...")
    question = state["question"]

    conversation = [
        {"role": "system", "content": "You are a helpful coding assistant. Answer the user's greeting or general question politely."},
        {"role": "user", "content": question}
    ]
    completion = client.chat.completions.create(
        model="llama-3.1-8b-instant",
        messages=conversation
    )
    reply = completion.choices[0].message.content
    return {"generation": reply}

# --- Build the LangGraph workflow ---
# Flow: route_question -> "retrieve" -> "grade_documents" -> "generate" -> END
#                      -> "direct_answer" -> END

print("Building Graph...")
workflow = StateGraph(GraphState)

# Register the nodes under the names used by the edges below.
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("direct_answer", direct_answer_node)

# Conditional entry point: route_question returns "retrieve" or
# "direct_answer", and the mapping sends that value to the node of the same name.
workflow.set_conditional_entry_point(
    route_question,
    {
        "retrieve": "retrieve",
        "direct_answer": "direct_answer",
    },
)

# After retrieval, grade the retrieved documents.
workflow.add_edge("retrieve", "grade_documents")

# After grading, generate the answer.
workflow.add_edge("grade_documents", "generate")

# Both answer-producing nodes end the run.
workflow.add_edge("generate", END)
workflow.add_edge("direct_answer", END)

# Compile the workflow into the runnable `app` used by ask_agent.
app = workflow.compile()
print("Agent Compiled Successfully!")

In [None]:
# Convenience wrapper: run one query through the compiled agent and print
# the progress and the final answer.
def ask_agent(query):
    """Stream ``query`` through ``app`` and print each finished node and the answer.

    Args:
        query: The user's question.

    Returns:
        None. The answer is the ``generation`` value of the last node output
        that contains one (empty string if none does) and is printed.
    """
    divider = '=' * 40

    print(f"\n{divider}")
    print(f"USER QUERY: {query}")
    print(divider)

    final_answer = ""

    # Each streamed item maps a node name to the update that node produced.
    for step in app.stream({"question": query}):
        for node_name, update in step.items():
            print(f"   Shape Shift -> Finished Node: {node_name}")
            if "generation" in update:
                final_answer = update["generation"]

    print(f"\nFINAL ANSWER:\n{final_answer}")

In [None]:

# Ingest the repository the user points at, then answer questions until the
# user types "quit" or "exit" (case-insensitive).
url = input("Enter GitHub URL: ")
ingest_repository(url)

print("System Ready! Ask away...")

while (user_input := input("User: ")).lower() not in ("quit", "exit"):
    ask_agent(user_input)