In [1]:
# Mount Google Drive at /content/drive; the file paths used later in this notebook live under it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# packages
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time

# from goatools import obo_parser

# 1. Pubtator API (TODO)
- using the input folder, check if the file(s) have been pre-annotated by checking if the following APIs return something
```
https://www.ncbi.nlm.nih.gov/research/pubtator3-api/publications/export/biocxml?pmids={id}
or
https://www.ncbi.nlm.nih.gov/research/pubtator3-api/publications/pmc_export/biocxml?pmcids=PMC{id}
```
  - if annotated, retrieve it and put it in the output folder
  - if not, use the example code they provide to send it to the server to be annotated

# 2. Relevance (TODO)
- This can be integrated with Gene Chunking. We can just check if the paper has genes (i.e check if there is an item in the dictionary)

# 3. Gene Chunking
Given a **pre-annotated** article (check the input folder), analyze its **XML** file via `parse_xml_file(xml_file_path)`, which returns a dictionary of `Gene` objects (keyed by gene ID) found in that file.

The Gene Class has the following information:
- self.gene_id: the gene id (from the xml file)
- self.occurrences: A list of occurrences of the gene in the text. Each occurrence is **at most 3 sentences long**.
    - The second sentence contains the explicit mention of the gene.
    - The first and third sentences are the context of the gene mention.
- self.symbol: the gene symbol (from NCBI. Initialized to None)
- self.organism: the organism of the gene (from NCBI. Initialized to None)
- self.full_name: the full name of the gene (from NCBI. Initialized to None. Set to name_from_article if not found in NCBI)
- self.also_known_as: a list of other names of the gene (from NCBI. Initialized to None. Set to name_from_article if not found in NCBI)
- self.name_from_article: the name of the gene as found in the article

Steps:
1. define the class Gene
2. Parse the xml file
3. update the Gene class with the information from NCBI
Now you have a dictionary of `Gene` objects.

TIPS:
- control + F "sentence_buffer" to find where you can adjust the buffer

In [None]:
!pip install biopython
import xml.etree.ElementTree as ET
from Bio import Entrez
import re
import time

class Gene:
    """Record for a single gene annotated in an article.

    Attributes:
        gene_id: identifier the gene was created with.
        occurrences: text snippets in which the gene was mentioned (no duplicates,
            insertion order preserved).
        symbol, organism, full_name, also_known_as: descriptive fields; they stay
            ``None`` until ``update_info`` is called.
        name_from_article: the gene's name according to the article; ``None`` until set.
    """

    def __init__(self, gene_id):
        self.gene_id = gene_id
        self.occurrences = []  # snippet(s) where the gene was mentioned
        # Descriptive fields are unknown at construction time.
        self.symbol = self.organism = self.full_name = self.also_known_as = None
        self.name_from_article = None

    def add_occurrence(self, snippet):
        """Record ``snippet`` unless an equal snippet is already stored."""
        if snippet in self.occurrences:
            return
        self.occurrences.append(snippet)

    def set_name_from_article(self, name_from_article):
        """Store the gene's name according to the article.

        This name serves as a stand-in when no official name is available.
        """
        self.name_from_article = name_from_article

    def get_name_from_article(self):
        """Return the name of the gene as found in the article."""
        return self.name_from_article

    def get_occurrences(self):
        """Return the list of stored snippets."""
        return self.occurrences

    def get_also_known_as(self):
        """Return the stored aliases (``None`` if never set)."""
        return self.also_known_as

    def get_gene_id(self):
        """Return the gene identifier."""
        return self.gene_id

    def update_info(self, symbol, organism, full_name, also_known_as):
        """Overwrite the four descriptive fields in one call."""
        (self.symbol,
         self.organism,
         self.full_name,
         self.also_known_as) = (symbol, organism, full_name, also_known_as)

    def __repr__(self):
        # One labelled line per field; occurrences is a list of snippets.
        rows = [
            f"Gene({self.gene_id})",
            f"  Symbol          : {self.symbol}",
            f"  Organism        : {self.organism}",
            f"  Full Name       : {self.full_name}",
            f"  Also Known As   : {self.also_known_as}",
            f"  In-text Name    : {self.name_from_article}",
            f"  Occurrences     : {self.occurrences}",
        ]
        return "\n".join(rows)

