In [1]:
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import Tuple, List, Optional
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
import os
from langchain_community.graphs import Neo4jGraph
from langchain.document_loaders import WikipediaLoader
from langchain.text_splitter import TokenTextSplitter
from langchain_openai import ChatOpenAI
from langchain_experimental.graph_transformers import LLMGraphTransformer
from neo4j import GraphDatabase
from yfiles_jupyter_graphs import GraphWidget
from langchain_community.vectorstores import Neo4jVector
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars
from langchain_core.runnables import ConfigurableField, RunnableParallel, RunnablePassthrough

# On Google Colab, enable the custom widget manager (presumably needed for the
# yFiles GraphWidget used later). Outside Colab the import fails and there is
# nothing to do; only ImportError is ignored so real failures stay visible.
try:
    from google.colab import output
except ImportError:
    pass  # not running on Colab
else:
    output.enable_custom_widget_manager()

In [3]:
# Credentials come from the environment. Only the ones that are missing are
# prompted for (hidden input), so secrets never need to be typed into the
# notebook. The previous version overwrote values already exported in the
# environment with empty strings.
import getpass

for _env_var in ("OPENAI_API_KEY", "NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD"):
    if not os.environ.get(_env_var):
        os.environ[_env_var] = getpass.getpass(f"{_env_var}: ")

graph = Neo4jGraph()


In [None]:
from langchain.text_splitter import TokenTextSplitter
import json
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter

# Training corpus location. Set CELLRAG_TRAIN_PATH to point elsewhere instead of
# editing this machine-specific absolute path.
TRAIN_PATH = os.environ.get(
    "CELLRAG_TRAIN_PATH", '/home/yu.zhiyin/CellRAG/data/total/total_train.txt'
)
CHUNK_SIZE = 1000    # passed to CharacterTextSplitter(chunk_size=...)
CHUNK_OVERLAP = 100  # passed to CharacterTextSplitter(chunk_overlap=...)

loader = TextLoader(TRAIN_PATH)
documents = loader.load()

text_splitter = CharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
chunks = text_splitter.split_documents(documents)

print(f"Number of chunks: {len(chunks)}")

In [None]:
from langchain_openai import ChatOpenAI
import time
from neo4j.exceptions import ServiceUnavailable
from langchain_experimental.graph_transformers import LLMGraphTransformer


def process_batch(batch, retry_attempts: int = 3, retry_delay: float = 5.0):
    """Extract a knowledge graph from a batch of chunks and write it to Neo4j.

    Uses the module-level ``llm_transformer`` (LLM extraction) and ``graph``
    (Neo4j writer). Only ``ServiceUnavailable`` from the Neo4j write is retried.
    Any other exception propagates immediately.

    Args:
        batch: Documents (chunks) to pass to ``convert_to_graph_documents``.
        retry_attempts: Total number of write attempts (must be >= 1).
        retry_delay: Seconds to sleep between failed write attempts.

    Raises:
        ValueError: if ``retry_attempts`` is less than 1.
        ServiceUnavailable: if every write attempt fails with it.
    """
    if retry_attempts < 1:
        raise ValueError("retry_attempts must be at least 1")

    # ServiceUnavailable is a Neo4j error, so only the database write can raise
    # it. Run the slow, paid LLM extraction once and retry just the write,
    # instead of re-extracting the whole batch on every retry.
    graph_documents = llm_transformer.convert_to_graph_documents(batch)

    for attempt in range(1, retry_attempts + 1):
        try:
            graph.add_graph_documents(
                graph_documents,
                baseEntityLabel=True,
                include_source=True
            )
        except ServiceUnavailable as e:
            print(f"Service unavailable, attempt {attempt} of {retry_attempts}: {e}")
            if attempt == retry_attempts:
                print("Failed to process batch after several attempts.")
                raise
            time.sleep(retry_delay)
        else:
            print("Batch processed successfully.")
            return

batch_size = 10  # chunks handed to each process_batch call

# Model used for graph extraction (later cells rebind `llm`).
llm = ChatOpenAI(model_name="gpt-4-turbo", temperature=0)
llm_transformer = LLMGraphTransformer(llm=llm)

