In [8]:
# Import libraries
import os
import pandas as pd
import torch
import numpy as np
import joblib
from transformers import AutoTokenizer, AutoModel
from sklearn.linear_model import LogisticRegression
from Bio import Entrez
from openai import OpenAI
from tqdm import tqdm
from difflib import SequenceMatcher
import xml.etree.ElementTree as ET
import time
import re

# Set device: run on the GPU whenever torch can see one, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device set to: {device}")

Device set to: cuda


In [2]:
# Load citation data
# The CSV is read without a header row, so the column names are supplied here.
col_names = ["PMID", "ISSN", "DP", "EDAT", "PYEAR"]
citation_path = "/user/work/nd23942/semmeddb/raw/CITATION.csv"

citation_df = pd.read_csv(citation_path, names=col_names, header=None, encoding="latin1")

n_rows = len(citation_df)
print(f"Loaded {n_rows} rows from CITATION.csv")
print(citation_df.head())

Loaded 37233341 rows from CITATION.csv
    PMID       ISSN           DP       EDAT  PYEAR
0      1  0006-2944     1975 Jun   1975-6-1   1975
1     10  1873-2968  1975 Sep 01   1975-9-1   1975
2    100  0547-6844         1975   1975-1-1   1975
3   1000  0264-6021     1975 Sep   1975-9-1   1975
4  10000  0006-3002  1976 Sep 28  1976-9-28   1976


In [3]:
# Sample 10000 unique PMIDs
# Missing IDs are dropped and duplicates removed before sampling; the fixed
# random_state makes the draw repeatable.
unique_pmids = citation_df["PMID"].dropna().astype(int).drop_duplicates()
sample_pmids = unique_pmids.sample(n=10000, random_state=42).tolist()
print(f"Sampled {len(sample_pmids)} PMIDs, first 5:", sample_pmids[:5])

Sampled 10000 PMIDs, first 5: [8057077, 27168519, 5484088, 20203436, 27353385]


In [4]:
# Function to fetch abstracts
def _abstract_from_xml(xml_data):
    """Return the abstract text found in one efetch XML payload, or None.

    Every ``AbstractText`` child of the ``Abstract`` element of the first
    ``PubmedArticle`` contributes one paragraph; non-empty paragraphs are
    joined with a single space.
    """
    root = ET.fromstring(xml_data)
    article = root.find(".//PubmedArticle")
    if article is None:
        return None
    abstract_elem = article.find(".//Abstract")
    if abstract_elem is None:
        return None

    paragraphs = []
    for node in abstract_elem.findall("AbstractText"):
        # itertext() also collects the text inside and after inline markup
        # (<i>, <sub>, <sup>, ...). node.text alone stops at the first child
        # element and would silently truncate the paragraph.
        text = "".join(node.itertext()).strip()
        if text:
            paragraphs.append(text)

    return " ".join(paragraphs) if paragraphs else None


def fetch_abstracts(pmids, delay=0.4):
    """Download the abstract of each PMID from PubMed, one request per PMID.

    Parameters
    ----------
    pmids : iterable
        PubMed identifiers; each one is converted with ``str()`` for the request.
    delay : float, default 0.4
        Seconds to sleep after every request, successful or not, so that
        failures do not turn into a burst of back-to-back requests.

    Returns
    -------
    pandas.DataFrame
        One row per input PMID with columns ``PMID``, ``Abstract`` and
        ``Status``. ``Status`` is ``"valid"`` when abstract text was found,
        ``"invalid"`` when the record has none, and ``"error"`` when the
        request or the XML parsing raised (the error is printed, not raised).
    """
    records = []
    for pmid in pmids:
        try:
            handle = Entrez.efetch(db="pubmed", id=str(pmid), rettype="abstract", retmode="xml")
            try:
                xml_data = handle.read()
            finally:
                # Release the connection even when read() fails.
                handle.close()

            abstract = _abstract_from_xml(xml_data)
            status = "valid" if abstract is not None else "invalid"
        except Exception as e:
            # Best effort: record the failure and carry on with the next PMID.
            print(f"Error fetching PMID {pmid}: {e}")
            abstract, status = None, "error"
        finally:
            time.sleep(delay)  # polite waiting

        records.append({
            "PMID": pmid,
            "Abstract": abstract,
            "Status": status
        })

    # Explicit columns keep the frame usable (e.g. df["Status"]) for empty input.
    return pd.DataFrame(records, columns=["PMID", "Abstract", "Status"])

