# Insight Generation Pipeline

## Imports

### General Imports

In [1]:
import os
import getpass
import re

### Langchain Imports

In [2]:
from langchain import hub
from langchain.chat_models import init_chat_model
from langchain_community.utilities import SQLDatabase
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, FewShotPromptTemplate
from langchain_ollama import ChatOllama


### OpenAI Imports

In [3]:
from openai import OpenAI

## LLM

In [None]:
from langchain.chat_models import init_chat_model
# Ask for the OpenAI key interactively only when it is missing or empty in
# the environment, so a pre-configured key is never overwritten.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter API key for OpenAI: ")

# Chat model used by dem2sql_gpt to generate SQL.
model = init_chat_model(
    "gpt-4o-mini",
    model_provider="openai",
)

### SQL agent

In [8]:
from langchain_community.agent_toolkits import create_sql_agent
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate, ChatPromptTemplate
from langchain.agents import AgentExecutor, Tool, AgentType, initialize_agent
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.agents.agent import AgentExecutor
import json


In [9]:
# Load every demonstration file; each entry of `data` is the parsed JSON
# content of the file at the same index in `file_names`.
data = []
file_names = [
    "demonstration_sqls_15B_3S_oncomx.json",
]
directory = "dem_sqls_oncomx"

for file_name in file_names:
    print(file_name)
    # JSON is UTF-8 by specification; don't rely on the platform default
    # encoding, which differs e.g. on Windows.
    with open(os.path.join(directory, file_name), "r", encoding="utf-8") as f:
        data.append(json.load(f))

demonstration_sqls_15B_3S_oncomx.json


In [10]:
import sqlite3
def get_database_prompt(db_path, num_rows=3) -> str:
    """Describe every table of a SQLite database as text for an LLM prompt.

    For each table listed in ``sqlite_master`` the result contains the table's
    ``CREATE`` statement, followed by a ``/* ... */`` block holding a
    tab-separated line of column names and up to ``num_rows`` sample rows
    (every value is followed by a tab).

    Args:
        db_path: Path to the SQLite database file.
        num_rows: Maximum number of sample rows shown per table (default 3).

    Returns:
        The concatenated schema description.
    """
    parts = []
    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()
        # Name and CREATE statement live in the same catalog row, so a single
        # query replaces a second per-table lookup by name.
        cur.execute("SELECT name, sql FROM sqlite_master WHERE type='table';")
        for table_name, create_statement in cur.fetchall():
            # Quote the identifier so table names that are SQL keywords
            # (e.g. "order") or contain spaces/quotes don't break the SQL.
            quoted_name = '"' + table_name.replace('"', '""') + '"'

            parts.append(create_statement + "\n\n")
            parts.append("/*\n" + f"{num_rows} rows from {table_name} table:\n")

            # Header line: column names (index 1 of each PRAGMA table_info row).
            cur.execute(f"PRAGMA table_info({quoted_name})")
            parts.append("".join(column[1] + "\t" for column in cur.fetchall()))
            parts.append("\n")

            # Sample rows, one line each.
            cur.execute(f"SELECT * FROM {quoted_name} LIMIT ?;", (num_rows,))
            for row in cur.fetchall():
                parts.append("".join(str(value) + "\t" for value in row))
                parts.append("\n")
            parts.append("*/\n\n\n")
    finally:
        # Release the database handle even if a query fails.
        conn.close()
    return "".join(parts)

# Location of the OncoMX SQLite database; print its schema description as a
# quick sanity check of what the model will see.
db_path = (
    "./data/sft_data_collections/oncomx/"
    "oncomx_v1_0_25_small/oncomx_v1_0_25_small.sqlite"
)
print(get_database_prompt(db_path=db_path))

CREATE TABLE anatomical_entity (
	id VARCHAR NOT NULL, 
	name VARCHAR, 
	description TEXT, 
	CONSTRAINT idx_18385_primary PRIMARY KEY (id)
)

/*
3 rows from anatomical_entity table:
id	name	description	
CL:0000057	fibroblast	A connective tissue cell which secretes an extracellular matrix rich in collagen and other macromolecules. Flattened and irregular in outline with branching processes; appear fusiform or spindle-shaped.	
CL:0000169	type B pancreatic cell	A cell that secretes insulin and is located towards the center of the islets of Langerhans.	
CL:0002092	bone marrow cell	A cell found in the bone marrow. This can include fibroblasts, macrophages, adipocytes, osteoblasts, osteoclasts, endothelial cells and hematopoietic cells.	
*/


CREATE TABLE biomarker (
	id VARCHAR DEFAULT '0', 
	gene_symbol VARCHAR, 
	biomarker_description TEXT, 
	biomarker_id VARCHAR, 
	test_is_a_panel INTEGER NOT NULL, 
	CONSTRAINT idx_18391_primary PRIMARY KEY (id)
)

