<a href="https://colab.research.google.com/github/Farhan99-hub/Mistral7B_Gene-Disease_RAG/blob/main/Mistral_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q transformers einops accelerate bitsandbytes langchain

In [None]:
!pip install -U langchain-community

In [None]:
!pip install langchain-huggingface

In [None]:
!pip install faiss-cpu

In [None]:
!huggingface-cli login

In [None]:
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_core.documents import Document
import pandas as pd

In [None]:
clinvar = pd.read_excel("/content/clinvar_drop.xlsx")
curated = pd.read_excel("/content/curated_grouped.xlsx")
gad = pd.read_excel("/content/gad_disease_desc.xlsx")
hpo = pd.read_excel("/content/hpo_disease_sort_desc.xlsx")
omim = pd.read_excel("/content/omim_disease_desc.xlsx")

In [None]:
clinvar.head()

In [None]:
# Columns every source table must provide
REQUIRED_COLUMNS = ["Genes", "Gene_IDs", "Disease_IDs", "Disease_desc"]

# Create a list to hold all documents
all_documents = []

# Convert each row of a dataframe into a Document and add it to all_documents
def process_dataframe(df, source_name):
    missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        print(f"Warning: skipping {source_name} dataframe; missing columns: {missing}")
        return
    for _, row in df.iterrows():
        # Label each value by what it holds: 'Genes' is the gene symbol, 'Gene_IDs' its ID
        text = (
            f"Gene: {row['Genes']} | Gene_IDs: {row['Gene_IDs']} | "
            f"Disease_IDs: {row['Disease_IDs']} | Description: {row['Disease_desc']}"
        )
        all_documents.append(Document(page_content=text, metadata={"source": source_name}))


# Process each dataframe
process_dataframe(clinvar, "clinvar")
process_dataframe(curated, "curated")
process_dataframe(gad, "gad")
process_dataframe(hpo, "hpo")
process_dataframe(omim, "omim")

all_documents


[Document(metadata={'source': 'clinvar'}, page_content='Disease: COL11A1 | Gene: 1301 | Disease_IDs: MedGen_CN071412 | Description: A disease characterized by a group of signs and symptoms that occur together and characterize a particular abnormality.'),
 Document(metadata={'source': 'clinvar'}, page_content='Disease: COL11A1 | Gene: 1301 | Disease_IDs: OMIM_604841 | Description: A Stickler syndrome that has_material_basis_in heterozygous mutation in the COL11A1 gene on chromosome 1p21.'),
 Document(metadata={'source': 'clinvar'}, page_content='Disease: COL11A1 | Gene: 1301 | Disease_IDs: OMIM_154780 | Description: An ectodermal dysplasia characterized by hypoplasia of the maxilla, nasal bones, and frontal sinuses, as well as calvarial thickening, myopia, early-onset cataracts, and sensorineural hearing loss that has_material_basis_in heterozygous or homozygous mutation (most frequently affecting splice sites) in the COL11A1 gene on chromosome 1p21.1. Mutations, typically null, in the 
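
Each spreadsheet row becomes one pipe-delimited string. That string is all the embedding model and the LLM ever see of the row. The sketch below shows the row-to-text step using a hypothetical `format_record` helper. It works on plain dicts instead of pandas rows, and it keeps empty cells out of the indexed text. Pandas reads empty cells as `NaN`, which would otherwise appear as the literal text `nan`. The field labels and the blanked-out description are illustrative assumptions, not the notebook's data.

```python
import math

# (label, column) pairs in the same column order as process_dataframe
FIELDS = [
    ("Gene", "Genes"),
    ("Gene_IDs", "Gene_IDs"),
    ("Disease_IDs", "Disease_IDs"),
    ("Description", "Disease_desc"),
]


def format_record(record):
    """Render one row (given as a dict) as pipe-delimited text, skipping empty cells.

    pandas reads empty spreadsheet cells as float NaN, which would otherwise be
    rendered as the literal string 'nan'.
    """
    parts = []
    for label, column in FIELDS:
        value = record.get(column)
        if value is None or (isinstance(value, float) and math.isnan(value)):
            continue  # skip missing or empty cells
        parts.append(f"{label}: {value}")
    return " | ".join(parts)


# Values taken from the output above, with the description blanked out for illustration
row = {"Genes": "COL11A1", "Gene_IDs": 1301, "Disease_IDs": "OMIM_604841", "Disease_desc": float("nan")}
print(format_record(row))  # Gene: COL11A1 | Gene_IDs: 1301 | Disease_IDs: OMIM_604841
```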

In [None]:
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vector_store = FAISS.from_documents(all_documents, embeddings)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
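
`FAISS.from_documents` embeds every document once and stores the vectors. `as_retriever()` then embeds each query and returns the closest stored documents. To our understanding, LangChain's FAISS wrapper uses an exact (flat) L2 index by default. The retriever returns 4 documents unless `search_kwargs={"k": ...}` is passed.

The brute-force sketch below shows what that exact search computes. It uses hypothetical 3-dimensional toy vectors; the real all-MiniLM-L6-v2 vectors have 384 dimensions. It is a conceptual illustration, not FAISS's implementation.

```python
import math


def l2_distance(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def top_k(query_vec, doc_vecs, k=4):
    """Return (index, distance) pairs for the k stored vectors closest to query_vec.

    Exhaustive search over every stored vector, like a flat (exact) index.
    """
    scored = [(i, l2_distance(query_vec, v)) for i, v in enumerate(doc_vecs)]
    scored.sort(key=lambda pair: pair[1])
    return scored[:k]


# Toy 3-dimensional "embeddings" standing in for real sentence embeddings
docs = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [0.5, 0.5, 0.0]]
query = [1.0, 0.05, 0.0]
print(top_k(query, docs, k=2))
```

Only the top `k` documents reach the prompt. A gene with many disease entries across the five sources can therefore have some of them left out of the answer. Raising `k` trades a longer prompt for better recall.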


In [None]:
num_docs = vector_store.index.ntotal  # number of vectors stored in the underlying FAISS index
print(f"Number of documents in the vector store: {num_docs}")

Number of documents in the vector store: 10758


In [None]:
model = "mistralai/Mistral-7B-Instruct-v0.3"

In [None]:
!pip install safetensors




In [None]:
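from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline

# Create the Hugging Face text-generation pipeline
pipee = pipeline(
    "text-generation",
    model=model,
    torch_dtype="auto",
    device_map="cuda",
    max_new_tokens=512,
    do_sample=True,
    top_k=30,
    num_return_sequences=1,
)

# Wrap the pipeline so LangChain chains can call it as an LLM
local_llm = HuggingFacePipeline(pipeline=pipee)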

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Device set to use cuda
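
With `do_sample=True` and `top_k=30`, each generated token is drawn at random from the 30 highest-scoring candidates rather than always being the single most likely one. The same query can therefore produce different answers on different runs. The sketch below illustrates top-k sampling on a hypothetical list of five logits with `k=2`. The actual `transformers` generation loop applies further logits processors (temperature, for example), so this is only a simplified model of the behaviour.

```python
import math
import random


def top_k_sample(logits, k, rng):
    """Sample a token index from the k highest-scoring logits (softmax over that subset)."""
    # Keep only the k most likely token ids
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax restricted to the kept ids (subtract the max for numerical stability)
    m = max(logits[i] for i in top)
    weights = [math.exp(logits[i] - m) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]


rng = random.Random(0)
logits = [2.0, 1.0, 0.5, -1.0, -3.0]
samples = [top_k_sample(logits, k=2, rng=rng) for _ in range(1000)]
print({i: samples.count(i) for i in sorted(set(samples))})
```

This notebook runs a lookup-style task. Passing `do_sample=False` (greedy decoding) instead should make answers repeatable, at the cost of the variety that sampling provides.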


In [None]:
from langchain_core.prompts import PromptTemplate

custom_prompt = PromptTemplate(
    input_variables=["context", "question"],
    template="""
    You are a biomedical expert specializing in gene-disease relationships.
    Based on the retrieved context, answer the user's question with relevant gene and disease information.

    If the query is about a **gene**, return:
    - Disease(s) associated with the gene
    - Disease ID(s)
    - Disease description(s)

    If the query is about a **disease**, return:
    - Gene(s) linked to the disease
    - Gene ID(s)
    - Disease description(s)

    Only use the retrieved context and avoid making assumptions.

    Context: {context}
    User Query: {question}

    Answer:
    """
)
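
`PromptTemplate`, with its default f-string format, fills `{context}` and `{question}` much like Python's `str.format`. The template below is a hypothetical, shortened stand-in for `custom_prompt`. It shows that the filled prompt ends with the `Answer:` marker. The text-generation pipeline appears to return the prompt followed by the generated continuation, which is why the query cells below cut the model output at that marker. The sketch also shows that curly braces inside the substituted context are left untouched.

```python
# Hypothetical, shortened stand-in for custom_prompt's template
TEMPLATE = """Only use the retrieved context and avoid making assumptions.

Context: {context}
User Query: {question}

Answer:"""


def fill_prompt(context, question):
    # Substitutes the two variables the same way str.format does;
    # substituted values are not re-parsed, so braces inside them are kept as-is
    return TEMPLATE.format(context=context, question=question)


prompt = fill_prompt(context="Gene: GENE_A | Gene_IDs: 1", question="GENE_A")
print(prompt)
```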


In [None]:
# def refine_prompt(user_query):
#     return f"""
#     The user asked: {user_query}

#     If the query mentions a **gene**, list all related **diseases** with their **IDs** and descriptions.
#     If the query mentions a **disease**, give disease description and list all associated **genes** with their **gene IDs**.

#     # Format output neatly.
#     """


In [None]:
from langchain.chains import RetrievalQA


# Initialize the QA model with the custom prompt
qa = RetrievalQA.from_chain_type(
    llm=local_llm,
    chain_type="stuff",
    retriever=vector_store.as_retriever(),
    chain_type_kwargs={"prompt": custom_prompt}  # Apply the prompt template
)
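
`chain_type="stuff"` is the simplest way to combine retrieved documents. All of them are concatenated into one string that fills `{context}`, and the LLM is called once. As far as we know, LangChain joins the documents' `page_content` with a blank line by default. The minimal sketch below uses a hypothetical `stuff_documents` helper and placeholder documents, not records from the notebook's data.

```python
def stuff_documents(page_contents, separator="\n\n"):
    """Concatenate retrieved document texts into a single context string."""
    return separator.join(page_contents)


# Placeholder documents; GENE_A and the IDs are made up for illustration
retrieved = [
    "Gene: GENE_A | Gene_IDs: 1 | Disease_IDs: OMIM_000001 | Description: placeholder disease 1",
    "Gene: GENE_A | Gene_IDs: 1 | Disease_IDs: OMIM_000002 | Description: placeholder disease 2",
]
context = stuff_documents(retrieved)
print(context)
```

The "stuff" design keeps every retrieved record visible to the model. However, prompt length grows with the number and size of documents, so it suits a small `k` and short records like these.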



In [None]:

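query = "Hepatitis C virus"
result = qa.invoke({"query": query})["result"]
# The pipeline echoes the prompt, so keep only the text after the "Answer:" marker
_, marker, answer = result.partition("Answer:")
final_answer = (answer if marker else result).strip()

print(final_answer)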

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


1. Gene(s) linked to Hepatitis C virus: CCR5, IL21R, IFNL3
    2. Gene ID(s): CCR5: OMIM:601373, IL21R: 50615, IFNL3: OMIM:607402

    This is based on the retrieved database, where Hepatitis C virus (Hepacivirus hominis) is shown to be associated with the genes CCR5, IL21R, and IFNL3. Hepatitis C virus is a viral infectious disease that results in inflammation located in the liver and has symptoms including fever, fatigue, loss of appetite, nausea, vomiting, abdominal pain, clay-colored bowel movements, joint pain, and jaundice.
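
A common way to strip the echoed prompt is `text.find("Answer:") + len("Answer:")`. It has a silent failure mode: when the marker is missing, `find` returns `-1`, so the slice drops the first six characters of the output. The sketch below compares that pattern with a `str.partition`-based version that falls back to the whole text. Both helper names are hypothetical.

```python
MARKER = "Answer:"


def extract_with_find(text):
    # find() returns -1 when the marker is absent, so start becomes 6 and text is cut
    start = text.find(MARKER) + len(MARKER)
    return text[start:].strip()


def extract_with_partition(text):
    # partition() returns (text, "", "") when the marker is absent
    before, sep, after = text.partition(MARKER)
    return (after if sep else before).strip()


echoed = "Context: ...\nUser Query: CCR5\n\nAnswer:\n1. Gene(s) linked: CCR5"
bare = "CCR5, IL21R"
print(extract_with_find(echoed), "|", extract_with_partition(echoed))
print(extract_with_find(bare), "|", extract_with_partition(bare))
```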


In [None]:
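query = "DHTKD1"
result = qa.invoke({"query": query})["result"]
# The pipeline echoes the prompt, so keep only the text after the "Answer:" marker
_, marker, answer = result.partition("Answer:")
final_answer = (answer if marker else result).strip()

print(final_answer)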

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


1. DHTKD1 is associated with the following diseases:
        - Amino acid metabolic disorder (OMIM:204750) - This disorder involves the accumulation of argininosuccinic acid (ASA) in the blood and urine.
        - Charcot-Marie-Tooth disease type 2 (OMIM:615025) - This is a neurodegenerative disorder characterized by the degeneration of peripheral nerves, leading to symptoms in the arms and legs such as weakness, numbness, and muscle atrophy.
    2. DHTKD1 gene ID: 55526
    3. DHTKD1 gene is located on chromosome 10p14. The mutations in this gene have a material basis for the aforementioned diseases, either homozygous or compound heterozygous for the amino acid metabolic disorder, or heterozygous loss-of-function mutations for the Charcot-Marie-Tooth disease type 2.
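
LLM answers can include details the retrieved records do not support, even with the prompt's instruction to use only the context. One cheap, partial check flags database IDs cited in the answer that never appear in the context. It is sketched below as an assumption, not part of the original pipeline, and covers only OMIM and HP IDs; descriptive claims are not checked. The documents store IDs as `OMIM_604841` while the model writes `OMIM:615025`, so the sketch normalizes both forms.

To obtain the context, build the chain with `RetrievalQA.from_chain_type(..., return_source_documents=True)` and call it with `qa.invoke({"query": query})`. The retrieved documents then come back under `source_documents`. With two outputs, `qa.run` no longer applies.

```python
import re

# IDs written as OMIM:123456 (model answers) or OMIM_123456 (document text); same for HP
ID_PATTERN = re.compile(r"\b(OMIM|HP)[:_](\d+)\b")


def normalized_ids(text):
    """Set of (prefix, number) pairs for every OMIM/HP ID in text, ignoring ':' vs '_'."""
    return set(ID_PATTERN.findall(text))


def ungrounded_ids(answer, context):
    """IDs cited in the answer that never appear in the retrieved context."""
    return normalized_ids(answer) - normalized_ids(context)


# Placeholder strings; in the notebook, context would be the joined page_content
# of result["source_documents"]
context = "Gene: GENE_A | Gene_IDs: 1 | Disease_IDs: OMIM_100100 | Description: ..."
answer = "Disease ID(s): OMIM:100100, OMIM:200200"
print(ungrounded_ids(answer, context))
```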


In [None]:

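query = "severe t lymphocytopenia"
result = qa.invoke({"query": query})["result"]
# The pipeline echoes the prompt, so keep only the text after the "Answer:" marker
_, marker, answer = result.partition("Answer:")
final_answer = (answer if marker else result).strip()

print(final_answer)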


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


1. Disease: T-Cell Deficiency
    2. Gene(s) linked to the disease: CD3G, RAG2, NBN, IGHM, CD8A, MAGT1, PTPRC
    3. Gene ID(s): OMIM:186740, 917, 64421, 100, 57724, 5788, 3575
    4. Disease_IDs: OMIM:615607, HP:0005403
    5. Disease Description: A T cell deficiency characterized by partial T-cell lymphopenia with normal numbers of B and NK cells and highly variable clinical severity that has_material_basis_in homozygous or compound heterozygous mutation in the relevant genes.
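
The answers follow the prompt's numbered layout only loosely; the DHTKD1 answer, for instance, nests a bullet list. Turning them into structured data therefore needs a tolerant parser. The minimal sketch below uses a hypothetical `parse_numbered_answer` helper that extracts `N. Key: value` lines and ignores everything else. It is tested on part of the T-cell deficiency answer above.

```python
import re

# One numbered item per line, e.g. "2. Gene(s) linked to the disease: CD3G, RAG2"
ITEM_PATTERN = re.compile(r"^\s*\d+\.\s*([^:]+):\s*(.+)$")


def parse_numbered_answer(text):
    """Map each 'N. Key: value' line to {key: value}; other lines are ignored."""
    fields = {}
    for line in text.splitlines():
        match = ITEM_PATTERN.match(line)
        if match:
            fields[match.group(1).strip()] = match.group(2).strip()
    return fields


answer = """1. Disease: T-Cell Deficiency
    2. Gene(s) linked to the disease: CD3G, RAG2, NBN, IGHM, CD8A, MAGT1, PTPRC
    4. Disease_IDs: OMIM:615607, HP:0005403"""
fields = parse_numbered_answer(answer)
genes = [gene.strip() for gene in fields["Gene(s) linked to the disease"].split(",")]
print(genes)
```

Values are kept as raw strings and split only where a list is expected. Splitting a free-text description on commas would break it apart.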