def _split_sentences_with_offsets(text, base_offset=0):
    """Split ``text`` into sentences and report where each sentence starts.

    Sentences are separated by whitespace that follows ``.``, ``!`` or ``?``. The
    start offsets are computed from the actual delimiter matches, so runs of several
    whitespace characters (double spaces, newlines) do not shift later offsets.

    :param text: the passage text to split.
    :param base_offset: offset of ``text[0]``; added to every returned start.
    :return: ``(sentences, start_offsets)`` -- two lists of equal length.
    """
    sentences, starts = [], []
    cursor = 0
    for delimiter in re.finditer(r'(?<=[.!?])\s+', text):
        sentences.append(text[cursor:delimiter.start()])
        starts.append(base_offset + cursor)
        cursor = delimiter.end()
    # Whatever follows the last delimiter (or the whole text) is the final sentence.
    sentences.append(text[cursor:])
    starts.append(base_offset + cursor)
    return sentences, starts


def parse_xml_file(xml_path, sentence_buffer=1):
    """Parse an annotated XML file and return a gene dictionary keyed by gene ID.

    For every ``passage`` of every ``document`` the passage text is split into
    sentences; each annotation whose ``type`` infon is ``Gene`` contributes a snippet
    made of the sentence containing the annotation offset plus ``sentence_buffer``
    sentences on either side. Passages whose ``section_type`` infon is METHODS, FIG
    or TABLE (case-insensitive) are skipped.

    :param xml_path: path (or file object) accepted by ``ET.parse``.
    :param sentence_buffer: number of context sentences kept before and after the
                            sentence that mentions the gene (default 1, i.e. snippets
                            of at most three sentences).
    :return: dict mapping gene ID -> ``Gene`` holding the snippets and the in-text
             name taken from the first annotation seen for that ID.
    """
    skipped_sections = {"METHODS", "FIG", "TABLE"}
    root = ET.parse(xml_path).getroot()
    gene_dict = {}

    for document in root.findall('document'):
        for passage in document.findall('passage'):
            # A passage without a section_type infon (or with an empty one) is kept.
            section_type_elem = passage.find("infon[@key='section_type']")
            section_type = ""
            if section_type_elem is not None and section_type_elem.text:
                section_type = section_type_elem.text.strip().upper()
            if section_type in skipped_sections:
                continue

            # An empty <text/> element yields None; treat it as an empty passage.
            passage_text_elem = passage.find("text")
            passage_text = (passage_text_elem.text or "") if passage_text_elem is not None else ""
            passage_offset_elem = passage.find("offset")
            passage_offset = int(passage_offset_elem.text) if passage_offset_elem is not None else 0

            sentences, start_indices = _split_sentences_with_offsets(passage_text, passage_offset)

            # (gene_id, start, end) windows already extracted from this passage. The gene
            # ID is part of the key so two different genes mentioned in the same sentence
            # each receive the snippet.
            processed_ranges = set()

            for annotation in passage.findall('annotation'):
                ann_type = annotation.find("infon[@key='type']")
                if ann_type is None or ann_type.text != "Gene":
                    continue

                # The gene ID is stored under either the 'identifier' or 'NCBI Gene' key.
                gene_id_elem = annotation.find("infon[@key='identifier']")
                if gene_id_elem is None:
                    gene_id_elem = annotation.find("infon[@key='NCBI Gene']")
                gene_id = gene_id_elem.text if gene_id_elem is not None else None
                if not gene_id:
                    continue  # nothing to key the occurrence on

                # Name as written in the article (temporary name if no official one is found).
                in_text_gene_name_elem = annotation.find("text")
                in_text_gene_name = in_text_gene_name_elem.text if in_text_gene_name_elem is not None else None

                location_elem = annotation.find("location")
                ann_offset = int(location_elem.attrib.get('offset', 0)) if location_elem is not None else 0

                # Find the sentence whose character span contains the annotation offset.
                sentence_index = None
                for i, start in enumerate(start_indices):
                    if start <= ann_offset < start + len(sentences[i]):
                        sentence_index = i
                        break
                if sentence_index is None:
                    continue

                # sentence_buffer
                start_sentence = max(0, sentence_index - sentence_buffer)
                end_sentence = min(len(sentences), sentence_index + sentence_buffer + 1)
                range_key = (gene_id, start_sentence, end_sentence)
                if range_key in processed_ranges:
                    continue  # this gene already has this exact window
                processed_ranges.add(range_key)

                snippet = " ".join(sentences[start_sentence:end_sentence])
                if gene_id in gene_dict:
                    gene_dict[gene_id].add_occurrence(snippet)
                else:
                    gene_obj = Gene(gene_id)
                    gene_obj.add_occurrence(snippet)
                    gene_obj.set_name_from_article(in_text_gene_name)
                    gene_dict[gene_id] = gene_obj
    return gene_dict

