# Gets articles from PubMed and processes them so they can be chunked, vectorized, and stored in Qdrant

In [10]:
from Bio import Entrez
from lxml import etree
from io import BytesIO
import re
import spacy
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from collections import defaultdict
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
import time

In [5]:
# Contact address attached to every Entrez request made through Biopython.
Entrez.email = "charlie.kotula@gmail.com"

# Search query for getting relevant research articles.
# Four parenthesised groups are combined with AND: topic, injury context,
# exercise/training, and publication type.
# NOTE(review): the first group mixes AND and OR without inner parentheses —
# confirm that the resulting evaluation order matches the intended search.
query = """
(
  rehabilitation AND "physical therapy" OR "return to sport" OR "return to play"
) AND (
  injury OR surgery OR postoperative OR musculoskeletal
) AND (
  exercise OR "therapeutic exercise" OR training
) AND (
  review[pt] OR systematic review[pt] OR meta-analysis[pt]
)
"""

def get_ids_with_metadata(query, retmax=10):
    """
    Gets PMC article UIDs, PMCIDs, and titles based on search query

    Args:
        query (str) - search query used to retrieve PMC articles
        retmax (int) - maximum number of search hits to request (default 10)
    Returns: metadata (list) - one dict per article with the keys 'uid',
        'pmcid' and 'title'; the title is lowercased and every run of
        characters outside a-z / 0-9 is replaced by an underscore. An empty
        list is returned when the search yields no UIDs.
    """
    # Get relevant UIDs; always release the handle, even if parsing fails
    handle = Entrez.esearch(
        db='pmc',
        term=query,
        retmax=retmax,
    )
    try:
        uids = Entrez.read(handle)['IdList']
    finally:
        handle.close()

    # Nothing matched: skip the summary request instead of sending an empty id list
    if not uids:
        return []

    # Get summaries for metadata
    summary = Entrez.esummary(
        db='pmc',
        id=','.join(uids)
    )
    try:
        records = Entrez.read(summary)
    finally:
        summary.close()

    # Map UIDs to titles and pmcids
    metadata = []
    for rec in records:
        title = re.sub(r'[^a-z0-9]+', '_', rec['Title'].lower())

        metadata.append(
            {
                'uid': rec['Id'],
                'pmcid': rec['ArticleIds']['pmcid'],
                'title': title
            }
        )

    return metadata

# Create metadata list (one dict per article with 'uid', 'pmcid' and 'title')
# to be used in multiprocessing
metadata = get_ids_with_metadata(query)
# Bare expression: shows the list as the notebook cell output
metadata

