In [1]:
# Import libraries
import os
import pandas as pd
import torch
import numpy as np
import joblib
from transformers import AutoTokenizer, AutoModel
from sklearn.linear_model import LogisticRegression
from Bio import Entrez
import xml.etree.ElementTree as ET
import time
import re

# Select the compute device: CUDA when PyTorch reports a usable GPU, CPU otherwise.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda") if use_gpu else torch.device("cpu")
print(f"Device set to: {device}")

Device set to: cuda


In [2]:
# Load the SciBERT tokenizer and encoder, then move the encoder to `device`.
model_name = "allenai/scibert_scivocab_uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
encoder = AutoModel.from_pretrained(model_name)
encoder = encoder.to(device)
encoder.eval()  # put the encoder in evaluation mode; it is only used for inference here

# Load the trained logistic classifier.
# NOTE(review): joblib files are pickle-based, so only load files from a trusted source.
classifier_path = "../01.1_binary_classification/logistic_classifier_binary.pkl"
clf = joblib.load(classifier_path)

print("Models loaded: SciBERT encoder + Logistic classifier")

Models loaded: SciBERT encoder + Logistic classifier


In [3]:
# Load citation data.
# The file is read with header=None, so the column labels are supplied explicitly.
col_names = ["PMID", "ISSN", "DP", "EDAT", "PYEAR"]
citation_path = "/user/work/nd23942/semmeddb/raw/CITATION.csv"

citation_df = pd.read_csv(citation_path, header=None, names=col_names, encoding="latin1")

n_citations = len(citation_df)
print(f"Loaded {n_citations} rows from CITATION.csv")
print(citation_df.head())

Loaded 37233341 rows from CITATION.csv
    PMID       ISSN           DP       EDAT  PYEAR
0      1  0006-2944     1975 Jun   1975-6-1   1975
1     10  1873-2968  1975 Sep 01   1975-9-1   1975
2    100  0547-6844         1975   1975-1-1   1975
3   1000  0264-6021     1975 Sep   1975-9-1   1975
4  10000  0006-3002  1976 Sep 28  1976-9-28   1976


In [4]:
# Draw a reproducible random sample of unique PMIDs for abstract retrieval.
N_SAMPLE_PMIDS = 1000  # how many distinct PMIDs to draw
SAMPLE_SEED = 42       # fixed seed so the same PMIDs are selected on every run

# Drop missing IDs before casting to int, then de-duplicate so each PMID can be drawn at most once.
unique_pmids = citation_df["PMID"].dropna().astype(int).drop_duplicates()

# Cap the request at the number of available PMIDs; Series.sample raises if n exceeds it.
sample_pmids = unique_pmids.sample(
    n=min(N_SAMPLE_PMIDS, len(unique_pmids)),
    random_state=SAMPLE_SEED,
).tolist()
print(f"Sampled {len(sample_pmids)} PMIDs, first 5:", sample_pmids[:5])

Sampled 1000 PMIDs, first 5: [8057077, 27168519, 5484088, 20203436, 27353385]


In [5]:
# Functions to fetch abstracts
def _extract_abstract(xml_data):
    """Return the abstract text contained in one efetch XML payload, or None.

    Every ``AbstractText`` child of the first ``PubmedArticle``'s ``Abstract``
    element is collected and the non-empty pieces are joined with a single
    space. Returns None when there is no article, no abstract, or only empty
    paragraphs.
    """
    root = ET.fromstring(xml_data)
    article = root.find(".//PubmedArticle")
    if article is None:
        return None
    abstract_elem = article.find(".//Abstract")
    if abstract_elem is None:
        return None

    paragraphs = []
    for node in abstract_elem.findall("AbstractText"):
        # itertext() also walks nested inline elements (e.g. <i>, <sub>, <sup>);
        # node.text alone stops at the first child tag and truncates the paragraph.
        text = "".join(node.itertext()).strip()
        if text:
            paragraphs.append(text)

    return " ".join(paragraphs) if paragraphs else None