def fetch_and_update_gene_info(gene_dict):
    """Retrieve gene information from NCBI and update the genes in ``gene_dict`` in place.

    All gene IDs are posted in a single ``Entrez.epost`` call and their summaries are
    fetched with one ``Entrez.esummary`` call. For every returned summary whose uid is
    a key of ``gene_dict``, the gene's symbol, organism, full name and aliases are set
    via ``update_info``. The full name and aliases fall back to the name found in the
    article when NCBI provides none. Summaries for uids that are not in ``gene_dict``
    are ignored.

    :param gene_dict: dict mapping gene ID (str) -> Gene object; nothing is requested
                      when it is empty.
    :return: None
    """
    gene_ids = list(gene_dict.keys())
    if not gene_ids:
        return

    # Post the gene IDs to NCBI; always release the handle, even if parsing fails.
    handle = Entrez.epost(db="gene", id=",".join(gene_ids))
    try:
        result = Entrez.read(handle)
    finally:
        handle.close()

    handle = Entrez.esummary(db="gene", webenv=result["WebEnv"], query_key=result["QueryKey"])
    try:
        record = Entrez.read(handle)
    finally:
        handle.close()

    # The loop below only reads the already-parsed response (no network requests),
    # so no pause between records is needed.
    for docsum in record["DocumentSummarySet"]["DocumentSummary"]:
        gene_id = docsum.attributes["uid"]
        gene = gene_dict.get(gene_id)
        if gene is None:
            continue  # summary for an ID we did not ask about under this key

        article_name = gene.get_name_from_article()
        symbol = docsum.get('NomenclatureSymbol', 'No symbol')
        organism = docsum.get('Organism', {}).get('ScientificName', 'No organism')
        # Missing or empty values fall back to the name used in the article.
        full_name = docsum.get('NomenclatureName') or article_name
        also_known_as = docsum.get('OtherAliases') or article_name

        gene.update_info(symbol, organism, full_name, also_known_as)


# Identify yourself to NCBI: set your email (and an API key if available).
Entrez.email = "email here"
# Entrez.api_key = "your_api_key"

# Parse the annotated article, then enrich each gene with NCBI information.
# Note: example2.xml takes a while since it has 64 genes.
xml_path = "/content/drive/MyDrive/GOLLM/output/full_text_annotated_example.xml"
gene_dict = parse_xml_file(xml_path)
fetch_and_update_gene_info(gene_dict)  # optional, but adds good information

# Show every gene together with a summary of how much text was collected for it.
for gene in gene_dict.values():
    print(gene)
    occurrence_count = len(gene.get_occurrences())
    print(f"There are {occurrence_count} occurrences (each are AT MOST 3-setences long SO MAX of {int(3*occurrence_count)} sentences in total) of {gene.get_gene_id()}:{gene.get_name_from_article()}.")
    print("")

