In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


*3. Gene Chunking*
=============
Given a **preannotated** article (check input folder) analyze the **XML** file via `parse_xml_file(xml_file_path)` which returns a dictionary of "Gene Class" from that file  

The Gene Class has the following information:
- self.gene_id: the gene id (from the xml file)
- self.occurences: A list of occurences of the gene in the text. Each occurence is **3-sentences long**. 
    - The second sentence contains the explicit mention of the gene. 
    - The first and third sentences are the context of the gene mention.
- self.symbol: the gene symbol (from NCBI. Initialized to None)
- self.organism: the organism of the gene (from NCBI. Initialized to None)
- self.full_name: the full name of the gene (from NCBI. Initialized to None. Set to name_from_article if not found in NCBI)
- self.also_known_as: a list of other names of the gene (from NCBI. Initialized to None. Set to name_from_article if not found in NCBI)
- self.name_from_article: the name of the gene as found in the article

Steps:
1. define the class Gene
2. Parse the xml file
3. update the Gene class with the information from NCBI
Now you have a dictionary of Gene classes.~

TIPS:
- control + F "sentence_buffer" to find where you can adjust the buffer 

In [3]:
import xml.etree.ElementTree as ET
from Bio import Entrez
import re
import time

class Gene:
    def __init__(self, gene_id):
        self.gene_id = gene_id
        self.occurrences = []  # list to store snippet(s) where the gene was mentioned
        self.symbol = None
        self.organism = None
        self.full_name = None
        self.also_known_as = None
        self.name_from_article = None

    def add_occurrence(self, snippet):
        if snippet not in self.occurrences:  # Avoid duplicates
            self.occurrences.append(snippet)

    def set_name_from_article(self, name_from_article):
        """Sets the temporary name of the gene. This is the name accroding to the article"""
        # will be used if the official name is not available
        self.name_from_article = name_from_article

    def get_name_from_article(self):
        return self.name_from_article

    def get_occurrences(self):
        return self.occurrences

    def get_also_known_as(self):
        return self.also_known_as
    
    def get_gene_id(self):
        return self.gene_id
    
    def update_info(self, symbol, organism, full_name, also_known_as):
        self.symbol = symbol
        self.organism = organism
        self.full_name = full_name
        self.also_known_as = also_known_as

    def __repr__(self):
        return (f"Gene({self.gene_id})\n"
                f"  Symbol          : {self.symbol}\n"
                f"  Organism        : {self.organism}\n"
                f"  Full Name       : {self.full_name}\n"
                f"  Also Known As   : {self.also_known_as}\n"
                f"  In-text Name    : {self.name_from_article}\n"
                f"  Occurrences     : {self.occurrences}")   # occurrences is a list of 3-sentence snippets

