# 1. Data loading
Firstly, we will load the *OHSUMED* data from different files and merge them into one combined dataframe, containing all years from 1987 - 1991. Afterwards, we will load the file where the relevance labels of documents during all five years will be labeled. In the end, we will left join the relevance labels to our main dataframe.

In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from transformers import DistilBertTokenizer
from transformers import DistilBertForSequenceClassification
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

def parse_ohsumed_file(file_path):
    """Parses an OHSUMED file into a DataFrame with proper column names."""
    documents = []
    document = {}

    # Read the file line by line
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()  # Remove extra whitespace

            if line.startswith(".I"):  # New document identifier
                if document:  # If there's an existing document, save it
                    documents.append(document)
                document = {"sequential identifier": line[3:]}  # Initialize a new document

            elif line.startswith(".U"):  # MEDLINE identifier
                document["MEDLINE identifier"] = next(f).strip()

            elif line.startswith(".S"):  # Source
                document["source"] = next(f).strip()

            elif line.startswith(".M"):  # MeSH terms
                document["mesh_terms"] = next(f).strip()

            elif line.startswith(".T"):  # Title
                document["title"] = next(f).strip()

            elif line.startswith(".P"):  # Publication type
                document["publication type"] = next(f).strip()

            elif line.startswith(".W"):  # Abstract
                document["abstract"] = next(f).strip()

            elif line.startswith(".A"):  # Author
                document["author"] = next(f).strip()

    # Add the last document if it exists
    if document:
        documents.append(document)

    # Convert the list of documents into a DataFrame
    return pd.DataFrame(documents)

# File paths for individual files
file_87_path = "./data/ohsumed.87.txt"
file_88_path = "./data/ohsumed.88.txt"
file_89_path = "./data/ohsumed.89.txt"
file_90_path = "./data/ohsumed.90.txt"
file_91_path = "./data/ohsumed.91.txt"

# Parse each file into its own DataFrame
df_ohsumed_87 = parse_ohsumed_file(file_87_path)
df_ohsumed_88 = parse_ohsumed_file(file_88_path)
df_ohsumed_89 = parse_ohsumed_file(file_89_path)
df_ohsumed_90 = parse_ohsumed_file(file_90_path)
df_ohsumed_91 = parse_ohsumed_file(file_91_path)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checking row amount of all years files
print(f"Number of rows in ohsumed.87: {len(df_ohsumed_87)}")
print(f"Number of rows in ohsumed.88: {len(df_ohsumed_88)}")
print(f"Number of rows in ohsumed.89: {len(df_ohsumed_89)}")
print(f"Number of rows in ohsumed.90: {len(df_ohsumed_90)}")
print(f"Number of rows in ohsumed.91: {len(df_ohsumed_91)}")

Number of rows in ohsumed.87: 54710
Number of rows in ohsumed.88: 70825
Number of rows in ohsumed.89: 74869
Number of rows in ohsumed.90: 73824
Number of rows in ohsumed.91: 74338


In [3]:
# Combine all DataFrames into a single DataFrame
ohsumed_combined_df = pd.concat([df_ohsumed_87, df_ohsumed_88, df_ohsumed_89, df_ohsumed_90, df_ohsumed_91], ignore_index=True)

# Print the combined DataFrame's shape (rows and columns). Should be: 348566
print(f"Combined DataFrame shape: {ohsumed_combined_df.shape}")

# Display the first few rows of the combined DataFrame
print(ohsumed_combined_df.head())

Combined DataFrame shape: (348566, 8)
  sequential identifier MEDLINE identifier                             source  \
0                     1           87049087    Am J Emerg Med 8703; 4(6):491-5   
1                     2           87049088  Am J Emerg Med 8703; 4(6):496-500   
2                     3           87049089    Am J Emerg Med 8703; 4(6):501-3   
3                     4           87049090    Am J Emerg Med 8703; 4(6):504-6   
4                     5           87049092    Am J Emerg Med 8703; 4(6):511-3   

                                          mesh_terms  \
