In [1]:
# ==============================================================================
# SCRIPT FOR CREATING A REPRODUCIBLE SUBSET AND CROSS-VALIDATION FOLDS
# ==============================================================================
#
# PURPOSE:
# This script takes a large dataset (e.g., the full MeDAL corpus) and
# creates a smaller, reproducible subset for experimentation. It also generates
# stratified 10-fold cross-validation splits to ensure that results are robust
# and that future research can use the exact same data partitions for fair
# comparison.
#
# Author: G. M. Farouk
# Date: 07/2025
#
# ==============================================================================

import pandas as pd
from sklearn.model_selection import StratifiedKFold
from google.colab import drive
import os

# --- 1. CONFIGURATION ---

# Column names ---------------------------------------------------------------
# Column holding a unique ID per row. The loading cell creates it from the
# DataFrame index when the CSV does not already contain it.
INSTANCE_ID_COLUMN = 'instance_id'
# Classification target ("LABEL" in the MeDAL CSV). Stratified splitting uses
# it so that every fold keeps a similar class distribution.
TARGET_COLUMN = 'LABEL'

# Split / randomness settings ------------------------------------------------
# How many cross-validation folds to generate.
N_SPLITS = 10
# Seed handed explicitly to DataFrame.sample and StratifiedKFold, so the subset
# and the folds come out identical on every run (it is not a global RNG seed).
RANDOM_SEED = 42


In [None]:
# --- Paths to original dataset download
!wget -nc -P data/ https://zenodo.org/record/4482922/files/train.csv
!wget -nc -P data/ https://zenodo.org/record/4482922/files/valid.csv
!wget -nc -P data/ https://zenodo.org/record/4482922/files/test.csv

--2024-09-28 18:25:24--  https://zenodo.org/record/4482922/files/train.csv
Resolving zenodo.org (zenodo.org)... 188.184.103.159, 188.185.79.172, 188.184.98.238, ...
Connecting to zenodo.org (zenodo.org)|188.184.103.159|:443... connected.
HTTP request sent, awaiting response... 301 MOVED PERMANENTLY
Location: /records/4482922/files/train.csv [following]
--2024-09-28 18:25:24--  https://zenodo.org/records/4482922/files/train.csv
Reusing existing connection to zenodo.org:443.
HTTP request sent, awaiting response... 200 OK
Length: 3541556520 (3.3G) [text/plain]
Saving to: ‘data/train.csv’


2024-09-28 18:29:52 (12.6 MB/s) - ‘data/train.csv’ saved [3541556520/3541556520]

--2024-09-28 18:29:52--  https://zenodo.org/record/4482922/files/valid.csv
Resolving zenodo.org (zenodo.org)... 188.185.79.172, 188.184.98.238, 188.184.103.159, ...
Connecting to zenodo.org (zenodo.org)|188.185.79.172|:443... connected.
HTTP request sent, awaiting response... 301 MOVED PERMANENTLY
Location: /records/448292

In [None]:
!pip install datasets
from datasets import load_dataset
dataset = load_dataset("medal")

In [2]:
# Directory on Google Drive that contains the downloaded dataset (edit to match your location)
DRIVE_PATH = '/content/drive/My Drive/'
# Directory where the output files will be saved (derived from DRIVE_PATH so both move together)
OUTPUT_PATH = os.path.join(DRIVE_PATH, 'data_for_publication/')

# Name of the original dataset file inside DRIVE_PATH
INPUT_FILENAME = 'medal_training.csv'

# --- 2. SETUP AND DATA LOADING ---

print("--- Starting Reproducible Subset Creation ---")
# Mount Google Drive to access files.
try:
    drive.mount('/content/drive', force_remount=True)  # force_remount re-mounts even if already mounted
    print("Google Drive mounted successfully.")
except Exception as e:
    print(f"Error mounting Google Drive: {e}")
    # Re-raise so the failure shows up as a normal traceback and halts "Run all".
    # The previous exit() call's effect depends on the kernel/front-end and is
    # not a reliable way to stop the notebook.
    raise


--- Starting Reproducible Subset Creation ---
Mounted at /content/drive
Google Drive mounted successfully.