# To resume an interrupted run, set start_batch to the first unfinished batch.
start_batch = 0
first_chunk = start_batch * batch_size
for i in range(first_chunk, len(chunks), batch_size):
    batch = chunks[i:i + batch_size]
    process_batch(batch)

In [None]:
# Default view: up to 50 relationships, excluding Document-MENTIONS edges.
default_cypher = "MATCH (s)-[r:!MENTIONS]->(t) RETURN s,r,t LIMIT 50"

def showGraph(cypher: str = default_cypher):
    """Run ``cypher`` against Neo4j and render the returned subgraph as a widget.

    Connection settings are read from the NEO4J_URI / NEO4J_USERNAME /
    NEO4J_PASSWORD environment variables.

    Args:
        cypher: Query whose returned nodes and relationships are drawn.

    Returns:
        A yFiles ``GraphWidget`` with nodes labelled by their ``id`` property.
    """
    # Context managers close the session and the driver (and its connection
    # pool) after the query. Previously both were leaked on every call.
    with GraphDatabase.driver(
        uri=os.environ["NEO4J_URI"],
        auth=(os.environ["NEO4J_USERNAME"], os.environ["NEO4J_PASSWORD"]),
    ) as driver:
        with driver.session() as session:
            # Result.graph() consumes the full result into a Graph object,
            # so it remains usable after the session is closed.
            result_graph = session.run(cypher).graph()
    widget = GraphWidget(graph=result_graph)
    widget.node_label_mapping = 'id'
    return widget

showGraph()

In [8]:
# Wrap the existing `Document` nodes as a hybrid-search vector store:
# `text` is the text property and `embedding` the property holding vectors.
vector_index = Neo4jVector.from_existing_graph(
    OpenAIEmbeddings(),
    node_label="Document",
    text_node_properties=["text"],
    embedding_node_property="embedding",
    search_type="hybrid",
)

In [9]:
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)  # used by entity_chain and the RAG chain

In [10]:

# Full-text index named 'entity' over __Entity__ node ids; structured_retriever
# looks entity names up through it.
graph.query(
    "CREATE FULLTEXT INDEX entity IF NOT EXISTS "
    "FOR (e:__Entity__) ON EACH [e.id]"
)

# Structured-output schema for the entity-extraction chain: the LLM returns an
# `Entities` object whose `names` list holds the extracted entity strings.
# NOTE(review): `with_structured_output` presumably sends this class docstring
# and the Field description to the model as schema text, so rewording either
# may change what the LLM extracts — confirm before editing them.
class Entities(BaseModel):
    """Identifying information about entities."""
    # Required field (`...`): one string per extracted entity.
    names: List[str] = Field(
        ...,
        description="All Tissue and top 100 gene entities that appear in the text.",
    )

# Entity-extraction chain: fill the prompt with {question}, then have the LLM
# answer in the `Entities` schema (the result exposes `.names`).
_entities_llm = llm.with_structured_output(Entities)
prompt = ChatPromptTemplate.from_messages([
    ("system", "You are extracting Tissue and top 100 gene entities from the text."),
    ("human", "Extract the following information from the input: {question}"),
])
entity_chain = prompt | _entities_llm

In [None]:
# Smoke test: entities extracted from one heart-left-ventricle example question.
sample_question = "Task: Given the following information about a cell, predict its most likely cell type. Provide only the single most probable cell type without any additional explanation. Tissue: heart left ventricle. Top 100 genes for this cell (highest expression first): MALAT1, RYR2, TTN, LINC02388, DMD, SLC8A1, FHL2, THSD4, TTN-AS1, RBM20, CACNB2, LAMA2, SORBS2, PRKG1, CTNNA3, RP11-362K2.2, ABLIM1, MLIP, MARK3, PRKN, FHOD3, IL1RAPL1, DAPK2, RNF150, CD36, PDLIM5, PHACTR1, QKI, CUX1, CACNA1C, TNNT2, AC011288.2, AKAP13, PDZRN3, CDIN1, CFLAR, LDB3, NEAT1, PDE3A, SORBS1, TXNIP, MYL2, PTPRK, WDPCP, SLC1A3, NDRG3, LARGE1, ATXN1, FHIT, GALNT17, ATP1A2, GPHN, ELL2, CRADD, TMEM117, FANCC, DTNA, PDE4DIP, PLEKHA5, ZNF721, FOXO1, PALLD, JMJD1C, LRMDA, OBSCN, RP11-499P20.2, SLC8A1-AS1, EXOC6B, LRRTM3, TECRL, RASSF3, NEBL, DANT2, SVIL, PDK4, ANKRD17, REV3L, MYH7, PDE1C, FBXL7, ARL15, RBMS3, ACTC1, PLCL1, MSRB3, PKP2, CH17-189H20.1, SDK1, AKAP6, EXOC4, EXOC6, GAPVD1, ENSG00000273748, MAGI1, EPHA4, AKAP9, UBE2E2, USP49, MEF2A, RWDD1."
entity_chain.invoke({"question": sample_question}).names