def parse_xml_file(xml_path):
    """Parses the XML file and returns a gene dictionary keyed by gene ID."""
    tree = ET.parse(xml_path)
    root = tree.getroot()
    gene_dict = {}

    for document in root.findall('document'):
        for passage in document.findall('passage'):
            section_type_elem = passage.find("infon[@key='section_type']")
            if section_type_elem is not None and section_type_elem.text.upper() == "METHODS":
                continue  # Skip passages under METHODS
            # Get the full passage text.
            passage_text_elem = passage.find("text")
            passage_text = passage_text_elem.text if passage_text_elem is not None else ""
            # determine the starting offset for this passage.
            passage_offset_elem = passage.find("offset")
            passage_offset = int(passage_offset_elem.text) if passage_offset_elem is not None else 0
            # Split the passage into sentences using regex.
            sentences = re.split(r'(?<=[.!?])\s+', passage_text)
            # Compute start indices for each sentence within the passage text.
            start_indices = []
            current_index = passage_offset
            for sentence in sentences:
                start_indices.append(current_index)
                current_index += len(sentence) + 1  # account for the delimiter space

            processed_ranges = set()  # To track which sentences have already been covered

            # Process each annotation in the passage.
            for annotation in passage.findall('annotation'):
                ann_type = annotation.find("infon[@key='type']")
                if ann_type is not None and ann_type.text == "Gene":
                    gene_id_elem = annotation.find("infon[@key='identifier']")
                    gene_id = gene_id_elem.text if gene_id_elem is not None else None
                    
                    # Extract the gene name from the annotation text. (Temporary name if official name is not available)
                    in_text_gene_name_elem = annotation.find("text")
                    in_text_gene_name = in_text_gene_name_elem.text if in_text_gene_name_elem is not None else None

                    # Extract the annotation location (offset) to find the sentence containing the gene.
                    location_elem = annotation.find("location")
                    ann_offset = int(location_elem.attrib.get('offset', 0)) if location_elem is not None else 0

                    # Determine which sentence contains the annotation based on its offset.
                    sentence_index = None
                    for i, start in enumerate(start_indices):
                        if start <= ann_offset < start + len(sentences[i]):
                            sentence_index = i
                            break
                    
                    # sentence_buffer
                    if sentence_index is not None:
                        start_sentence = max(0, sentence_index - 1)  # one sentence before
                        end_sentence = min(len(sentences), sentence_index + 2)  # one sentence after
                        range_tuple = (start_sentence, end_sentence)
                        if range_tuple in processed_ranges:
                            continue  # Skip duplicate extractions
                        processed_ranges.add(range_tuple)
                        snippet = " ".join(sentences[start_sentence:end_sentence])
                        if gene_id:
                            if gene_id in gene_dict:
                                gene_dict[gene_id].add_occurrence(snippet)
                            else:
                                gene_obj = Gene(gene_id)
                                gene_obj.add_occurrence(snippet)
                                gene_obj.set_name_from_article(in_text_gene_name)       # Temporary name if official name is not available
                                gene_dict[gene_id] = gene_obj
    return gene_dict

def fetch_and_update_gene_info(gene_dict):
    """Retrieves gene information from NCBI and updates the genes in gene_dict."""
    gene_ids = list(gene_dict.keys())
    if gene_ids:
        # Post the gene IDs to NCBI.
        handle = Entrez.epost(db="gene", id=",".join(gene_ids))
        result = Entrez.read(handle)
        handle.close()

        webenv = result["WebEnv"]
        query_key = result["QueryKey"]

        handle = Entrez.esummary(db="gene", webenv=webenv, query_key=query_key)
        record = Entrez.read(handle)
        handle.close()

        for docsum in record["DocumentSummarySet"]["DocumentSummary"]:
            gene_id = docsum.attributes["uid"]
            symbol = docsum.get('NomenclatureSymbol', 'No symbol')
            organism = docsum.get('Organism', {}).get('ScientificName', 'No organism')
            full_name = docsum.get('NomenclatureName', gene_dict[gene_id].get_name_from_article())
            also_known_as = docsum.get('OtherAliases', gene_dict[gene_id].get_name_from_article())
            
            # if full name and also known as are not available, use the name from the article
            if (full_name == ''):
                full_name = gene_dict[gene_id].get_name_from_article()
            if (also_known_as == ''):
                also_known_as = gene_dict[gene_id].get_name_from_article()
            
            if gene_id in gene_dict:
                gene_dict[gene_id].update_info(symbol, organism, full_name, also_known_as)
            # Pause briefly to avoid overwhelming NCBI servers.
            time.sleep(0.5)     # 2 requests per second (safe). WITH API KEY, be increased if needed to 10 requests per second.


# Set your email (and API key if available)
Entrez.email = "email here"
# Entrez.api_key = "your_api_key"

# example2.xml takes a while since it has 64 genes. 
xml_path = "input/full_text_annotated_example.xml"
gene_dict = parse_xml_file(xml_path)
fetch_and_update_gene_info(gene_dict)           # Fetch gene information from NCBI (Optional but has good information)
for gene in gene_dict.values():
    print(gene)
    print(f"There are {len(gene.get_occurrences())} occurrences (each are AT MOST 3-setences long SO MAX of {int(3*len(gene.get_occurrences()))} sentences in total) of {gene.get_gene_id()}:{gene.get_name_from_article()}.")
    print("")

