# 🧠 GPT now understands my repo like a senior dev – here's how

This notebook is an end-to-end reproduction and reinterpretation of the **CodeRAG** framework (see [paper](https://arxiv.org/pdf/2504.10046)), adapted to run locally on your own codebase.

We want to give GPT (or any LLM) the ability to:
- Parse and **understand your entire repo**
- Retrieve and reason over related code
- Answer questions like a senior developer would — with **zero fine-tuning**

---

## 🎯 Goal of this experiment

Our objective is to build a **code-aware assistant** using a combination of:

- 🔍 **Tree-sitter** to parse and structure the codebase
- 🧠 **LLMs** to describe each function or class (aka "requirements")
- 🕸️ **Code graphs** to represent dependencies (calls, imports, etc)
- 🧭 **Agentic reasoning** to let the LLM query and retrieve context dynamically
- ⚡ **RAG (Retrieval-Augmented Generation)** to reduce hallucinations and give smarter answers

The end result is a local-first, fully transparent, and extensible RAG pipeline tailored for your own project.

---

## 🧪 Inspired by CodeRAG (What we're replicating)

From the CodeRAG paper (April 2025), we aim to recreate the following innovations:

1. **Requirement Graph**  
   A graph where each node is a *natural language description* of a function or class. Edges represent semantic similarity or parent-child relations.

2. **DS-Code Graph**  
   A code graph that encodes structural dependencies like:
   - function calls
   - class inheritance
   - file/module containment
   - semantic similarity (via embeddings)

3. **BiGraph Mapping**  
   Links between requirements and code elements — allowing retrieval of relevant code given a high-level prompt.

4. **Agentic Reasoning**  
   An LLM-driven reasoning loop that dynamically:
   - queries the graph
   - follows dependencies
   - does web search if needed
   - formats and tests generated code
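
To make these structures concrete before building them for real, here is a toy sketch using plain dictionaries. The node ids, the sample data, and the `retrieve` helper are hypothetical illustrations; they are not taken from the paper or from the cells below.

```python
# Toy illustration of the requirement graph, code graph and bigraph mapping.
# All ids and descriptions here are made up for the example.
requirements = {
    "R0": "Validate the X-Token header",
    "R1": "Validate the token query parameter",
}
req_edges = {"R0": ["R1"], "R1": ["R0"]}  # similar_to links between requirements

code_nodes = {
    "C0": {"name": "get_token_header", "file": "dependencies.py"},
    "C1": {"name": "get_query_token", "file": "dependencies.py"},
}
req_to_code = {"R0": "C0", "R1": "C1"}  # the bigraph mapping


def retrieve(rid: str) -> list[str]:
    """Return the code element for a requirement, then those of its neighbours."""
    rids = [rid] + req_edges.get(rid, [])
    return [code_nodes[req_to_code[r]]["name"] for r in rids]


print(retrieve("R0"))
```

The point of the mapping is that retrieval starts from a natural-language node and ends at concrete code, with graph neighbours pulled in as extra context.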


## Install dependencies

In [1]:
%pip install tree-sitter-language-pack sentence_transformers scikit-learn pyvis langchain openai duckduckgo-search black langchain-community langchain-openai python-dotenv langgraph "langchain[openai]" pandas networkx

Note: you may need to restart the kernel to use updated packages.


## Parse all Python files in the repo

In [2]:
import ast, json, pandas as pd
from tree_sitter_language_pack import get_parser
from pathlib import Path

parser = get_parser("python")
ROOT   = Path("./app").resolve()
out    = []

def ts_name_and_span(src: str):
    """Return dict {name,start,end} for every function/class via Tree-sitter."""
    tree = parser.parse(src.encode())
    root = tree.root_node
    res  = []

    def walk(node):
        if node.type in ("function_definition", "class_definition"):
            name = node.child_by_field_name("name").text.decode()
            res.append((name, node.start_point[0]+1, node.end_point[0]+1))
        for c in node.named_children:
            walk(c)
    walk(root)
    return res

for py in ROOT.rglob("*.py"):
    code = py.read_text(encoding="utf-8")
    # 1️⃣ positions with tree-sitter
    spans = ts_name_and_span(code)
    # 2️⃣ docstrings with ast
    module = ast.parse(code, filename=str(py))
    for node in ast.walk(module):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            doc = ast.get_docstring(node) or ""
            name = node.name
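            # match span by name and start line (names can repeat inside a file, e.g. methods);
            # fall back to the line numbers that ast itself reports
            start, end = next(
                ((s, e) for n, s, e in spans if n == name and s == node.lineno),
                (node.lineno, node.end_lineno),
            )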
            out.append({
                "type": "class_definition" if isinstance(node, ast.ClassDef) else "function_definition",
                "name": name,
                "docstring": doc,
                "start_line": start,
                "end_line": end,
                "file": str(py.relative_to(ROOT))
            })

with open("code_elements_with_docstrings.json", "w") as f:
    json.dump(out, f, indent=2)

print("✅ saved", len(out), "elements with docstrings")
pd.DataFrame(out).head()

✅ saved 12 elements with docstrings


|   | type | name | docstring | start_line | end_line | file |
|---|---|---|---|---|---|---|
| 0 | function_definition | root | | 18 | 19 | main.py |
| 1 | function_definition | get_token_header | Dependency to validate X-Token header | 5 | 11 | dependencies.py |
| 2 | function_definition | get_query_token | Dependency to validate token query parameter | 14 | 20 | dependencies.py |
| 3 | class_definition | Task | | 7 | 12 | routers/tasks.py |
| 4 | function_definition | get_tasks | Get all tasks with optional status filter | 33 | 39 | routers/tasks.py |
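
For Python-only repos, the `ast` module alone already exposes line spans (`lineno` and `end_lineno`, available since Python 3.8), so the extraction step can also be sketched without Tree-sitter. `SAMPLE` and `extract_elements` are hypothetical names used only for this illustration.

```python
import ast

# A tiny stand-in for a source file (hypothetical content).
SAMPLE = '''
def get_token_header(x_token):
    """Dependency to validate X-Token header"""
    return x_token

class Task:
    pass
'''


def extract_elements(source: str) -> list[dict]:
    """Collect name, docstring and line span for every function and class."""
    found = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            found.append({
                "name": node.name,
                "docstring": ast.get_docstring(node) or "",
                "start_line": node.lineno,
                "end_line": node.end_lineno,
            })
    return found


elements_demo = extract_elements(SAMPLE)
for el in elements_demo:
    print(el)
```

Tree-sitter remains the better choice once the repo mixes languages, since `ast` only understands Python.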


## Load elements (docstrings included)

In [3]:
import json, pandas as pd, networkx as nx
from pathlib import Path

with open("code_elements_with_docstrings.json") as f:
    elements = json.load(f)

df = pd.DataFrame(elements)
df.head()

|   | type | name | docstring | start_line | end_line | file |
|---|---|---|---|---|---|---|
| 0 | function_definition | root | | 18 | 19 | main.py |
| 1 | function_definition | get_token_header | Dependency to validate X-Token header | 5 | 11 | dependencies.py |
| 2 | function_definition | get_query_token | Dependency to validate token query parameter | 14 | 20 | dependencies.py |
| 3 | class_definition | Task | | 7 | 12 | routers/tasks.py |
| 4 | function_definition | get_tasks | Get all tasks with optional status filter | 33 | 39 | routers/tasks.py |


## 2. Requirement Graph – build *similar_to* edges

In [4]:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

texts = [e["docstring"] or "" for e in elements]
emb = model.encode(texts, normalize_embeddings=True)

RG = nx.Graph()

# add nodes
for idx, e in enumerate(elements):
    RG.add_node(f"R{idx}", **e)

# similarity edges
cos_matrix = cosine_similarity(emb)
THRESH = 0.8
for i in range(len(elements)):
    for j in range(i + 1, len(elements)):
        if cos_matrix[i, j] >= THRESH:
            RG.add_edge(f"R{i}", f"R{j}", kind="similar_to", weight=float(cos_matrix[i, j]))


In [5]:
print(f"RG: {RG.number_of_nodes():,} nodes  |  {RG.number_of_edges():,} similar_to edges")

RG: 12 nodes  |  3 similar_to edges
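
One caveat with the similarity cell above: elements without a docstring are all encoded from the same empty string, so they receive identical embeddings and a cosine of 1.0 with each other, which links undocumented elements regardless of what they do. A minimal sketch of a guard, using made-up 2-D vectors in place of real embeddings (`cosine` and `similar_pairs` are hypothetical helpers):

```python
import math


def cosine(a: list[float], b: list[float]) -> float:
    """Plain cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def similar_pairs(texts, vectors, thresh=0.8):
    """Index pairs above the threshold, skipping elements with empty text."""
    pairs = []
    for i in range(len(texts)):
        for j in range(i + 1, len(texts)):
            if not texts[i] or not texts[j]:
                continue  # an empty docstring carries no meaning to compare
            if cosine(vectors[i], vectors[j]) >= thresh:
                pairs.append((i, j))
    return pairs


# Toy data: items 0 and 3 are undocumented and share the same vector.
texts = ["", "validate header token", "validate query token", ""]
vectors = [[1.0, 0.0], [0.6, 0.8], [0.8, 0.6], [1.0, 0.0]]
print(similar_pairs(texts, vectors))
```

Without the guard, the pair `(0, 3)` would be reported with a perfect score even though neither element has a description.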


In [6]:
from networkx.classes.reportviews import NodeView, EdgeView
print("Nodes:", RG.number_of_nodes())
print("Edges:", RG.number_of_edges())
print("Node examples:", list(RG.nodes(data=True))[:3])

Nodes: 12
Edges: 3
Node examples: [('R0', {'type': 'function_definition', 'name': 'root', 'docstring': '', 'start_line': 18, 'end_line': 19, 'file': 'main.py'}), ('R1', {'type': 'function_definition', 'name': 'get_token_header', 'docstring': 'Dependency to validate X-Token header', 'start_line': 5, 'end_line': 11, 'file': 'dependencies.py'}), ('R2', {'type': 'function_definition', 'name': 'get_query_token', 'docstring': 'Dependency to validate token query parameter', 'start_line': 14, 'end_line': 20, 'file': 'dependencies.py'})]


### 2.2 Add *parent_child* edges (calls)

In [7]:
import ast, collections, itertools

# 1. lookup (file, name)  ->  requirement ID
name_to_rid = {
    (data["file"], data["name"]): rid
    for rid, data in RG.nodes(data=True)
}

def deepest_attr(node: ast.Attribute) -> str:
    """Return the last attribute name of a dotted call: pkg.mod.func -> func"""
    # The outermost Attribute node already holds the final name; walking down
    # node.value would end on the first attribute instead (pkg.mod.func -> mod).
    return node.attr

for py in ROOT.rglob("*.py"):
    src   = py.read_text(encoding="utf-8")
    mod   = ast.parse(src, filename=str(py))
    file_ = str(py.relative_to(ROOT))

    # all defs in this file
    defs = {n.name: n for n in ast.walk(mod)
            if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))}

    for def_name, fn in defs.items():
        rid_caller = name_to_rid.get((file_, def_name))
        if not rid_caller:
            continue

        for call in ast.walk(fn):
            if not isinstance(call, ast.Call):
                continue

            callee_name = None
            # foo()
            if isinstance(call.func, ast.Name):
                callee_name = call.func.id
            # obj.foo()  /  pkg.mod.bar()
            elif isinstance(call.func, ast.Attribute):
                callee_name = deepest_attr(call.func)

            if callee_name:
                rid_callee = name_to_rid.get((file_, callee_name))
                if rid_callee:
                    RG.add_edge(rid_caller, rid_callee, kind="parent_child")

print(f"Parent_child edges added. Graph now has {RG.number_of_edges()} edges.")

Parent_child edges added. Graph now has 3 edges.
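
To see how call names come out of the AST, the following standalone sketch lists the callee name for plain and dotted calls. For a dotted call such as `pkg.mod.baz()`, the outermost `ast.Attribute` node holds the final name in `.attr`. `called_names` is a hypothetical helper written for this illustration.

```python
import ast


def called_names(source: str) -> list[str]:
    """Return the callee name of every call expression in the source."""
    names = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name):         # foo()
                names.append(node.func.id)
            elif isinstance(node.func, ast.Attribute):  # obj.bar(), pkg.mod.baz()
                names.append(node.func.attr)
    return names


print(called_names("foo()\nobj.bar()\npkg.mod.baz()"))
# → ['foo', 'bar', 'baz']
```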


In [8]:
parents = [
    u for u, v, kind in RG.edges(data="kind")
    if kind == "parent_child"
][:10]

for rid in parents:
    data = RG.nodes[rid]
    print(f"{rid}  {data['file']}:{data['name']}")

### Quick graph visualisation

In [9]:
from pyvis.network import Network

net = Network(height="750px", width="100%", notebook=True, directed=False)
net.toggle_physics(True)

for n, data in RG.nodes(data=True):
    net.add_node(n, label=data["name"], title=data["docstring"][:200])

for u, v, k in RG.edges(data="kind"):
    color = "#2ca02c" if k == "parent_child" else "#9467bd"
    net.add_edge(u, v, color=color)

net.show("requirement_graph.html")

requirement_graph.html


## 3. DS-Code Graph (CG) – structure of the real code

In [10]:
import ast, networkx as nx
from pathlib import Path
from collections import defaultdict
import json

ROOT = Path("./app").resolve()

# ---------- load elements ----------
with open("code_elements_with_docstrings.json") as f:
    elements = json.load(f)

# helper: (file, name) -> code-node id
cid_map = {}
CG = nx.DiGraph()

# ---------- add code nodes ----------
for idx, el in enumerate(elements):
    cid = f"C{idx}"
    cid_map[(el["file"], el["name"])] = cid
    CG.add_node(cid, **el)

# ---------- add file nodes ----------
for el in elements:
    CG.add_node(el["file"], type="module")

# ---------- contain edges ----------
for idx, el in enumerate(elements):
    CG.add_edge(el["file"], f"C{idx}", kind="contain")

# ---------- scan every file with ast ----------
for py in ROOT.rglob("*.py"):
    file_id = str(py.relative_to(ROOT))
    src     = py.read_text(encoding="utf-8")
    tree    = ast.parse(src, filename=file_id)

    # --- import edges (file -> imported module/file) ---
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for n in node.names:
                CG.add_edge(file_id, n.name, kind="import")  # crude, module string
        elif isinstance(node, ast.ImportFrom):
            mod = node.module or ""
            CG.add_edge(file_id, mod, kind="import")

    # --- call + inherit edges inside this file ---
    defs = {n.name: n for n in ast.walk(tree)
            if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))}

    for def_name, obj in defs.items():
        caller_cid = cid_map.get((file_id, def_name))
        if not caller_cid:
            continue

        # inherit (classes only)
        if isinstance(obj, ast.ClassDef):
            for base in obj.bases:
                if isinstance(base, ast.Name):
                    parent_cid = cid_map.get((file_id, base.id))
                    if parent_cid:
                        CG.add_edge(caller_cid, parent_cid, kind="inherit")

        # call edges
        for call in ast.walk(obj):
            if isinstance(call, ast.Call):
                # simple cases: foo(), obj.foo()
                target_name = None
                if isinstance(call.func, ast.Name):
                    target_name = call.func.id
                elif isinstance(call.func, ast.Attribute):
                    target_name = call.func.attr
                if target_name:
                    callee_cid = cid_map.get((file_id, target_name))
                    if callee_cid:
                        CG.add_edge(caller_cid, callee_cid, kind="call")

print(f"CG built: {CG.number_of_nodes()} nodes  |  {CG.number_of_edges()} edges")

# preview a few edges
list(CG.edges(data="kind"))[:10]

CG built: 21 nodes  |  25 edges


[('main.py', 'C0', 'contain'),
 ('main.py', 'fastapi', 'import'),
 ('main.py', 'dependencies', 'import'),
 ('main.py', 'routers', 'import'),
 ('dependencies.py', 'C1', 'contain'),
 ('dependencies.py', 'C2', 'contain'),
 ('dependencies.py', 'typing', 'import'),
 ('dependencies.py', 'fastapi', 'import'),
 ('routers/tasks.py', 'C3', 'contain'),
 ('routers/tasks.py', 'C4', 'contain')]

### 3.1 Quick sanity-checks for the Code Graph

In [11]:
from collections import Counter

print("Nodes:", CG.number_of_nodes())
print("Edges:", CG.number_of_edges())

# distribution by edge type
edge_kinds = Counter(k for _,_,k in CG.edges(data="kind"))
print("Edge types:", edge_kinds)

Nodes: 21
Edges: 25
Edge types: Counter({'import': 13, 'contain': 12})
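
The counter shows only `import` and `contain` edges. The build cell looks up callees with `(file_id, target_name)`, so a call to a function defined in another file can never match. One possible extension, sketched here under the assumption that function names are unique across the repo, is a name-only index. `FILES` and `build_call_edges` are hypothetical and operate on in-memory sources rather than on `CG`.

```python
import ast

# Hypothetical two-file repo held in memory.
FILES = {
    "agents.py": "def get_agent(agent_id):\n    return agent_id\n",
    "tasks.py": (
        "import agents\n\n"
        "def get_agent_tasks(agent_id):\n"
        "    return agents.get_agent(agent_id)\n"
    ),
}

FUNC_TYPES = (ast.FunctionDef, ast.AsyncFunctionDef)


def build_call_edges(files: dict[str, str]) -> set[tuple[str, str]]:
    """Resolve calls by bare name across files (assumes names are unique)."""
    trees = {path: ast.parse(src) for path, src in files.items()}

    # name -> "file:name" for every function in the repo
    index = {}
    for path, tree in trees.items():
        for node in ast.walk(tree):
            if isinstance(node, FUNC_TYPES):
                index[node.name] = f"{path}:{node.name}"

    edges = set()
    for tree in trees.values():
        for fn in ast.walk(tree):
            if not isinstance(fn, FUNC_TYPES):
                continue
            for call in ast.walk(fn):
                if not isinstance(call, ast.Call):
                    continue
                func = call.func
                name = func.id if isinstance(func, ast.Name) else getattr(func, "attr", None)
                if name in index:
                    edges.add((index[fn.name], index[name]))
    return edges


print(build_call_edges(FILES))
```

Name-only matching trades precision for recall: two unrelated functions with the same name would be conflated, so a real implementation would want to consult the import edges as well.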


### 3.2 Add `similar_to` edges inside the Code Graph

In [12]:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from pathlib import Path
import textwrap

THRESH_CODE_SIM = 0.80          # similarity threshold
EMBED_MODEL      = "sentence-transformers/paraphrase-mpnet-base-v2"

# 1️⃣ gather code-node ids (functions & classes only)
code_nodes = [
    n for n, d in CG.nodes(data=True)
    if d.get("type") in ("function_definition", "class_definition")
]

# 2️⃣ extract raw source text for each node
def get_source(node_data):
    file_path = Path("./app") / node_data["file"]
    lines = file_path.read_text(encoding="utf-8").splitlines()
    return "\n".join(lines[node_data["start_line"]-1 : node_data["end_line"]])

corpus = [get_source(CG.nodes[n]) for n in code_nodes]

# 3️⃣ embed & normalise
model_code = SentenceTransformer(EMBED_MODEL)
embeddings = model_code.encode(corpus, normalize_embeddings=True)

# 4️⃣ compute cosine matrix (small n, so brute force is fine)
cos_mat = cosine_similarity(embeddings)

# 5️⃣ add edges
added = 0
for i in range(len(code_nodes)):
    for j in range(i+1, len(code_nodes)):
        if cos_mat[i, j] >= THRESH_CODE_SIM:
            CG.add_edge(code_nodes[i], code_nodes[j],
                        kind="similar_to",
                        weight=float(cos_mat[i, j]))
            added += 1

print(f"➕ Added {added} code 'similar_to' edges – CG now has {CG.number_of_edges()} edges")

➕ Added 4 code 'similar_to' edges – CG now has 29 edges


In [13]:
# show a few similarity pairs
examples = [
    (u, v, CG.edges[u, v]["weight"])
    for u, v, k in CG.edges(data="kind") if k == "similar_to"
][:5]

for u, v, w in examples:
    print(f"{CG.nodes[u]['name']}  ≃  {CG.nodes[v]['name']}   (cos={w:.2f})")

get_token_header  ≃  get_query_token   (cos=0.80)
Task  ≃  Agent   (cos=0.92)
get_task  ≃  get_agent_tasks   (cos=0.87)
get_agent_tasks  ≃  get_agent   (cos=0.85)


## 4. ID map  (Requirement  ↔  Code)  –  the Bigraph

In [14]:
# 1-to-1 map
id_map = {f"R{i}": f"C{i}" for i in range(len(elements))}

IDMAP = id_map  

# store the link as a cross-reference attribute on both graphs
for rid, cid in id_map.items():
    RG.nodes[rid]["code_id"] = cid
    CG.nodes[cid]["req_id"]  = rid
print("Bigraph mapping added:", len(id_map), "links")

Bigraph mapping added: 12 links


### 💾 Persist graphs to disk

We’ll save both graphs in **GraphML** format so they can be:

* Reloaded later in NetworkX without rebuilding.
* Imported into Neo4j (for example with APOC's `apoc.import.graphml` procedure) or visual tools like Gephi.

Files created:

* `requirement_graph.graphml`
* `code_graph.graphml`

In [15]:
# ⬇️ RUN THIS AFTER building RG, CG and id_map
import pickle, networkx as nx

# 1. Backup in GraphML format: useful for Neo4j, Gephi, etc.
nx.write_graphml(RG, "requirement_graph.graphml")
nx.write_graphml(CG, "code_graph.graphml")

# 2. "All-in-one" backup with pickle (fast to reload in Jupyter);
#    np.savez is meant for arrays and does not reliably round-trip graph objects
with open("graphs.pkl", "wb") as f:
    pickle.dump({"RG": RG, "CG": CG, "id_map": id_map}, f)
print("✅ graphs.pkl written 🗂️")

✅ graphs.pkl written 🗂️


# 5. CodeRAG Agent (local, ReAct style)

## 5.1 Tools  (GraphReason · WebSearch · CodeTest)

In [16]:
# ✅ Dependencies
from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import InMemorySaver
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
from pydantic import BaseModel
import os

# ✅ Load .env
load_dotenv()

# ✅ Init LLM using the env key
llm = ChatOpenAI(model="gpt-4o", temperature=0.2)

In [17]:
# ✅ Tools (GraphReason, WebSearch, CodeTest)
from langchain.agents import Tool
from pathlib import Path
import os, sys, pathlib, subprocess, tempfile, black
from duckduckgo_search import DDGS

# Assume RG is already defined (the undirected networkx Graph built above)
def graph_reason(rid: str, n_sim=3, n_child=3):
    sims, childs = [], []
    for _, v, k in RG.edges(rid, data="kind"):
        if k == "similar_to" and len(sims) < n_sim: sims.append(RG.nodes[v])
        if k == "parent_child" and len(childs) < n_child: childs.append(RG.nodes[v])
    return {"similar": sims, "child": childs}

def web_search(q: str, k=3):
    with DDGS() as ddgs:
        results = list(ddgs.text(q, max_results=k))
    return "\n".join(f"{r['title']}: {r['body']}" for r in results if 'body' in r)

def code_test(code: str):
    # mkstemp returns (fd, path); close the descriptor so it is not leaked
    fd, path = tempfile.mkstemp(suffix=".py")
    os.close(fd)
    tmp = pathlib.Path(path)
    tmp.write_text(code, encoding="utf-8")
    try:
        # write_back defaults to WriteBack.NO, so ask black to rewrite the file
        black.format_file_in_place(tmp, fast=True, mode=black.FileMode(), write_back=black.WriteBack.YES)
    except Exception:
        pass  # black could not parse the code; let py_compile report the error
    # use the kernel's own interpreter rather than whatever "python" is on PATH
    proc = subprocess.run([sys.executable, "-m", "py_compile", str(tmp)], capture_output=True, text=True)
    return "\u2705 compiled ok" if proc.returncode == 0 else proc.stderr[:300]

def show_source(node):
    file = node.get("file")
    start = node.get("start_line")
    end = node.get("end_line")
    if not file or not start or not end:
        return "File or line info missing."
    # "file" is stored relative to ROOT, so resolve it against ROOT
    lines = (ROOT / file).read_text(encoding="utf-8").splitlines()
    return "\n".join(lines[start-1:end])
    
def code_graph_lookup(func_name: str):
    # Search the code-graph nodes by name
    matches = [n for n, d in CG.nodes(data=True)
               if d.get("name") == func_name]

    if not matches:
        return f"No match found for function `{func_name}`."

    # Assume the name is unique (otherwise take the first match)
    node_id = matches[0]
    node_data = CG.nodes[node_id]

    # Collect outgoing and incoming neighbours; note that this follows every
    # edge kind (contain, import, similar_to, ...), not only call edges
    calls = {
        "calls": [CG.nodes[v] for _, v in CG.out_edges(node_id)],
        "called_by": [CG.nodes[u] for u, _ in CG.in_edges(node_id)],
    }

    # Package the response
    return {
        "function": node_data,
        **calls
    }

tools = [
    Tool(name="LookupInGraph", func=graph_reason, description="Input: Requirement id (e.g. R42) -> similar and child requirements"),
    Tool(name="WebSearch", func=web_search, description="Input: string query -> DuckDuckGo search result titles and snippets"),
    Tool(name="CodeTest", func=code_test, description="Input: Python code string -> formats with black, then compiles to check for syntax errors")
]

tools.append(
    Tool(
        name="CodeGraphLookup",
        func=code_graph_lookup,
        description="Input: function name (e.g. 'process_data') -> gets the function's docstring, file, and connected calls in the Code Graph."
    )
)
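
For a quick check that does not spawn a subprocess, Python's built-in `compile()` can play a role similar to `py_compile`. The `syntax_check` helper below is a hypothetical sketch and is not one of the registered tools; it only detects syntax errors and does no formatting.

```python
def syntax_check(code: str) -> str:
    """Compile the code in-process and report the first syntax error, if any."""
    try:
        compile(code, "<generated>", "exec")  # parses and compiles, never executes
    except SyntaxError as exc:
        return f"SyntaxError at line {exc.lineno}: {exc.msg}"
    return "compiled ok"


print(syntax_check("x = 1"))
print(syntax_check("def f(:\n    pass"))
```

Because nothing is executed, this is safe to run on LLM-generated code, though it says nothing about whether the code behaves correctly at runtime.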


## 5.2 Build the ReAct agent

In [18]:
prompt = """You are **CodeRAG-Agent**, an expert developer assistant designed to help understand and reason about a codebase and its requirements.

You have access to two graphs:
1. **Requirement Graph (RG)** – a graph where each node is a software requirement (e.g. R0, R1, R42), connected by semantic relationships such as:
   - `similar_to`: nodes that express related or overlapping functionality.
   - `parent_child`: hierarchical breakdown of features or sub-requirements.

2. **Code Graph (CG)** – a graph where each node is a function or class in the codebase (e.g. C0, C1). Nodes include:
   - Metadata: name, type, file path, line range, docstring.
   - Edges represent `calls` and `called_by` relationships between functions/classes.

---

Your job is to help the user understand:
- What a requirement means and which code is relevant to it.
- What a function or class does, how it connects to others, and its implementation details.
- How different requirements or code elements relate structurally or semantically.

---

### 🛠️ Tools at your disposal

1. **LookupInGraph**
   - Input: requirement id (e.g. "R42").
   - Use this to explore `similar` and `child` requirements.
   - Always call this FIRST when a requirement is mentioned.

2. **CodeGraphLookup**
   - Input: function or class name (e.g. "process_data").
   - Use this to retrieve metadata from the Code Graph (CG), including file, lines, docstring, and related functions (calls/called_by).
   - Use this for all code-related questions.

3. **WebSearch**
   - Input: any query string.
   - Use only if the requirement or function lacks context in the graphs.

4. **CodeTest**
   - Input: Python code string.
   - Use this to auto-format the code and test it for syntax errors (compilation via `black` and `py_compile`).

---

### 🧠 Strategy for every query

- If the user mentions a requirement (e.g. "R42"), always call **LookupInGraph** first.
- If the user asks about a class or function, use **CodeGraphLookup**.
- Extract context from the graphs before answering.
- If context is still insufficient, optionally use **WebSearch** to enrich your answer.
- If you produce or edit Python code, test it with **CodeTest** before showing the final version.
- Always return the final answer as clear text, optionally including relevant code snippets or summaries.

Do not make assumptions. Always ground your answers in the graph context or tool outputs.
If a tool fails or the input isn't found, explain what you tried and ask the user for clarification if needed.
"""

# ✅ LangGraph checkpointing (in-memory)
checkpointer = InMemorySaver()

# ✅ Create the agent
agent = create_react_agent(
    model=llm,
    tools=tools,
    prompt=prompt,
    checkpointer=checkpointer,
)


## 5.3 Ask the agent

In [19]:
def display_agent_response(response):
    messages = response.get("messages", [])
    for msg in messages:
        role = type(msg).__name__.replace("Message", "")
        content = getattr(msg, "content", "")
        name = getattr(msg, "name", None)

        if name:
            print(f"🛠️ Tool ({name}):\n{content}\n")
        else:
            prefix = "🤖 AI" if role == "AI" else "🧑 Human"
            print(f"{prefix}:\n{content}\n")

In [20]:
response = agent.invoke(
    {"messages": [{"role": "user", "content": "how i can us the function get agent"}]},
    config={"configurable": {"thread_id": "session-001"}}
)

display_agent_response(response)

🧑 Human:
how i can us the function get agent

🤖 AI:


🛠️ Tool (CodeGraphLookup):
{"function": {"type": "function_definition", "name": "get_agent", "docstring": "Get a specific agent by ID", "start_line": 40, "end_line": 47, "file": "routers/agents.py", "req_id": "R10"}, "calls": [], "called_by": [{"type": "module"}, {"type": "function_definition", "name": "get_agent_tasks", "docstring": "Get all tasks assigned to a specific agent", "start_line": 54, "end_line": 61, "file": "routers/tasks.py", "req_id": "R6"}]}

🤖 AI:
The function `get_agent` is defined in the file `routers/agents.py`, from lines 40 to 47. Its purpose is to "Get a specific agent by ID". 

### Usage
To use the `get_agent` function, you would typically call it with the ID of the agent you want to retrieve. However, the specific parameters and return values are not detailed in the docstring, so you would need to look at the function's implementation in the file for more details.

### Connections
- **Called By**: The `get_a