# Code Generation using Retrieval Augmented Generation + LangChain

In [1]:
# importing packages
import nbformat
import os
import pandas as pd
import regex as re
import datetime
import json
import scanpy as sc

from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langchain.schema.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.text_splitter import Language
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
import langchain_experimental.agents.agent_toolkits.pandas.base as lp
from dotenv import load_dotenv

In [None]:
## VARIABLES

In [2]:
user_question = "Create a gene expression plot for the gene CD13."

In [3]:
# setting keys and variables
# .env file should have OPENAI_API_KEY - the .env file is no longer needed for the GITHUB token
load_dotenv()

FILE_URL_PATH = "./"
FILE_URL_NAME = "code_files_URL.txt"

mypath = os.path.abspath('')

In [4]:
# Path for the local JSON log file
LOG_FILE = "./../../../chat_log.json"

# Function to ensure the JSON log file exists
def initialize_json_log():
    if not os.path.exists(LOG_FILE):
        with open(LOG_FILE, 'w') as file:
            json.dump([], file)  # Initialize with an empty list

# Function to log data to the JSON file
def log_to_json(context, prompt, prompt_template, response):
    # Ensure the log file exists
    initialize_json_log()

    # Prepare log entry
    timestamp = datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
    log_entry = {
        "timestamp": timestamp,
        "context": context,
        "prompt": prompt,
        "prompt_template": prompt_template,
        "response": response
    }

    # Read the existing logs
    with open(LOG_FILE, 'r') as file:
        logs = json.load(file)
    
    # Append the new log entry
    logs.append(log_entry)

    # Write back the updated logs to the file
    with open(LOG_FILE, 'w') as file:
        json.dump(logs, file, indent=4)

In [5]:
# Initialize an instance of the ChatOpenAI class with specified parameters
# Set the temperature to 0.1 for more deterministic responses
# Specify the model to use as "gpt-4o"

code_llm = ChatOpenAI(temperature=0.1, model_name="gpt-4o")

dataframe_llm = ChatOpenAI(temperature=0.1, model_name="gpt-4o")

In [6]:
def crawl_local_repo(directory_path): #url, is_sub_dir, branch_name, project_path, access_token=f"{GITHUB_TOKEN}"):
    """
    Crawls a local directory to retrieve file paths based on specified criteria.

    Args:
        directory_path (str): The path to the local project directory.

    Returns:
        list: List of file paths that match the criteria.
    """
    
    # List of files to ignore
    ignore_list = ['__init__.py', 'pbmc3k_tutorial.ipynb', 'pbmc3k_tutorial.py']

    # Initialize an empty list to store file paths
    files = []

    # Walk through the directory tree
    for root, dirs, file_names in os.walk(directory_path):
        # Skip hidden directories (those starting with '.')
        dirs[:] = [d for d in dirs if not d.startswith('.')]
        
        for file_name in file_names:
            # Check if the file meets the criteria for inclusion
            if file_name not in ignore_list and (file_name.endswith('.py') or file_name.endswith('.ipynb')):
                file_path = os.path.join(root, file_name)
                files.append(file_path)

    # Return the list of collected file paths
    return files

In [7]:
# Crawl the local directory to get a list of relevant file paths
local_directory_path = "../test_projects/"#TAURUS_examples"
code_files_paths = crawl_local_repo(local_directory_path)

# Write the list of file paths to a specified text file
with open(FILE_URL_PATH + FILE_URL_NAME, 'w') as f:
    # Iterate through the list of file paths and write each one to the file
    for item in code_files_paths:
        f.write(item + '\n')

In [8]:
# Extracts the Python code from a .ipynb (Jupyter Notebook) file from the local filesystem
def extract_python_code_from_ipynb(local_file_path, cell_type="code"):
    # Read the notebook content from the local file
    with open(local_file_path, 'r', encoding='utf-8') as f:
        notebook_content = f.read()

    # Parse the notebook content using nbformat
    notebook = nbformat.reads(notebook_content, as_version=nbformat.NO_CONVERT)

    # Initialize a variable to store the extracted Python code
    python_code = None

    # Iterate over the cells in the notebook
    for cell in notebook.cells:
        # Check if the cell type matches the specified type
        if cell.cell_type == cell_type:
            # Append the cell's source code to the python_code variable
            if not python_code:
                python_code = cell.source
            else:
                python_code += "\n" + cell.source

    # Return the extracted Python code
    return python_code

