In [29]:
# Print an ASCII rendering of the compiled workflow graph.
# NOTE: `app` is created by `graph.compile()` in the cell labelled In [1]
# further down in this file, so that cell has to be executed before this one.
app.get_graph().print_ascii()


        +-----------+           
        | __start__ |           
        +-----------+           
               *                
               *                
               *                
          +-------+             
          | start |             
          +-------+             
               *                
               *                
               *                
         +----------+           
         | classify |           
         +----------+           
          .        .            
        ..          ..          
       .              .         
+---------+       +----------+  
| explain |       | generate |  
+---------+       +----------+  
          *        *            
           **    **             
             *  *               
           +-----+              
           | end |              
           +-----+              
               *                
               *                
               *                
         +

# Dependencies 

In [1]:
# langgraph_rag_gradio_app.py

import json
import os
import re
import traceback
from typing import Optional, List, Literal, Dict

import faiss
import gradio as gr
import numpy as np
from langgraph.graph import StateGraph
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM

# ====== STATE DEFINITION ======

class AgentState(BaseModel):
    """State object handed from node to node in the LangGraph workflow.

    Only ``user_input`` is required; every other field defaults to ``None``
    and is filled in by the node that produces it.
    """
    # Raw text supplied by the caller (the Gradio callback passes the textbox value).
    user_input: str
    # Branch chosen by `classify_intent` and returned by `route`.
    intent: Optional[Literal["generate", "explain"]] = None
    # Examples returned by `retrieve_similar_examples` (set in `generate_code_node`).
    retrieved_examples: Optional[List[Dict]] = None
    # Few-shot prompt built by `generation_prompt` (set in `generate_code_node`).
    formatted_prompt: Optional[str] = None
    # Text returned by `generate_code` (set in `generate_code_node`).
    llm_output: Optional[str] = None
    # Text shown to the user; `gradio_interface` returns this field.
    final_response: Optional[str] = None


# ====== RAG SETUP ======

# Location of the HumanEval JSONL file. Set the HUMANEVAL_PATH environment
# variable to point somewhere else; the default is the original local path.
HUMANEVAL_PATH = os.environ.get(
    "HUMANEVAL_PATH",
    "C:/Users/sief x/Desktop/Study Sessions/RAG/human-eval/data/HumanEval.jsonl/human-eval-v2-20210705.jsonl",
)

with open(HUMANEVAL_PATH, "r", encoding="utf-8") as f:
    # Skip blank lines (e.g. a trailing newline) so json.loads never receives "".
    data = [json.loads(line) for line in f if line.strip()]

if not data:
    # Fail here with a clear message instead of later at `embeddings.shape[1]`.
    raise ValueError(f"No records found in {HUMANEVAL_PATH}")

# Parallel lists: position i in each refers to the same record in `data`.
prompts = [item['prompt'] for item in data]
task_ids = [item["task_id"] for item in data]

# Embed every prompt once; the vectors are cast to float32 before indexing.
model_embed = SentenceTransformer('all-MiniLM-L6-v2')
embeddings = model_embed.encode(prompts, show_progress_bar=True).astype("float32")

# Flat L2 index over the prompt embeddings; row i corresponds to data[i].
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)

# Causal language model and tokenizer used by `generate_code`.
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
model_llm = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")


# ====== HELPERS ======

def retrieve_similar_examples(query: str, k: int = 2) -> List[Dict]:
    """Return up to ``k`` dataset entries whose prompts are nearest to ``query``.

    Args:
        query: Free-text description to embed and search for.
        k: Maximum number of neighbours to return. Values larger than the
            number of indexed prompts are clamped; values <= 0 yield ``[]``.

    Returns:
        A list of dicts with the keys ``task_id``, ``prompt`` and
        ``canonical_solution``, in the order returned by the index search.
    """
    # Never ask for more neighbours than the index holds: FAISS fills the
    # missing slots with -1, which Python would treat as "last element".
    k = min(k, index.ntotal)
    if k <= 0:
        return []

    query_vec = model_embed.encode([query]).astype("float32")
    _distances, neighbor_ids = index.search(query_vec, k)
    return [
        {
            "task_id": task_ids[idx],
            "prompt": prompts[idx],
            "canonical_solution": data[idx]["canonical_solution"]
        }
        for idx in neighbor_ids[0]
        if idx >= 0  # defensive: drop any -1 padding entries
    ]

def generation_prompt(retrieved_context: List[Dict], user_prompt: str) -> str:
    """Assemble the few-shot prompt sent to the code model.

    Each retrieved example contributes a "# Task:" line and a "# Solution:"
    section (both whitespace-trimmed). The user's request follows under
    "### NEW TASK:", and the prompt ends with a bare ``def`` so the model
    continues with a function definition.
    """
    example_lines = []
    for example in retrieved_context:
        task_text = example['prompt'].strip()
        solution_text = example['canonical_solution'].strip()
        example_lines.extend((f"# Task: {task_text}", f"# Solution:\n{solution_text}\n"))

    new_task_lines = ["### NEW TASK:", f"# Task: {user_prompt.strip()}", "# Solution:", "def"]
    return "\n".join(["### EXAMPLES:\n", *example_lines, *new_task_lines])