# Fetch abstracts
# Contact address attached to the Entrez requests issued by fetch_abstracts
Entrez.email = "nd23942@bristol.ac.uk"

abstract_df = fetch_abstracts(sample_pmids)

# How many PMIDs came back with text, without text, or with an error
status_counts = abstract_df["Status"].value_counts()
print(status_counts)
abstract_df.head(2)

Error fetching PMID 31382024: HTTP Error 400: Bad Request
Status
valid      6863
invalid    3136
error         1
Name: count, dtype: int64


Unnamed: 0,PMID,Abstract,Status
0,8057077,Inward rectifier (IR) K+ channels of bovine pu...,valid
1,27168519,Self-harm (SH; intentional self-poisoning or s...,valid


In [6]:
# Filter valid abstracts
# Keep only rows marked "valid" that actually carry abstract text, then renumber.
is_valid = abstract_df["Status"] == "valid"
has_text = abstract_df["Abstract"].notnull()
abstract_df = abstract_df.loc[is_valid & has_text].reset_index(drop=True)

n_valid = len(abstract_df)
print(f"Valid abstracts for classification: {n_valid}")

# Save to CSV
abstract_df.to_csv("abstracts_data.csv", index=False)
print(abstract_df.head())

Valid abstracts for classification: 6863
       PMID                                           Abstract Status
0   8057077  Inward rectifier (IR) K+ channels of bovine pu...  valid
1  27168519  Self-harm (SH; intentional self-poisoning or s...  valid
2  20203436  Fecal samples from Ruddy Shelduck, Tadorna fer...  valid
3  27353385  Young women, especially adolescents, often lac...  valid
4  34657444                                [Figure: see text].  valid


In [9]:
# Load API key from environment (e.g. `export OPENAI_API_KEY=...` before
# starting Jupyter). A key pasted into a notebook is exposed to everyone who
# can read the file or its version history.
# NOTE(review): the key that was previously hard-coded in this cell should be
# treated as compromised and revoked.
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
    raise RuntimeError(
        "OPENAI_API_KEY is not set; define it in the environment before running this cell."
    )
client = OpenAI(api_key=api_key)

In [10]:
# System message sent with every chat request: restricts the model to quoting
# result sentences verbatim, one per line.
SYSTEM_PROMPT = "\n".join([
    "You are a biomedical expert. Given a full abstract, identify the sentence(s)",
    "that clearly describe the key factual findings or results of the study.",
    "Do not label background, objectives, methods, or interpretative statements.",
    "Return each result sentence exactly as it appears, one per line.",
])

def build_prompt(text):
    """Build the user message for one abstract.

    The fixed two-line instruction is followed by a blank line and then
    ``text``; surrounding whitespace of the whole message is stripped.
    """
    instructions = (
        "Below is a biomedical abstract. Identify and return only the factual result sentences.\n"
        "Return each sentence exactly as it appears, one per line."
    )
    message = f"{instructions}\n\n{text}"
    return message.strip()

In [11]:
def query_gpt(abstract, max_retries=3, pause=2, model="gpt-3.5-turbo"):
    """Ask the chat model for the result sentences of one abstract.

    Parameters
    ----------
    abstract : str
        Abstract text; wrapped into the user message by ``build_prompt``.
    max_retries : int, default 3
        Maximum number of attempts when a request raises.
    pause : float, default 2
        Seconds to wait between failed attempts.
    model : str, default "gpt-3.5-turbo"
        Chat model name passed to the API.

    Returns
    -------
    tuple[str, int]
        The non-empty reply lines joined with newlines, and how many there
        are. ``("", 0)`` is returned when the reply is empty, starts with an
        apology, or every attempt failed.
    """
    prompt = build_prompt(abstract)
    for attempt in range(1, max_retries + 1):
        try:
            resp = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system",  "content": SYSTEM_PROMPT},
                    {"role": "user",    "content": prompt}
                ],
                temperature=0,
                max_tokens=512,
            )
            # content can be None; treat that as an empty reply instead of
            # letting .strip() raise and burn the remaining retries.
            txt = (resp.choices[0].message.content or "").strip()
            # Filter out apology
            if txt.lower().startswith(("i'm sorry","sorry","unfortunately")):
                return "", 0
            lines = [line.strip() for line in txt.split("\n") if line.strip()]
            return "\n".join(lines), len(lines)
        except Exception as e:
            print(f" GPT retry {attempt}/{max_retries} failed: {e}")
            # No point in waiting once there is no attempt left.
            if attempt < max_retries:
                time.sleep(pause)
    return "", 0