In [5]:
# Create the output directory if it doesn't exist
print(f"Creating output directory: {OUTPUT_PATH}")
os.makedirs(OUTPUT_PATH, exist_ok=True)
print(f"Output directory ensured at: {OUTPUT_PATH}")

# Load the full downloaded dataset
full_dataset_path = os.path.join(DRIVE_PATH, INPUT_FILENAME)
try:
    print(f"Loading full dataset from: {full_dataset_path}...")
    df_full = pd.read_csv(full_dataset_path)
    print(f"Successfully loaded dataset with {len(df_full)} rows.")
except FileNotFoundError:
    print(f"ERROR: The file was not found at {full_dataset_path}. Please check the path and filename.")
    # Re-raise instead of exit() so the notebook stops with a clear traceback.
    raise

# The file has no ID column, so create one from the DataFrame index. read_csv is
# called without index_col, so the index is the 0-based position of each data row.
if INSTANCE_ID_COLUMN not in df_full.columns:
    print(f"'{INSTANCE_ID_COLUMN}' column not found. Creating it from the DataFrame's index...")
    df_full[INSTANCE_ID_COLUMN] = df_full.index
    print("Successfully created instance_id column.")
else:
    print(f"'{INSTANCE_ID_COLUMN}' column already exists.")

# Verify that necessary columns exist AFTER creating the instance_id
if INSTANCE_ID_COLUMN not in df_full.columns or TARGET_COLUMN not in df_full.columns:
    print(f"ERROR: The required columns ('{INSTANCE_ID_COLUMN}' or '{TARGET_COLUMN}') were not found in the CSV.")
    print("Please check the column names in your file and the script's configuration.")
    raise KeyError(f"Missing required column '{INSTANCE_ID_COLUMN}' or '{TARGET_COLUMN}'")

# The published ID list is used to rebuild the subset, so IDs must identify rows uniquely
# (relevant when the CSV already shipped its own instance_id column).
if not df_full[INSTANCE_ID_COLUMN].is_unique:
    raise ValueError(f"'{INSTANCE_ID_COLUMN}' contains duplicate values and cannot serve as a unique row ID.")


Creating output directory: /content/drive/My Drive/data_for_publication/
Output directory ensured at: /content/drive/My Drive/data_for_publication/
Loading full dataset from: /content/drive/My Drive/medal_training.csv...
Successfully loaded dataset with 100000 rows.
'instance_id' column not found. Creating it from the DataFrame's index...
Successfully created instance_id column.


In [11]:
#Extract ABBREV term from dataset and put in new column#
import re
def create_abbrev_and_extract_from_text(df):
    """
    Add an "ABBREV" column holding the token of TEXT located at index LOCATION.

    TEXT is tokenized into runs of word characters (regex word boundaries, see
    ``token_pattern`` below), and the token at position ``int(LOCATION)`` is
    taken as the abbreviation. A row gets an empty string when its TEXT is not
    a string, its LOCATION cannot be converted to an int, or the index falls
    outside the token list.

    NOTE(review): if the dataset defines LOCATION over whitespace-split tokens,
    this word-character tokenization can disagree with it on texts that contain
    punctuation or hyphens -- confirm against the dataset documentation.

    Args:
        df (pd.DataFrame): Must contain 'TEXT' and 'LOCATION' columns.

    Returns:
        pd.DataFrame: The same ``df`` object, modified in place with a new
        'ABBREV' column.
    """
    token_pattern = re.compile(r'\b\w+\b')

    def extract_abbr(text, location):
        if not isinstance(text, str):
            return ''
        try:
            word_index = int(location)
        except (TypeError, ValueError, OverflowError):
            # None / NaN / inf / non-numeric locations cannot index a token.
            return ''
        words = token_pattern.findall(text)
        if 0 <= word_index < len(words):
            return words[word_index]
        return ''

    # Write to the frame passed in (not a global), so the function works for any
    # DataFrame. A list comprehension over the two columns avoids the overhead of
    # a row-wise apply and also yields an empty column for an empty frame.
    df['ABBREV'] = [
        extract_abbr(text, location)
        for text, location in zip(df['TEXT'], df['LOCATION'])
    ]
    print("Added 'ABBREV' column to DataFrame successfully.")
    return df


# Destination (relative to the current working directory) for the dataset
# augmented with the extracted 'ABBREV' column.
NEW_CSV_FILE_PATH = "medal_with_abbrev.csv"