In [None]:
def generate_full_text_query(input: str, fuzziness: int = 0) -> str:
    """
    Build a Lucene full-text query that requires every word of ``input``.

    Lucene special characters are removed with ``remove_lucene_chars``. The
    remainder is split on whitespace and the words are joined with ``AND``.

    By default no fuzzy operator is added, so similar gene symbols (e.g. RPL3
    vs RPL8) are not conflated. Pass ``fuzziness`` (1 or 2) to append Lucene's
    ``~N`` edit-distance operator to each word, which tolerates misspellings.

    Args:
        input: Entity name or free text to search for.
        fuzziness: Maximum edit distance per word; 0 disables fuzzy matching.

    Returns:
        The query string, or ``""`` when ``input`` contains no searchable
        words. Previously that case raised IndexError.
    """
    words = remove_lucene_chars(input).split()
    suffix = f"~{fuzziness}" if fuzziness > 0 else ""
    return " AND ".join(f"{word}{suffix}" for word in words)

def structured_retriever(question: str) -> str:
    """
    Collect the graph neighbourhood of every entity mentioned in ``question``.

    Steps:

    * Entity names are extracted with the module-level ``entity_chain``.
    * Each distinct name becomes a Lucene query via ``generate_full_text_query``.
    * The query is run against the ``entity`` full-text index.
    * For up to two matching nodes, non-MENTIONS relationships in both
      directions are returned, at most 50 lines per name.

    Returns:
        Newline-separated ``"source - REL_TYPE -> target"`` lines, or ``""``
        when nothing matched.
    """
    result = []
    entities = entity_chain.invoke({"question": question})

    # dict.fromkeys de-duplicates while keeping first-seen order, so a name the
    # LLM repeats does not trigger repeated identical graph queries.
    for entity in dict.fromkeys(entities.names):
        query_str = generate_full_text_query(entity)
        if not query_str:
            # Nothing searchable left after stripping Lucene special
            # characters; skip rather than send an empty query to the index.
            continue

        response = graph.query(
            """CALL db.index.fulltext.queryNodes('entity', $query, {limit:2})
            YIELD node,score
            CALL {
              WITH node
              MATCH (node)-[r:!MENTIONS]->(neighbor)
              RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
              UNION ALL
              WITH node
              MATCH (node)<-[r:!MENTIONS]-(neighbor)
              RETURN neighbor.id + ' - ' + type(r) + ' -> ' + node.id AS output
            }
            RETURN output LIMIT 50
            """,
            {"query": query_str},
        )

        # In Cypher, concatenating a string with a null `id` yields null. Skip
        # such rows so the final "\n".join does not fail on None.
        result.extend(row['output'] for row in response if row['output'] is not None)

    return "\n".join(result)