[{'uid': '12788749',
  'pmcid': 'PMC12788749',
  'title': 'comparisons_of_the_radiological_and_functional_results_of_femoral_tunnels_created_in_figure_4_and_110_flexion_positions_in_arthroscopic_anterior_cruciate_ligament_reconstruction'},
 {'uid': '12788292',
  'pmcid': 'PMC12788292',
  'title': 'assessment_and_measurement_of_the_side_effects_of_an_evidence_based_intervention_with_an_advanced_smart_cricket_ball_exemplified_by_a_case_report_on_correcting_illegal_bowling_action'},
 {'uid': '12788198',
  'pmcid': 'PMC12788198',
  'title': 'impact_of_anatomical_placement_on_the_accuracy_of_wearable_heart_rate_monitors_during_rest_and_various_exercise_intensities'},
 {'uid': '12788087',
  'pmcid': 'PMC12788087',
  'title': 'from_laboratory_to_field_concurrent_validity_of_kinovea_s_linear_kinematics_tracking_tool_for_semi_automated_countermovement_jump_analysis'},
 {'uid': '12787239',
  'pmcid': 'PMC12787239',
  'title': 'clinical_outcomes_of_arthroscopic_treatment_for_triangular_fibrocarti

In [8]:
######################################################
# Functions to extract, clean, and chunk text from PMC
def get_xml(pmc_id):
    """
    Returns the xml tree representation of the PMC article corresponding to
    the input UID.

    Args: pmc_id - article identifier passed as `id` to Entrez.efetch (db='pmc')
    Returns: xml_tree - lxml element tree parsed from the fetched XML
    """
    handle = Entrez.efetch(
        db='pmc',
        id=pmc_id,
        retmode='xml',
        # rettype='full'
    )

    # Read the whole payload, then release the connection even if read() fails
    try:
        xml_dat = handle.read()
    finally:
        handle.close()

    # BytesIO needs bytes; encode defensively in case the handle yields text
    # NOTE(review): assumes UTF-8 is the right encoding in that case — confirm
    if isinstance(xml_dat, str):
        xml_dat = xml_dat.encode('utf-8')

    # Converts xml bytes to tree
    return etree.parse(BytesIO(xml_dat))

def get_text(xml):
    """
    Returns a list of (section title, content) tuples for the xml tree root.

    Every <xref> element is removed first, together with an opening bracket or
    parenthesis directly before it and closing ones directly after it. The
    text that follows a removed <xref> is kept.

    Args: xml - parsed lxml element tree of the article
    Returns: text (list) - (lowercased section title, section text) tuples for
        each <sec> under <body>; sections without a title or without direct
        <p> children are skipped
    """
    root = xml.getroot()

    # Remove references. xpath() returns a list, so removing while looping is safe
    for xref in root.xpath('.//xref'):
        parent = xref.getparent()
        if parent is None:
            continue

        # Handles punctuation after ref
        tail = xref.tail or ''
        if tail:
            tail = re.sub(r'^\s*[\]\)]*', ' ', tail)

        # A tail made only of separators between two adjacent refs
        # (e.g. the ", " in "[1, 2]") is dropped with the reference
        following = xref.getnext()
        if (
            following is not None
            and following.tag == 'xref'
            and re.fullmatch(r'[\s,;\u2013\u2014-]*', tail)
        ):
            tail = ''

        # Handles punctuation before ref. The text right before the xref lives
        # in the previous sibling's tail, or in parent.text when the xref is
        # the first child
        prev = xref.getprevious()
        if prev is not None:
            lead = re.sub(r'[\[\(]\s*$', ' ', prev.tail or '')
            # remove() below also drops xref.tail, so re-attach it here first
            prev.tail = (lead + tail) or None
        else:
            lead = re.sub(r'[\[\(]\s*$', ' ', parent.text or '')
            parent.text = (lead + tail) or None

        parent.remove(xref)

    text = []
    for sec in root.xpath('.//body//sec'):
        title = sec.findtext('title')
        if not title:
            continue

        # Gets paragraphs that are direct children of the section
        paragraphs = [''.join(p.itertext()) for p in sec.findall('p')]

        # Add section to the list, ignoring empty sections
        if paragraphs:
            text.append((title.lower(), ' '.join(paragraphs)))

    return text

def clean_text(text):
    """
    Normalizes whitespace in every section of an article.

    Args: text - iterable of (section, words) pairs
    Returns: list of (section, words) pairs where each run of whitespace in
        `words` is collapsed to a single space and the ends are stripped;
        `section` is passed through untouched
    """
    return [
        (section, re.sub(r'\s+', ' ', words).replace('\xa0', ' ').strip())
        for section, words in text
    ]

def chunk_text(cleaned_text, uid, pmcid, title):
    """
    Takes cleaned article text and chunks it into LangChain Documents

    Args:
        cleaned_text - iterable of (section title, section text) pairs
        uid, pmcid, title - article identifiers copied into every chunk's metadata
    Returns: docs (list) - one Document per chunk with the metadata keys
        'uid', 'pmcid', 'article', 'section' and 'chunk_id'
    """
    docs = []

    # The splitter configuration never changes, so build it once for all sections
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        separators=['\n\n', '\n', '. ', ' ', '']
    )

    # Next free chunk index per section label. Sections whose titles normalize
    # to the same label continue the numbering instead of restarting at 0, so
    # every chunk_id stays unique within the article
    next_index = defaultdict(int)

    for section, text in cleaned_text:
        # Create section label for metadata
        section = re.sub(r'[^a-z0-9]+', '_', section.lower())

        # Chunk text and create langchain Documents
        for chunk in splitter.split_text(text):
            i = next_index[section]
            next_index[section] += 1

            docs.append(
                Document(
                    page_content=chunk,
                    metadata={
                        "uid": uid,
                        "pmcid": pmcid,
                        "article": title,
                        "section": section,
                        "chunk_id": f'{uid}-{section}-{i}'
                    }
                )
            )

    return docs

def process_article(article):
    """
    Runs the full fetch -> extract -> clean -> chunk pipeline for one article.

    Args: article (dict) - must contain the keys 'uid', 'pmcid' and 'title'
    Returns: whatever chunk_text returns for the article's cleaned sections
    """
    # Identifiers that end up in each Document's metadata
    uid, pmcid, title = article['uid'], article['pmcid'], article['title']

    # Fetch the XML by UID, pull out the titled sections, normalize whitespace
    sections = clean_text(get_text(get_xml(uid)))

    # Split into chunks and wrap them as LangChain Documents
    return chunk_text(sections, uid, pmcid, title)

In [9]:
from data_preprocessing import process_article

if __name__ == "__main__":
    # Flat list collecting the results of every processed article
    documents = []

    # Articles are processed in parallel in up to 8 worker processes.
    # (A sequential alternative is to call process_article(article) for each
    # entry of `metadata` in a plain loop.)
    with ProcessPoolExecutor(max_workers=8) as executor:
        futures = []
        for article in metadata:
            futures.append(executor.submit(process_article, article))

        # Results are gathered in completion order, not submission order;
        # result() re-raises any exception raised in the worker
        for done in as_completed(futures):
            documents.extend(done.result())

# Embedding using OpenAI and Qdrant

### To run the Qdrant Docker container:

`docker run -p 6333:6333 -p 6334:6334 \
    -v "$(pwd)/qdrant_storage:/qdrant/storage:z" \
    qdrant/qdrant`

In [None]:
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from langchain_openai.embeddings import OpenAIEmbeddings
import uuid

In [None]:
# Connect to the Qdrant instance at localhost:6333
client = QdrantClient(url="http://localhost:6333")

# Create the collection that stores the chunk vectors, compared by cosine distance.
# NOTE(review): size=3072 presumably has to equal the length of the vectors
# produced by the embedding model configured later in this notebook — confirm.
# NOTE(review): re-running this cell while the collection already exists
# presumably fails — confirm, or run the "wipe db" cell first.
client.create_collection(
    collection_name="rehab_collection",
    vectors_config=VectorParams(
        size=3072,
        distance=Distance.COSINE
    )
)

In [None]:
# Embedding client for the "text-embedding-3-large" model
# NOTE(review): presumably reads the OpenAI API key from the environment — confirm
embeddings = OpenAIEmbeddings(model="text-embedding-3-large")

In [None]:
# Parallel lists with the chunk text and the metadata of every Document
texts = [doc.page_content for doc in documents]
# NOTE(review): `metadatas` is not referenced in the visible cells
metadatas = [doc.metadata for doc in documents]

In [None]:
# Embed every chunk text; the result is zipped with `documents` when the points are built
vectors = embeddings.embed_documents(texts)

In [None]:
# create points for vector db
points = []

# zip() would silently drop the surplus items, so fail loudly on a mismatch
if len(vectors) != len(documents):
    raise ValueError(
        f'got {len(vectors)} vectors for {len(documents)} documents'
    )

for vec, doc in zip(vectors, documents):
    # Deterministic id derived from chunk_id: the same chunk always maps to
    # the same point, so re-running the upsert overwrites instead of duplicating.
    # The id is passed in its string form rather than as a uuid.UUID object
    point_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, doc.metadata['chunk_id']))

    points.append(
        PointStruct(
            id=point_id,
            vector=vec,
            # Payload carries the chunk metadata plus the chunk text itself
            payload={
                **doc.metadata,
                'text': doc.page_content
            }
        )
    )

In [None]:
# Sanity check: show the id of the first point as the cell output
points[0].id

In [None]:
# upsert to db in batches
BATCH_SIZE = 100


def batched(points, batch_size):
    """
    Yields consecutive slices of `points`, each at most `batch_size` long.

    Args:
        points - sliceable sequence with a length
        batch_size (int) - maximum number of items per yielded slice
    """
    total = len(points)
    yield from (
        points[start : start + batch_size]
        for start in range(0, total, batch_size)
    )

# Send the points to the collection BATCH_SIZE at a time
for point_batch in batched(points, BATCH_SIZE):
    # wait=True is passed so the call blocks until the operation is applied
    # NOTE(review): confirm against the Qdrant client documentation
    operation_info = client.upsert(
        collection_name="rehab_collection",
        wait=True,
        points=point_batch
    )
    
    # Print the status returned for this batch
    print(operation_info)

In [None]:
# checking db: exact number of points currently stored in the collection
client.count(collection_name='rehab_collection', exact=True)

In [None]:
# wipe db
# WARNING: deletes the whole 'rehab_collection' collection. With "Run All"
# this cell runs right after the upsert and count cells above.
client.delete_collection('rehab_collection')

In [None]:
# testing search