# Derive the ABBREV column, then persist the augmented frame as UTF-8 CSV.
df_full = create_abbrev_and_extract_from_text(df=df_full)
df_full.to_csv(
    NEW_CSV_FILE_PATH,
    encoding='utf-8',
    index=False,
)
print("Saved updated CSV to: " + NEW_CSV_FILE_PATH)

Added 'ABBREV' column to DataFrame successfully.
Saved updated CSV to: medal_with_abbrev.csv


In [12]:
# Extract the most frequent abbreviations and filter rows down to them
def extract_top_frequent_abbrevs(df, column_name, top_n):
    """
    Return the ``top_n`` most frequent values of ``df[column_name]`` with their counts.

    Args:
        df (pd.DataFrame): Frame to analyze.
        column_name (str): Column whose values are counted.
        top_n (int): How many of the most frequent values to keep.

    Returns:
        pd.DataFrame: Two columns, 'ABBREV' (the value) and 'count' (its number
        of occurrences), ordered from most to least frequent. The value column is
        always named 'ABBREV', regardless of ``column_name``.
    """
    ranked = df[column_name].value_counts().nlargest(top_n)
    abbrevs, counts = ranked.index, ranked.to_numpy()
    return pd.DataFrame({'ABBREV': abbrevs, 'count': counts})

def filter_dataframe_by_top_abbrevs(df, column_name, top_abbreviations_df):
    """
    Keep only the rows of ``df`` whose ``column_name`` value is a listed top abbreviation.

    Args:
        df (pd.DataFrame): Frame to filter.
        column_name (str): Column of ``df`` holding the abbreviation.
        top_abbreviations_df (pd.DataFrame): Frame with an 'ABBREV' column listing
            the abbreviations to keep.

    Returns:
        pd.DataFrame: The matching rows, with their original index labels.
    """
    allowed = top_abbreviations_df['ABBREV'].to_list()
    keep_mask = df[column_name].isin(allowed)
    return df.loc[keep_mask]


def save_dataframe_to_csv(df, output_file_path):
    """Write ``df`` as CSV (without the index column) to ``output_file_path`` and report where."""
    df.to_csv(path_or_buf=output_file_path, index=False)
    message = "DataFrame saved to: {}".format(output_file_path)
    print(message)

In [21]:
# Parameters of the frequency filter
column_to_analyze = 'ABBREV'
top_count = 500  # keep rows whose abbreviation is among the 500 most frequent values
output_csv_file = 'MeDAL_Training_Subset.csv'  # relative to the current working directory; change to save elsewhere

# Rank abbreviations by frequency, keep only rows that use one of the top ones,
# and persist those rows.
# NOTE(review): rows whose ABBREV extraction failed hold '' and are ranked like any other value.
top_abbreviations_df = extract_top_frequent_abbrevs(
    df=df_full, column_name=column_to_analyze, top_n=top_count
)
filtered_df = filter_dataframe_by_top_abbrevs(
    df=df_full,
    column_name=column_to_analyze,
    top_abbreviations_df=top_abbreviations_df,
)
save_dataframe_to_csv(df=filtered_df, output_file_path=output_csv_file)

DataFrame saved to: MeDAL_Training_Subset.csv


In [23]:

# --- 3. CREATE THE REPRODUCIBLE SUBSET ---
# The subset consists of the rows kept by the top-abbreviation filter (filtered_df).
# Sampling df_full here would only reuse filtered_df's row count and silently drop
# the filter, yielding a random slice of the whole dataset instead.
# NOTE(review): files already published from this notebook may have been produced
# by sampling df_full; regenerating them with this cell changes the subset/ID list.
N_SUBSET = filtered_df.shape[0]  # every filtered row goes into the subset
print(f"\nCreating a reproducible subset of {N_SUBSET} instances...")

# .sample() with a fixed random_state yields the same row order on every run.
df_subset = filtered_df.sample(n=N_SUBSET, random_state=RANDOM_SEED)
# Reset to a 0..n-1 index so that positional fold indices equal row labels later.
df_subset = df_subset.reset_index(drop=True)