def truncate_at_delimiter(text, delimiter="# Task:"):
    """Return the part of ``text`` before the first ``delimiter``, whitespace-trimmed.

    If ``delimiter`` does not occur, the whole (trimmed) text is returned.
    """
    # Only the first segment is needed, so one split is enough.
    leading_segment = text.split(delimiter, 1)[0]
    return leading_segment.strip()

def generate_code(prompt_text: str, max_tokens=128) -> str:
    """Sample a continuation of ``prompt_text`` from the code model.

    Args:
        prompt_text: Full prompt to feed to the model.
        max_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        Only the newly generated text (prompt tokens are sliced off),
        cut at the first "# Task:" marker and whitespace-trimmed.
    """
    encoded = tokenizer(prompt_text, return_tensors="pt")
    prompt_len = encoded["input_ids"].shape[1]

    # Pad with the EOS id; 50256 is the fallback when the tokenizer reports none
    # (presumably the GPT-2 style end-of-text id, TODO confirm).
    padding_id = tokenizer.eos_token_id or 50256
    sequences = model_llm.generate(
        **encoded,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        pad_token_id=padding_id,
    )

    # The returned sequence starts with the prompt; keep only what follows it.
    completion_ids = sequences[0][prompt_len:]
    completion_text = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return truncate_at_delimiter(completion_text)


# ====== LANGGRAPH NODES ======

def start_node(state: AgentState) -> AgentState:
    """Entry node of the graph; returns the state unchanged."""
    return state  # Gradio input already captured
def classify_intent(state: AgentState) -> AgentState:
    """Set ``state.intent`` from keywords found in the lower-cased user input.

    "explain" takes precedence; otherwise any of the generation keywords
    selects "generate". Matching is by substring.

    Raises:
        ValueError: If none of the keywords occurs in the input.
    """
    generation_keywords = ("generate", "write", "create", "make", "build")
    lowered = state.user_input.lower()

    if "explain" in lowered:
        state.intent = "explain"
        return state

    if not any(keyword in lowered for keyword in generation_keywords):
        raise ValueError("❌ Could not classify user intent. Please use keywords like 'generate' or 'explain'.")

    state.intent = "generate"
    return state
def route(state: AgentState) -> str:
    """Return ``state.intent``; used as the branch selector after the classify node."""
    return state.intent
def explain_code(state: AgentState) -> AgentState:
    """Store a placeholder explanation that echoes the submitted code.

    The "Explain this code:" marker is removed regardless of letter case,
    matching the case-insensitive keyword check in ``classify_intent``.
    No model is called; the response is a fixed template around the code.
    """
    # Case-insensitive so inputs like "explain this code: ..." are stripped too.
    code = re.sub(r"explain this code:", "", state.user_input, flags=re.IGNORECASE).strip()
    explanation = f"This code defines a Python function or snippet. Without execution or full context, a static guess:\n\n{code}\n\n☝️ You can improve the result by adding more context."
    state.final_response = explanation
    return state
def generate_code_node(state: AgentState) -> AgentState:
    """Run retrieval, prompt building and generation, recording each step on ``state``.

    Intermediate results are written to the state as soon as they exist, so
    they remain available if a later step fails. Any exception is turned
    into a "Generation failed" message in ``final_response``.
    """
    try:
        state.retrieved_examples = examples = retrieve_similar_examples(state.user_input, k=2)
        state.formatted_prompt = prompt_text = generation_prompt(examples, state.user_input)
        state.llm_output = completion = generate_code(prompt_text)
        # The prompt ends with a bare "def", so put it back in front of the completion.
        state.final_response = "def " + completion
    except Exception as error:
        state.final_response = "⚠️ Generation failed:\n" + str(error)
    return state
def end_node(state: AgentState) -> AgentState:
    """Final node of the graph; returns the state unchanged."""
    return state


# ====== LANGGRAPH STRUCTURE ======

# Workflow: start -> classify -> (generate | explain) -> end
graph = StateGraph(AgentState)

# Register each node function under the name used by the edges below.
graph.add_node("start", start_node)
graph.add_node("classify", classify_intent)
graph.add_node("generate", generate_code_node)
graph.add_node("explain", explain_code)
graph.add_node("end", end_node)

graph.set_entry_point("start")
graph.add_edge("start", "classify")
# `route` returns `state.intent`; the mapping sends each value to the node of the same name.
graph.add_conditional_edges("classify", route, {
    "generate": "generate",
    "explain": "explain"
})
# Both branches converge on the "end" node.
graph.add_edge("generate", "end")
graph.add_edge("explain", "end")
graph.set_finish_point("end")

# Compiled graph; `gradio_interface` calls `app.invoke(...)` on it.
app = graph.compile()


# ====== GRADIO UI ======

def gradio_interface(user_input: str) -> str:
    """Run the compiled graph on ``user_input`` and return the text to display.

    Any exception (including the ValueError raised for unclassifiable input)
    is caught and returned as an "❌ Error:" message with the full traceback.
    """
    try:
        result_state = app.invoke(AgentState(user_input=user_input))
        # The invoke result is read by key rather than by attribute.
        response_text = result_state["final_response"]
    except Exception:
        response_text = "❌ Error:\n" + traceback.format_exc()
    return response_text