# Smoke test: graph context retrieved for one pancreas example question.
pancreas_question = "Given the following information about a cell, predict two candidate cell types. Provide only the cell types without additional explanation. Tissue: pancreas. Top 100 genes for this cell (highest expression first): SST, SERPINA1, GNAS, PCSK1N, RBP4, CHGA, RPL3, ACTG1, EEF1A1, TPT1, RPL19, CHGB, HLA-A, HSPA1A, CPE, RPL41, SCG5, EDN3, RPS4X, RPL8, RPL37A, TUBA1B, DYNLL1, RPL7A, GAD2, RPS8, RPL27A, RPS11, B2M, TIMP1, PTPRN, RPS2, RPL15, CD63, RPS15, TTR, RPL13A, SCG2, AQP3, IDS, PCSK2, RPS3A, RPL23A, GPX3, RPL10, TUBA1A, FOS, H3F3A, SEC11C, SERF2, RPS27A, EMC10, SCGN, RPS12, GAPDH, H3F3B, TAGLN2, NLRP1, RPL13, RPL14, PEG10, RPS14, RPS9, RPL24, ZFP36, RPS24, JUNB, RPS23, RPS28, EIF1, FAU, RPL11, FTH1, CLU, ATP5E, CALY, TMSB4X, RPL18, RPS29, RPL35A, FTL, PSAP, ENO1, RPL23, RPS18, DHRS2, RPLP2, RPS19, S100A6, MIF, RPLP1, HSP90AA1, RNASEK, CHCHD2, SSR4, RPL6, RPL28, HSPA5, HINT1, MALAT1."
pancreas_context = structured_retriever(pancreas_question)
print(pancreas_context)

In [26]:
def unstructured_retriever(question: str, k: int = 2) -> List[str]:
    """
    Return the text of the ``k`` Document chunks most similar to ``question``.

    Searches the module-level ``vector_index`` (hybrid search over the
    ``Document`` nodes' ``text`` property).

    Args:
        question: Query text.
        k: Number of chunks to return (default 2, as before).

    Returns:
        ``page_content`` of each hit, in the order returned by the index.
    """
    return [doc.page_content for doc in vector_index.similarity_search(question, k=k)]

In [None]:
from collections import defaultdict

def merge_structured_data(structured_data: str) -> str:
    """
    Collapse duplicate relationship lines produced by ``structured_retriever``.

    Only two relationship types are kept; every other line is dropped:

    * ``X - LOCATED_IN -> T`` lines are grouped by target:
      ``X1, X2 - LOCATED_IN -> T``.
    * ``S - EXPRESSED_IN -> Y`` lines are grouped by source:
      ``S - EXPRESSED_IN -> Y1, Y2``.

    Duplicates are removed in first-seen order, so identical input always
    yields identical text. ``set`` iteration order varies between interpreter
    runs, which made the downstream LLM context non-reproducible.

    Lines that mention a type but lack the expected separator are skipped
    instead of raising ValueError.

    Returns:
        Merged lines (LOCATED_IN groups first) joined by newlines.
    """
    # Inner dicts act as insertion-ordered sets for de-duplication.
    merged_located = defaultdict(dict)
    merged_expressed = defaultdict(dict)

    for raw_line in structured_data.strip().split('\n'):
        line = raw_line.strip()
        if "LOCATED_IN" in line:
            key, sep, value = line.partition(' - LOCATED_IN')
            if sep:
                # value keeps its leading arrow, e.g. "-> pancreas".
                merged_located[value.strip()][key.strip()] = None
        elif "EXPRESSED_IN" in line:
            key, sep, value = line.partition('->')
            if sep:
                # key keeps the relation, e.g. "SST - EXPRESSED_IN".
                merged_expressed[key.strip()][value.strip()] = None

    merged_lines = [
        f"{', '.join(keys)} - LOCATED_IN {value}"
        for value, keys in merged_located.items()
    ]
    merged_lines.extend(
        f"{key} -> {', '.join(values)}"
        for key, values in merged_expressed.items()
    )
    return '\n'.join(merged_lines)

def retriever(question: str):
    """
    Build the RAG context string for ``question``.

    The context has two parts:

    * Merged graph relationships (``structured_retriever`` followed by
      ``merge_structured_data``).
    * The text of the 4 most similar Document chunks from ``vector_index``,
      separated by ``"#Document "``.
    """
    print(f"Search query: {question}")
    graph_context = merge_structured_data(structured_retriever(question))

    hits = vector_index.similarity_search(question, k=4)
    document_context = "#Document ".join(hit.page_content for hit in hits)

    return (
        "Structured data:\n"
        f"{graph_context}  \n"
        "Unstructured data:\n"
        f"{document_context}\n"
        "    "
    )