In [12]:
def split_sentences(text):
    """Split text into sentences with a simple regex heuristic.

    The text is first cut at runs of newlines; each piece is then cut wherever
    ``.``, ``!`` or ``?`` is followed by whitespace and an upper-case ASCII
    letter or a digit. Pieces are stripped and empty ones are dropped.
    """
    boundary = r'(?<=[\.!?])\s+(?=[A-Z0-9])'
    pieces = (
        chunk.strip()
        for paragraph in re.split(r'\n+', text)
        for chunk in re.split(boundary, paragraph)
    )
    return [piece for piece in pieces if piece]


import string
def normalize(s):
    """Lower-case ``s``, delete every ``string.punctuation`` character, and strip the ends."""
    drop_punct = str.maketrans(dict.fromkeys(string.punctuation))
    lowered = s.lower()
    return lowered.translate(drop_punct).strip()

def close_match(s, candidates, thresh=0.6):
    """Tell whether sentence ``s`` loosely matches any of ``candidates``.

    Both sides are compared after ``normalize``. A candidate matches when one
    normalized string contains the other, or when their ``SequenceMatcher``
    ratio is at least ``thresh``.

    Strings that normalize to nothing (empty, or punctuation only such as a
    lone bullet "-") never match: the empty string is a substring of every
    string, so without this guard a single such candidate would make every
    sentence match.

    Parameters
    ----------
    s : str
        Sentence to look up.
    candidates : iterable of str
        Strings to compare against.
    thresh : float, default 0.6
        Minimum similarity ratio for a fuzzy match.

    Returns
    -------
    bool
    """
    s_norm = normalize(s)
    if not s_norm:
        return False
    for c in candidates:
        c_norm = normalize(c)
        if not c_norm:
            continue
        if s_norm in c_norm or c_norm in s_norm:
            return True
        if SequenceMatcher(None, s_norm, c_norm).ratio() >= thresh:
            return True
    return False

In [14]:
# Tag every abstract: ask GPT for its result sentences, then mark the abstract
# sentences that loosely match one of the returned lines with a "[GPT]" prefix.
results = []
for _, row in tqdm(abstract_df.iterrows(), total=len(abstract_df), desc="Querying GPT"):
    pmid     = row["PMID"]
    abstract = row["Abstract"]
    gpt_out, gpt_count = query_gpt(abstract)

    # Strip list bullets first and only then drop empty lines. Filtering before
    # stripping would keep a line such as "-" as an empty candidate, and an
    # empty candidate is a substring of every sentence.
    gpt_lines = []
    for raw_line in gpt_out.split("\n"):
        cleaned = raw_line.lstrip("-–— ").strip()
        if cleaned:
            gpt_lines.append(cleaned)

    tagged_lines = []
    tag_count = 0
    for sent in split_sentences(abstract):
        if close_match(sent, gpt_lines):
            tagged_lines.append(f"[GPT] {sent}")
            # Count tagged sentences directly, so a literal "[GPT]" inside an
            # abstract cannot inflate the number.
            tag_count += 1
        else:
            tagged_lines.append(sent)

    tagged_abstract = " ".join(tagged_lines)

    results.append({
        "PMID": pmid,
        "Abstract": abstract,
        "GPT_Result_Sentences": gpt_out,
        "GPT_Count": gpt_count,
        "Tagged_Abstract": tagged_abstract,
        "Tagged_Count": tag_count
    })
    # Pause between abstracts to space out the API requests.
    time.sleep(1.2)

