# Recursive code generation with LangChain, LangGraph, LangSmith, HuggingFace and Gemma

References: [AlphaCodium](https://github.com/Codium-ai/AlphaCodium) , [LangChain examples](https://github.com/langchain-ai/langgraph/blob/main/examples/code_assistant/langgraph_code_assistant.ipynb)

Gemma is a relatively small model and you may need to add some more tweaks to the code in this notebook to improve the performance. E.g. you might want to finetune gemma first for code generation and only after that use it as cod assistant. Here we directly use gemma, but [here](https://github.com/VladimerKhasia/ML-in-Notebooks/blob/main/custom%20GenAI%20and%20tools/gemma_qlora_rlhf.ipynb) is the notebook showing how to fine-tune gemma for the code generation.

Do not forget to set your HugggingFace `HF_TOKEN` and LangChain `LANGCHAIN_API_KEY` secrets in google colab.

</br>
NOTE:

- Gemma 2b strugles wenn long questions end with question mark. If you remove question marks present in the evaluation dataset it will boost the performance.

- In this notebook there is few-shot prompting implemented but commented out. You can create good propts in the following way: 1. write a prompt and get the models output, 2. restructure models output and make its structure as similar to your need as possible which gives you new prompt, 3. You can iterate and compare which prompt will give you bes desired results. It is basically injection of your needs and structure into the "language" of the model.

In [None]:
#%%capture --no-stderr          # this is basically to avoid warnings
%pip install -U -q langchain_community tiktoken langchainhub chromadb langchain langgraph faiss-cpu bs4

In [None]:
!pip install -q bitsandbytes accelerate

In [3]:
model_id = "google/gemma-2b-it"

In [None]:
#@title  for GPU

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# bnb_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_use_double_quant=True,
#         bnb_4bit_quant_type="nf4",
#         bnb_4bit_compute_dtype=torch.bfloat16
# )

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # quantization_config=bnb_config,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token=True)

In [11]:
#@title for CPU

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = AutoModelForCausalLM.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token=True)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# gemma coding assistant

In [5]:
from bs4 import BeautifulSoup as Soup
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader

# LCEL docs
url = "https://python.langchain.com/docs/expression_language/"
loader = RecursiveUrlLoader(
    url=url, max_depth=20, extractor=lambda x: Soup(x, "html.parser").text
)
docs = loader.load()

# Sort the list based on the URLs and get the text
d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
d_reversed = list(reversed(d_sorted))
concatenated_content = "\n\n\n --- \n\n\n".join(
    [doc.page_content for doc in d_reversed]
)

In [7]:
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.messages.ai import AIMessage
from langchain_core.language_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import AIMessageChunk, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import run_in_executor
from transformers import pipeline
import re
import json
from typing import Any


class GemmaChatModel(BaseChatModel):
    """
    A custom chat model powered by Gemma from Hugging Face, designed to be informative, comprehensive, and engaging.
    See the custom model guide here: https://python.langchain.com/docs/modules/model_io/chat/custom_chat_model/
    """

    model_name: str = "gemma_chat_model"  # Replace with the actual Gemma model name
    task: str = "conversational"  # Task for the pipeline (conversational or summarization)
    n: int = 2048
    model : Any = None
    tokenizer : Any = None


    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Args:
            messages: The list of prompt messages.
            stop: Optional list of stop tokens.
            run_manager: Optional callback manager.
            **kwargs: Additional keyword arguments.

        Returns:
            A ChatResult object containing the generated response.
        """

        prompt = messages[-1].content #[: self.n]
        input_ids = self.tokenizer(prompt, return_tensors="pt").to(device)
        outputs = self.model.generate(**input_ids, max_new_tokens=self.n)
        text = self.tokenizer.decode(outputs[0])
        #text = " ".join(text.split("\n"))

        start_index, end_index = text.find("<eos>"), text.rfind("<eos>")
        response = text[start_index+len("<eos>"):end_index].strip()

        message = AIMessage(content=response, additional_kwargs={}, response_metadata={"time_in_seconds": 3})
        return ChatResult(generations=[ChatGeneration(message=message)])

    @property
    def _llm_type(self) -> str:
        """
        Returns the type of language model used: "gemma_chat_model".
        """
        return "gemma_chat_model"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """
        Returns a dictionary of identifying parameters for LangChain callbacks.
        """
        return {"model_name": self.model_name, "task": self.task}

llm = GemmaChatModel()
llm.model = model
llm.tokenizer = tokenizer

In [8]:
llm.invoke("hello")

AIMessage(content="Hello! 👋\n\nIt's great to hear from you. What can I do for you today? 😊", response_metadata={'time_in_seconds': 3}, id='run-1eebbc13-9bcf-4c19-9331-14ec1cd15dde-0')

In [9]:
from langchain.prompts import (
    ChatPromptTemplate,
    FewShotChatMessagePromptTemplate,
)

# # This part is for few shot prompting
# #https://python.langchain.com/docs/modules/model_io/prompts/few_shot_examples_chat/

# example_input_one = "write python function for adding two numbers"
# example_output_one = """Get the two numbers from the user. a = int(input("Enter the first number: ")) b = int(input("Enter the second number: "))
# Add the two numbers and print the result. sum = add_two_numbers(a, b) print(f"The sum of {a} and {b} is {sum}")
# import python
# def add_two_numbers(a, b):
#       \"\"\"This function adds two numbers. Args: a (int): The first number. b (int): The second number. Returns: int: The sum of a and b.\"\"\"
#       return a + b """

# examples = [
#     {"input": example_input_one, "output": example_output_one},
# ]
# example_prompt = ChatPromptTemplate.from_messages(
#     [
#         ("human", "{input}"),
#         ("ai", "{output}"),
#     ]
# )
# few_shot_prompt = FewShotChatMessagePromptTemplate(
#     example_prompt=example_prompt,
#     examples=examples,
# )

In [10]:
from functools import reduce

def dictifier(ai_message):

    def code_only(x):
      if x.startswith('python'):
        return x[6:]

    def add_non_nones(acc, x):
        if x is not None:
            return acc + x
        else:
            return acc

    filtered = list( map( code_only, ai_message.content.split("```") ) )
    code_string = reduce(add_non_nones, filtered, '')

    # Create a dictionary
    result_dict = {
        'prefix': ai_message.content,
        #'imports': import_section,
        'code': code_string,
    }
    return json.dumps(result_dict)


In [11]:
## Parser: https://python.langchain.com/docs/modules/model_io/output_parsers/types/pydantic/   https://python.langchain.com/docs/modules/model_io/output_parsers/quick_start/
##         https://python.langchain.com/docs/modules/model_io/output_parsers/custom/  custom one

from typing import List
from langchain.output_parsers import PydanticOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field, validator

class Code(BaseModel):
    """Code output"""

    prefix: str = Field(description="Description of the problem and approach")
    ## imports: str = Field(description="Code block import statements")
    code: str = Field(description="Code block not including import statements")
    description = "Schema for code solutions to questions about LCEL."

    # # Custom validation logic with Pydantic for prefix.
    # @validator("prefix")
    # def description_isnot_python_comment(cls, content):
    #     if not content: #TODO: implement logic here
    #         raise ValueError("Desctiption is not in a comment format!")
    #     return content


# Set up a parser + inject instructions into the prompt template.
parser = PydanticOutputParser(pydantic_object=Code)

# Just inspect what it contains
##parser.dict()['pydantic_object'].__dict__

In [None]:
code_gen_prompt = ChatPromptTemplate.from_messages(
    [("system","""You are a coding assistant with expertise in LCEL, LangChain expression language. \n
    Here is a full set of LCEL documentation:  \n ------- \n  {context} \n ------- \n Answer the user
    question based on the above provided documentation. Ensure any code you provide can be executed \n
    with all required imports and variables defined. Structure your answer with a description of the code solution. \n
    Then list the imports. And finally list the functioning code block.\n
    Use these format instructions for the structure of your answer: {format_instructions}.""",),
    #few_shot_prompt,
    ("user", "{messages}"),])

coder = code_gen_prompt | llm | dictifier | parser

question = "How can I directly pass a string to a runnable and use it to construct the input needed for my prompt"
result = coder.invoke({"context": concatenated_content, "format_instructions": parser.get_format_instructions(), "messages":[("user", question)]})
# ## parser.invoke(result)   #use this if you have not imcluded parser in the chain and want to test how parser works
# ## result.content          #use this in case you are dealing with direct llm output - it's AIMessage

print(result.prefix, "\n--------------------------------------\n", result.code)

In [13]:
#@title graph part

from operator import itemgetter
from langchain.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough
from typing import Dict, TypedDict, List
from langgraph.graph import END, StateGraph

coder = code_gen_prompt | llm | dictifier | parser

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        error : Binary flag for control flow to indicate whether test error was tripped
        messages : With user question, error messages, reasoning
        generation : Code solution
        iterations : Number of tries
    """

    error : str
    messages : List
    generation : str
    iterations : int


### Parameter

# Max tries
max_iterations = 3
# Reflect
# flag = 'reflect'
flag = 'do not reflect'

### Nodes

def generate(state: GraphState):
    """
    Generate a code solution

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation
    """

    print("---GENERATING CODE SOLUTION---")

    # State
    messages = state["messages"]
    iterations = state["iterations"]
    error = state["error"]

    # We have been routed back to generation with an error
    if error == "yes":
        messages += [("user","Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:")]

    # Solution
    code_solution = coder.invoke({"context": concatenated_content, "format_instructions": parser.get_format_instructions(), "messages" : messages})
    ## messages += [("assistant",f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}")]
    messages += [("assistant",f"{code_solution.prefix} \n \n Code: {code_solution.code}")]

    # Increment
    iterations = iterations + 1
    return {"generation": code_solution, "messages": messages, "iterations": iterations}

def code_check(state: GraphState):
    """
    Check code

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, error
    """

    print("---CHECKING CODE---")

    # State
    messages = state["messages"]
    code_solution = state["generation"]
    iterations = state["iterations"]

    # Get solution components
    prefix = code_solution.prefix
    ##imports = code_solution.imports
    code = code_solution.code

    # # Check imports
    # try:
    #     exec(imports)
    # except Exception as e:
    #     print("---CODE IMPORT CHECK: FAILED---")
    #     error_message = [("user", f"Your solution failed the import test: {e}")]
    #     messages += error_message
    #     return {"generation": code_solution, "messages": messages, "iterations": iterations, "error": "yes"}

    # # Check execution
    # try:
    #     exec(imports + "\n" + code)
    # except Exception as e:
    #     print("---CODE BLOCK CHECK: FAILED---")
    #     error_message = [("user", f"Your solution failed the code execution test: {e}")]
    #     messages += error_message
    #     return {"generation": code_solution, "messages": messages, "iterations": iterations, "error": "yes"}

    try:
        exec(code)
    except Exception as e:
        print("---CODE BLOCK CHECK: FAILED---")
        error_message = [("user", f"Your solution failed the code execution test: {e}")]
        messages += error_message
        return {"generation": code_solution, "messages": messages, "iterations": iterations, "error": "yes"}


    # No errors
    print("---NO CODE TEST FAILURES---")
    return {"generation": code_solution, "messages": messages, "iterations": iterations, "error": "no"}

def reflect(state: GraphState):
    """
    Reflect on errors

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation
    """

    print("---GENERATING CODE SOLUTION---")

    # State
    messages = state["messages"]
    iterations = state["iterations"]
    code_solution = state["generation"]

    # Prompt reflection
    reflection_message = [("user", """You tried to solve this problem and failed a unit test. Reflect on this failure
                                    given the provided documentation. Write a few key suggestions based on the
                                    documentation to avoid making this mistake again.""")]

    # Add reflection
    reflections = coder.invoke({"context"  : concatenated_content, "format_instructions": parser.get_format_instructions(), "messages" : messages})
    messages += [("assistant" , f"Here are reflections on the error: {reflections}")]
    return {"generation": code_solution, "messages": messages, "iterations": iterations}

### Edges

def decide_to_finish(state: GraphState):
    """
    Determines whether to finish.

    Args:
        state (dict): The current graph state

    Returns:
        str: Next node to call
    """
    error = state["error"]
    iterations = state["iterations"]

    if error == "no" or iterations == max_iterations:
        print("---DECISION: FINISH---")
        return "end"
    else:
        print("---DECISION: RE-TRY SOLUTION---")
        if flag == 'reflect':
            return "reflect"
        else:
            return "generate"



workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("generate", generate)  # generation solution
workflow.add_node("check_code", code_check)  # check code
workflow.add_node("reflect", reflect)  # reflect

# Build graph
workflow.set_entry_point("generate")
workflow.add_edge("generate", "check_code")
workflow.add_conditional_edges(
    "check_code",
    decide_to_finish,
    {
        "end": END,
        "reflect": "reflect",
        "generate": "generate",
    },
)
workflow.add_edge("reflect", "generate")
app = workflow.compile()

In [None]:
question = "How can I directly pass a string to a runnable and use it to construct the input needed for my prompt"
app.invoke({"messages":[("user", question)],"iterations":0})

# Evaluation of gemma coding assistant

In [15]:
import langsmith
from langsmith.schemas import Example, Run
from langsmith.evaluation import evaluate

In [16]:
import os
from google.colab import userdata

os.environ["LANGCHAIN_API_KEY"] = userdata.get('LANGCHAIN_API_KEY')

client = langsmith.Client()
# Clone the dataset to your tenant to use it
public_dataset = ("https://smith.langchain.com/public/326674a6-62bd-462d-88ae-eea49d503f9d/d")
client.clone_public_dataset(public_dataset)

In [17]:
# def check_import(run: Run, example: Example) -> dict:
#     imports = run.outputs.get("imports")
#     try:
#         exec(imports)
#         return {"key": "import_check" , "score": 1}
#     except:
#         return {"key": "import_check" , "score": 0}

# def check_execution(run: Run, example: Example) -> dict:
#     imports = run.outputs.get("imports")
#     code = run.outputs.get("code")
#     try:
#         exec(imports + "\n" + code)
#         return {"key": "code_execution_check" , "score": 1}
#     except:
#         return {"key": "code_execution_check" , "score": 0}

def check_execution(run: Run, example: Example) -> dict:
    code = run.outputs.get("code")
    try:
        exec("\n" + code)
        return {"key": "code_execution_check" , "score": 1}
    except:
        return {"key": "code_execution_check" , "score": 0}

def predict_base_case(example: dict):
    """ Context stuffing """
    solution = coder.invoke({"context"  : concatenated_content, "messages" : [("user",example["question"])]})
    # return {"imports": solution_structured.imports, "code": solution_structured.code}
    return {"code": solution.code}

def predict_langgraph(example: dict):
    """ LangGraph """
    graph = app.invoke({"messages":[("user",example["question"])],"iterations":0})
    solution = graph["generation"]
    # return {"imports": solution.imports, "code": solution.code}
    return {"code": solution.code}


# Evaluator
##code_evalulator = [check_import,check_execution]
code_evalulator = [check_execution]

# Dataset
dataset_name = "lcel-teacher-eval"

In [None]:
# Run base case
experiment_results_ = evaluate(
    predict_base_case,
    data=dataset_name,
    evaluators=code_evalulator,
    experiment_prefix=f"test-without-langgraph-{model_id}",
    max_concurrency=2,
    metadata={
      "llm": model_id,
    },
)
# Run with langgraph
experiment_results = evaluate(
    predict_langgraph,
    data=dataset_name,
    evaluators=code_evalulator,
    experiment_prefix=f"test-with-langgraph-{model_id}-{flag}",
    max_concurrency=2,
    metadata={
      "llm": model_id,
      "feedback": flag,
    },
)