# Build a single-textbox UI around `gradio_interface` and start serving it.
gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(lines=4, placeholder="Enter 'Generate a function...' or 'Explain this code: ...'"),
    outputs="text",
    title="🧠 LangGraph Code Assistant",
    description="Supports code generation and explanation using retrieval-augmented reasoning."
).launch()





Batches:   0%|          | 0/6 [00:00<?, ?it/s]

  return forward_call(*args, **kwargs)
Some weights of the model checkpoint at Salesforce/codegen-350M-mono were not used when initializing CodeGenForCausalLM: ['transformer.h.0.attn.causal_mask', 'transformer.h.1.attn.causal_mask', 'transformer.h.10.attn.causal_mask', 'transformer.h.11.attn.causal_mask', 'transformer.h.12.attn.causal_mask', 'transformer.h.13.attn.causal_mask', 'transformer.h.14.attn.causal_mask', 'transformer.h.15.attn.causal_mask', 'transformer.h.16.attn.causal_mask', 'transformer.h.17.attn.causal_mask', 'transformer.h.18.attn.causal_mask', 'transformer.h.19.attn.causal_mask', 'transformer.h.2.attn.causal_mask', 'transformer.h.3.attn.causal_mask', 'transformer.h.4.attn.causal_mask', 'transformer.h.5.attn.causal_mask', 'transformer.h.6.attn.causal_mask', 'transformer.h.7.attn.causal_mask', 'transformer.h.8.attn.causal_mask', 'transformer.h.9.attn.causal_mask']
- This IS expected if you are initializing CodeGenForCausalLM from the checkpoint of a model trained on anoth

* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.




In [27]:
def _run_and_report(input_text: str) -> str:
    """Send ``input_text`` through ``gradio_interface``, print both sides, return the output."""
    output = gradio_interface(input_text)
    print("Input:", input_text)
    print("Output:", output)
    return output


def test_generate_add_function():
    """A request to add two numbers should yield a function that adds."""
    output = _run_and_report("Generate a function that adds two numbers")
    assert "def" in output.lower(), "expected a function definition in the output"
    assert "+" in output or "add" in output.lower(), "expected an addition in the output"
    print("✅ Passed: test_generate_add_function")


def test_generate_sorting_function():
    """A request to sort a list should yield a function that mentions sorting."""
    output = _run_and_report("Generate a function that sorts a list")
    assert "def" in output.lower(), "expected a function definition in the output"
    assert "sort" in output.lower(), "expected 'sort' in the output"
    print("✅ Passed: test_generate_sorting_function")


def test_explain_loop_code():
    """An explain request should produce text that refers to the loop.

    NOTE(review): the explain node echoes the submitted code, so "range"
    from the input itself is enough to satisfy this assertion.
    """
    output = _run_and_report("Explain this code: for i in range(5): print(i)")
    assert any(keyword in output.lower() for keyword in ["loop", "iterates", "range", "prints"]), \
        "expected a loop-related keyword in the explanation"
    print("✅ Passed: test_explain_loop_code")


def test_invalid_input():
    """Input without any intent keyword must be reported as unclassifiable.

    ``gradio_interface`` catches every Exception and returns an error string,
    so the check is made on the returned text. The classifier's own message
    is required: the bare "❌" marker is present on every error and would
    let unrelated failures pass.
    """
    input_text = "What's the weather in Cairo?"
    output = gradio_interface(input_text)
    print("Output:", output)
    assert "could not classify" in output.lower(), "expected the intent-classification error message"
    print("✅ Passed: test_invalid_input")


In [28]:
# Run the checks defined in the previous cell, in order. Each prints a ✅ line
# when it passes; a failed `assert` raises AssertionError and stops the cell.
test_generate_add_function()
test_generate_sorting_function()
test_explain_loop_code()
test_invalid_input()


Input: Generate a function that adds two numbers
Output: def add(x: int, y: int) -> int:
    """Add two numbers x and y
    >>> add(2, 3)
    5
    >>> add(5, 7)
    12
    """
# Solution:
    return x + y
✅ Passed: test_generate_add_function
Input: Generate a function that sorts a list
Output: def sort_list_with_sort(l: list, sort_func: callable = None) -> list:
    """Sort a list.

    Sort the elements of the list, using a given function.
    If sort_func is None, it uses a default sort function.

    >>> sort_list_with_sort([1, 2, 3, 4, 5, 6])
    [1, 2, 3, 4, 5, 6]
    >>> sort_list_with_sort([1, 2, 3, 4, 5, 6], lambda x: x*2
✅ Passed: test_generate_sorting_function
Input: Explain this code: for i in range(5): print(i)
Output: This code defines a Python function or snippet. Without execution or full context, a static guess:

for i in range(5): print(i)

☝️ You can improve the result by adding more context.
✅ Passed: test_explain_loop_code
Output: ❌ Error:
Traceback (most recent ca