def fetch_abstracts(pmids, delay=0.4):
    """Fetch PubMed abstracts one PMID at a time via ``Entrez.efetch``.

    Parameters
    ----------
    pmids : iterable
        PubMed identifiers; each one is converted with ``str()`` for the request.
    delay : float, default 0.4
        Seconds to sleep after every request (successful or not) so the
        service is not hit in a tight loop.

    Returns
    -------
    pandas.DataFrame
        One row per input PMID with columns ``PMID``, ``Abstract`` and
        ``Status``. ``Status`` is ``"valid"`` when abstract text was found,
        ``"invalid"`` when the record has no abstract text, and ``"error"``
        when the request or the XML parsing raised. ``Abstract`` is None for
        the latter two. The columns are present even when ``pmids`` is empty.

    Notes
    -----
    Errors are reported with ``print`` and recorded as ``"error"`` rows rather
    than raised, so a single bad record does not abort the whole run.
    """
    records = []
    for pmid in pmids:
        try:
            handle = Entrez.efetch(db="pubmed", id=str(pmid), rettype="abstract", retmode="xml")
            try:
                xml_data = handle.read()
            finally:
                # Always release the connection, even if reading failed.
                handle.close()

            abstract = _extract_abstract(xml_data)
            status = "valid" if abstract is not None else "invalid"
        except Exception as e:
            print(f"Error fetching PMID {pmid}: {e}")
            abstract = None
            status = "error"
        finally:
            # Polite waiting, applied after failures too so retries of a
            # struggling service are still throttled.
            if delay:
                time.sleep(delay)

        records.append({
            "PMID": pmid,
            "Abstract": abstract,
            "Status": status
        })

    # Explicit columns keep the schema stable when no PMIDs were supplied.
    return pd.DataFrame(records, columns=["PMID", "Abstract", "Status"])

# Fetch abstracts
# Entrez requests carry a contact e-mail address. It can be overridden through
# the ENTREZ_EMAIL environment variable so the notebook can be run by others
# without editing the code; the original address remains the default.
Entrez.email = os.environ.get("ENTREZ_EMAIL", "nd23942@bristol.ac.uk")

abstract_df = fetch_abstracts(sample_pmids)
print(abstract_df["Status"].value_counts())
abstract_df.head(2)

Status
valid      693
invalid    307
Name: count, dtype: int64


