## 01_IR

Task 1-4: ran on modal in a modal notebook

In [1]:
!pip install tqdm pubmed-parser

Collecting pubmed-parser
  Downloading pubmed_parser-0.5.1-py3-none-any.whl.metadata (17 kB)
Collecting lxml (from pubmed-parser)
  Downloading lxml-6.0.2-cp312-cp312-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl.metadata (3.6 kB)
Collecting unidecode (from pubmed-parser)
  Downloading Unidecode-1.4.0-py3-none-any.whl.metadata (13 kB)
Downloading pubmed_parser-0.5.1-py3-none-any.whl (56.9 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/56.9 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.5/56.9 MB[0m [31m176.0 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.0/56.9 MB[0m [31m44.0 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━[0m [32m36.7/56.9 MB[0m [31m62.4 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.9/56.9 MB[0m [31m73.8 M

### I Fetch pubmed

In [2]:
# I_fetch_pubmed.py --> i can run the original here on modal

import os
import pickle
import shutil
from tqdm import tqdm
from ftplib import FTP
from time import sleep
import pubmed_parser as pp
from urllib import request
from random import shuffle
from itertools import chain
from multiprocessing import Pool
from collections import defaultdict


num_workers = 10
base_url = 'https://ftp.ncbi.nlm.nih.gov/pubmed/baseline/'

medline_folder = 'pmid2contents'
os.makedirs(medline_folder, exist_ok=True)


def clean_title(title):
    """
    :param title:
    :return: Basic text cleaning for title
    """
    title = ' '.join(title) if isinstance(title, list) else title

    if title.startswith('['):title = title[1:]
    if title.endswith(']'): title = title[:-1]
    if title.endswith('.'): title = title[:-1]
    if title.endswith(']'): title = title[:-1]
    return title.lower() + ' .'


def clean_abstract(abstract):
    """
    :param abstract:
    :return: Basic text cleaning for abstract
    """
    if abstract.endswith('.'): abstract = abstract[:-1] + ' .'
    return abstract.lower()


def get_medline_files_path():
    """
    :return: helper function to get medline file names
    """
    file_names = []
    with FTP('ftp.ncbi.nlm.nih.gov') as ftp:
        ftp.login()
        lines = []
        ftp.dir('pubmed/baseline', lines.append)
        for i in lines:
            tokens = i.split()
            name = tokens[-1]
            if name.endswith('.gz'):
                file_names.append(name)
    return file_names


def medline_download(renew=False):
    print('Downloading Medline XML files ...')
    file_names = get_medline_files_path()[:30] # Remove this for full-scale operation
    for f_name in tqdm(file_names):
        if not os.path.isfile(os.path.join(medline_folder, f_name)) or renew:
            if f_name not in os.listdir(medline_folder):
                with request.urlopen(os.path.join(base_url, f_name)) as response, open(os.path.join(medline_folder, f_name), 'wb') as out_file:
                    shutil.copyfileobj(response, out_file)
                    sleep(1)


def medline_parser(med_xml):
    dicts_out = pp.parse_medline_xml(os.path.join(medline_folder, med_xml),
                                     year_info_only=False,
                                     nlm_category=False,
                                     author_list=False,
                                     reference_list=False)

    pack = []
    for i in dicts_out:
        pmid = i['pmid']
        c_title = clean_title(i['title'])
        title = c_title if len(c_title)>10 else None # ignore noise titles

        c_abstract = clean_abstract(i['abstract'])
        abstract = c_abstract if len(c_abstract)>10 else None # ignore noise abstract

        if len(i['mesh_terms']):
            mesh_terms = [x.strip().split(':')[1].lower() for x in i['mesh_terms'].split(';')]
        else:
            mesh_terms = None

        if all([title, abstract, mesh_terms]):
            pack.append((pmid, title, abstract, mesh_terms))
    return pack


def multi_process_medline():
    """
    :return: list of pickle files in which pmids are mapped to their mesh terms, titles and abstracts (strings)
    """
    print('Processing XML files ...')
    xml_files = [xml_file for xml_file in os.listdir(medline_folder) if xml_file.endswith('.xml.gz')]
    shuffle(xml_files) #load-balance files with different sizes
    for idx in tqdm(range(0, len(xml_files), 10)):
        xml_files_batch = xml_files[idx: idx + 10]
        with Pool(processes=num_workers) as pool:
            pmid2content_map_all = pool.map(medline_parser, xml_files_batch)
        pmid2content_map_all = list(chain(*pmid2content_map_all))

        pmid2content = defaultdict(set)
        for entry in pmid2content_map_all:
            pmid2content[entry[0]] = entry[1:]

        with open(os.path.join(medline_folder, 'pmid2content%d.pkl' % idx), 'wb') as f:
            pickle.dump(pmid2content, f)
        pmid2content.clear()
    for gz_file in os.listdir(medline_folder): # remove processed files
        if gz_file.endswith('.gz'):
            os.remove(os.path.join(medline_folder, gz_file))



if __name__ == "__main__":
    medline_download()
    multi_process_medline()

Downloading Medline XML files ...


  0%|                                                                                   | 0/30 [00:00<?, ?it/s]  3%|██▌                                                                        | 1/30 [00:01<00:36,  1.27s/it]  7%|█████                                                                      | 2/30 [00:02<00:36,  1.29s/it] 10%|███████▌                                                                   | 3/30 [00:03<00:34,  1.28s/it] 13%|██████████                                                                 | 4/30 [00:05<00:34,  1.34s/it] 17%|████████████▌                                                              | 5/30 [00:06<00:34,  1.37s/it] 20%|███████████████                                                            | 6/30 [00:08<00:32,  1.37s/it] 23%|█████████████████▌                                                         | 7/30 [00:09<00:30,  1.34s/it] 27%|████████████████████                                                       | 8/30 [00:10<00:29,  1

Processing XML files ...


  0%|                                                                                    | 0/3 [00:00<?, ?it/s] 33%|█████████████████████████▎                                                  | 1/3 [00:28<00:57, 28.87s/it] 67%|██████████████████████████████████████████████████▋                         | 2/3 [01:00<00:30, 30.43s/it]100%|████████████████████████████████████████████████████████████████████████████| 3/3 [01:28<00:00, 29.34s/it]100%|████████████████████████████████████████████████████████████████████████████| 3/3 [01:28<00:00, 29.49s/it]


### II Index

In [3]:
!pip install whoosh

Collecting whoosh
  Downloading Whoosh-2.7.4-py2.py3-none-any.whl.metadata (3.1 kB)
Downloading Whoosh-2.7.4-py2.py3-none-any.whl (468 kB)
Installing collected packages: whoosh
Successfully installed whoosh-2.7.4

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m26.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [4]:
# check structure of pickle
import os
import pickle

# pick one pickle file
pkl_file = os.listdir(medline_folder)[0]
pkl_path = os.path.join(medline_folder, pkl_file)

with open(pkl_path, "rb") as f:
    data_dict = pickle.load(f)

first_key = list(data_dict.keys())[0]
sample = data_dict[first_key]

print(f"Pickle file: {pkl_file}")
print(f"PMID: {first_key}")
print(f"Type: {type(sample)}")
print(f"Content: {sample}")


Pickle file: pmid2content0.pkl
PMID: 759666
Type: <class 'tuple'>
Content: ('lyme arthritis in wisconsin .', 'rash, severe constitutional symptoms, and arthritis developed in three persons who were bitten by ticks in wisconsin. on comparison with other reports of arthritis related to tick bites, we found that the illness of our patients had clinical features consistent with lyme arthritis. lyme arthritis appears not to be restricted to new england as has been previously reported .', ['adult', 'arthritis', 'child, preschool', 'erythema', 'female', 'humans', 'male', 'meningism', 'new england', 'recurrence', 'tick infestations', 'wisconsin'])


In [5]:
# used II_index_solution.py -->  added mesh term

import os
import pickle
import shutil
from tqdm import tqdm
from whoosh import index
#from I_fetch_pubmed import medline_folder
from whoosh.fields import Schema, TEXT, ID, KEYWORD


# Schema
schema = Schema(
    id=ID(stored=True, unique=True),
    title=TEXT(stored=True),
    body=TEXT(stored=True),
    mesh=KEYWORD(stored=True, commas=True) 
)

# Index
index_dir = "pubmed_index"


def get_index():
    if os.path.exists(index_dir):
        shutil.rmtree(index_dir)
    os.mkdir(index_dir)
    ix = index.create_in(index_dir, schema)
    writer = ix.writer()
    for pkl_file in os.listdir(medline_folder):
        pkl_obj = pickle.load(open(os.path.join(medline_folder, pkl_file), 'rb'))
        print('Indexing %s ...'%pkl_file)

        for idx in tqdm(pkl_obj):
            mesh_terms = pkl_obj[idx][2] if len(pkl_obj[idx]) > 2 else []
            mesh = ",".join(mesh_terms) if isinstance(mesh_terms, list) else ""  # to handle multiple or no mesh terms

            writer.add_document(id=str(idx), title=pkl_obj[idx][0], body=pkl_obj[idx][1], mesh=mesh)

    writer.commit()


if __name__ == "__main__":
    get_index()



Indexing pmid2content0.pkl ...


  0%|                                                                               | 0/125086 [00:00<?, ?it/s]  0%|                                                                     | 6/125086 [00:00<1:11:32, 29.14it/s]  0%|                                                                    | 162/125086 [00:00<03:10, 655.60it/s]  0%|▏                                                                   | 303/125086 [00:00<02:13, 932.56it/s]  0%|▏                                                                  | 443/125086 [00:00<01:54, 1091.93it/s]  0%|▎                                                                  | 570/125086 [00:00<01:48, 1147.26it/s]  1%|▍                                                                   | 695/125086 [00:01<04:28, 463.80it/s]  1%|▍                                                                   | 817/125086 [00:01<03:34, 578.67it/s]  1%|▌                                                                   | 964/125086 [00:01<02:48, 736

Indexing pmid2content10.pkl ...


  0%|                                                                               | 0/125376 [00:00<?, ?it/s]  0%|                                                                     | 53/125376 [00:00<07:55, 263.50it/s]  0%|                                                                      | 80/125376 [00:00<22:50, 91.42it/s]  0%|                                                                    | 216/125376 [00:00<06:48, 306.02it/s]  0%|▏                                                                   | 344/125376 [00:00<04:12, 495.18it/s]  0%|▎                                                                   | 520/125376 [00:01<02:42, 769.88it/s]  1%|▍                                                                   | 694/125376 [00:01<02:04, 998.85it/s]  1%|▍                                                                  | 856/125376 [00:01<01:47, 1155.04it/s]  1%|▌                                                                 | 1008/125376 [00:01<01:39, 1249

Indexing pmid2content20.pkl ...


  0%|                                                                               | 0/138076 [00:00<?, ?it/s]  0%|                                                                     | 8/138076 [00:00<2:31:33, 15.18it/s]  0%|                                                                    | 163/138076 [00:00<06:43, 341.92it/s]  0%|▏                                                                   | 328/138076 [00:00<03:32, 646.75it/s]  0%|▏                                                                    | 453/138076 [00:07<51:40, 44.39it/s]  0%|▎                                                                    | 572/138076 [00:07<34:21, 66.70it/s]  0%|▎                                                                    | 669/138076 [00:07<25:07, 91.15it/s]  1%|▍                                                                   | 789/138076 [00:07<17:17, 132.34it/s]  1%|▍                                                                   | 904/138076 [00:08<12:27, 183

### III Retrieve

In [6]:
# II_retrieve.py

from whoosh import index
#from II import index_dir
from whoosh.qparser import MultifieldParser, OrGroup, AndGroup, FuzzyTermPlugin, PhrasePlugin

# Load your index here
ix = index.open_dir(index_dir)

def search(text, mode="or", fuzzy=False, limit=5):
    with ix.searcher() as searcher:
        # Define parser and search logic
        group = OrGroup if mode == "or" else AndGroup
        parser=MultifieldParser(["title","body"], schema=ix.schema, group=group)

        if fuzzy:
            parser.add_plugin(FuzzyTermPlugin())

        parser.add_plugin(PhrasePlugin())

        query=parser.parse(text)
        
        results=searcher.search(query, limit=limit)

        for hit in results:
            print(hit["id"], hit["title"])


In [8]:
# basic one word
search("schizophrenia")                    

658577 schizophrenia and albinism .
363576 schizophrenia .
743940 schizophrenia and addiction .
919379 catecholamine storage in schizophrenia .
869058 anhedonia and schizophrenia .


In [11]:
# default or (widens search results, only one of the terms needs to match)
search("schizophrenia albinism")              
    

658577 schizophrenia and albinism .
608941 albinism in icelandic sheep .
666626 x-linked ocular albinism in blacks. ocular albinism cum pigmento .
687204 autosomal recessively inherited ocular albinism. a new form of ocular albinism affecting females as severely as males .
687555 the perifoveal vasculature in albinism .


In [12]:
# and
search("schizophrenia albinism", mode="and")   

658577 schizophrenia and albinism .


In [17]:
# fuzzy (with wildcard, handles minor misspelling / typos)
search("schiz*phrenia", fuzzy=True)           

163074 sleep disturbance in schizophrenia. a revisit .
658577 schizophrenia and albinism .
363576 schizophrenia .
743940 schizophrenia and addiction .
919379 catecholamine storage in schizophrenia .


In [21]:
# phrase
search("behavioral therapy for schizophrenia") 

638933 family therapy and schizophrenia .
4627 behavioral and psychodynamic dimensions of the new sex therapy .
678348 behavioral marriage therapy. ii. empirical perspective .
727896 conjoint marital therapy: a cognitive behavioral model .
605181 group therapy of schizophrenia in the inital phase of the disease .


### IV MeSH prediction and evaluation

In [22]:
# attempted to solve task, but cannot test it. 
# i tried it with the 3 pickle files i was able to process with my setup...
# it runs, but the results are not meaningful (both 0.0)


for pmid, (text, mesh) in tqdm(test_data.items()):
    query_text = mesh_to_query(mesh)
    query = qp.parse(query_text)
    results = searcher.search(query, limit=5)
    
    # Check if the correct PMID is in top-1 and top-5
    result_pmids = [hit["pmid"] for hit in results]
    
    if pmid in result_pmids[:1]:
        correct_at_1 += 1
    if pmid in result_pmids[:5]:
        correct_at_5 += 1
    
    total += 1

print(f"Precision@1: {correct_at_1/total:.3f}")
print(f"Precision@5: {correct_at_5/total:.3f}")
print(f"Total test queries: {total}")

import os
import pickle
from random import shuffle
from tqdm import tqdm
from whoosh.fields import Schema, TEXT, ID
from whoosh import index, scoring
from whoosh.qparser import OrGroup
import shutil
#from I_fetch_pubmed import medline_folder
from whoosh.qparser import QueryParser

# Constants
medline_folder = "pmid2contents"  
train_size = 2       #1000000      
test_size = 1        #1000         
index_dir = "mesh_eval_index"  

if os.path.exists(index_dir):
    shutil.rmtree(index_dir)
os.mkdir(index_dir)

# Schema
schema = Schema(pmid=ID(stored=True, unique=True),content=TEXT(stored=True))

ix = index.create_in(index_dir, schema)
writer = ix.writer()


# Load records
all_pmids = []

# Create test set
for pkl_file in os.listdir(medline_folder):
    if pkl_file.endswith(".pkl"):
        with open(os.path.join(medline_folder, pkl_file), "rb") as f:
            all_pmids.extend(list(pickle.load(f).keys()))

shuffle(all_pmids)
test_set_pmids = set(all_pmids[:test_size])
test_data = {}



# Create and index training set
for pkl_file in os.listdir(medline_folder):
    if not pkl_file.endswith(".pkl"):
        continue
    with open(os.path.join(medline_folder, pkl_file), "rb") as f:
        obj = pickle.load(f)
        for pmid, data in obj.items():
            if len(data) < 3 or not data[2]:
                continue
            title, abstract, mesh = data
            text = f"{title} {abstract}"

            if pmid in test_set_pmids:
                test_data[pmid] = (text, mesh)
            else:
                writer.add_document(pmid=str(pmid), content=text)

writer.commit()

# Evaluation 
with ix.searcher(weighting=scoring.BM25F()) as searcher:
    # OrGroup: rank by most matches
    qp = QueryParser("content", ix.schema, group=OrGroup)


def mesh_to_query(mesh):
    return " ".join(mesh) if isinstance(mesh, list) else str(mesh)

correct_at_1 = 0
correct_at_5 = 0
total = 0


with ix.searcher(weighting=scoring.BM25F()) as searcher:
    qp = QueryParser("content", ix.schema, group=OrGroup)

    print("Evaluating test abstracts...")
    for pmid, (_, mesh) in tqdm(test_data.items()):
        query_str = mesh_to_query(mesh)
        query = qp.parse(query_str)

        results = searcher.search(query, limit=5)
        hit_pmids = [hit["pmid"] for hit in results]

        total += 1
        if hit_pmids:
            if hit_pmids[0] == str(pmid):
                correct_at_1 += 1
            if str(pmid) in hit_pmids:
                correct_at_5 += 1

# accuracies
acc_at_1 = correct_at_1 / total
acc_at_5 = correct_at_5 / total

print("Evaluation:")
print(f"Accuracy@1: {acc_at_1:.4f}")
print(f"Accuracy@5: {acc_at_5:.4f}")


Evaluating test abstracts...


  0%|                                                                                    | 0/1 [00:00<?, ?it/s]100%|████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.15it/s]100%|████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.14it/s]

Evaluation:
Accuracy@1: 0.0000
Accuracy@5: 0.0000



