# 1. Data loading
Firstly, we will load the *OHSUMED* data from different files and merge them into one combined dataframe, containing all years from 1987 - 1991. Afterwards, we will load the file where the relevance labels of documents during all five years will be labeled. In the end, we will left join the relevance labels to our main dataframe.

In [3]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from transformers import DistilBertTokenizer
from transformers import DistilBertForSequenceClassification
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

def parse_ohsumed_file(file_path):
    """Parses an OHSUMED file into a DataFrame with proper column names."""
    documents = []
    document = {}

    # Read the file line by line
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()  # Remove extra whitespace

            if line.startswith(".I"):  # New document identifier
                if document:  # If there's an existing document, save it
                    documents.append(document)
                document = {"sequential identifier": line[3:]}  # Initialize a new document

            elif line.startswith(".U"):  # MEDLINE identifier
                document["MEDLINE identifier"] = next(f).strip()

            elif line.startswith(".S"):  # Source
                document["source"] = next(f).strip()

            elif line.startswith(".M"):  # MeSH terms
                document["mesh_terms"] = next(f).strip()

            elif line.startswith(".T"):  # Title
                document["title"] = next(f).strip()

            elif line.startswith(".P"):  # Publication type
                document["publication type"] = next(f).strip()

            elif line.startswith(".W"):  # Abstract
                document["abstract"] = next(f).strip()

            elif line.startswith(".A"):  # Author
                document["author"] = next(f).strip()

    # Add the last document if it exists
    if document:
        documents.append(document)

    # Convert the list of documents into a DataFrame
    return pd.DataFrame(documents)

# File paths for individual files
file_87_path = "./data/ohsumed.87.txt"
file_88_path = "./data/ohsumed.88.txt"
file_89_path = "./data/ohsumed.89.txt"
file_90_path = "./data/ohsumed.90.txt"
file_91_path = "./data/ohsumed.91.txt"

# Parse each file into its own DataFrame
df_ohsumed_87 = parse_ohsumed_file(file_87_path)
df_ohsumed_88 = parse_ohsumed_file(file_88_path)
df_ohsumed_89 = parse_ohsumed_file(file_89_path)
df_ohsumed_90 = parse_ohsumed_file(file_90_path)
df_ohsumed_91 = parse_ohsumed_file(file_91_path)

In [4]:
# Checking row amount of all years files
print(f"Number of rows in ohsumed.87: {len(df_ohsumed_87)}")
print(f"Number of rows in ohsumed.88: {len(df_ohsumed_88)}")
print(f"Number of rows in ohsumed.89: {len(df_ohsumed_89)}")
print(f"Number of rows in ohsumed.90: {len(df_ohsumed_90)}")
print(f"Number of rows in ohsumed.91: {len(df_ohsumed_91)}")

Number of rows in ohsumed.87: 54710
Number of rows in ohsumed.88: 70825
Number of rows in ohsumed.89: 74869
Number of rows in ohsumed.90: 73824
Number of rows in ohsumed.91: 74338


In [5]:
# Combine all DataFrames into a single DataFrame
ohsumed_combined_df = pd.concat([df_ohsumed_87, df_ohsumed_88, df_ohsumed_89, df_ohsumed_90, df_ohsumed_91], ignore_index=True)

# Print the combined DataFrame's shape (rows and columns). Should be: 348566
print(f"Combined DataFrame shape: {ohsumed_combined_df.shape}")

# Display the first few rows of the combined DataFrame
print(ohsumed_combined_df.head())

Combined DataFrame shape: (348566, 8)
  sequential identifier MEDLINE identifier                             source  \
0                     1           87049087    Am J Emerg Med 8703; 4(6):491-5   
1                     2           87049088  Am J Emerg Med 8703; 4(6):496-500   
2                     3           87049089    Am J Emerg Med 8703; 4(6):501-3   
3                     4           87049090    Am J Emerg Med 8703; 4(6):504-6   
4                     5           87049092    Am J Emerg Med 8703; 4(6):511-3   

                                          mesh_terms  \