Gene(2670)
  Symbol          : GFAP
  Organism        : Homo sapiens
  Full Name       : glial fibrillary acidic protein
  Also Known As   : ALXDRD
  In-text Name    : glial fibrillary acidic protein
  Occurrences     : ['We found that a lectin, Datura stramonium agglutinin, induced irreversible differentiation in C6 glioma cells. The differentiated cells had long processes, a low rate of proliferation and a high content of glial fibrillary acidic protein. When the medium was replaced with Datura stramonium agglutinin-free medium after 1 h, cell proliferation continued to be inhibited.', 'Proliferation of four human glial tumour cells was also inhibited by Datura stramonium agglutinin. Further, these differentiated human glial tumour cells had long processes and a high content of glial fibrillary acidic protein similar to differentiated C6 glioma cells. Taken together, these observations suggest that Datura stramonium agglutinin may be useful as a new therapy for treating glioma withou

# 3.5 Prompt Engineering
- Creates the prompt using the Gene dictionary (from Gene chunking)
- batch_size: allows you to have more than 1 Gene in the prompt such that you have something like the following. Ideally keep it at batch_size=1 to ensure that the GO terms generated are exclusive to the Gene being analyzed.

```
# batch_size > 1 has the basic format:
Basic Prompt
Gene 1
Gene 2
...
# batch_size = 1 has the basic format:
Basic Prompt
Gene 1
```
- batch_index: allows you to iterate through the gene dictionary so that if you have a paper with x amount of genes, you can just increment  the index and run the inference x amount of times.
- direct: ignore this. This is for the future.
- customized_prompt: ignore this as well. This is for the future.

In [15]:
def make_go_term_prompt(gene_dict, batch_size=1, batch_index=0, direct=False, customized_prompt=None):
    """
    Build the prompt asking the model for Gene Ontology (GO) terms for one batch of genes.

    The prompt consists of an instruction block, a worked example (gene information plus
    the expected JSON answer) and, for every gene in the selected batch, the name used in
    the article followed by its occurrence snippets.

    :param gene_dict: A dictionary where each key is a gene ID and each value is a Gene object.
                      Only ``get_name_from_article()`` and ``occurrences`` are read.
    :param batch_size: The number of genes to include in this prompt batch.
    :param batch_index: The index (0-based) of the batch to process. For instance, if batch_size=1,
                        batch_index=2 will include the 3rd gene in the dictionary. An index past
                        the end of the dictionary yields a prompt without any gene section.
    :param direct: Reserved for a more direct instruction block. No such text exists yet, so the
                   default instructions are used (previously this raised UnboundLocalError).
    :param customized_prompt: If provided (and ``direct`` is False), this text replaces the default
                              instruction block; the example and gene sections are still appended.
    :return: A string containing the complete prompt.
    """

    # Instruction blocks

    general_instructions = """Analyze the following gene information and text snippets under "Occurrences".
Identify all potentially relevant Gene Ontology (GO) terms.
Base your analysis on prior knowledge available in your training data.

After completing your analysis, for each Gene Ontology (GO) term you list, provide a short one-line justification referencing the text as well as a relationship between the gene and the term.
If you cannot identify any GO terms for a particular gene information, say “None found.”

Be concise, do not use unnecessary words.
Be factual, do not editorialize.
Be specific, avoid overly general statements like "it is involved in many cellular processes."

"""

    # The example answer must itself be valid JSON, since it is what the model imitates.
    example = """To help you in your work, I am providing an example of gene information and the corresponding example analysis output structured in JSON.

Example gene information is:
Gene: myogenin
Organism: Mus musculus
Occurrences:
- Activation of NF-kappaB increases the expression of the inducible nitric oxide synthase (iNOS) in muscle cells that sequester HuR (RNA-binding protein) by preventing transcription of MyoD. Denervation-induced atrophy is related to the upregulation of specific histone deacetylases (HDAC), able to repress a negative regulator of myogenin, with an increase in MuRF1 expression and muscle wasting. AMPK (protein kinase activated by 5'adenosine monophosphate) is also involved in muscle atrophy.

Example anaylysis output is:
[
  {
    "gene": "myogenin",
    "go_data": [
      {
        "term": "DNA-binding transcription activator activity",
        "relationship": "enables",
        "namespace": "molecular_function",
        "justification": "Myogenin is a key transcription factor in muscle differentiation that binds to E-box sequences in muscle-specific genes and activates their transcription through RNA polymerase II."

      },
      {
        "term": "skeletal muscle cell differentiation",
        "relationship": "involved_in",
        "namespace": "biological_process",
        "justification": "Myogenin is one of the four myogenic regulatory factors (MRFs) that drive skeletal muscle differentiation by promoting the transition from myoblasts to myotubes."
      }]
    }
]
"""

    # Choose the instruction block. `direct` keeps precedence over `customized_prompt`,
    # but has no dedicated text yet and therefore uses the default instructions.
    if customized_prompt and not direct:
        instructions = customized_prompt.rstrip("\n") + "\n\n"
    else:
        instructions = general_instructions

    parts = [instructions, example, "\nHere is the gene information:\n"]

    # Select the genes that belong to the requested batch.
    genes = list(gene_dict.values())
    start_index = batch_index * batch_size
    batch_genes = genes[start_index:start_index + batch_size]

    # Add the in-text name and the occurrence snippets of each selected gene.
    for gene_obj in batch_genes:
        parts.append(f"\nGene: {gene_obj.get_name_from_article()}\n")
        parts.append("Occurrences:\n")
        parts.extend(f"- {occ}\n" for occ in gene_obj.occurrences)

    return "".join(parts)

# Example usage:
# gene_dict must already be defined and hold Gene objects (see the Gene Chunking section).
# One gene is processed at a time; raise batch_index to move on to the next gene.
prompt = make_go_term_prompt(gene_dict, batch_size=1, batch_index=0)
print(prompt)

Analyze the following gene information and text snippets under "Occurrences".
Identify all potentially relevant Gene Ontology (GO) terms.
Base your analysis on prior knowledge available in your training data.

After completing your analysis, for each Gene Ontology (GO) term you list, provide a short one-line justification referencing the text as well as a relationship between the gene and the term.
If you cannot identify any GO terms for a particular gene information, say “None found.”

Be concise, do not use unnecessary words.
Be factual, do not editorialize.
Be specific, avoid overly general statements like "it is involved in many cellular processes.

To help you in your work, I am providing an example of gene information and the corresponding example analysis output structured in JSON.

Example gene information is:
Gene: myogenin
Organism: Mus musculus
Occurrences:
- Activation of NF-kappaB increases the expression of the inducible nitric oxide synthase (iNOS) in muscle cells that s

# 4. Inference (Local)
The two cells below rely on local inference, meaning that you must have the model stored locally. Skip them if you're using an API such as GROQ.

In [None]:
print("Loading LLaMA model.")
print(torch.cuda.is_available())  # shows whether a CUDA device is visible

# Local inference targets the GPU.
device = torch.device("cuda")

# Directory holding the locally stored model files.
# model_dir = "/content/drive/MyDrive/Llama 3.2-3B-Instruct-model"
model_dir = "/content/drive/MyDrive/GOLLM/Llama 3.2-3B-Instruct-model"

# Load strictly from disk (local_files_only=True), then move the model to the device.
tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True)
model = AutoModelForCausalLM.from_pretrained(model_dir, local_files_only=True)
model = model.to(device)