0  Allied Health Personnel/*; Electric Countersho...   
1  Antidepressive Agents, Tricyclic/*PO; Arrhythm...   
2  Adult; Aircraft/*; Altitude/*; Blood Gas Monit...   
3  Adolescence; Adult; Aged; Blood Glucose/*ME; D...   
4  Aged; Aged, 80 and over; Case Report; Female; ...   

                                               title  publication type  \
0  Refibrillation managed by EMT-Ds: incidence an...  JOURNAL AR

In [4]:
# Load the judged file (relevance labeles)
judged_df = pd.read_csv("./data/judged.txt", sep="\t", header=None,
                        names=["Query", "Document-UI", "Document-Index", "Relevance1", "Relevance2", "Relevance3"])

print(f"Loaded judged file with {len(judged_df)} rows")
judged_df.head()

Loaded judged file with 16140 rows


Unnamed: 0,Query,Document-UI,Document-Index,Relevance1,Relevance2,Relevance3
0,1,87097544,40626,d,,d
1,1,87153566,11852,n,,n
2,1,87157536,12693,d,,
3,1,87157537,12694,d,,
4,1,87184723,15450,n,,


In [5]:
judged_df.rename(columns={"Document-UI": "MEDLINE identifier"}, inplace=True)

# Converting key to string
ohsumed_combined_df["MEDLINE identifier"] = ohsumed_combined_df["MEDLINE identifier"].astype(int)
judged_df["MEDLINE identifier"] = judged_df["MEDLINE identifier"].astype(int)
judged_df["is_relevant_ind"] = 1

# Perform the left join
merged_df = ohsumed_combined_df.merge(judged_df, on="MEDLINE identifier", how="left")

# Display the resulting DataFrame
print(f"Resulting DataFrame shape: {merged_df.shape}")
print(merged_df.head())

# Create a new column Relevance_total based on the rules that can be kept showing relevance for all three relevance columns
merged_df["Relevance_total"] = np.where(
    ~merged_df["Relevance1"].isna(),  # If Relevance1 is not NaN, take it
    merged_df["Relevance1"],
    np.where(
        ~merged_df["Relevance2"].isna(),  # Else if Relevance2 is not NaN, take it
        merged_df["Relevance2"],
        merged_df["Relevance3"]  # Else take Relevance3
    )
)
merged_df.head(10000)

# Drop the specified columns
columns_to_drop = ["Query", "Document-Index", "Relevance1", "Relevance2", "Relevance3"]
merged_df.drop(columns=columns_to_drop, inplace=True)
# Mapping relevance labels to int
relevance_mapping = {'n': 0, 'p': 1, 'd': 2}

# Filling mising values in abstract so that BERT can be trained on strings
merged_df["abstract"] = merged_df["abstract"].fillna("")

# Apply the mapping to the Relevance1 column
merged_df["Relevance_total"] = merged_df["Relevance_total"].map(relevance_mapping)

# Display the resulting DataFrame
print(f"Resulting DataFrame shape: {merged_df.shape}")
print(merged_df.head())

#filtered_df = merged_df[merged_df["Relevance1"].notna()]
#
## Display the filtered DataFrame
#print(f"Number of rows where Relevance1 is not NaN: {len(filtered_df)}")
#print(filtered_df.head())

Resulting DataFrame shape: (350276, 14)
  sequential identifier  MEDLINE identifier  \
0                     1            87049087   
1                     2            87049088   
2                     3            87049089   
3                     4            87049090   
4                     5            87049092   

                              source  \
0    Am J Emerg Med 8703; 4(6):491-5   
1  Am J Emerg Med 8703; 4(6):496-500   
2    Am J Emerg Med 8703; 4(6):501-3   
3    Am J Emerg Med 8703; 4(6):504-6   
4    Am J Emerg Med 8703; 4(6):511-3   

                                          mesh_terms  \
0  Allied Health Personnel/*; Electric Countersho...   
1  Antidepressive Agents, Tricyclic/*PO; Arrhythm...   
2  Adult; Aircraft/*; Altitude/*; Blood Gas Monit...   
3  Adolescence; Adult; Aged; Blood Glucose/*ME; D...   
4  Aged; Aged, 80 and over; Case Report; Female; ...   

                                               title  publication type  \
0  Refibrillation managed

In [6]:
#Checking how many relevant documents are present in merged DF. Expectation: 16140
count_is_relevant = merged_df[merged_df["is_relevant_ind"] == 1].shape[0]
print(f"Number of rows where is_relevant_ind = 1: {count_is_relevant}")

Number of rows where is_relevant_ind = 1: 16140


Now, we have a merged dataframe, containing all document data and also the relevance labeling of the documents. Now we can proceed with splitting the data.

In [7]:
test = merged_df[merged_df["is_relevant_ind"] == 1]
remaining_rows = merged_df[merged_df["is_relevant_ind"] != 1]
training, validation = train_test_split(remaining_rows, test_size=0.15, random_state=42)

print(f"Test set size: {len(test)}")
print(f"Training set size: {len(training)}")
print(f"Validation set size: {len(validation)}")

Test set size: 16140
Training set size: 284015
Validation set size: 50121


Training BERT

In [8]:
training.loc[:, 'title_abstract'] = training['title'] + ' ' + training['abstract']
test.loc[:, 'title_abstract'] = test['title'] + ' ' + test['abstract']
validation.loc[:, 'title_abstract'] = validation['title'] + ' ' + validation['abstract']
training.head()
test.head()
validation.head()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test.loc[:, 'title_abstract'] = test['title'] + ' ' + test['abstract']


Unnamed: 0,sequential identifier,MEDLINE identifier,source,mesh_terms,title,publication type,abstract,author,is_relevant_ind,Relevance_total,title_abstract
33852,33698,87124329,Am Heart J 8705; 113(2 Pt 1):273-9,Aged; Comparative Study; Electrocardiography/*...,Non-Q wave myocardial infarction: recent chang...,JOURNAL ARTICLE.,A community-wide study of patients hospitalize...,Goldberg RJ; Gore JM; Alpert JS; Dalen JE.,,,Non-Q wave myocardial infarction: recent chang...
246463,245341,90357619,Transplant Proc 9011; 22(4):1885-6,"Antibodies, Anti-Idiotypic/*IM; Antibodies, Mo...",IgM-anti-IgG antibody as cause of positive B-c...,JOURNAL ARTICLE.,,Terness P; Berteli AJ; Steinitz M; Mytillineos...,,,IgM-anti-IgG antibody as cause of positive B-c...
307688,306209,91170056,J Appl Physiol 9106; 69(6):2091-6,Animal; Blood Pressure; Cardiac Output; Hemody...,Altered baroreflex function after tail suspens...,JOURNAL ARTICLE.,Experiments were performed on conscious chroni...,Brizzee BL; Walker BR.,,,Altered baroreflex function after tail suspens...
5082,5058,87097540,Am J Obstet Gynecol 8704; 156(1):52-6,Apgar Score; Cesarean Section/*; Delivery/*MT;...,Randomized management of the second nonvertex ...,JOURNAL ARTICLE.,Sixty twin deliveries after the thirty-fifth g...,Rabinovici J; Barkai G; Reichman B; Serr DM; M...,,,Randomized management of the second nonvertex ...
120978,120428,88110706,Chest 8805; 93(2):294-8,Adult; Aged; Arteries/*; Bloodletting/*; Carbo...,Single arterial puncture vs arterial cannula f...,JOURNAL ARTICLE.,"In an attempt to find the least invasive, safe...",Frye M; DiBenedetto R; Lain D; Morgan K.,,,Single arterial puncture vs arterial cannula f...


In [9]:
def is_sensitive(mesh_terms):
    if isinstance(mesh_terms, str):
        return 1 if 'urogenital' in mesh_terms.lower() or 'pregnancy complications' in mesh_terms.lower() else 0
    else:
        return 0

training.loc[:, 'label'] = training['mesh_terms'].apply(is_sensitive)
test.loc[:, 'label'] = test['mesh_terms'].apply(is_sensitive)
validation.loc[:, 'label'] = validation['mesh_terms'].apply(is_sensitive)
training.head()
test.head()
validation.head()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test.loc[:, 'label'] = test['mesh_terms'].apply(is_sensitive)


Unnamed: 0,sequential identifier,MEDLINE identifier,source,mesh_terms,title,publication type,abstract,author,is_relevant_ind,Relevance_total,title_abstract,label
33852,33698,87124329,Am Heart J 8705; 113(2 Pt 1):273-9,Aged; Comparative Study; Electrocardiography/*...,Non-Q wave myocardial infarction: recent chang...,JOURNAL ARTICLE.,A community-wide study of patients hospitalize...,Goldberg RJ; Gore JM; Alpert JS; Dalen JE.,,,Non-Q wave myocardial infarction: recent chang...,0
246463,245341,90357619,Transplant Proc 9011; 22(4):1885-6,"Antibodies, Anti-Idiotypic/*IM; Antibodies, Mo...",IgM-anti-IgG antibody as cause of positive B-c...,JOURNAL ARTICLE.,,Terness P; Berteli AJ; Steinitz M; Mytillineos...,,,IgM-anti-IgG antibody as cause of positive B-c...,0
307688,306209,91170056,J Appl Physiol 9106; 69(6):2091-6,Animal; Blood Pressure; Cardiac Output; Hemody...,Altered baroreflex function after tail suspens...,JOURNAL ARTICLE.,Experiments were performed on conscious chroni...,Brizzee BL; Walker BR.,,,Altered baroreflex function after tail suspens...,0
5082,5058,87097540,Am J Obstet Gynecol 8704; 156(1):52-6,Apgar Score; Cesarean Section/*; Delivery/*MT;...,Randomized management of the second nonvertex ...,JOURNAL ARTICLE.,Sixty twin deliveries after the thirty-fifth g...,Rabinovici J; Barkai G; Reichman B; Serr DM; M...,,,Randomized management of the second nonvertex ...,0
120978,120428,88110706,Chest 8805; 93(2):294-8,Adult; Aged; Arteries/*; Bloodletting/*; Carbo...,Single arterial puncture vs arterial cannula f...,JOURNAL ARTICLE.,"In an attempt to find the least invasive, safe...",Frye M; DiBenedetto R; Lain D; Morgan K.,,,Single arterial puncture vs arterial cannula f...,0


In [10]:
from datasets import Dataset
train_dataset = Dataset.from_pandas(training)
test_dataset = Dataset.from_pandas(test)
validation_dataset = Dataset.from_pandas(validation)

Tokenizing sentences

In [11]:
# Load the tokenizer
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Tokenize the dataset for BERTs training
def tokenize_function(examples):
    return tokenizer(examples["title_abstract"], truncation=True, padding="max_length", max_length=512)

train_dataset = train_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)
validation_dataset = validation_dataset.map(tokenize_function, batched=True)

train_dataset = train_dataset.remove_columns(["title_abstract"])
test_dataset = test_dataset.remove_columns(["title_abstract"])
validation_dataset = validation_dataset.remove_columns(["title_abstract"])

train_dataset.set_format("torch")
test_dataset.set_format("torch")
validation_dataset.set_format("torch")

Map: 100%|██████████| 284015/284015 [29:02<00:00, 163.01 examples/s]
Map: 100%|██████████| 16140/16140 [03:02<00:00, 88.34 examples/s]
Map: 100%|██████████| 50121/50121 [06:48<00:00, 122.63 examples/s]


Loading DistilBERT

In [12]:
print(training.columns)

Index(['sequential identifier', 'MEDLINE identifier', 'source', 'mesh_terms',
       'title', 'publication type', 'abstract', 'author', 'is_relevant_ind',
       'Relevance_total', 'title_abstract', 'label'],
      dtype='object')


In [None]:
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

# Define the compute_metrics function
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = logits.argmax(axis=-1)

    precision = precision_score(labels, predictions, average="binary")
    recall = recall_score(labels, predictions, average="binary")
    f1 = f1_score(labels, predictions, average="binary")
    f2 = f1_score(labels, predictions, beta=2, average="binary")
    accuracy = accuracy_score(labels, predictions)

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "f2": f2,
    }

# Load the model with the number of labels
num_labels = len(set(training["label"]))
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=num_labels
)

# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",                  # Directory to save models and checkpoints
    evaluation_strategy="epoch",            # Evaluate after each epoch
    #learning_rate=2e-5,                     # Learning rate
    #per_device_train_batch_size=16,         # Training batch size
    #per_device_eval_batch_size=16,          # Evaluation batch size
    #num_train_epochs=3,                     # Number of epochs
    #weight_decay=0.01,                      # Regularization weight decay
    logging_dir="./logs",                   # Directory for logs
    logging_steps=10,                       # Log every 10 steps
    save_strategy="epoch",                  # Save checkpoint after each epoch
    metric_for_best_model="f1",             # Choose F1 as the metric to optimize
    #load_best_model_at_end=True             # Load the best model at the end of training
)

# Defining the Trainer
trainer = Trainer(
    model=model,                            # Model to train
    args=training_args,                     # Training arguments
    train_dataset=train_dataset,            # Training dataset
    eval_dataset=validation_dataset,        # Validation dataset
    tokenizer=tokenizer,                    # Tokenizer for data processing
    compute_metrics=compute_metrics         # Function to compute metrics
)

# Train the model
trainer.train()

# Evaluate the model on the test dataset
test_results = trainer.evaluate(test_dataset)

print("Test Metrics:", test_results)


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: sergejs (sergejs-tu-wien). Use `wandb login --relogin` to force relogin


Epoch,Training Loss,Validation Loss


Defining trainer and finetuning BERT

Training and evaluating