# question = "Given the following information about a cell, predict two candidate cell types. Provide only the cell types without additional explanation. Tissue: pancreas. Top 100 genes for this cell (highest expression first): SST, SERPINA1, GNAS, PCSK1N, RBP4, CHGA, RPL3, ACTG1, EEF1A1, TPT1, RPL19, CHGB, HLA-A, HSPA1A, CPE, RPL41, SCG5, EDN3, RPS4X, RPL8, RPL37A, TUBA1B, DYNLL1, RPL7A, GAD2, RPS8, RPL27A, RPS11, B2M, TIMP1, PTPRN, RPS2, RPL15, CD63, RPS15, TTR, RPL13A, SCG2, AQP3, IDS, PCSK2, RPS3A, RPL23A, GPX3, RPL10, TUBA1A, FOS, H3F3A, SEC11C, SERF2, RPS27A, EMC10, SCGN, RPS12, GAPDH, H3F3B, TAGLN2, NLRP1, RPL13, RPL14, PEG10, RPS14, RPS9, RPL24, ZFP36, RPS24, JUNB, RPS23, RPS28, EIF1, FAU, RPL11, FTH1, CLU, ATP5E, CALY, TMSB4X, RPL18, RPS29, RPL35A, FTL, PSAP, ENO1, RPL23, RPS18, DHRS2, RPLP2, RPS19, S100A6, MIF, RPLP1, HSP90AA1, RNASEK, CHCHD2, SSR4, RPL6, RPL28, HSPA5, HINT1, MALAT1."

# result = retriever(question)
# print(result)

In [14]:
# Prompt that condenses a chat history plus a follow-up question into a single
# standalone question; used by the chat-history branch of `_search_query`.
# Template variables: {chat_history}, {question}.
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question,
in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""  # noqa: E501
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
    buffer = []
    for human, ai in chat_history:
        buffer.append(HumanMessage(content=human))
        buffer.append(AIMessage(content=ai))
    return buffer

# Choose the text used as the search query:
#   * non-empty `chat_history`: condense history + follow-up into a standalone
#     question with an LLM;
#   * otherwise: use `question` unchanged.
_has_chat_history = RunnableLambda(
    lambda x: bool(x.get("chat_history"))
).with_config(run_name="HasChatHistoryCheck")

# NOTE(review): ChatOpenAI(temperature=0) here uses the library's default model,
# not the notebook's `llm` — confirm this is intended.
_condense_question_chain = (
    RunnablePassthrough.assign(
        chat_history=lambda x: _format_chat_history(x["chat_history"])
    )
    | CONDENSE_QUESTION_PROMPT
    | ChatOpenAI(temperature=0)
    | StrOutputParser()
)

_search_query = RunnableBranch(
    (_has_chat_history, _condense_question_chain),
    RunnableLambda(lambda x: x["question"]),
)

In [15]:
# Answer prompt: retrieved graph + vector context followed by the question.
template = """Answer the question based only on the following context:
{context}

Question: {question}
Use natural language and be concise.
Answer:"""
prompt = ChatPromptTemplate.from_template(template)

# RAG chain, invoked as chain.invoke({"question": ...}).
# `context`: the (possibly condensed) search query passed through `retriever`.
# `question`: must be the question *string*. A bare RunnablePassthrough() here
# forwarded the whole input dict, so the prompt rendered
# "Question: {'question': '...'}".
# Note: `llm` is captured when this cell runs; rebinding it later has no effect.
chain = (
    RunnableParallel(
        {
            "context": _search_query | retriever,
            "question": RunnableLambda(lambda x: x["question"]),
        }
    )
    | prompt
    | llm
    | StrOutputParser()
)

In [None]:
# NOTE(review): `chain` was built in an earlier cell and captured whichever
# `llm` was bound then; rebinding `llm` here does not change the chain's model.
llm=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")
# Alias requests' exception so it no longer shadows the built-in
# ConnectionError for the rest of the notebook; both are caught below.
from requests.exceptions import ConnectionError as RequestsConnectionError

# Half-open slice [n_start, n_end) of test questions answered in this run.
n_start = 800
n_end = 900

questions_path = '/home/yu.zhiyin/CellRAG/data/total/total_test_task1.txt'
predictions_path = 'cell_type_predictions12_add2.txt'

