In [6]:
from transformers import AutoTokenizer, AutoModel
import torch
from data_gatherer.orchestrator import Orchestrator
import os
from lxml import etree
import torch.nn.functional as F
import json
import pandas as pd
import dspy

In [7]:
# Single source of truth for the checkpoint so the tokenizer and the encoder always match.
# General-domain alternative that was considered: google-bert/bert-base-uncased
MODEL_CHECKPOINT = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"

tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = AutoModel.from_pretrained(MODEL_CHECKPOINT)
# No explicit model.eval() call: from_pretrained already returns the model in evaluation mode
# (dropout disabled), so repeated embeddings of the same text are deterministic.

In [8]:
def get_embedding(text: str) -> torch.Tensor:
    """
    Embed a single text as the attention-masked mean of the model's last hidden states.

    The text is tokenized with the notebook-level ``tokenizer`` (truncated to at most
    512 tokens) and encoded with the notebook-level ``model`` without tracking gradients.

    Args:
        text: the string to embed; anything beyond 512 tokens is ignored.

    Returns:
        The mean-pooled token representation with all singleton dimensions squeezed out.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

    with torch.no_grad():
        token_states = model(**encoded).last_hidden_state

    # Broadcast the attention mask over the hidden dimension so that only real
    # (non-padding) tokens contribute to the average.
    weights = encoded['attention_mask'].unsqueeze(-1).expand(token_states.size())
    pooled = (token_states * weights).sum(1) / weights.sum(1)
    return pooled.squeeze()

In [9]:
# Instantiate the project Orchestrator from config.json; later cells read settings such as
# data_gatherer.config['html_xml_dir'] from it.
data_gatherer = Orchestrator('config.json')

orchestrator.py - line 40 - INFO - Data_Gatherer Orchestrator initialized. Extraction step Model: gemini-2.0-flash


In [10]:
# Find every .xml file below the configured html_xml_dir directory (relative to the parent folder).
# os.walk yields directory entries in arbitrary, filesystem-dependent order, so the result is
# sorted to keep the "file N" numbering printed by later cells stable across runs and machines.
files = sorted(
    os.path.join(root, file_name)
    for root, dirs, file_names in os.walk('../' + data_gatherer.config['html_xml_dir'])
    for file_name in file_names
    if file_name.endswith('.xml')
)
print(f"Found {len(files)} XML files in {data_gatherer.config['html_xml_dir']}")

Found 27 XML files in html_xml_samples/


In [11]:
def extract_paragraphs_from_xml(xml_root) -> list[dict]:
    """
    Extract paragraphs and their section context from an XML document.

    Every ``<sec>`` element below ``xml_root`` is visited and every ``<p>`` element
    below that section is reported. Both searches are recursive, so a paragraph that
    sits inside a nested sub-section is reported once for each enclosing ``<sec>``.

    Args:
        xml_root: lxml.etree.Element — parsed XML root.

    Returns:
        List of dicts, one per kept paragraph, with the keys:
            'paragraph'     — the ``<p>`` element serialized back to XML (markup included),
            'section_title' — full text of the section's direct ``<title>`` child,
                              or "No Title" when it is missing or blank,
            'sec_type'      — the section's ``sec-type`` attribute, or "unknown",
            'text'          — the paragraph's text fragments joined with single spaces.
    """
    paragraphs = []

    # Iterate over all section blocks
    for sec in xml_root.findall(".//sec"):
        sec_type = sec.get("sec-type", "unknown")

        # Read the whole title via itertext(): ``title_elem.text`` only holds the text
        # before the first child element, so a title such as
        # <title><italic>TP53</italic> status</title> would otherwise be lost.
        title_elem = sec.find("title")
        section_title = "".join(title_elem.itertext()).strip() if title_elem is not None else ""
        if not section_title:
            section_title = "No Title"

        for p in sec.findall(".//p"):
            itertext = " ".join(p.itertext()).strip()
            para_text = etree.tostring(p, encoding="unicode", method="xml").strip()
            # The threshold is applied to the serialized element, so the markup itself
            # counts toward the length.
            if len(para_text) >= 5:  # avoid tiny/junk paragraphs
                paragraphs.append({
                    "paragraph": para_text,
                    "section_title": section_title,
                    "sec_type": sec_type,
                    "text": itertext
                })

    return paragraphs

In [12]:
def extract_sections_from_xml(xml_root) -> list[dict]:
    """
    Extract sections from an XML document.

    Every ``<sec>`` element below ``xml_root`` produces one entry, and each entry
    gathers every ``<p>`` element below that section. Both searches are recursive, so
    a parent section's text also contains the paragraphs of its nested sub-sections,
    and each sub-section additionally appears as an entry of its own.

    Args:
        xml_root: lxml.etree.Element — parsed XML root.

    Returns:
        List of dicts, one per section, with the keys:
            'raw_sec_txt'   — the section's ``<p>`` elements serialized as XML, each
                              wrapped in newlines,
            'section_title' — full text of the section's direct ``<title>`` child,
                              or "No Title" when it is missing or blank,
            'sec_type'      — the section's ``sec-type`` attribute, or "unknown",
            'sec_txt'       — the title line followed by the plain text of every
                              paragraph of at least 5 characters.
    """
    sections = []

    # Iterate over all section blocks
    for sec in xml_root.findall(".//sec"):
        sec_type = sec.get("sec-type", "unknown")

        # Read the whole title via itertext(): ``title_elem.text`` only holds the text
        # before the first child element, so a title such as
        # <title><italic>TP53</italic> status</title> would otherwise be lost.
        title_elem = sec.find("title")
        section_title = "".join(title_elem.itertext()).strip() if title_elem is not None else ""
        if not section_title:
            section_title = "No Title"

        # Collect the pieces in lists and join once instead of growing strings in the loop.
        text_chunks = []
        raw_chunks = []

        for p in sec.findall(".//p"):
            itertext = " ".join(p.itertext()).strip()
            if len(itertext) >= 5:  # skip paragraphs with (almost) no readable text
                text_chunks.append(f"\n{itertext}\n")

            para_text = etree.tostring(p, encoding="unicode", method="xml").strip()
            # Here the threshold applies to the serialized element, markup included.
            if len(para_text) >= 5:  # avoid tiny/junk paragraphs
                raw_chunks.append(f"\n{para_text}\n")

        sections.append({
            "raw_sec_txt": "".join(raw_chunks),
            "section_title": section_title,
            "sec_type": sec_type,
            "sec_txt": f"{section_title}\n" + "".join(text_chunks),
        })

    return sections

In [13]:
corpus = []

# The query is the same for every file, so embed and normalize it once up front
# instead of paying for a model forward pass per file.
query = "dataset reference"
query_vec = get_embedding(query)
print(f"Query '{query}' vectorized to shape {query_vec.shape}")
query_unit = F.normalize(query_vec.unsqueeze(0), dim=1)

for i, file in enumerate(files):
    print(f"Processing file {i+1}: {file}")
    # Read bytes (not str) so the parser can honour the encoding declared in the XML itself.
    with open(file, 'rb') as f:
        xml_root = etree.fromstring(f.read())

    sections = extract_sections_from_xml(xml_root)  # called at the beginning of LLM_Parser.parse_data
    print(f"Extracted {len(sections)} sections")

    for j, sect in enumerate(sections):
        # get_embedding truncates to 512 tokens, so long sections are scored on their beginning only.
        sect_vec = get_embedding(sect['sec_txt'])
        score = F.cosine_similarity(
            query_unit,
            F.normalize(sect_vec.unsqueeze(0), dim=1),
            dim=1
        ).item()
        # Annotate the section dict in place; it becomes one corpus entry.
        sect['embedding'] = sect_vec
        sect['score'] = score
        sect['source'] = file
        print(f"File_{i+1}-Section_{j} score: {score:.4f} for sec_txt: {sect['sec_txt'][:100]}...")

    corpus.extend(sections)

Processing file 1: ../html_xml_samples/PMC/Recurrent WNT pathway alterations are frequent in relapsed small cell lung cancer.xml
Extracted 38 sections
Query 'dataset reference' vectorized to shape torch.Size([768])
File_1-Section_0 score: 0.8664 for sec_txt: Introduction

Lung cancer is the leading cause of cancer-related death. Nearly 13% of patients with ...
File_1-Section_1 score: 0.8727 for sec_txt: Results

Consistent with previously published studies, the total tumor mutation burden (TMB) of sing...
File_1-Section_2 score: 0.8728 for sec_txt: The mutational landscape of relapsed SCLCs

Consistent with previously published studies, the total ...


KeyboardInterrupt: 

In [None]:
# dspy is already imported in the notebook's first cell; only the sentence-transformers
# import is specific to this experiment.
from sentence_transformers import SentenceTransformer

# Load the sentence encoder once. It does not depend on the file being processed, and
# defining it before `embedder` means the function never refers to a not-yet-assigned global.
embedder_model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L3-v2")


def embedder(texts):
    """Encode texts with the shared sentence encoder and return the result as NumPy data."""
    return embedder_model.encode(texts, convert_to_numpy=True)


# The probe query is identical for every file.
query = "Data available with accession code ABC0123 in Repository XYZ"

for i, file in enumerate(files):
    print(f"Processing file {i+1}: {file}")
    # Read bytes (not str) so the parser can honour the encoding declared in the XML itself.
    with open(file, 'rb') as f:
        xml_root = etree.fromstring(f.read())

    sections = extract_sections_from_xml(xml_root)  # called at the beginning of LLM_Parser.parse_data
    print(f"Extracted {len(sections)} sections")

    if len(sections) == 0:
        print(f"No sections found in file {file}. Skipping.")
        continue

    texts = [section['sec_txt'] for section in sections]

    # A fresh retriever per file: the searchable corpus is this file's section texts.
    retriever = dspy.retrievers.Embeddings(
        corpus=texts,
        embedder=embedder,
        k=1  # number of results to return
    )

    result = retriever(query)

    for p in result.passages:
        # Sanity check: the retrieved passage should be one of this file's section texts verbatim.
        if p in texts:
            print("Found in corpus!!")
        else:
            print("Not found in corpus")
            print("Section text:", p)

In [None]:
# Quick sanity check of the accumulated corpus: container type, number of entries, and the
# keys/values of the first entry. corpus[0] raises IndexError when the corpus is empty, and
# .values() displays that entry's complete embedding and section text.
type(corpus), len(corpus), corpus[0].keys(), corpus[0].values()

In [None]:
# Make every corpus entry JSON-serializable in place. Safe to re-run: entries whose
# embedding has already been converted are left as they are.
for entry in corpus:
    vector = entry['embedding']
    if isinstance(vector, torch.Tensor):
        # Tensors are not JSON-serializable; store plain nested Python lists instead
        entry['embedding'] = vector.tolist()
    # Ensure the score is a plain Python float
    entry['score'] = float(entry['score'])

In [None]:
# write corpus to a file
with open('pubmed_section_corpus_PubMedBERT-embeddings.json', 'w') as f:
    json.dump(corpus, f, indent=2)

In [None]:
# json to excel
df = pd.DataFrame(corpus)
df.to_excel('PubMedBERT_pubmed_paragraphs_corpus-embeddings.xlsx', index=False)