In [None]:
def run_inference(prompt):
    """
    Run inference on the provided prompt using the notebook-level ``model`` and ``tokenizer``.

    Generation uses beam search (3 beams, no sampling). Other decoding setups that were
    tried in this notebook: purely greedy (``do_sample=False, num_beams=1``), a wider beam
    search (``num_beams=4, length_penalty=1.0``) and nucleus sampling
    (``do_sample=True, temperature=0.7, top_p=0.9``).

    :param prompt: The prompt string to send to the model.
    :return: The generated continuation only (prompt removed), stripped of surrounding whitespace.
    """
    # Encode the prompt and place the tensors on whatever device the model lives on.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_token_count = inputs["input_ids"].shape[-1]

    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=400,       # Enough tokens for a concise but thorough answer
        do_sample=False,          # Turn off sampling; we want a more deterministic, stable answer
        num_beams=3,              # Explore multiple beams for higher-quality completions
        length_penalty=0.1,       # Length exponent used when scoring beams (see generate docs)
        no_repeat_ngram_size=4,   # Helps avoid repeating the same phrase
        early_stopping=True,      # Stop once enough finished beam candidates exist
        pad_token_id=tokenizer.eos_token_id,
        temperature=None,         # sampling-only settings, explicitly unset
        top_p=None,
    )

    # The returned sequence starts with the prompt tokens. Cut by token count rather
    # than by len(prompt): the decoded text is not guaranteed to reproduce the prompt
    # character for character, so a character slice can chop or leak text.
    new_tokens = outputs[0][prompt_token_count:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Run the prompt built in the Prompt Engineering section three times and time each run.
# (`test_prompt` was never defined in this notebook, so a fresh kernel raised NameError.)
for i in range(3):
    start_time = time.time()
    output = run_inference(prompt)
    print(output)

    # with open(f'/content/drive/MyDrive/GOLLM/Inference Outputs/output_{i+1}.txt', 'w') as file:
    #     file.write(output)

    # The elapsed time covers the run_inference call above.
    print(f"-->Model loaded on {device} named {torch.cuda.get_device_name(0)} and it took {time.time()-start_time:.2f} seconds.")
    print("-"*50)

# prompt = "what is 2+2"
# print("Running inference...")
# start_time = time.time()
# output = run_inference(prompt)
# print(f"Model loaded on {device} in {time.time()-start_time:.2f} seconds.")
# print(output)


# 4v2. GROQ API Inference
- need API Key
- uses the `llama-3.3-70b-versatile` model

In [7]:
# using GROQ API
!pip install Groq
from groq import Groq

client = Groq(
    api_key="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"
)




In [16]:
# Build the prompt for one gene. Remember to increment batch_index by 1 to cover
# every gene rather than just the first one.
prompt = make_go_term_prompt(gene_dict, batch_index=0)

# Request a completion for gene analysis.
user_message = {
    "role": "user",
    "content": prompt,
}
request_options = dict(
    model="llama-3.3-70b-versatile",
    temperature=0,
    max_completion_tokens=2048,
    top_p=0.5,
    stream=False,
    response_format={"type": "json_object"},
    stop=None,
)
chat_completion = client.chat.completions.create(messages=[user_message], **request_options)

# Print the output from the model
print(chat_completion.choices[0].message.content)

{
   "gene":"glial fibrillary acidic protein",
   "go_data":[
      {
         "term":"intermediate filament",
         "relationship":"enables",
         "namespace":"cellular_component",
         "justification":"GFAP is expressed exclusively in astrocytes and is considered a marker of matured astrocytes."
      },
      {
         "term":"astrocyte differentiation",
         "relationship":"involved_in",
         "namespace":"biological_process",
         "justification":"GFAP is closely related to the differentiation of astrocytes and its introduction to C6 cells and human astrocytoma induced morphological changes."
      },
      {
         "term":"regulation of cell proliferation",
         "relationship":"involved_in",
         "namespace":"biological_process",
         "justification":"GFAP expression is closely related to the suppression of cell proliferation in astrocytes."
      },
      {
         "term":"central nervous system development",
         "relationship":"involve

# 5. Normalize
- use an embedding model (SAPBERT) to ensure that the terms generated by Llama are actual GO terms. You will need the Go.obo file from the [Gene Ontology Consortium](https://geneontology.org/docs/download-ontology/)
  - Essentially, for each go term created by Llama, Sapbert will do a similarity search against the Go.obo and retrieve the actual term + GO:ID