# Extracts the Python code from a .py file from the local filesystem
def extract_python_code_from_py(local_file_path):
    # Read the Python file from the local file system
    with open(local_file_path, 'r', encoding='utf-8') as f:
        python_code = f.read()

    # Return the extracted Python code
    return python_code

# Read the list of file paths from the specified text file
with open(FILE_URL_PATH + FILE_URL_NAME) as f:
    code_files_paths = f.read().splitlines()

In [9]:
# Initialize an empty list to store the extracted code documents
code_strings = []

# Iterate over the list of file URLs
for i in range(0, len(code_files_paths)):
    # Check if the file URL ends with ".py"
    if code_files_paths[i].endswith(".py"):
        # Extract the Python code from the .py file
        content = extract_python_code_from_py(code_files_paths[i])
        # Create a Document object with the extracted content and metadata
        doc = Document(page_content=content, metadata={"url": code_files_paths[i], "file_index": i})
        # Append the Document object to the code_strings list
        code_strings.append(doc)
        # Check if the file URL ends with ".py"
    elif code_files_paths[i].endswith(".ipynb"):
        # Extract the Python code from the .py file
        content_ipynb = extract_python_code_from_ipynb(code_files_paths[i])
        # Create a Document object with the extracted content and metadata
        doc_ipynb = Document(page_content=content_ipynb, metadata={"url": code_files_paths[i], "file_index": i})
        # Append the Document object to the code_strings list
        code_strings.append(doc_ipynb)


In [10]:
# Initialize a text splitter for chunking the code strings
text_splitter = RecursiveCharacterTextSplitter.from_language(
    language=Language.PYTHON,  # Specify the language as Python
    chunk_size=20000,           # Set the chunk size to 1500 characters
    chunk_overlap=2000          # Set the chunk overlap to 150 characters
)

# Split the code documents into chunks using the text splitter
texts = text_splitter.split_documents(code_strings)

In [11]:
# Set the number of queries per minute (QPM) for embedding requests
EMBEDDING_QPM = 100

# Set the number of batches for processing embeddings
EMBEDDING_NUM_BATCH = 5

# Initialize an instance of the OpenAIEmbeddings class
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-large"  # Specify the model to use for generating embeddings
    )

In [12]:
# Create an index from the embedded code chunks
# Use FAISS (Facebook AI Similarity Search) to create a searchable index
db = FAISS.from_documents(texts, embeddings)

In [13]:
# Initialize the retriever from the FAISS index
retriever = db.as_retriever(
    search_type="similarity",      # Specify the search type as "similarity"
    search_kwargs={"k": 5},        # Set search parameters, in this case, return the top 5 results
)

In [14]:
#user_question = "Create a heatmap plot of the localisation status vs the UTR length"
#user_question = input("What would you like to ask the LLM?")

In [15]:
adata = sc.read_h5ad("../../../../../../mariak/anndata_obj/sub_buckets/bcells_final.h5ad",)
cells_df = pd.DataFrame(adata.obs)
cells_df.name = 'cells'

genes_df = pd.DataFrame(adata.var)
genes_df.name = 'genes'
genes_df['gene_id'] = genes_df.index


In [16]:
agent = lp.create_pandas_dataframe_agent(
    dataframe_llm, [cells_df,genes_df], verbose=True, handle_parse_errors=True, allow_dangerous_code=True
)

prompt = """
Based on the question asked and the dataframes provided, please perform the following steps:

1. Identify the data asked for in the question.
2. Based on step 1, find the relevant column names in the dataframes provided based on the information identified earlier in the question asked regarding data.
3. The relevant column names can be a combination from the two dataframes.
4. Provide the relevant column names from step 2 in a list.
5. If the question asks for a gene name, do provide a gene name in addition to the relevant columns names.
"""