# Exactly one output line is appended per question, including failures, so
# the predictions file stays line-aligned with the questions processed.
with open(questions_path, 'r') as question_file, open(predictions_path, 'a+') as output_file:
    questions = question_file.readlines()

    for question in questions[n_start:n_end]:
        try:
            response = chain.invoke({
                "question": question.strip()
            })
            # Collapse newlines/whitespace runs. A multi-line answer would
            # otherwise shift every following line out of alignment.
            output_file.write(f"{' '.join(response.split())}\n")
        except (RequestsConnectionError, ConnectionError) as e:
            print(f"API connection error {e}")
            output_file.write("API connection error\n")
        except Exception as e:
            print(f"error: {e}")
            output_file.write("error\n")
        # Persist each line immediately so a crash keeps finished answers.
        output_file.flush()



In [None]:
import csv
from langchain import LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI

def read_top100_genes(file_path):
    """
    Return the gene list of every line that describes a cell's top 100 genes.

    The text after ``"Top 100 genes for this cell (highest expression first): "``
    is returned, whitespace-stripped. Lines without the marker are ignored.
    """
    marker = "Top 100 genes for this cell"
    prefix = "Top 100 genes for this cell (highest expression first): "
    with open(file_path, mode='r') as handle:
        return [text.split(prefix)[1].strip() for text in handle if marker in text]

def read_cell_types(file_path):
    """
    Return one lower-cased, whitespace-stripped entry per line of ``file_path``.

    Blank lines are kept as ``""``, so index i still corresponds to line i.
    """
    with open(file_path, mode='r') as handle:
        return [raw.strip().lower() for raw in handle]

def get_marker_sentence(cell_types, cell_marker_dict):
    """
    Describe each cell type in a ``", "``-separated string with its marker genes.

    The lookup key is the stripped, lower-cased type, but the type is echoed as
    given. Types missing from ``cell_marker_dict`` get ``"unknown"``.

    Example: ``"Beta, xyz"`` -> ``"Beta: INS, IAPP. xyz: unknown"``.
    """
    def describe(cell_type):
        key = cell_type.strip().lower()
        markers = ', '.join(cell_marker_dict[key]) if key in cell_marker_dict else "unknown"
        return f"{cell_type}: {markers}"

    return ". ".join(describe(cell_type) for cell_type in cell_types.split(', '))

def remove_duplicates(cell_types):
    """
    De-duplicate a ``", "``-separated list of cell types, keeping first-seen order.

    ``dict.fromkeys`` replaces ``set`` so the result is deterministic. String
    hashing is randomised per Python process, so set order (and the prompts
    built from it) changed from run to run.

    Example: ``"b, a, b"`` -> ``"b, a"``.
    """
    return ', '.join(dict.fromkeys(cell_types.split(', ')))

def read_total_marker(file_path):
    """
    Read a marker-gene CSV into ``{cell_type: [marker, ...]}``.

    Layout and normalisation:

    * The first row holds one cell-type name per column.
    * Each later row contributes at most one marker per column.
    * Keys are stripped and lower-cased; markers are stripped and blank cells
      are skipped.
    * A key is created only once a data row reaches its column.

    Robustness:

    * An empty file returns ``{}``; previously it raised StopIteration.
    * Cells beyond the header width are ignored; previously they raised
      IndexError.
    """
    cell_marker_dict = {}
    # newline='' is what the csv module expects, so quoted fields containing
    # line breaks are parsed correctly.
    with open(file_path, mode='r', newline='') as file:
        reader = csv.reader(file)
        headers = next(reader, None)
        if headers is None:
            return cell_marker_dict
        keys = [header.strip().lower() for header in headers]
        for row in reader:
            # zip stops at the shorter sequence: short rows simply contribute
            # fewer markers, and extra cells are ignored. Empty rows do nothing.
            for key, marker in zip(keys, row):
                markers = cell_marker_dict.setdefault(key, [])
                marker = marker.strip()
                if marker:
                    markers.append(marker)
    return cell_marker_dict

llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

# Inputs for the re-ranking step. All three lists are indexed with the same i,
# so they are assumed to be line-aligned (one cell per line) — TODO confirm.
top100_genes_list = read_top100_genes('/home/yu.zhiyin/CellRAG/data/total/total_test_task2.txt')
predicted_cell_types_list = read_cell_types('/home/yu.zhiyin/CellRAG/data/cell_type_predictions12_add.txt')
retrieved_cell_types_list = read_cell_types('/home/yu.zhiyin/CellRAG/data/similar_cell_types111.txt')