# The instance IDs are published to rebuild the subset, so they must be unique.
assert df_subset[INSTANCE_ID_COLUMN].is_unique, "duplicate instance IDs in subset"
print(f"Subset created with {len(df_subset)} rows.")

# --- 4. CREATE STRATIFIED K-FOLD SPLITS ---

print(f"\nGenerating {N_SPLITS}-fold stratified cross-validation splits...")

# Shuffling with a fixed random_state makes the fold assignment identical on every run.
skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_SEED)

# -1 is a sentinel meaning "not yet placed in any fold".
df_subset['fold'] = -1

# Stratification is driven by y; per the scikit-learn docs, X only supplies the sample count.
X = df_subset
y = df_subset[TARGET_COLUMN]

# Every split's held-out indices form one fold. The subset index is 0..n-1, so the
# positional indices returned by split() can be used directly as .loc labels.
for fold_num, (train_index, test_index) in enumerate(skf.split(X, y)):
    df_subset.loc[test_index, 'fold'] = fold_num
    print(f"Assigned {len(test_index)} instances to fold {fold_num}.")

# Sanity check: the sentinel must be gone from every row.
has_unassigned = df_subset['fold'].eq(-1).any()
if not has_unassigned:
    print("All instances successfully assigned to a fold.")
else:
    print("WARNING: Some rows were not assigned to a fold. Please check the process.")

# --- 5. SAVE THE OUTPUT FILES ---

print("\n--- Saving output files for publication ---")

# 1. The full subset data, including the new 'fold' column
subset_filename = 'MeDAL_subset_with_folds.csv'
subset_path = os.path.join(OUTPUT_PATH, subset_filename)
df_subset.to_csv(subset_path, index=False)
print(f"1. Saved subset with fold information to: {subset_path}")

# 2. A plain text file with one instance ID per line (no header)
ids_filename = 'MeDAL_subset_instance_ids.txt'
ids_path = os.path.join(OUTPUT_PATH, ids_filename)
df_subset[INSTANCE_ID_COLUMN].to_csv(ids_path, index=False, header=False)
print(f"2. Saved list of instance IDs to: {ids_path}")

# 3. A README file explaining the files. The fold count and range are derived from
# N_SPLITS so the README stays correct if the split configuration changes.
readme_filename = 'README.txt'
readme_path = os.path.join(OUTPUT_PATH, readme_filename)
with open(readme_path, 'w', encoding='utf-8') as f:
    f.write("DATASET FILES FOR REPRODUCIBILITY\n")
    f.write("=====================================\n\n")
    f.write("This folder contains the data subset used in the paper: 'Harnessing Transformer Knowledge: A Novel Approach for Biomedical Abbreviation Disambiguation'.\n\n")
    f.write(f"1. {subset_filename}:\n")
    f.write("   - Contains the full subset from the data with instances used in the study.\n")
    f.write(f"   - The 'fold' column indicates which of the {N_SPLITS} cross-validation folds each instance belongs to (0-{N_SPLITS - 1}).\n\n")
    f.write(f"2. {ids_filename}:\n")
    f.write("   - A lightweight text file containing only the unique instance IDs (corresponding to original row numbers) that make up our subset.\n")
    f.write("   - This file can be used to reconstruct the exact subset from the original full dataset.\n")

print(f"3. Saved README file to: {readme_path}")
print("\n--- Process complete! ---")


Creating a reproducible subset of 44669 instances...
Subset created with 44669 rows.

Generating 10-fold stratified cross-validation splits...




Assigned 4467 instances to fold 0.
Assigned 4467 instances to fold 1.
Assigned 4467 instances to fold 2.
Assigned 4467 instances to fold 3.
Assigned 4467 instances to fold 4.
Assigned 4467 instances to fold 5.
Assigned 4467 instances to fold 6.
Assigned 4467 instances to fold 7.
Assigned 4467 instances to fold 8.
Assigned 4466 instances to fold 9.
All instances successfully assigned to a fold.

--- Saving output files for publication ---
1. Saved subset with fold information to: /content/drive/My Drive/data_for_publication/MeDAL_subset_with_folds.csv
2. Saved list of instance IDs to: /content/drive/My Drive/data_for_publication/MeDAL_subset_instance_ids.txt
3. Saved README file to: /content/drive/My Drive/data_for_publication/README.txt

--- Process complete! ---