full_prompt = prompt + "\nQuestion: " + user_question

# the agent might raise an error. Sometimes repeating the same prompt helps...
final_answer = agent.invoke(full_prompt)





[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mThought: To create a gene expression plot for the gene CD13, I need to identify the relevant columns from the dataframes. The gene expression data is likely in the second dataframe, which contains gene-related information. I need to find the column that contains gene symbols and check if CD13 is present. Additionally, I need to identify columns related to gene expression counts.

Action: Check if the gene CD13 is present in the second dataframe and identify relevant columns for gene expression.

Action Input: `df2[df2['gene_symbol'] == 'CD13']`
[0mCheck if the gene CD13 is present in the second dataframe and identify relevant columns for gene expression. is not a valid tool, try one of [python_repl_ast].[32;1m[1;3mTo determine if the gene CD13 is present in the second dataframe and identify relevant columns for gene expression, I should use the Python shell to query the dataframe.

Action: Use the Python shell to check if 

In [31]:
final_answer['output']

'Agent stopped due to iteration limit or time limit.'

In [17]:
prompt_RAG = """ 

Context: {context}]

The collection of Python scripts provided in the context, is designed to generate various types of data visualizations 
using the mdvtools library. Each script focuses on a specific type of plot and follows a common structure that includes loading 
data from a CSV file, creating a plot using specific parameters, and serving the visualization through an MDV project. 

All scripts in the context share a common workflow:

Setup: Define the project path, data path, and view name, the project path should always be: project_path = os.path.expanduser('~/mdv/project')
Plot function definition: Define the respective plot (dot plot, heatmap, histogram, box plot, scatter plot, 3D scatter plot, pie/ring chart, stacked row plot) using a function in the same way as the context.
Project Creation: Initialize an MDVProject instance using the method: MDVProject(project_path, delete_existing=True).
Data Loading: Load data from the specified CSV file into a pandas DataFrame using the load_data(path) function.
Data adding: Add the data source to the project using the method: project.add_datasource(data_path, data).
Plot Creation: Create the respective plot (dot plot, heatmap, histogram, box plot, scatter plot, 3D scatter plot, pie/ring chart, stacked row plot) and define the plot paramaters in the same way as in the context.
Data Conversion: Convert the plot data to JSON format for integration with the MDV project using the convert_plot_to_json(plot) function.
Serving: Configure the project view, set it to editable, and serve the project using the .set_view(view_name, plot_view), .set_editable(True) and .serve() methods.

You are a top-class Python developer. Based on the question: {question}, decide which script from the context {context} is more relevant to the question: {question} and update the script to address the question.
If no script is relevant, guided by the context generate a new script. 
This list """ + final_answer['output'] + """ specifies the names of the data fields that need to be plotted, for example in the params field. Get the structure of params definition from the context.
The data should be loaded in the same way as in this notebook, in this case the lines of code to be used are below: 
import scanpy as sc
adata = sc.read_h5ad("../../../../../../mariak/anndata_obj/sub_buckets/bcells_final.h5ad",)
cells_df = pd.DataFrame(adata.obs)
cells_df.name = 'cells' 

If the prompt asks for a gene, make sure you load this datasource and that you create a link between the two datasets.

"""

#The plot you should create is the same as the plot created in the context. Specify the parameters according to the respective files in the context for each plot type. DO NOT add any parameters that have not been defined previously.

In [18]:
# Create a PromptTemplate object using the defined RAG prompt
prompt_RAG_template = PromptTemplate(
    template=prompt_RAG,          # Specify the template string
    input_variables=["context", "question"]  # Define the input variables for the template
)

# Initialize a RetrievalQA chain using the specified language model, prompt template, and retriever
qa_chain = RetrievalQA.from_llm(
    llm=code_llm,                 # Specify the language model to use
    prompt=prompt_RAG_template,   # Use the defined prompt template
    retriever=retriever,          # Use the initialized retriever for context retrieval
    return_source_documents=True  # Configure the chain to return source documents
)

In [19]:
# Define the context for the question (this should be retrieved by the retriever, but showing as an example)
context = retriever

# Invoke the QA chain with the query and context
output = qa_chain.invoke({"context": context, "query": user_question})
result = output["result"]

In [32]:
output

{'context': VectorStoreRetriever(tags=['FAISS', 'OpenAIEmbeddings'], vectorstore=<langchain_community.vectorstores.faiss.FAISS object at 0x3432ff0e0>, search_kwargs={'k': 5}),
 'query': 'Create a gene expression plot for the gene CD13.',
 'result': 'To create a gene expression plot for the gene CD13, we can use the script from the context that focuses on creating a dot plot for gene expression. The relevant script is the one that uses the `create_dot_plot` function to generate a dot plot for gene expression. We will modify this script to focus on the gene CD13, which is also known as ANPEP.\n\nHere\'s the updated script:\n\n```python\nimport os\nimport sys\nimport pandas as pd\nimport scanpy as sc\nimport numpy as np\nfrom mdvtools.mdvproject import MDVProject\nfrom mdvtools.charts.dot_plot import DotPlot\nimport json\n\ndef create_dot_plot(title, params, size, position):\n    plot = DotPlot(\n        title=title,\n        params=params,\n        size=size,\n        position=position\n

In [20]:
import pprint
pprint.pp(result)

('To create a gene expression plot for the gene CD13, we can use the script '
 'from the context that focuses on creating a dot plot for gene expression. '
 'The relevant script is the one that uses the `create_dot_plot` function to '
 'generate a dot plot for gene expression. We will modify this script to focus '
 'on the gene CD13, which is also known as ANPEP.\n'
 '\n'
 "Here's the updated script:\n"
 '\n'
 '```python\n'
 'import os\n'
 'import sys\n'
 'import pandas as pd\n'
 'import scanpy as sc\n'
 'import numpy as np\n'
 'from mdvtools.mdvproject import MDVProject\n'
 'from mdvtools.charts.dot_plot import DotPlot\n'
 'import json\n'
 '\n'
 'def create_dot_plot(title, params, size, position):\n'
 '    plot = DotPlot(\n'
 '        title=title,\n'
 '        params=params,\n'
 '        size=size,\n'
 '        position=position\n'
 '    )\n'
 '    plot.set_axis_properties("x", {"label": "", "textSize": 13, "tickfont": '
 '10})\n'
 '    plot.set_axis_properties("y", {"label": "", "tex

In [21]:
# Extracting the file urls retrieved from the context 

context_information = output['source_documents']
context_information_metadata = [context_information[i].metadata for i in range(len(context_information))]
context_information_metadata_url = [context_information_metadata[i]['url'] for i in range(len(context_information_metadata))]

In [22]:
def extract_code_from_response(response):
    """Extracts Python code from a string response."""
    # Use a regex pattern to match content between triple backticks
    code_pattern = r"```python(.*?)```"
    match = re.search(code_pattern, response, re.DOTALL)
    
    if match:
        # Extract the matched code and strip any leading/trailing whitespaces
        return match.group(1).strip()
    return None

code = extract_code_from_response(result)

In [23]:
original_script = code

In [24]:
def reorder_parameters(script, dataframe_path):
    # Load the dataframe to infer the column types
    df = pd.read_csv(dataframe_path)
    #df['leiden'] = df['leiden'].apply(str)
    categorical_columns = df.select_dtypes(include=['object', 'category']).columns.tolist()
    numerical_columns = df.select_dtypes(include=['number']).columns.tolist()

    def is_categorical(column):
        return column in categorical_columns

    def is_numerical(column):
        return column in numerical_columns
    
    # Define a regex pattern to find function definitions that create BoxPlots
    patterns = [re.compile(r'def\s+(\w*)\s*\((.*?)\):\s*\n(.*?)BoxPlot\((.*?)\)', re.DOTALL),
                re.compile(r'def\s+(\w*)\s*\((.*?)\):\s*\n(.*?)DotPlot\((.*?)\)', re.DOTALL),
                re.compile(r'def\s+(\w*)\s*\((.*?)\):\s*\n(.*?)AbundanceBoxPlot\((.*?)\)', re.DOTALL),
                re.compile(r'def\s+(\w*)\s*\((.*?)\):\s*\n(.*?)ViolinPlot\((.*?)\)', re.DOTALL),
                re.compile(r'def\s+(\w*)\s*\((.*?)\):\s*\n(.*?)RingChart\((.*?)\)', re.DOTALL),
                re.compile(r'def\s+(\w*)\s*\((.*?)\):\s*\n(.*?)RowChart\((.*?)\)', re.DOTALL),
                re.compile(r'def\s+(\w*)\s*\((.*?)\):\s*\n(.*?)StackedRowChart\((.*?)\)', re.DOTALL),
                re.compile(r'def\s+(\w*)\s*\((.*?)\):\s*\n(.*?)HeatmapPlot\((.*?)\)', re.DOTALL)]
    
    pattern_multiline = re.compile(r'def\s+(\w*)\s*\((.*?)\):\s*\n(.*?)MultiLinePlot\((.*?)\)', re.DOTALL)

    for pattern in patterns:
        if pattern.search(script):
            
            # Define a regex pattern to find params and param patterns
            pattern_param = re.compile(r'params\s*=\s*\[.*?\]|param\s*=\s*".*?"')
            
            def reorder_params(match_param):
                matched_text = match_param.group(0)  # Get the entire matched text

                # Extract parameter names
                if 'params' in matched_text:
                    param_list = re.findall(r'\'(.*?)\'', matched_text)
                    param_list = re.findall(r'\"(.*?)\"', matched_text)
                else:
                    param_list = [re.findall(r'\"(.*?)\"', matched_text)[0]]
                
                # Check for the presence of categorical and numerical variables
                has_categorical = any(is_categorical(param) for param in param_list)
                has_numerical = any(is_numerical(param) for param in param_list)

                # Add a categorical variable if none is present
                if not has_categorical and categorical_columns:
                    param_list.insert(0, categorical_columns[0])
                    has_categorical = True

                if len(param_list) < 2:
                    return matched_text  # No need to reorder if there are fewer than 2 parameters

                first_param = param_list[0]
                second_param = param_list[1]

                # Check the types of the parameters using the dataframe
                #if first_param in df.columns and second_param in df.columns:
                if has_categorical and has_numerical:
                    if not (is_categorical(first_param) and is_numerical(second_param)):
                        param_list[0], param_list[1] = param_list[1], param_list[0]

                # Reconstruct the parameters with reordered values
                if 'params' in matched_text:
                    reordered_params = f"params = ['{param_list[0]}', '{param_list[1:]}']"
                else:
                    reordered_params = f'param = "{param_list[0]}"'

                return reordered_params.replace('\'[', ' ').replace(']\'','')

            # Substitute the matches with reordered parameters
            modified_script = re.sub(pattern_param, reorder_params, script)

            return modified_script

    if pattern_multiline.search(script):
        # Define a regex pattern to find params and param patterns
        pattern_param = re.compile(r'params\s*=\s*\[.*?\]|param\s*=\s*".*?"')
        
        def reorder_params_multiline(match_param):
            matched_text = match_param.group(0)  # Get the entire matched text

            # Extract parameter names
            if 'params' in matched_text:
                param_list = re.findall(r'\'(.*?)\'', matched_text)
                param_list = re.findall(r'\"(.*?)\"', matched_text)
            else:
                param_list = [re.findall(r'\"(.*?)\"', matched_text)[0]]
            
            # Check for the presence of categorical and numerical variables
            has_categorical = any(is_categorical(param) for param in param_list)
            has_numerical = any(is_numerical(param) for param in param_list)

            # Add a categorical variable if none is present
            if not has_categorical and categorical_columns:
                param_list.insert(0, categorical_columns[0])
                has_categorical = True

            if len(param_list) < 2:
                return matched_text  # No need to reorder if there are fewer than 2 parameters

            first_param = param_list[0]
            second_param = param_list[1]

            # Check the types of the parameters using the dataframe
            #if first_param in df.columns and second_param in df.columns:
            if has_categorical and has_numerical:
                if not (is_numerical(first_param) and is_categorical(second_param)):
                    param_list[0], param_list[1] = param_list[1], param_list[0]

            # Reconstruct the parameters with reordered values
            if 'params' in matched_text:
                reordered_params = f"params = ['{param_list[0]}', '{param_list[1:]}']"
            else:
                reordered_params = f'param = "{param_list[0]}"'

            return reordered_params.replace('\'[', ' ').replace(']\'','')

        # Substitute the matches with reordered parameters
        modified_script_multiline = re.sub(pattern_param, reorder_params_multiline, script)

        return modified_script_multiline

        
    return script
        

In [25]:
# Apply the reorder transformation
modified_script = original_script#reorder_parameters(original_script, path_to_data)

In [26]:
packages_functions = """import os
import pandas as pd
import scanpy as sc
import sys
import  numpy as np
from mdvtools.mdvproject import MDVProject
from mdvtools.charts.heatmap_plot import HeatmapPlot
from mdvtools.charts.histogram_plot import HistogramPlot
from mdvtools.charts.dot_plot import DotPlot
from mdvtools.charts.box_plot import BoxPlot
from mdvtools.charts.scatter_plot_3D import ScatterPlot3D
from mdvtools.charts.row_chart import RowChart
from mdvtools.charts.scatter_plot import ScatterPlot
from mdvtools.charts.abundance_box_plot import AbundanceBoxPlot
from mdvtools.charts.stacked_row_plot import StackedRowChart
from mdvtools.charts.ring_chart import RingChart
from mdvtools.charts.violin_plot import ViolinPlot
from mdvtools.charts.multi_line_plot import MultiLinePlot
from mdvtools.charts.pie_chart import PieChart

import json \n
\n

def load_data(path):
    #Load data from the specified CSV file.
    return pd.read_csv(path, low_memory=False)

def convert_plot_to_json(plot):
    #Convert plot data to JSON format.
    return json.loads(json.dumps(plot.plot_data, indent=2).replace("\\\\", ""))
    
"""

In [27]:
# Split the text into lines
lines = modified_script.splitlines()

# Find the starting line index
start_index = next((i for i, line in enumerate(lines) if line.strip().startswith('def')), None)

if start_index is not None:
    # Capture all lines starting from the first 'def'
    captured_lines = "\n".join(lines[start_index:])
    #print("Captured part:\n", captured_lines)
else:
    print("Pattern not found")

with open("temp_code_3.py", "w") as f:
    f.write(packages_functions)
    f.write(captured_lines)
    #f.write("\n".join(lines))

In [28]:
# Example usage
log_to_json(context_information_metadata_url, output['query'], prompt_RAG, code)

In [29]:
# Run the saved Python file. This will start a server on localhost:5050, open the browser and display the plot with the server continuing to run in the background.
%run temp_code_3.py



# setting keys and variables
# Crawl the local repository to get a list of relevant file paths
Block 'b1: Local repo crawling' took 0.0890 seconds
# Initialize a text splitter for chunking the code strings
# Split the code documents into chunks using the text splitter
Block 'b2: Text splitter initialising' took 0.0006 seconds
# Initialize an instance of the OpenAIEmbeddings class
Block 'b3: Embeddings creating' took 0.0307 seconds
client=<openai.resources.embeddings.Embeddings object at 0x38cd2e090> async_client=<openai.resources.embeddings.AsyncEmbeddings object at 0x38cd2f770> model='text-embedding-3-large' dimensions=None deployment='text-embedding-ada-002' openai_api_version=None openai_api_base=None openai_api_type=None openai_proxy=None embedding_ctx_length=8191 openai_api_key=SecretStr('**********') openai_organization=None allowed_special=None disallowed_special=None chunk_size=1000 max_retries=2 request_timeout=None headers=None tiktoken_enabled=True tiktoken_model_name=None s

Address already in use
Port 5050 is in use by another program. Either identify and stop that program, or start the server with a different port.


AttributeError: 'tuple' object has no attribute 'tb_frame'