Unnamed: 0,PMID,Abstract,Status
0,8057077,Inward rectifier (IR) K+ channels of bovine pu...,valid
1,27168519,Self-harm (SH; intentional self-poisoning or s...,valid


In [6]:
# Keep only rows that were fetched successfully and actually carry abstract text.
is_valid = abstract_df["Status"].eq("valid")
has_text = abstract_df["Abstract"].notna()
abstract_df = abstract_df.loc[is_valid & has_text].reset_index(drop=True)

n_valid = len(abstract_df)
print(f"Valid abstracts for classification: {n_valid}")
abstract_df

Valid abstracts for classification: 693


Unnamed: 0,PMID,Abstract,Status
0,8057077,Inward rectifier (IR) K+ channels of bovine pu...,valid
1,27168519,Self-harm (SH; intentional self-poisoning or s...,valid
2,20203436,"Fecal samples from Ruddy Shelduck, Tadorna fer...",valid
3,27353385,"Young women, especially adolescents, often lac...",valid
4,34657444,[Figure: see text].,valid
...,...,...,...
688,2383353,Systemic absorption of water-soluble and water...,valid
689,36694093,To evaluate the impact of an optimal and repro...,valid
690,1593194,Lateral elbow pain syndrome is probably a fair...,valid
691,16739407,Anopheles gambiae s.s. Giles accepted a range ...,valid


In [7]:
# Persist the filtered abstracts (without the index) to CSV
raw_abstracts_csv = "raw_abstracts.csv"
abstract_df.to_csv(raw_abstracts_csv, index=False)
print(abstract_df.head())

       PMID                                           Abstract Status
0   8057077  Inward rectifier (IR) K+ channels of bovine pu...  valid
1  27168519  Self-harm (SH; intentional self-poisoning or s...  valid
2  20203436  Fecal samples from Ruddy Shelduck, Tadorna fer...  valid
3  27353385  Young women, especially adolescents, often lac...  valid
4  34657444                                [Figure: see text].  valid


In [8]:
# Function to embed a batch of sentences
def get_cls_embeddings(sentences, batch_size=16, max_length=256):
    """Embed sentences with the module-level ``tokenizer`` and ``encoder``.

    For each sentence the vector at token position 0 (presumably the [CLS]
    token for this tokenizer) is taken from each of the last four hidden-state
    layers and the four vectors are averaged.

    Parameters
    ----------
    sentences : sequence of str
        Sentences to embed. Any sequence is accepted; it is copied to a list.
    batch_size : int, default 16
        Number of sentences sent through the encoder at once.
    max_length : int, default 256
        Token limit passed to the tokenizer; longer inputs are truncated.

    Returns
    -------
    numpy.ndarray
        Array with one row per sentence, in input order. For empty input an
        array of shape ``(0, encoder.config.hidden_size)`` is returned.
    """
    sentences = list(sentences)
    if not sentences:
        # np.vstack([]) raises ValueError, so return an explicit empty matrix.
        return np.empty((0, encoder.config.hidden_size), dtype=np.float32)

    embeddings = []
    for start in range(0, len(sentences), batch_size):
        batch = sentences[start:start + batch_size]
        inputs = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(device)
        with torch.no_grad():
            outputs = encoder(**inputs, output_hidden_states=True)

        # Slice position 0 from each of the last four layers *before* stacking,
        # so only those vectors are stacked and averaged, not every token.
        first_token_per_layer = torch.stack(
            [layer[:, 0, :] for layer in outputs.hidden_states[-4:]], dim=0
        )
        pooled = first_token_per_layer.mean(dim=0)
        embeddings.append(pooled.cpu().numpy())

    return np.vstack(embeddings)

In [9]:
# Very basic sentence splitter
def split_sentences(text):
    """Split ``text`` into sentences on whitespace that follows ``.``, ``!`` or ``?``.

    This is a purely punctuation-based heuristic, so abbreviations such as
    "e.g." or "s.s." are treated as sentence ends. Fragments of five
    characters or fewer are discarded.

    Parameters
    ----------
    text : str
        Text to split. Non-string values (e.g. None, or the float NaN that
        pandas uses for an empty CSV cell) yield an empty list instead of
        raising.

    Returns
    -------
    list of str
        The sentences longer than five characters, in order of appearance.
    """
    if not isinstance(text, str):
        return []
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    return [s for s in sentences if len(s) > 5]

# Example check: reload the saved abstracts and split the first one
abstract_df = pd.read_csv("raw_abstracts.csv")
first_abstract = abstract_df.loc[0, "Abstract"]
sample_sentences = split_sentences(first_abstract)
print(f"Example abstract split into {len(sample_sentences)} sentences.")
print(sample_sentences[:3])

Example abstract split into 14 sentences.
['Inward rectifier (IR) K+ channels of bovine pulmonary artery endothelial cells were studied using the whole-cell, cell-attached, and outside-out patch-clamp configurations.', 'The effects of Rb+ on the voltage dependence and kinetics of IR gating were explored, with [Rb+]o + [K+]o = 160 mM.', 'Partial substitution of Rb+ for K+ resulted in voltage-dependent reduction of inward currents, consistent with Rb+ being a weakly permeant blocker of the IR.']


In [10]:
# Classify every sentence with the binary model and mark "finding" sentences with [TAR]
label_names = ["background", "finding"]  # indexed by the classifier's predicted label id
all_records = []

for pmid, abstract in zip(abstract_df["PMID"], abstract_df["Abstract"]):
    sentences = split_sentences(abstract)

    # Abstracts that yield no sentences are left out of the output.
    if sentences:
        preds = clf.predict(get_cls_embeddings(sentences))

        # Rebuild the abstract, prefixing each sentence predicted as a finding.
        targeted_abstract = " ".join(
            "[TAR] " + sent if label_names[label_id] == "finding" else sent
            for sent, label_id in zip(sentences, preds)
        )

        all_records.append({
            "PMID": pmid,
            "Targeted_Abstract": targeted_abstract
        })

print(f"Generated targeted abstracts for {len(all_records)} articles.")



Generated targeted abstracts for 693 articles.


In [11]:
# Collect the tagged abstracts into a DataFrame and write them to CSV (no index column)
targeted_csv = "targeted_abstracts.csv"
targeted_df = pd.DataFrame(all_records)
targeted_df.to_csv(targeted_csv, index=False)

print(targeted_df.head())

       PMID                                  Targeted_Abstract
0   8057077  Inward rectifier (IR) K+ channels of bovine pu...
1  27168519  Self-harm (SH; intentional self-poisoning or s...
2  20203436  Fecal samples from Ruddy Shelduck, Tadorna fer...
3  27353385  Young women, especially adolescents, often lac...
4  34657444                          [TAR] [Figure: see text].