Gene(2670)
  Symbol          : GFAP
  Organism        : Homo sapiens
  Full Name       : glial fibrillary acidic protein
  Also Known As   : ALXDRD
  In-text Name    : glial fibrillary acidic protein
  Occurrences     : ['We found that a lectin, Datura stramonium agglutinin, induced irreversible differentiation in C6 glioma cells. The differentiated cells had long processes, a low rate of proliferation and a high content of glial fibrillary acidic protein. When the medium was replaced with Datura stramonium agglutinin-free medium after 1 h, cell proliferation continued to be inhibited.', 'Proliferation of four human glial tumour cells was also inhibited by Datura stramonium agglutinin. Further, these differentiated human glial tumour cells had long processes and a high content of glial fibrillary acidic protein similar to differentiated C6 glioma cells. Taken together, these observations suggest that Datura stramonium agglutinin may be useful as a new therapy for treating glioma withou

In [None]:
# packages
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# from goatools import obo_parser

### 1. download the obo file from the gene ontology consortium https://current.geneontology.org/products/pages/downloads.html

In [3]:
print("Loading Gene Ontology database...")
GO_OBO_FILE = "/content/drive/MyDrive/goa_human.gaf"    # DESTINATION OF THE GO OBO FILE
# go_dag = obo_parser.GODag(GO_OBO_FILE)

def load_human_go_annotations(filepath=GO_OBO_FILE):
    """Load human-specific GO annotations from the GOA database"""
    go_annotations = {}

    with open(filepath, "r") as f:
        for line in f:
            if line.startswith("!"):  # Ignore header lines
                continue

            columns = line.strip().split("\t")
            if len(columns) < 9:
                continue  # Ensure we have enough columns

            gene_name = columns[2]  # DB_Object_Symbol (gene name)
            go_id = columns[4]  # GO term
            qualifier = columns[3]  # Relationship (enables, involved_in, etc.)
            aspect = columns[8]  # BP, MF, or CC

            if gene_name not in go_annotations:
                go_annotations[gene_name] = {"BP": [], "MF": [], "CC": []}

            if aspect == "P":
                go_annotations[gene_name]["BP"].append((go_id, qualifier))
            elif aspect == "F":
                go_annotations[gene_name]["MF"].append((go_id, qualifier))
            elif aspect == "C":
                go_annotations[gene_name]["CC"].append((go_id, qualifier))

    return go_annotations

# Load species-specific GO annotations
human_go_annotations = load_human_go_annotations()

# Check if insulin is in the dataset
if "INSULIN" in human_go_annotations:
    print("Found INSULIN in GO annotations!")
else:
    print("INSULIN NOT found in GO annotations!")



Loading Gene Ontology database...
INSULIN NOT found in GO annotations!


### load the tokenizer and the model and put it to cuda

In [4]:
print("Loading LLaMA model.")
model_dir = "/content/drive/MyDrive/Llama 3.2-3B-Instruct-model"
device = torch.device("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True)
import torch
print(torch.cuda.is_available())

model = AutoModelForCausalLM.from_pretrained(model_dir, local_files_only=True).to(device)


Loading LLaMA model.
True


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [15]:
def analyze_article(article_text: str) -> dict:
    """Process article text and infer Gene Ontology terms"""

    # Step 1: Extract genes/proteins using SciSpacy
    doc = nlp(article_text)

    detected_genes = {}
    genes = list({ent.text for ent in doc.ents
                if ent.label_ in ["GENE", "PROTEIN"]})

    print(f"Detected entities: {genes}")  # Debugging entity extraction

    # Ensure function always returns a dictionary with expected keys
    if not genes:
        return {"genes": [], "result_text": ""}

    # Step 3: Build structured prompt
    prompt = f"""You are an expert biomedical researcher. Analyze this biomedical article and extract Gene Ontology (GO) terms.

    Article excerpt:
    {article_text[:3000]}... [truncated]

    Detected entities: {', '.join(genes)}

    [OUTPUT FORMAT]
    Return **ONLY** GO terms in this structured format:
    - BP: GO:####### (Biological Process Name)
    - MF: GO:####### (Molecular Function Name)
    - CC: GO:####### (Cellular Component Name)

    Example Output:
    - BP: GO:0006006 (glucose metabolic process)
    - MF: GO:0005543 (insulin receptor binding)
    - CC: GO:0005886 (plasma membrane)
    """


    # Step 4: Generate predictions
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.1,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id
    )

    # Step 5: Process and validate results
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    result_text = generated_text[len(prompt):].strip()

    print(f"\nGenerated GO Terms:\n{result_text}")  # Debugging LLM output

    return {"genes": genes, "result_text": result_text}

In [16]:
import re  # Import regex for better parsing

def remove_duplicates(go_terms):
    """Remove duplicate GO terms while preserving order"""
    seen = set()
    unique_terms = []
    for term in go_terms:
        if term not in seen:
            seen.add(term)
            unique_terms.append(term)
    return unique_terms

def parse_and_validate_results(result_text, genes):
    """Validate generated GO annotations against GO database"""

    # Detect broken output
    if "[IMPORTANT]" in result_text or "Please provide" in result_text:
        print("⚠️ Warning: The model failed to generate GO terms correctly.")
        return {gene: {"BP": [], "MF": [], "CC": []} for gene in genes}

    validated_results = {gene: {"BP": [], "MF": [], "CC": []} for gene in genes}

    bp_terms = re.findall(r"BP:\s*(GO:\d+ \(.+?\))", result_text)
    mf_terms = re.findall(r"MF:\s*(GO:\d+ \(.+?\))", result_text)
    cc_terms = re.findall(r"CC:\s*(GO:\d+ \(.+?\))", result_text)

    for gene in genes:
        validated_results[gene]["BP"] = remove_duplicates(bp_terms)
        validated_results[gene]["MF"] = remove_duplicates(mf_terms)
        validated_results[gene]["CC"] = remove_duplicates(cc_terms)

    print(f"\n✅ Parsed GO terms for {genes}: {validated_results}")  # Debugging

    return validated_results



# ----------------------------
# Test with the sample article
# ----------------------------
sample_article = "insulin is a hormone that regulates blood sugar levels"
print("\nAnalyzing sample article...")
analysis = analyze_article(sample_article)

if not analysis["genes"]:  # Ensure genes exist before proceeding
    print("No genes/proteins detected.")
else:
    genes = analysis["genes"]
    result_text = analysis["result_text"]

    print("\nResults:")
    print(parse_and_validate_results(result_text, genes))



Analyzing sample article...
Detected entities: ['insulin']

Generated GO Terms:
GO terms extracted from the article excerpt:
    - BP: GO:0008150 (carbohydrate metabolic process)
    - MF: GO:0005518 (protein binding)
    - CC: GO:0005886 (plasma membrane)
    - BP: GO:0006006 (glucose metabolic process)
    - MF: GO:0005543 (insulin receptor binding)
    - MF: GO:0005519 (insulin binding)
    - CC: GO:0009611 (extracellular space)
    - MF: GO:0005515 (transmembrane receptor activity)
    - MF: GO:0005516 (transmembrane receptor activity)
    - MF: GO:0005517 (receptor activity)
    - BP: GO:0008152 (glucose-6-phosphate metabolic process)
    - MF: GO:0005518 (protein binding)
    - MF: GO:0005519 (insulin binding)
    - MF: GO:0005515 (transmembrane receptor activity)
    - MF: GO:0005516 (transmembrane receptor activity)
    - MF: GO:0005517 (receptor activity)
    - BP: GO:0008150 (carbohydrate metabolic process)
    - MF: GO:0005543 (insulin receptor binding)
    - MF: GO:0005518