# {lower-cased cell type: [marker genes]}
cell_marker_dict = read_total_marker('/home/yu.zhiyin/CellRAG/data/marker gene/total.csv')

# Re-ranking prompt. It combines the cell's top-100 genes with marker genes for
# the candidate cell types and for the retrieved similar cell types, then asks
# for a single label from a closed list.
# Template variables: {top100_genes}, {predicted_markers}, {retrieved_markers}.
template = """
Given the following information about a cell:
Top 100 genes: {top100_genes}.
Candidate cell types and their marker genes: {predicted_markers}.
Similar cell types retrieved and their marker genes: {retrieved_markers}.

Task: Given the following information about a cell, predict its most likely cell type. Provide only the single most probable cell type without any additional explanation.
From the following cell types, select the most probable: 'fibroblast', 'activated CD4-positive, alpha-beta T cell', 'HSPCs', 'ductal', 'regular ventricular cardiac myocyte', 'vein endothelial cell', 'Erythrocytes', 'endothelial cell of artery', 'acinar', 'beta', 'B cell', 'natural killer cell', 'CD4-positive, alpha-beta cytotoxic T cell', 'epicardial adipocyte', 'regular atrial cardiac myocyte', 'macrophage', 'Plasmacytoid dendritic cells', 'gamma', 'native cell', 'smooth muscle cell', 'endothelial', 'CD20+ B cells', 'neural cell', 'Megakaryocyte progenitors', 'Plasma cells', 'endothelial cell', 'CD4+ T cells', 'CD10+ B cells', 'Monocyte-derived dendritic cells', 'CD14+ Monocytes', 'delta', 'Erythroid progenitors', 'activated CD8-positive, alpha-beta T cell', 'NK cells', 'mature NK T cell', 'alpha', 'CD8-positive, alpha-beta cytotoxic T cell', 'monocyte', 'pericyte cell', 'capillary endothelial cell', 'NKT cells', 'CD14-positive, CD16-positive monocyte', 'CD8+ T cells', 'Monocyte progenitors'.
"""

# NOTE: this rebinds the global `prompt` (the RAG chain built earlier already
# captured its own prompt object, so it is unaffected).
prompt = PromptTemplate(
    input_variables=["top100_genes", "predicted_markers", "retrieved_markers"],
    template=template,
)

# Define the start and end indices (half-open range of cells to process).
start_index = 0
end_index = 900

output_path = '/home/yu.zhiyin/CellRAG/data/llm_predictions_whole_graph_gpt4omini.txt'
n_cells = min(end_index, len(top100_genes_list))

# The three input lists are indexed in lockstep. Fail fast rather than raise
# IndexError halfway through, after paying for earlier LLM calls.
if len(predicted_cell_types_list) < n_cells or len(retrieved_cell_types_list) < n_cells:
    raise ValueError(
        f"Need {n_cells} rows but got {len(predicted_cell_types_list)} predicted and "
        f"{len(retrieved_cell_types_list)} retrieved cell-type lines."
    )

# Build the chain once (it used to be rebuilt for every cell) with LCEL instead
# of the deprecated LLMChain. StrOutputParser yields the model's text, as
# LLMChain.run did. Named `rerank_chain` so it no longer overwrites the RAG
# `chain` defined earlier in the notebook.
rerank_chain = prompt | llm | StrOutputParser()

# When resuming from a non-zero start_index, append so the lines written by
# the earlier run are not truncated.
output_mode = 'w' if start_index == 0 else 'a'
with open(output_path, mode=output_mode) as file:

    for i in range(start_index, n_cells):
        unique_predicted_cell_types = remove_duplicates(predicted_cell_types_list[i])
        unique_retrieved_cell_types = remove_duplicates(retrieved_cell_types_list[i])

        response = rerank_chain.invoke({
            "top100_genes": top100_genes_list[i],
            "predicted_markers": get_marker_sentence(unique_predicted_cell_types, cell_marker_dict),
            "retrieved_markers": get_marker_sentence(unique_retrieved_cell_types, cell_marker_dict),
        }).strip()

        file.write(f"{response}\n")
        file.flush()  # keep finished predictions if the kernel dies mid-run

        print(f"Cell {i+1}: Response saved to file.")

    print(f"All responses have been saved to {output_path}.")