Querying GPT: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6863/6863 [4:32:08<00:00,  2.38s/it]


In [15]:
df_out = pd.DataFrame(results)
# The file name must match what the later pd.read_csv("training_targeted_data.csv")
# cell loads; the previous spelling ("trainning_...") wrote a file that cell never reads.
df_out.to_csv("training_targeted_data.csv", index=False)
print("Saved training_targeted_data.csv")

print(df_out[["PMID","GPT_Count","Tagged_Count"]].head())

# Tagged sentences per GPT line, averaged over all abstracts. replace(0, 1)
# avoids dividing by zero for abstracts where GPT returned nothing.
avg_prec = (df_out["Tagged_Count"] / df_out["GPT_Count"].replace(0,1)).mean()

# NOTE(review): despite its name this is not a recall (no ground-truth labels
# exist at this point). It is the same ratio, averaged only over abstracts for
# which GPT returned at least one line. The former `.replace(1, 1)` was a no-op.
has_gpt_lines = df_out["GPT_Count"] > 0
avg_recall = (
    df_out.loc[has_gpt_lines, "Tagged_Count"] / df_out.loc[has_gpt_lines, "GPT_Count"]
).mean()
print(f"Approx match rate: Precision-like = {avg_prec:.3f}")

Saved trainning_targeted_data.csv
       PMID  GPT_Count  Tagged_Count
0   8057077          6             6
1  27168519          6             5
2  20203436          3             3
3  27353385          6             6
4  34657444          1             1
Approx match rate: Precision-like = 0.995


In [27]:
import pandas as pd
import nltk
from nltk.tokenize import sent_tokenize
import json
# Fetch NLTK's "punkt" package; the call's return value is shown as the cell output.
# NOTE(review): sent_tokenize is imported above, but the visible cells split
# sentences with the regex-based split_sentences instead — confirm whether this
# download is still needed.
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     /user/home/nd23942/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [19]:
# Load the GPT-tagged abstracts table from disk; the processing cell below
# reads its PMID, Abstract and Tagged_Abstract columns.
targeted_csv = "training_targeted_data.csv"
df = pd.read_csv(targeted_csv)

In [31]:
# Turn every tagged abstract into one training record:
# {"pmid": ..., "sentences": [...], "labels": [0/1 per sentence]}.
records = []

for _, row in tqdm(df.iterrows(), total=len(df), desc="🔧 Processing abstracts"):
    pmid = str(row.get('PMID'))
    abstract = str(row.get('Abstract', ''))
    tagged = str(row.get('Tagged_Abstract', ''))

    # Split sentence
    abstract_sents = split_sentences(abstract)
    # split_sentences only breaks before an upper-case letter or a digit, so a
    # tagged sentence (which starts with "[") would stay glued to the untagged
    # sentence in front of it, and that untagged sentence would then be
    # labelled as a finding too. Putting every marker on its own line forces a
    # break there, because split_sentences always cuts at newlines.
    tagged_sents = split_sentences(tagged.replace("[GPT]", "\n[GPT]"))

    finding_sents = [s.replace("[GPT]", "").strip() for s in tagged_sents if "[GPT]" in s]

    # Labelling: 1 when the sentence matches one of the tagged finding sentences
    labels = [1 if close_match(s, finding_sents) else 0 for s in abstract_sents]

    records.append({
        "pmid": pmid,
        "sentences": abstract_sents,
        "labels": labels
    })

# One JSON object per line; ensure_ascii=False keeps non-ASCII characters readable.
output_path = "training_data.jsonl"
with open(output_path, "w", encoding="utf-8") as f:
    for record in records:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")

print(f"Saved in {output_path}")

🔧 Processing abstracts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6863/6863 [00:14<00:00, 485.42it/s]


Saved in training_data.jsonl