0  Allied Health Personnel/*; Electric Countersho...   
1  Antidepressive Agents, Tricyclic/*PO; Arrhythm...   
2  Adult; Aircraft/*; Altitude/*; Blood Gas Monit...   
3  Adolescence; Adult; Aged; Blood Glucose/*ME; D...   
4  Aged; Aged, 80 and over; Case Report; Female; ...   

                                               title  publication type  \
0  Refibrillation managed by EMT-Ds: incidence an...  JOURNAL AR

In [6]:
# Load the judged file (relevance labeles)
judged_df = pd.read_csv("./data/judged.txt", sep="\t", header=None,
                        names=["Query", "Document-UI", "Document-Index", "Relevance1", "Relevance2", "Relevance3"])

print(f"Loaded judged file with {len(judged_df)} rows")
judged_df.head()

Loaded judged file with 16140 rows


Unnamed: 0,Query,Document-UI,Document-Index,Relevance1,Relevance2,Relevance3
0,1,87097544,40626,d,,d
1,1,87153566,11852,n,,n
2,1,87157536,12693,d,,
3,1,87157537,12694,d,,
4,1,87184723,15450,n,,


In [7]:
# Get unique documents
unique_docs = judged_df['Document-UI'].nunique()
judged_df = judged_df.drop_duplicates(subset=['Document-UI'], keep='first')

print(f"Unique documents: {unique_docs}")
print(f"New df size: {len(judged_df)}")

judged_df.head()

Unique documents: 14430
New df size: 14430


Unnamed: 0,Query,Document-UI,Document-Index,Relevance1,Relevance2,Relevance3
0,1,87097544,40626,d,,d
1,1,87153566,11852,n,,n
2,1,87157536,12693,d,,
3,1,87157537,12694,d,,
4,1,87184723,15450,n,,


In [8]:
judged_df.rename(columns={"Document-UI": "MEDLINE identifier"}, inplace=True)

# Converting key to string
ohsumed_combined_df["MEDLINE identifier"] = ohsumed_combined_df["MEDLINE identifier"].astype(int)
judged_df["MEDLINE identifier"] = judged_df["MEDLINE identifier"].astype(int)
judged_df["is_relevant_ind"] = 1

# Perform the left join
merged_df = ohsumed_combined_df.merge(judged_df, on="MEDLINE identifier", how="left")

# Display the resulting DataFrame
print(f"Resulting DataFrame shape: {merged_df.shape}")
print(merged_df.head())

# Create a new column Relevance_total based on the rules that can be kept showing relevance for all three relevance columns
merged_df["Relevance_total"] = np.where(
    ~merged_df["Relevance1"].isna(),  # If Relevance1 is not NaN, take it
    merged_df["Relevance1"],
    np.where(
        ~merged_df["Relevance2"].isna(),  # Else if Relevance2 is not NaN, take it
        merged_df["Relevance2"],
        merged_df["Relevance3"]  # Else take Relevance3
    )
)
merged_df.head(10000)

# Drop the specified columns
columns_to_drop = ["Query", "Document-Index", "Relevance1", "Relevance2", "Relevance3"]
merged_df.drop(columns=columns_to_drop, inplace=True)
# Mapping relevance labels to int
relevance_mapping = {'n': 0, 'p': 1, 'd': 2}

# Filling mising values in abstract so that BERT can be trained on strings
merged_df["abstract"] = merged_df["abstract"].fillna("")

# Apply the mapping to the Relevance1 column
merged_df["Relevance_total"] = merged_df["Relevance_total"].map(relevance_mapping)

# Display the resulting DataFrame
print(f"Resulting DataFrame shape: {merged_df.shape}")
merged_df.head()

#filtered_df = merged_df[merged_df["Relevance1"].notna()]
#
## Display the filtered DataFrame
#print(f"Number of rows where Relevance1 is not NaN: {len(filtered_df)}")
#print(filtered_df.head())

Resulting DataFrame shape: (348566, 14)
  sequential identifier  MEDLINE identifier  \
0                     1            87049087   
1                     2            87049088   
2                     3            87049089   
3                     4            87049090   
4                     5            87049092   

                              source  \
0    Am J Emerg Med 8703; 4(6):491-5   
1  Am J Emerg Med 8703; 4(6):496-500   
2    Am J Emerg Med 8703; 4(6):501-3   
3    Am J Emerg Med 8703; 4(6):504-6   
4    Am J Emerg Med 8703; 4(6):511-3   

                                          mesh_terms  \
0  Allied Health Personnel/*; Electric Countersho...   
1  Antidepressive Agents, Tricyclic/*PO; Arrhythm...   
2  Adult; Aircraft/*; Altitude/*; Blood Gas Monit...   
3  Adolescence; Adult; Aged; Blood Glucose/*ME; D...   
4  Aged; Aged, 80 and over; Case Report; Female; ...   

                                               title  publication type  \
0  Refibrillation managed

Unnamed: 0,sequential identifier,MEDLINE identifier,source,mesh_terms,title,publication type,abstract,author,is_relevant_ind,Relevance_total
0,1,87049087,Am J Emerg Med 8703; 4(6):491-5,Allied Health Personnel/*; Electric Countersho...,Refibrillation managed by EMT-Ds: incidence an...,JOURNAL ARTICLE.,Some patients converted from ventricular fibri...,Stults KR; Brown DD.,,
1,2,87049088,Am J Emerg Med 8703; 4(6):496-500,"Antidepressive Agents, Tricyclic/*PO; Arrhythm...",Tricyclic antidepressant overdose: emergency d...,JOURNAL ARTICLE.,There is controversy regarding the appropriate...,Foulke GE; Albertson TE; Walby WF.,,
2,3,87049089,Am J Emerg Med 8703; 4(6):501-3,Adult; Aircraft/*; Altitude/*; Blood Gas Monit...,Transconjunctival oxygen monitoring as a predi...,JOURNAL ARTICLE.,As the use of helicopters for air transport of...,Shufflebarger C; Jehle D; Cottington E; Martin M.,,
3,4,87049090,Am J Emerg Med 8703; 4(6):504-6,Adolescence; Adult; Aged; Blood Glucose/*ME; D...,Serum glucose changes after administration of ...,JOURNAL ARTICLE.,A prospective clinical trial was conducted to ...,Adler PM.,,
4,5,87049092,Am J Emerg Med 8703; 4(6):511-3,"Aged; Aged, 80 and over; Case Report; Female; ...",Nasogastric intubation: morbidity in an asympt...,JOURNAL ARTICLE.,An unusual case of a misdirected nasogastric t...,Gough D; Rust D.,,


In [9]:
#Checking how many relevant documents are present in merged DF. Expectation: 16140
count_is_relevant = merged_df[merged_df["is_relevant_ind"] == 1].shape[0]
print(f"Number of rows where is_relevant_ind = 1: {count_is_relevant}")

Number of rows where is_relevant_ind = 1: 14430


Now, we have a merged dataframe, containing all document data and also the relevance labeling of the documents. Now we can proceed with splitting the data.

In [10]:
test = merged_df[merged_df["is_relevant_ind"] == 1]
remaining_rows = merged_df[merged_df["is_relevant_ind"] != 1]
training, validation = train_test_split(remaining_rows, test_size=0.15, random_state=42)

print(f"Test set size: {len(test)}")
print(f"Training set size: {len(training)}")
print(f"Validation set size: {len(validation)}")

Test set size: 14430
Training set size: 284015
Validation set size: 50121


Training BERT

In [11]:
training.loc[:, 'title_abstract'] = training['title'] + ' ' + training['abstract']
test.loc[:, 'title_abstract'] = test['title'] + ' ' + test['abstract']
validation.loc[:, 'title_abstract'] = validation['title'] + ' ' + validation['abstract']
training.head()
test.head()
validation.head()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test.loc[:, 'title_abstract'] = test['title'] + ' ' + test['abstract']


Unnamed: 0,sequential identifier,MEDLINE identifier,source,mesh_terms,title,publication type,abstract,author,is_relevant_ind,Relevance_total,title_abstract
33697,33698,87124329,Am Heart J 8705; 113(2 Pt 1):273-9,Aged; Comparative Study; Electrocardiography/*...,Non-Q wave myocardial infarction: recent chang...,JOURNAL ARTICLE.,A community-wide study of patients hospitalize...,Goldberg RJ; Gore JM; Alpert JS; Dalen JE.,,,Non-Q wave myocardial infarction: recent chang...
245340,245341,90357619,Transplant Proc 9011; 22(4):1885-6,"Antibodies, Anti-Idiotypic/*IM; Antibodies, Mo...",IgM-anti-IgG antibody as cause of positive B-c...,JOURNAL ARTICLE.,,Terness P; Berteli AJ; Steinitz M; Mytillineos...,,,IgM-anti-IgG antibody as cause of positive B-c...
306208,306209,91170056,J Appl Physiol 9106; 69(6):2091-6,Animal; Blood Pressure; Cardiac Output; Hemody...,Altered baroreflex function after tail suspens...,JOURNAL ARTICLE.,Experiments were performed on conscious chroni...,Brizzee BL; Walker BR.,,,Altered baroreflex function after tail suspens...
5057,5058,87097540,Am J Obstet Gynecol 8704; 156(1):52-6,Apgar Score; Cesarean Section/*; Delivery/*MT;...,Randomized management of the second nonvertex ...,JOURNAL ARTICLE.,Sixty twin deliveries after the thirty-fifth g...,Rabinovici J; Barkai G; Reichman B; Serr DM; M...,,,Randomized management of the second nonvertex ...
120427,120428,88110706,Chest 8805; 93(2):294-8,Adult; Aged; Arteries/*; Bloodletting/*; Carbo...,Single arterial puncture vs arterial cannula f...,JOURNAL ARTICLE.,"In an attempt to find the least invasive, safe...",Frye M; DiBenedetto R; Lain D; Morgan K.,,,Single arterial puncture vs arterial cannula f...


In [12]:
import pandas as pd
import re

# Step 1: Parse the .bin file to extract sensitive MeSH terms
def parse_mesh_bin(file_path, target_categories):
    """
    Extract MeSH terms under specific categories from a .bin file.

    Args:
        file_path (str): Path to the .bin file.
        target_categories (list): List of categories (e.g., ["C12", "C13"]).

    Returns:
        list: List of MeSH terms under the target categories.
    """
    mesh_terms = []
    with open(file_path, "r") as f:
        for line in f:
            # Split by ';' to separate terms and categories
            parts = line.strip().split(";")
            if len(parts) > 1:
                term, category = parts[0].strip().lower(), parts[1].strip()
                # Include terms under target categories
                if any(category.startswith(target) for target in target_categories):
                    mesh_terms.append(term)
    return mesh_terms

# Path to the .bin file
file_path = "mtrees2019.bin"

# Extract MeSH terms under C12 and C13
sensitive_terms = parse_mesh_bin(file_path, ["C12", "C13"])

# Step 2: Preprocessing function for the `mesh_terms` column
def preprocess_mesh_terms(mesh_terms):
    """
    Normalize and preprocess the MeSH terms in a document.

    Args:
        mesh_terms (str): The raw MeSH terms for a document.

    Returns:
        list: A list of cleaned MeSH terms.
    """
    if isinstance(mesh_terms, str):
        terms = mesh_terms.split(";")
        return [re.sub(r"/.*", "", term).strip().lower() for term in terms]
    return []

# Step 3: Define the matching function using sensitive terms
sensitive_pattern = re.compile(r"\b(" + "|".join(re.escape(term) for term in sensitive_terms) + r")\b", re.IGNORECASE)

def is_sensitive_regex(terms):
    """
    Check if any term in the document is sensitive based on MeSH terms.

    Args:
        terms (list of str): Processed MeSH terms.

    Returns:
        int: 1 if sensitive, 0 otherwise.
    """
    return 1 if any(sensitive_pattern.search(term) for term in terms) else 0

# Step 4: Apply preprocessing and matching to datasets
for df in [training, validation, test]:
    df.loc[:, "processed_mesh_terms"] = df["mesh_terms"].apply(preprocess_mesh_terms)
    df.loc[:, "sensitive"] = df["processed_mesh_terms"].apply(is_sensitive_regex)
    df.loc[:, "label"] = df["sensitive"]

# Step 5: Calculate sensitive document percentages
datasets = [
    (pd.concat([training, validation, test]), "Full dataset"),
    (test, "Judged documents")
]

for df, name in datasets:
    percentage = df["label"].mean() * 100
    print(f"{name}: {percentage:.1f}%")

for df in [training, validation, test]:
    df.drop(columns=["sensitive"], inplace=True)


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df.loc[:, "processed_mesh_terms"] = df["mesh_terms"].apply(preprocess_mesh_terms)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df.loc[:, "sensitive"] = df["processed_mesh_terms"].apply(is_sensitive_regex)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df.loc[:, "label"] = df["sensitive"]


Full dataset: 8.1%
Judged documents: 12.3%


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df.drop(columns=["sensitive"], inplace=True)


In [13]:
for df, name in [(training, "Training"), (validation, "Validation"), (test, "Test")]:
    sensitive_docs = df[df["label"] == 1]
    total_docs = len(df)
    relative_percentage = (len(sensitive_docs) / total_docs) * 100
    
    print(f"Sensitive documents in {name}:")
    print(sensitive_docs[["mesh_terms", "processed_mesh_terms", "label"]].head())
    print(f"Total sensitive documents in {name}: {len(sensitive_docs)}")
    print(f"Relative amount of sensitive documents in {name}: {relative_percentage:.1f}%\n")

Sensitive documents in Training:
                                               mesh_terms  \
54170   Adult; Cardiovascular Diseases/ET/*MO; Diabete...   
3827    Antigens, Bacterial/*AN; Chlamydia trachomatis...   
63723   Adult; Evaluation Studies; Female; Hospitals, ...   
144305  Adult; Antineoplastic Agents, Combined/*TU; Ca...   
160573  Acidosis, Renal Tubular/*CO/ET; Adult; Case Re...   

                                     processed_mesh_terms  label  
54170   [adult, cardiovascular diseases, diabetes mell...      1  
3827    [antigens, bacterial, chlamydia trachomatis, c...      1  
63723   [adult, evaluation studies, female, hospitals,...      1  
144305  [adult, antineoplastic agents, combined, case ...      1  
160573  [acidosis, renal tubular, adult, case report, ...      1  
Total sensitive documents in Training: 22371
Relative amount of sensitive documents in Training: 7.9%

Sensitive documents in Validation:
                                               mesh_terms  \

In [14]:
from datasets import Dataset
train_dataset = Dataset.from_pandas(training)
test_dataset = Dataset.from_pandas(test)
validation_dataset = Dataset.from_pandas(validation)

Tokenizing sentences

In [15]:
# Load the tokenizer
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Tokenize the dataset for BERTs training
def tokenize_function(examples):
    return tokenizer(examples["title_abstract"], truncation=True, padding="max_length", max_length=512)

train_dataset = train_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)
validation_dataset = validation_dataset.map(tokenize_function, batched=True)

train_dataset = train_dataset.remove_columns(["title_abstract"])
test_dataset = test_dataset.remove_columns(["title_abstract"])
validation_dataset = validation_dataset.remove_columns(["title_abstract"])

train_dataset.set_format("torch")
test_dataset.set_format("torch")
validation_dataset.set_format("torch")

Map:   0%|          | 0/284015 [00:00<?, ? examples/s]

Map:   0%|          | 0/14430 [00:00<?, ? examples/s]

Map:   0%|          | 0/50121 [00:00<?, ? examples/s]

In [16]:
import torch
print("CUDA available:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)

CUDA available: True
CUDA version: 12.1


Loading DistilBERT

In [17]:
print(training.columns)

Index(['sequential identifier', 'MEDLINE identifier', 'source', 'mesh_terms',
       'title', 'publication type', 'abstract', 'author', 'is_relevant_ind',
       'Relevance_total', 'title_abstract', 'processed_mesh_terms', 'label'],
      dtype='object')


In [19]:
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments 
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, fbeta_score 
import torch
import numpy as np

# Add GPU check function 
def check_gpu_availability():     
   if torch.cuda.is_available():         
       print(f"GPU available: {torch.cuda.get_device_name(0)}")         
       print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")         
       print(f"Available GPU memory: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")     
   else:         
       print("No GPU available, using CPU")  

# Define the compute_metrics function - FIXED VERSION with F2 
def compute_metrics(eval_pred):     
   logits, labels = eval_pred     
   predictions = logits.argmax(axis=-1)          
   precision = precision_score(labels, predictions, average="binary")     
   recall = recall_score(labels, predictions, average="binary")     
   f1 = f1_score(labels, predictions, average="binary")     
   f2 = fbeta_score(labels, predictions, beta=2, average="binary")     
   accuracy = accuracy_score(labels, predictions)          
   
   return {         
       "accuracy": accuracy,         
       "precision": precision,         
       "recall": recall,         
       "f1": f1,         
       "f2": f2     
   }  

check_gpu_availability()  

# Load the model with GPU support 
num_labels = len(set(training["label"])) 
model = DistilBertForSequenceClassification.from_pretrained(     
   "distilbert-base-uncased",     
   num_labels=num_labels 
).cuda()  

training_args = TrainingArguments(     
   output_dir="./results",     
   evaluation_strategy="epoch",     
   no_cuda=False,      # Enable GPU     
   fp16=True,         # Enable mixed precision training     
   logging_dir="./logs",     
   logging_steps=10,     
   save_strategy="epoch",     
   metric_for_best_model="f1" 
)  

trainer = Trainer(     
   model=model,     
   args=training_args,     
   train_dataset=train_dataset,     
   eval_dataset=validation_dataset,     
   tokenizer=tokenizer,     
   compute_metrics=compute_metrics 
)  

# Train the model
print("Starting training...")
trainer.train()

# Get validation predictions for threshold optimization
print("Getting validation predictions for threshold optimization...")
validation_output = trainer.predict(validation_dataset)
validation_logits = validation_output.predictions
validation_labels = validation_output.label_ids

# Convert logits to probabilities (on GPU)
validation_probabilities = torch.nn.functional.softmax(torch.tensor(validation_logits).cuda(), dim=-1)[:, 1].cpu().numpy()

# Find optimal threshold
print("Finding optimal threshold...")
thresholds = np.arange(0, 1.01, 0.01)
best_threshold = 0
best_f1 = 0

for threshold in thresholds:
   predictions = (validation_probabilities >= threshold).astype(int)
   f1 = f1_score(validation_labels, predictions)
   if f1 > best_f1:
       best_f1 = f1
       best_threshold = threshold

print(f"Best Threshold: {best_threshold}, Best F1-Score: {best_f1}")

# Apply best threshold to test set
print("Evaluating test set with optimal threshold...")
test_output = trainer.predict(test_dataset)
test_logits = test_output.predictions
test_labels = test_output.label_ids
test_probabilities = torch.nn.functional.softmax(torch.tensor(test_logits).cuda(), dim=-1)[:, 1].cpu().numpy()
test_predictions = (test_probabilities >= best_threshold).astype(int)

# Calculate and print final metrics
final_metrics = {
   "accuracy": accuracy_score(test_labels, test_predictions),
   "precision": precision_score(test_labels, test_predictions),
   "recall": recall_score(test_labels, test_predictions),
   "f1": f1_score(test_labels, test_predictions),
   "f2": fbeta_score(test_labels, test_predictions, beta=2)
}

print("\nFinal Test Metrics with optimized threshold:")
for metric, value in final_metrics.items():
   print(f"{metric}: {value:.4f}")

GPU available: NVIDIA GeForce RTX 2060 SUPER
Total GPU memory: 8.00 GB
Available GPU memory: 1.86 GB


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


  0%|          | 0/106506 [00:00<?, ?it/s]

{'loss': 0.3762, 'grad_norm': 1.5239543914794922, 'learning_rate': 4.9995305428802135e-05, 'epoch': 0.0}
{'loss': 0.3949, 'grad_norm': 2.042564868927002, 'learning_rate': 4.999061085760427e-05, 'epoch': 0.0}
{'loss': 0.1684, 'grad_norm': 0.5480949282646179, 'learning_rate': 4.9985916286406406e-05, 'epoch': 0.0}
{'loss': 0.2539, 'grad_norm': 0.3744119107723236, 'learning_rate': 4.998122171520854e-05, 'epoch': 0.0}
{'loss': 0.5055, 'grad_norm': 1.013243556022644, 'learning_rate': 4.997652714401067e-05, 'epoch': 0.0}
{'loss': 0.2877, 'grad_norm': 6.541798114776611, 'learning_rate': 4.9971832572812796e-05, 'epoch': 0.0}
{'loss': 0.3436, 'grad_norm': 3.2850570678710938, 'learning_rate': 4.9967138001614935e-05, 'epoch': 0.0}
{'loss': 0.3122, 'grad_norm': 1.4219810962677002, 'learning_rate': 4.996244343041707e-05, 'epoch': 0.0}
{'loss': 0.1098, 'grad_norm': 0.37721264362335205, 'learning_rate': 4.99577488592192e-05, 'epoch': 0.0}
{'loss': 0.1285, 'grad_norm': 0.1596882939338684, 'learning_rat

  0%|          | 0/6266 [00:00<?, ?it/s]

{'eval_loss': 0.15131372213363647, 'eval_accuracy': 0.9564653538436982, 'eval_precision': 0.7927494385627205, 'eval_recall': 0.6166708260544048, 'eval_f1': 0.6937113980909602, 'eval_f2': 0.6453382084095064, 'eval_runtime': 167.745, 'eval_samples_per_second': 298.793, 'eval_steps_per_second': 37.354, 'epoch': 1.0}
{'loss': 0.2288, 'grad_norm': 3.4459927082061768, 'learning_rate': 3.333849736165099e-05, 'epoch': 1.0}
{'loss': 0.1519, 'grad_norm': 0.2785826325416565, 'learning_rate': 3.333380279045312e-05, 'epoch': 1.0}
{'loss': 0.1149, 'grad_norm': 3.8876028060913086, 'learning_rate': 3.332910821925526e-05, 'epoch': 1.0}
{'loss': 0.1692, 'grad_norm': 0.2403293401002884, 'learning_rate': 3.332441364805739e-05, 'epoch': 1.0}
{'loss': 0.2197, 'grad_norm': 0.21083779633045197, 'learning_rate': 3.331971907685952e-05, 'epoch': 1.0}
{'loss': 0.0791, 'grad_norm': 1.4012254476547241, 'learning_rate': 3.3315024505661655e-05, 'epoch': 1.0}
{'loss': 0.0193, 'grad_norm': 0.12275075912475586, 'learnin

  0%|          | 0/6266 [00:00<?, ?it/s]

{'eval_loss': 0.19948692619800568, 'eval_accuracy': 0.9507192593922708, 'eval_precision': 0.8200749687630154, 'eval_recall': 0.49139006738208135, 'eval_f1': 0.6145443196004994, 'eval_f2': 0.5342123826577676, 'eval_runtime': 167.451, 'eval_samples_per_second': 299.317, 'eval_steps_per_second': 37.42, 'epoch': 2.0}
{'loss': 0.1404, 'grad_norm': 0.12999160587787628, 'learning_rate': 1.6678403094661334e-05, 'epoch': 2.0}
{'loss': 0.4279, 'grad_norm': 0.30992043018341064, 'learning_rate': 1.667370852346347e-05, 'epoch': 2.0}
{'loss': 0.1335, 'grad_norm': 26.28274154663086, 'learning_rate': 1.66690139522656e-05, 'epoch': 2.0}
{'loss': 0.0882, 'grad_norm': 0.15940260887145996, 'learning_rate': 1.6664319381067734e-05, 'epoch': 2.0}
{'loss': 0.0749, 'grad_norm': 20.11797523498535, 'learning_rate': 1.6659624809869866e-05, 'epoch': 2.0}
{'loss': 0.1215, 'grad_norm': 144.13897705078125, 'learning_rate': 1.6654930238672e-05, 'epoch': 2.0}
{'loss': 0.3128, 'grad_norm': 54.09147262573242, 'learning_r

  0%|          | 0/6266 [00:00<?, ?it/s]

{'eval_loss': 0.16985344886779785, 'eval_accuracy': 0.9609943935675664, 'eval_precision': 0.7821782178217822, 'eval_recall': 0.7097579236336411, 'eval_f1': 0.7442103885908674, 'eval_f2': 0.7231489015459723, 'eval_runtime': 167.387, 'eval_samples_per_second': 299.432, 'eval_steps_per_second': 37.434, 'epoch': 3.0}
{'train_runtime': 10844.2158, 'train_samples_per_second': 78.571, 'train_steps_per_second': 9.821, 'train_loss': 0.15172410184219182, 'epoch': 3.0}


  0%|          | 0/1804 [00:00<?, ?it/s]

Test Metrics: {'eval_loss': 0.24031223356723785, 'eval_accuracy': 0.9458766458766459, 'eval_precision': 0.8188775510204082, 'eval_recall': 0.720943290286356, 'eval_f1': 0.7667960585249328, 'eval_f2': 0.7386102162908421, 'eval_runtime': 48.726, 'eval_samples_per_second': 296.146, 'eval_steps_per_second': 37.023, 'epoch': 3.0}


Defining trainer and finetuning BERT

Training and evaluating