# Graph-RAG with Gemma 3 1B IT (Colab)

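This notebook demonstrates end-to-end retrieval-augmented generation over a code repository. It merges semantic neighbors retrieved from Chroma with call-graph context from a Joern code property graph (CPG) stored in Neo4j, then asks Gemma to reason over both.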

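Prereqs:
- Accept the Hugging Face license for the gated model `google/gemma-3-1b-it`, and expose an access token via the `HF_TOKEN` environment variable.
- Have a running Neo4j instance (configured via `NEO4J_URI`, `NEO4J_USER`, `NEO4J_PASSWORD`) with the Joern CPG CSV exports ingested.
- Have a persistent Chroma collection named `repo_chunks` at `CHROMA_PATH` (default `/content/chroma_store`), built from the repo with the same embedding model the demo queries with (`all-mpnet-base-v2`).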



In [None]:
pip -q install transformers accelerate torch neo4j chromadb sentence-transformers


In [None]:
import os
from neo4j import GraphDatabase
import chromadb
from chromadb.config import Settings
from transformers import AutoTokenizer, AutoModelForCausalLM

# Optional Hugging Face access token; gemma-3-1b-it is a gated model.
HF_TOKEN = os.getenv('HF_TOKEN', '')
if not HF_TOKEN:
    print('Note: Set HF_TOKEN to use gated models if needed.')

# Connect to Chroma (adjust if remote)
CHROMA_PATH = os.getenv('CHROMA_PATH', '/content/chroma_store')
chroma_client = chromadb.PersistentClient(path=CHROMA_PATH, settings=Settings(anonymized_telemetry=False))
collection = chroma_client.get_collection('repo_chunks')

# Connect to Neo4j
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test-password')
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))

# Load Gemma.
# `token=` replaces the deprecated `use_auth_token=` argument of from_pretrained;
# an empty HF_TOKEN is passed as None so no explicit token is sent.
model_id = 'google/gemma-3-1b-it'
print('Loading model:', model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN or None)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map='auto',
    torch_dtype='auto',
    token=HF_TOKEN or None,
)



In [None]:
from repo_indexer.graph.query_graph import get_functions_for_chunk, get_call_subgraph, serialize_graph_for_model


def _line_span(span) -> tuple:
    """Return ``(start, end)`` from a ``[start, end]`` pair, using '?' for missing or None values."""
    values = list(span or [])[:2]
    values += [None] * (2 - len(values))
    return tuple('?' if v is None else v for v in values)


def build_prompt(graph_text: str, semantic_chunks: list, question: str) -> str:
    """Assemble the model prompt from call-graph text, retrieved chunks and a question.

    Args:
        graph_text: Serialized call subgraph text; ``None`` is treated as empty.
        semantic_chunks: Dicts with optional ``filepath``, ``lines`` (a
            ``[start, end]`` pair), ``summary`` and ``snippet`` keys. Missing or
            ``None`` location fields are rendered as ``?``.
        question: Task text appended after the context.

    Returns:
        One newline-joined string containing, in order: system instructions,
        the call graph, numbered semantic chunks (snippets truncated to 400
        characters), the task, and the expected response format.
    """
    lines = [
        'System:\nYou are a code reasoning assistant. Use the code snippets and the call graph to answer questions concisely. When giving step-by-step flow, number the steps. When referring to code, include filepath:line ranges.',
        '\nContext:\n[CALL GRAPH]\n' + (graph_text or ''),
        '\n[SEMANTIC CHUNKS]',
    ]
    for i, c in enumerate(semantic_chunks, 1):
        # Location is computed outside the f-string: reusing the same quote
        # character inside an f-string replacement field is a SyntaxError
        # before Python 3.12.
        start, end = _line_span(c.get('lines'))
        loc = f"{c.get('filepath') or '?'}:{start}-{end}"
        snippet = (c.get('snippet') or '')[:400]
        lines.append(f"- {i}) {c.get('summary')} ({loc})\n{snippet}")
    lines.append('\nTask:\n' + question)
    lines.append('\nResponse format:\n- Short answer (1-2 lines)\n- Key steps (numbered)\n- Relevant code pointers (filepath:lines)\n- If unsure, state assumptions.')
    return '\n'.join(lines)


def run_inference(prompt: str, max_new_tokens: int = 256, temperature: float = 0.2):
    """Generate a completion for *prompt* with the module-level ``tokenizer``/``model``.

    The prompt is wrapped as a single user turn in the tokenizer's chat template
    when one is available, since instruction-tuned Gemma expects that format.
    Otherwise it is tokenized as raw text.

    Args:
        prompt: Full prompt text (e.g. from ``build_prompt``).
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature; values <= 0 (or falsy) select
            greedy decoding.

    Returns:
        The decoded completion only (the prompt tokens are stripped), with
        special tokens removed.
    """
    import torch

    if getattr(tokenizer, 'chat_template', None):
        inputs = tokenizer.apply_chat_template(
            [{'role': 'user', 'content': prompt}],
            add_generation_prompt=True,
            return_tensors='pt',
            return_dict=True,
        )
    else:
        inputs = tokenizer(prompt, return_tensors='pt')
    inputs = inputs.to(model.device)

    # `temperature` only takes effect when sampling is enabled; without
    # do_sample=True, generate() decodes greedily and ignores it.
    gen_kwargs = {'max_new_tokens': max_new_tokens}
    if temperature and temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature)
    else:
        gen_kwargs['do_sample'] = False

    with torch.no_grad():
        out = model.generate(**inputs, **gen_kwargs)
    # generate() returns prompt + completion for decoder-only models; keep only the new tokens.
    prompt_len = inputs['input_ids'].shape[-1]
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)



In [None]:
# Demo: take the first chunk from chunks.jsonl, retrieve semantic neighbors
# (Chroma) and call-graph context (Neo4j) for it, then ask Gemma to explain it.
import json
from pathlib import Path

chunks_path = Path('/content/repo-indexer/outputs/chunks.jsonl')

chunk = None
if chunks_path.exists():
    with open(chunks_path, 'r', encoding='utf-8') as f:
        line = f.readline()
        # strip(): a blank first line would otherwise make json.loads raise.
        if line.strip():
            item = json.loads(line)
            # `or {}` guards against an explicit null "metadata" field.
            chunk = {'document': item.get('text', ''), 'metadata': item.get('metadata') or {}}
else:
    print('Note: mount your repo outputs to /content/repo-indexer/outputs/chunks.jsonl')

if not chunk:
    chunk = {'document': 'def foo():\n  return 1', 'metadata': {'summary': 'demo function foo'}}

# Semantic neighbors via a quick embedding using sentence-transformers
from sentence_transformers import SentenceTransformer

# Named `embedder`, not `model`: `model` is the Gemma model used by
# run_inference(); rebinding it here made run_inference() call .generate()
# on the SentenceTransformer and fail.
embedder = SentenceTransformer('all-mpnet-base-v2')
query_text = chunk['metadata'].get('summary') or chunk['document']
emb = embedder.encode([query_text], convert_to_tensor=False)[0]
# NOTE(review): this must be the same embedding model used to build
# `repo_chunks`, otherwise the query dimension/space will not match.
results = collection.query(query_embeddings=[emb.tolist()], n_results=3, include=['documents', 'metadatas'])

semantic = []
for chunk_id, meta, doc in zip(results['ids'][0], results['metadatas'][0], results['documents'][0]):
    # Metadata/document may be None for entries stored without them.
    meta = meta or {}
    semantic.append({
        'id': chunk_id,
        'filepath': meta.get('filepath'),
        'summary': meta.get('summary'),
        'language': meta.get('language'),
        'node_type': meta.get('node_type'),
        'lines': [meta.get('start_line'), meta.get('end_line')],
        'snippet': (doc or '')[:800],
    })

# Graph subgraph
with driver.session() as session:
    fnames = get_functions_for_chunk(chunk)
    nodes, edges = get_call_subgraph(session, fnames, direction='both', hops=2)
    graph_text = serialize_graph_for_model(nodes, edges, max_tokens=600)

prompt = build_prompt(graph_text, semantic, 'Explain the flow and purpose of the selected code.')
print(prompt[:1000])

output = run_inference(prompt)
print('\n--- Model Output ---\n')
print(output)