/*
3 rows from biomarker table:
id	gene

In [11]:
def dem2sql_gpt(dem_sqls, db, db_schema):
    """Ask the chat model to translate one question into SQL using few-shot examples.

    Args:
        dem_sqls: One demonstration record. It must provide ``"question"``
            (the natural-language question to translate) and
            ``"demonstrations"`` (the few-shot examples). Each demonstration
            presumably is a dict with ``"input"`` and ``"query"`` keys, since
            those are the placeholders of the example template — TODO confirm
            against the demonstration JSON files.
        db: Unused; kept only so existing call sites keep working.
        db_schema: Schema text inserted into the prompt prefix (e.g. the
            output of ``get_database_prompt``).

    Returns:
        The raw result of ``model.invoke`` (uses the notebook-level ``model``);
        pass it to ``extract_sql`` to obtain the query string.
    """
    prefix = "You are a helpful assistant that translates natural language queries into SQL queries. You have access to the following database schema:\n{db_schema}\n\n\nThese are some examples of natural language queries and their corresponding SQL queries:"
    suffix = (
        "Question: {input}\n"
        "SQL:\n"
        "```sql\n"
        "-- Write only the SQL query needed to answer the question.\n"
        "-- Do not include any explanation.\n"
        "-- Always enclose your SQL in a ```sql code block.\n"
        "-- Do not add anything before or after the code block.\n"
    )

    examples = dem_sqls["demonstrations"]
    question = dem_sqls["question"]

    # Declare every placeholder the template uses: FewShotPromptTemplate keeps
    # only the example keys listed here, and older langchain versions reject a
    # template whose variables don't match this list.
    example_prompt = PromptTemplate(
        input_variables=["input", "query"],
        template="Question: {input}\nSQL: {query}",
    )

    prompt = FewShotPromptTemplate(
        prefix=prefix,
        examples=examples,
        example_prompt=example_prompt,
        suffix=suffix,
        input_variables=["input", "db_schema"],
    )

    final_prompt = prompt.format(
        input=question,
        db_schema=db_schema,
    )

    return model.invoke(final_prompt)

def extract_sql(response):
    """Pull the SQL query out of a chat-model response.

    The query is taken from the first fenced code block tagged ``sql`` or
    ``sqlite`` (case-insensitive). If no such block exists, the first
    untagged ``` fence is used instead. When nothing matches, the raw
    response text is printed for debugging and ``None`` is returned.

    Args:
        response: Object whose ``content`` attribute holds the model's text
            reply (e.g. the value returned by ``dem2sql_gpt``).

    Returns:
        The stripped SQL string, or ``None`` if no code block was found.
    """
    raw_content = response.content

    # `\b` keeps "```sqlite" from matching as "```sql" and leaking "ite"
    # into the extracted query.
    match = re.search(
        r"```(?:sql|sqlite)\b\s*(.*?)\s*```", raw_content, re.DOTALL | re.IGNORECASE
    )
    if match is None:
        # Fallback: an untagged fence (``` followed directly by a newline).
        match = re.search(r"```[ \t]*\r?\n(.*?)```", raw_content, re.DOTALL)

    if match is None:
        # Show the unusable reply so failures can be inspected.
        print(raw_content)
        return None
    return match.group(1).strip()

In [12]:
# Generate one predicted SQL query per question and write them, one per line,
# to pred_sqls_oncomx/pred_sqls_gpt4omini_<run tag>.txt.
db_schema = get_database_prompt(db_path=db_path)
output_dir = "pred_sqls_oncomx"
# open(..., "w") below fails if the directory does not exist yet.
os.makedirs(output_dir, exist_ok=True)

for dem_sqls, file_name in zip(data, file_names):
    # "demonstration_sqls_15B_3S_oncomx.json" -> "15B_3S_oncomx"; same result
    # as the former file_name[19:-5] slice for names following this pattern.
    run_tag = file_name.removeprefix("demonstration_sqls_").removesuffix(".json")
    out_path = os.path.join(output_dir, f"pred_sqls_gpt4omini_{run_tag}.txt")
    with open(out_path, "w", encoding="utf-8") as f:
        for dem_sql in dem_sqls:
            # dem2sql_gpt ignores its `db` argument and no `db` is defined in
            # the visible cells, so pass None rather than raise NameError.
            response = dem2sql_gpt(dem_sql, None, db_schema)
            generated_sql = extract_sql(response)

            # One query per line: flatten multi-line SQL; the literal "None"
            # marks a response from which no SQL could be extracted.
            if generated_sql is not None:
                line = generated_sql.replace("\n", " ")
            else:
                line = "None"
            f.write(line + "\n")