In [1]:
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
import pickle
import os

from torch.utils.data import Dataset, DataLoader
from sklearn.utils.class_weight import compute_class_weight
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Read Data

In [None]:
dir = ''

In [None]:
# Load data
df_train = pd.read_csv(dir + '/trials_train.csv')
df_val = pd.read_csv(dir + '/trials_val.csv')
df_test = pd.read_csv(dir + '/trials_test.csv')

df_train.head()

Unnamed: 0,id,start_date,status,why_stopped,hasResults,phase,allocation,intervention_model,primary_purpose,acc_text,...,text_outcomes,text_criteria,dmc_oversight,fda_drug,fda_device,unapproved_device,ae_score,sae_events,other_ae_events,stringency_index
0,NCT05099822,2020-03-13,TERMINATED,Business objectives changed.,False,PHASE1,RANDOMIZED,SEQUENTIAL,TREATMENT,"This study aims to evaluate the safety, tolera...",...,Primary Outcomes: \n1. Measure: Incidence of A...,"Inclusion Criteria:\n* In good health, as dete...",False,False,False,,,,,30.09
1,NCT05225870,2021-01-01,COMPLETED,,False,,,SINGLE_GROUP,BASIC_SCIENCE,Colorectal carcinoma is one of the most aggres...,...,Primary Outcomes: \n1. Measure: immunohistoche...,Inclusion Criteria:\n* patients with colorecta...,True,False,False,,,,,71.76
2,NCT05617417,2021-05-05,COMPLETED,,False,,,SINGLE_GROUP,TREATMENT,We aimed to evaluate the efficacy of locally a...,...,Primary Outcomes: \n1. Measure: Changes in uri...,Inclusion Criteria:\n* The patient who has pur...,False,False,False,,,,,53.98
3,NCT03696576,2018-09-20,TERMINATED,"Due to the pandemic, recruitment ended earlier...",True,,RANDOMIZED,PARALLEL,TREATMENT,The larynx and vocal folds undergo many age-re...,...,Primary Outcomes: \n1. Measure: Voice Handicap...,Inclusion Criteria:\n* Age 65 or older\n* Diag...,False,False,True,,0.0,,,0.0
4,NCT02400723,2018-12-05,COMPLETED,,True,,RANDOMIZED,PARALLEL,TREATMENT,"Anxiety leads to poor quality of life, avoidan...",...,Primary Outcomes: \n1. Measure: Change in Anxi...,Inclusion Criteria:\n* Veterans aged 60 years ...,False,False,False,,0.005893,,{'Musculoskeletal and connective tissue disord...,0.0


# Load Embedding Model

Special tokens
- CLS - 101
- SEP - 102

**STEP:** Set embedding input size.

In [None]:
#SET
emb_input_size = 512

In [None]:
# Check if CUDA (GPU) is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load BioBERT
model_name = "dmis-lab/biobert-base-cased-v1.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
bb_model = AutoModel.from_pretrained(model_name)

# Move the model to GPU if available, otherwise CPU
bb_model = bb_model.to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/313 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

# Data Preprocess

## Classification Indicator
Create binary classification target indicator

In [None]:
def set_targets_terminated(df):
  status_mapping = {'COMPLETED': 0, 'TERMINATED': 1}  # Map strings to 0 and 1
  df.loc[:, 'terminated'] = df['status'].map(status_mapping)  # Create a new numeric column
  return df

## Target AE Risk Score

Process Risk Score. Add ceiling to range [0, 1.33]

In [None]:
##SET
ceiling=1.33

In [None]:
def cap_targets_ae_score(df, ceiling = ceiling):
  df['ae_score'] = df['ae_score'].clip(upper=ceiling)
  return df

## Date

In [None]:
def preprocess_date(df, date_col = 'start_date'):
    """Extracts month, year, cyclical month, and month-year interaction."""

    # Create a copy of the DataFrame
    df_copy = df.copy()

    # Convert the date column to datetime objects
    df_copy[date_col] = pd.to_datetime(df_copy[date_col])

    # Extract month and year
    df_copy['month'] = df_copy[date_col].dt.month
    df_copy['year'] = df_copy[date_col].dt.year

    # Cyclical encoding for month
    df_copy['month_sin'] = np.sin(2 * np.pi * df_copy['month'] / 12)
    df_copy['month_cos'] = np.cos(2 * np.pi * df_copy['month'] / 12)

    # Drop intermediate columns
    df_copy = df_copy.drop(columns=['month'])

    return df_copy

## Text
1. Tokenize each individual text section and create input IDs. Leave space for special tokens.
2. Add special tokens CLS and SEP
3. Add padding and attention masks




In [None]:
# Use this one to create input ids and attention mask for BioBERT unfrozen training
def tokenize_text_sections(df, tokenizer=tokenizer, max_length=emb_input_size, batch_size=128):

    all_intro_input_ids = []
    all_intro_attention_masks = []
    all_outcomes_input_ids = []
    all_outcomes_attention_masks = []
    all_criteria_input_ids = []
    all_criteria_attention_masks = []


    # Create batch
    for i in range(0, len(df), batch_size):
        batch_df = df.iloc[i:i + batch_size].copy() # Batch dataframe
        text_intros_batch = batch_df['text_intro'].tolist() # Batch of text rows
        text_outcomes_batch = batch_df['text_outcomes'].tolist()
        text_criteria_batch = batch_df['text_criteria'].tolist()

        # Handle potential NaN values from reading CSV default behavior
        text_intros_batch = [str(text) if not pd.isnull(text) else '' for text in text_intros_batch]
        text_outcomes_batch = [str(text) if not pd.isnull(text) else '' for text in text_outcomes_batch]
        text_criteria_batch = [str(text) if not pd.isnull(text) else '' for text in text_criteria_batch]

        # Batch tokenize the texts.
        # Tokenize intros
        encoded_intro_batch = tokenizer(text_intros_batch,
            max_length=max_length,
            truncation=True,
            padding='max_length',
            add_special_tokens=True,  # Let tokenizer add CLS and SEP
            return_attention_mask=True,
        )
        all_intro_input_ids.extend(encoded_intro_batch['input_ids'])
        all_intro_attention_masks.extend(encoded_intro_batch['attention_mask'])

        # Tokenize outcomes
        encoded_outcomes_batch = tokenizer(text_outcomes_batch,
            max_length=max_length,
            truncation=True,
            padding='max_length',
            add_special_tokens=True,  # Let tokenizer add CLS and SEP
            return_attention_mask=True,
        )
        all_outcomes_input_ids.extend(encoded_outcomes_batch['input_ids'])
        all_outcomes_attention_masks.extend(encoded_outcomes_batch['attention_mask'])

        # Tokenize criteria
        encoded_criteria_batch = tokenizer(text_criteria_batch,
            max_length=max_length,
            truncation=True,
            padding='max_length',
            add_special_tokens=True,  # Let tokenizer add CLS and SEP
            return_attention_mask=True,
        )
        all_criteria_input_ids.extend(encoded_criteria_batch['input_ids'])
        all_criteria_attention_masks.extend(encoded_criteria_batch['attention_mask'])


    df['intro_input_ids'] = all_intro_input_ids
    df['intro_attention_mask'] = all_intro_attention_masks
    df['outcomes_input_ids'] = all_outcomes_input_ids
    df['outcomes_attention_mask'] = all_outcomes_attention_masks
    df['criteria_input_ids'] = all_criteria_input_ids
    df['criteria_attention_mask'] = all_criteria_attention_masks

    return df

## Run Preprocess Steps

In [None]:
#SET
# Start with smaller data for exploring at first. Remove for full model run.
# df_train = df_train.head(35000)
# df_val = df_val.head(1000)
df_test = df_test.head(1000)

In [None]:
# Set targets terminated
df_train = set_targets_terminated(df_train)
df_val = set_targets_terminated(df_val)
df_test = set_targets_terminated(df_test)

# Set targets column to data splits
df_train = cap_targets_ae_score(df_train)
df_val = cap_targets_ae_score(df_val)
df_test = cap_targets_ae_score(df_test)

# Preproces date
df_train = preprocess_date(df_train)
df_val = preprocess_date(df_val)
df_test = preprocess_date(df_test)

# Tokenize, chunk
df_train = tokenize_text_sections(df_train)
df_val = tokenize_text_sections(df_val)
df_test = tokenize_text_sections(df_test)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df.loc[:, 'terminated'] = df['status'].map(status_mapping)  # Create a new numeric column
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['ae_score'] = df['ae_score'].clip(upper=ceiling)


Will produce multiple chunk rows for each chunked trial. CLS tokens are to be aggregated, whether within the model or after, for classification.

## Prepare and Check
- Each token_id and attention mask length should be the set embedding input size, default 512.
- Each chunk should start with 101 and end with 102
- Padding and attention mask should be 0
- No chunking, the resulting number of rows should be equal to original


In [None]:
#SET split
df = df_train

print(f"Embedding input size: {emb_input_size}\n\nCheck:")

# Check
print(f"Check max token ids length: {df['intro_input_ids'].apply(len).max()}")
print(f"Check min token ids length: {df['intro_input_ids'].apply(len).min()}")
print(f"Check max attention mask length: {df['intro_attention_mask'].apply(len).max()}")
print(f"Check min attention mask length: {df['intro_attention_mask'].apply(len).min()}")

print(f"Check max token ids length: {df['outcomes_input_ids'].apply(len).max()}")
print(f"Check min token ids length: {df['outcomes_input_ids'].apply(len).min()}")
print(f"Check max attention mask length: {df['outcomes_attention_mask'].apply(len).max()}")
print(f"Check min attention mask length: {df['outcomes_attention_mask'].apply(len).min()}")

print(f"Check max token ids length: {df['criteria_input_ids'].apply(len).max()}")
print(f"Check min token ids length: {df['criteria_input_ids'].apply(len).min()}")
print(f"Check max attention mask length: {df['criteria_attention_mask'].apply(len).max()}")
print(f"Check min attention mask length: {df['criteria_attention_mask'].apply(len).min()}")

# Check known example k
k=0
print(f"k = {k}")
print(f"Check token_ids: {df['intro_input_ids'].iloc[k]}")
print(f"Check attention_mask: {df['intro_attention_mask'].iloc[k]}")
print(f"Check token_ids: {df['outcomes_input_ids'].iloc[k]}")
print(f"Check attention_mask: {df['outcomes_attention_mask'].iloc[k]}")
print(f"Check token_ids: {df['criteria_input_ids'].iloc[k]}")
print(f"Check attention_mask: {df['criteria_attention_mask'].iloc[k]}")

display(df)

Embedding input size: 512

Check:
Check max token ids length: 512
Check min token ids length: 512
Check max attention mask length: 512
Check min attention mask length: 512
Check max token ids length: 512
Check min token ids length: 512
Check max attention mask length: 512
Check min attention mask length: 512
Check max token ids length: 512
Check min token ids length: 512
Check max attention mask length: 512
Check min attention mask length: 512
k = 0
Check token_ids: [101, 1142, 2025, 8469, 1106, 17459, 1103, 3429, 117, 1106, 2879, 6328, 117, 1104, 14402, 118, 5311, 19203, 1580, 2975, 131, 8071, 8118, 3443, 1739, 131, 122, 119, 3107, 131, 3469, 1104, 14402, 118, 5311, 19203, 1580, 2076, 131, 6700, 9108, 16124, 1116, 131, 3850, 131, 14402, 118, 5311, 19203, 1580, 123, 119, 3107, 131, 3469, 1104, 1282, 4043, 2076, 131, 6700, 9108, 16124, 1116, 131, 1168, 131, 1282, 4043, 3443, 22496, 131, 122, 119, 2076, 131, 3850, 1271, 131, 14402, 118, 5311, 19203, 1580, 6136, 131, 9467, 13753, 1113, 94

Unnamed: 0,id,start_date,status,why_stopped,hasResults,phase,allocation,intervention_model,primary_purpose,acc_text,...,terminated,year,month_sin,month_cos,intro_input_ids,intro_attention_mask,outcomes_input_ids,outcomes_attention_mask,criteria_input_ids,criteria_attention_mask
0,NCT05099822,2020-03-13,TERMINATED,Business objectives changed.,False,PHASE1,RANDOMIZED,SEQUENTIAL,TREATMENT,"This study aims to evaluate the safety, tolera...",...,1,2020,1.000000e+00,6.123234e-17,"[101, 1142, 2025, 8469, 1106, 17459, 1103, 342...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 2...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 1107, 1363, 2332,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
1,NCT05225870,2021-01-01,COMPLETED,,False,,,SINGLE_GROUP,BASIC_SCIENCE,Colorectal carcinoma is one of the most aggres...,...,0,2021,5.000000e-01,8.660254e-01,"[101, 2942, 10294, 6163, 1610, 16430, 7903, 11...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 4420, 1114, 2942,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
2,NCT05617417,2021-05-05,COMPLETED,,False,,,SINGLE_GROUP,TREATMENT,We aimed to evaluate the efficacy of locally a...,...,0,2021,5.000000e-01,-8.660254e-01,"[101, 1195, 5850, 1106, 17459, 1103, 23891, 11...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 2...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 1103, 5351, 1150,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
3,NCT03696576,2018-09-20,TERMINATED,"Due to the pandemic, recruitment ended earlier...",True,,RANDOMIZED,PARALLEL,TREATMENT,The larynx and vocal folds undergo many age-re...,...,1,2018,-1.000000e+00,-1.836970e-16,"[101, 1103, 2495, 15023, 1775, 1105, 5563, 173...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 1425, 2625, 1137,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
4,NCT02400723,2018-12-05,COMPLETED,,True,,RANDOMIZED,PARALLEL,TREATMENT,"Anxiety leads to poor quality of life, avoidan...",...,0,2018,-2.449294e-16,1.000000e+00,"[101, 10507, 4501, 1106, 2869, 3068, 1104, 129...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 11461, 4079, 2539...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
47935,NCT03277495,2023-07-12,COMPLETED,,False,,RANDOMIZED,PARALLEL,TREATMENT,The primary goal of this study is to examine w...,...,0,2023,-5.000000e-01,-8.660254e-01,"[101, 1103, 2425, 2273, 1104, 1142, 2025, 1110...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 2...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 122, 119, 1441, 1105, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
47936,NCT05284409,2022-07-01,COMPLETED,,False,PHASE4,RANDOMIZED,PARALLEL,SUPPORTIVE_CARE,Single Shot Spinal anesthesia (SSSA) is associ...,...,0,2022,-5.000000e-01,-8.660254e-01,"[101, 1423, 2046, 19245, 1126, 2556, 27300, 11...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 1821, 26237, 1389...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
47937,NCT03986502,2021-01-22,COMPLETED,,False,,RANDOMIZED,PARALLEL,HEALTH_SERVICES_RESEARCH,This trial studies how well a financial naviga...,...,0,2021,5.000000e-01,8.660254e-01,"[101, 1142, 3443, 2527, 1293, 1218, 170, 2798,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 5351, 131, 4035, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
47938,NCT03684473,2018-10-31,COMPLETED,,False,,,SINGLE_GROUP,TREATMENT,"Post-traumatic stress disorder (PTSD), a chron...",...,0,2018,-8.660254e-01,5.000000e-01,"[101, 2112, 118, 23057, 6600, 8936, 113, 185, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 7401, 1106, 122, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."


## Feature Selection

Features and other features for model input

In [None]:
# Feature types
id_col = 'id'
date_col = 'start_date'
numerical_cols = ['stringency_index',
                  # 'year',
                  # 'month_sin',
                  # 'month_cos'
                  ]
categorical_cols = ['phase',
                    'allocation',
                    'intervention_model',
                    'primary_purpose',
                    'dmc_oversight',
                    'fda_drug',
                    'fda_device',
                    'unapproved_device'
                    ]
target_term_col = 'terminated'
target_score_col = 'ae_score'

## Handling Nulls

For robustness, set categorical missing values to explicit null value

In [None]:
# For data consistency, set dataframe null to explicity representation of null
def fill_nan_categorical(df, categorical_cols, fill_value="null"):
    for col in categorical_cols:
        # Replace actual np.nan
        df[col] = df[col].fillna(fill_value)
        # Replace literal 'NA' string
        df[col] = df[col].replace('NA', fill_value)

In [None]:
fill_nan_categorical(df_train, categorical_cols)
fill_nan_categorical(df_val, categorical_cols)
fill_nan_categorical(df_test, categorical_cols)

  df[col] = df[col].fillna(fill_value)


## Time Series Sorting

In [None]:
# def sort_data(df):
#   ''' Sort by date, chunk sequence number '''
#   df = df.sort_values(by=['start_date', 'id', 'chunk_seq'])
#   return df

# df_train = sort_data(df_train)
# df_val = sort_data(df_val)
# df_test = sort_data(df_test)

# Display
df_train

Unnamed: 0,id,start_date,status,why_stopped,hasResults,phase,allocation,intervention_model,primary_purpose,acc_text,...,terminated,year,month_sin,month_cos,intro_input_ids,intro_attention_mask,outcomes_input_ids,outcomes_attention_mask,criteria_input_ids,criteria_attention_mask
0,NCT05099822,2020-03-13,TERMINATED,Business objectives changed.,False,PHASE1,RANDOMIZED,SEQUENTIAL,TREATMENT,"This study aims to evaluate the safety, tolera...",...,1,2020,1.000000e+00,6.123234e-17,"[101, 1142, 2025, 8469, 1106, 17459, 1103, 342...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 2...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 1107, 1363, 2332,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
1,NCT05225870,2021-01-01,COMPLETED,,False,,,SINGLE_GROUP,BASIC_SCIENCE,Colorectal carcinoma is one of the most aggres...,...,0,2021,5.000000e-01,8.660254e-01,"[101, 2942, 10294, 6163, 1610, 16430, 7903, 11...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 4420, 1114, 2942,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
2,NCT05617417,2021-05-05,COMPLETED,,False,,,SINGLE_GROUP,TREATMENT,We aimed to evaluate the efficacy of locally a...,...,0,2021,5.000000e-01,-8.660254e-01,"[101, 1195, 5850, 1106, 17459, 1103, 23891, 11...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 2...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 1103, 5351, 1150,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
3,NCT03696576,2018-09-20,TERMINATED,"Due to the pandemic, recruitment ended earlier...",True,,RANDOMIZED,PARALLEL,TREATMENT,The larynx and vocal folds undergo many age-re...,...,1,2018,-1.000000e+00,-1.836970e-16,"[101, 1103, 2495, 15023, 1775, 1105, 5563, 173...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 1425, 2625, 1137,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
4,NCT02400723,2018-12-05,COMPLETED,,True,,RANDOMIZED,PARALLEL,TREATMENT,"Anxiety leads to poor quality of life, avoidan...",...,0,2018,-2.449294e-16,1.000000e+00,"[101, 10507, 4501, 1106, 2869, 3068, 1104, 129...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 11461, 4079, 2539...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
47935,NCT03277495,2023-07-12,COMPLETED,,False,,RANDOMIZED,PARALLEL,TREATMENT,The primary goal of this study is to examine w...,...,0,2023,-5.000000e-01,-8.660254e-01,"[101, 1103, 2425, 2273, 1104, 1142, 2025, 1110...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 2...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 122, 119, 1441, 1105, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
47936,NCT05284409,2022-07-01,COMPLETED,,False,PHASE4,RANDOMIZED,PARALLEL,SUPPORTIVE_CARE,Single Shot Spinal anesthesia (SSSA) is associ...,...,0,2022,-5.000000e-01,-8.660254e-01,"[101, 1423, 2046, 19245, 1126, 2556, 27300, 11...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 1821, 26237, 1389...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
47937,NCT03986502,2021-01-22,COMPLETED,,False,,RANDOMIZED,PARALLEL,HEALTH_SERVICES_RESEARCH,This trial studies how well a financial naviga...,...,0,2021,5.000000e-01,8.660254e-01,"[101, 1142, 3443, 2527, 1293, 1218, 170, 2798,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 5351, 131, 4035, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
47938,NCT03684473,2018-10-31,COMPLETED,,False,,,SINGLE_GROUP,TREATMENT,"Post-traumatic stress disorder (PTSD), a chron...",...,0,2018,-8.660254e-01,5.000000e-01,"[101, 2112, 118, 23057, 6600, 8936, 113, 185, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 7401, 1106, 122, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."


# Model

## Preprocess Fitting on Train
Due to time series nature, do not scale year

In [None]:
#SET
numerical_excl = ['year']

# Pipeline fit on training data
def fit_pipeline_train(df, numerical_cols, categorical_cols):
    # fit numerical scaler
    numerical_scalers = {}
    for col in numerical_cols:
        if col in numerical_excl: continue # Do not scale year
        scaler = StandardScaler()
        scaler.fit(df[[col]])  # Fit on a DataFrame with a single column
        numerical_scalers[col] = scaler

    # fit categorical mappings
    categorical_mappings = {} # Store category mappings for all categorical cols
    for col in categorical_cols:
        unique_categories = df[col].unique()
        category_mapping = {category: i+1 for i, category in enumerate(unique_categories)} # i+1 to start index from 1
        categorical_mappings[col] = category_mapping

    return numerical_scalers, categorical_mappings # Return mappings instead of encoders.

# Fit
numerical_scalers, categorical_mappings = fit_pipeline_train(df_train, numerical_cols, categorical_cols)

In [None]:
# Pipeline transform data
def pipeline_transform(df, numerical_scalers, categorical_mappings):
    ''' Iterate through the scalers and mappings to modify relevant features
    A KeyError would indicate column missing in the dataframe '''

    # Transform numerical columns (iterate through scaler keys)
    for col in numerical_scalers:
        df[col] = numerical_scalers[col].transform(df[[col]])  # Transform single column

    # Transform categorical columns (iterate through mapping keys)
    for col in categorical_mappings:
        mapping = categorical_mappings[col]
        df[col] = df[col].apply(lambda x: mapping.get(x, 0))  # 0 to handle unseen values

    return df


# Transform
df_train = pipeline_transform(df_train,
                              numerical_scalers,
                              categorical_mappings
                              )
df_val = pipeline_transform(df_val,
                            numerical_scalers,
                            categorical_mappings
                            )
df_test = pipeline_transform(df_test,
                             numerical_scalers,
                             categorical_mappings
                             )

# Display
df_train

Unnamed: 0,id,start_date,status,why_stopped,hasResults,phase,allocation,intervention_model,primary_purpose,acc_text,...,terminated,year,month_sin,month_cos,intro_input_ids,intro_attention_mask,outcomes_input_ids,outcomes_attention_mask,criteria_input_ids,criteria_attention_mask
0,NCT05099822,2020-03-13,TERMINATED,Business objectives changed.,False,1,1,1,1,"This study aims to evaluate the safety, tolera...",...,1,2020,1.000000e+00,6.123234e-17,"[101, 1142, 2025, 8469, 1106, 17459, 1103, 342...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 2...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 1107, 1363, 2332,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
1,NCT05225870,2021-01-01,COMPLETED,,False,2,2,2,2,Colorectal carcinoma is one of the most aggres...,...,0,2021,5.000000e-01,8.660254e-01,"[101, 2942, 10294, 6163, 1610, 16430, 7903, 11...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 4420, 1114, 2942,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
2,NCT05617417,2021-05-05,COMPLETED,,False,2,2,2,1,We aimed to evaluate the efficacy of locally a...,...,0,2021,5.000000e-01,-8.660254e-01,"[101, 1195, 5850, 1106, 17459, 1103, 23891, 11...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 2...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 1103, 5351, 1150,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
3,NCT03696576,2018-09-20,TERMINATED,"Due to the pandemic, recruitment ended earlier...",True,2,1,3,1,The larynx and vocal folds undergo many age-re...,...,1,2018,-1.000000e+00,-1.836970e-16,"[101, 1103, 2495, 15023, 1775, 1105, 5563, 173...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 1425, 2625, 1137,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
4,NCT02400723,2018-12-05,COMPLETED,,True,2,1,3,1,"Anxiety leads to poor quality of life, avoidan...",...,0,2018,-2.449294e-16,1.000000e+00,"[101, 10507, 4501, 1106, 2869, 3068, 1104, 129...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 11461, 4079, 2539...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
47935,NCT03277495,2023-07-12,COMPLETED,,False,2,1,3,1,The primary goal of this study is to examine w...,...,0,2023,-5.000000e-01,-8.660254e-01,"[101, 1103, 2425, 2273, 1104, 1142, 2025, 1110...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 2...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 122, 119, 1441, 1105, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
47936,NCT05284409,2022-07-01,COMPLETED,,False,6,1,3,3,Single Shot Spinal anesthesia (SSSA) is associ...,...,0,2022,-5.000000e-01,-8.660254e-01,"[101, 1423, 2046, 19245, 1126, 2556, 27300, 11...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 1821, 26237, 1389...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
47937,NCT03986502,2021-01-22,COMPLETED,,False,2,1,3,7,This trial studies how well a financial naviga...,...,0,2021,5.000000e-01,8.660254e-01,"[101, 1142, 3443, 2527, 1293, 1218, 170, 2798,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 5351, 131, 4035, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
47938,NCT03684473,2018-10-31,COMPLETED,,False,2,2,2,1,"Post-traumatic stress disorder (PTSD), a chron...",...,0,2018,-8.660254e-01,5.000000e-01,"[101, 2112, 118, 23057, 6600, 8936, 113, 185, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 2425, 13950, 131, 122, 119, 4929, 131, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[101, 10838, 9173, 131, 115, 7401, 1106, 122, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."


Save scalers and mappers

In [None]:
# # Save
# dir = '/content/drive/MyDrive/W210-Capstone-ClincalGroup/trial_risk_models'
# numerical_scaler_filename = os.path.join(dir, 'numerical_scalers.pkl')
# categorical_mapping_filename = os.path.join(dir, 'categorical_mappings.pkl')

# try:
#     with open(numerical_scaler_filename, 'wb') as f:
#         pickle.dump(numerical_scalers, f)
#     print(f"Numerical scalers saved to: {numerical_scaler_filename}")
# except Exception as e:
#     print(f"Error saving numerical scalers: {e}")

# try:
#     with open(categorical_mapping_filename, 'wb') as f:
#         pickle.dump(categorical_mappings, f)
#     print(f"Categorical mappings saved to: {categorical_mapping_filename}")
# except Exception as e:
#     print(f"Error saving categorical mappings: {e}")

## Torch Data Preparation

In [None]:
# Define PyTorch Dataset class
class TrialDataset(Dataset):
    def __init__(self, dataframe, id_col, date_col, categorical_cols, numerical_cols,
                 intro_ids_col, intro_mask_col,
                 outcomes_ids_col, outcomes_mask_col,
                 criteria_ids_col, criteria_mask_col,
                 target_term_col=None, target_score_col=None):
        self.id_col = id_col
        self.date_col = date_col
        self.categorical_cols = categorical_cols
        self.numerical_cols = numerical_cols
        self.intro_ids_col = intro_ids_col
        self.intro_mask_col = intro_mask_col
        self.outcomes_ids_col = outcomes_ids_col
        self.outcomes_mask_col = outcomes_mask_col
        self.criteria_ids_col = criteria_ids_col
        self.criteria_mask_col = criteria_mask_col
        self.target_term_col = target_term_col
        self.target_score_col = target_score_col
        #
        self.data = dataframe

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):  # For dataloader
        item = self.data.iloc[idx]
        trial_id = item[self.id_col]
        intro_input_ids = torch.tensor(item[self.intro_ids_col], dtype=torch.long)
        intro_attention_mask = torch.tensor(item[self.intro_mask_col], dtype=torch.long)
        outcomes_input_ids = torch.tensor(item[self.outcomes_ids_col], dtype=torch.long)
        outcomes_attention_mask = torch.tensor(item[self.outcomes_mask_col], dtype=torch.long)
        criteria_input_ids = torch.tensor(item[self.criteria_ids_col], dtype=torch.long)
        criteria_attention_mask = torch.tensor(item[self.criteria_mask_col], dtype=torch.long)
        categorical_inputs = torch.tensor([item[col] for col in self.categorical_cols], dtype=torch.long)
        numerical_inputs = torch.tensor([item[col] for col in self.numerical_cols], dtype=torch.float)

        dloader_dict =  {
            'id': trial_id,
            'intro_input_ids': intro_input_ids,
            'intro_attention_mask': intro_attention_mask,
            'outcomes_input_ids': outcomes_input_ids,
            'outcomes_attention_mask': outcomes_attention_mask,
            'criteria_input_ids': criteria_input_ids,
            'criteria_attention_mask': criteria_attention_mask,
            'categorical_inputs': categorical_inputs,
            'numerical_inputs': numerical_inputs
        }
        # Target is optional to account for new-world data
        if self.target_term_col is not None:
            target = torch.tensor(item[self.target_term_col], dtype=torch.long)
            dloader_dict['targets_term'] = target
        if self.target_score_col is not None:
            target = torch.tensor(item[self.target_score_col], dtype=torch.float)
            dloader_dict['targets_score'] = target

        return dloader_dict

In [None]:
#TEST
# Create sample PyTorch Dataset to test
trial_dataset = TrialDataset(df_train.head(100),
                             id_col=id_col,
                             date_col=date_col,
                             categorical_cols=categorical_cols,
                             numerical_cols=numerical_cols,
                             intro_ids_col='intro_input_ids', intro_mask_col='intro_attention_mask',
                             outcomes_ids_col='outcomes_input_ids', outcomes_mask_col='outcomes_attention_mask',
                             criteria_ids_col='criteria_input_ids', criteria_mask_col='criteria_attention_mask',
                             target_term_col=target_term_col, target_score_col=target_score_col
                             )

# Create Dataloader object to view
trial_dataloader = DataLoader(trial_dataset, batch_size=10, shuffle=False)


# Check dataloader data
for batch_idx, batch in enumerate(trial_dataloader):
    print(f"Batch {batch_idx}:")
    print("  ID:", batch['id'])
    print("  Intro Input IDs (Tensor):", batch['intro_input_ids'])
    print("  Intro Attention Mask (Tensor):", batch['intro_attention_mask'])
    print("  Outcomes Input IDs (Tensor):", batch['outcomes_input_ids'])
    print("  Outcomes Attention Mask (Tensor):", batch['outcomes_attention_mask'])
    print("  Criteria Input IDs (Tensor):", batch['criteria_input_ids'])
    print("  Cirteria Attention Mask (Tensor):", batch['criteria_attention_mask'])
    print("  Categorical Data:", batch['categorical_inputs'])
    print("  Numerical Data:", batch['numerical_inputs'])
    print("  Targets Terminated:", batch['targets_term'])
    print("  Targets Risk Score:", batch['targets_score'])
    if batch_idx == 0: # Just print the first batch for demonstration
        break

Batch 0:
  ID: ['NCT05099822', 'NCT05225870', 'NCT05617417', 'NCT03696576', 'NCT02400723', 'NCT06043479', 'NCT04828642', 'NCT05778227', 'NCT05887349', 'NCT03871725']
  Intro Input IDs (Tensor): tensor([[  101,  1142,  2025,  ...,     0,     0,     0],
        [  101,  2942, 10294,  ...,  5511,  1209,   102],
        [  101,  1195,  5850,  ...,  3987, 20497,   102],
        ...,
        [  101,  1103,  3501,  ...,   119,  1160,   102],
        [  101,  9071, 14928,  ..., 10024, 12149,   102],
        [  101,  1488, 22320,  ...,  8241,  1988,   102]])
  Intro Attention Mask (Tensor): tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])
  Outcomes Input IDs (Tensor): tensor([[  101,  2425, 13950,  ...,   119,  4929,   102],
        [  101,  2425, 13950,  ...,     0,     0,     0],
        [  101,  2425, 13950,  ...,     0,    

Create (actual) model Dataset and Dataloader objects for model development


In [None]:
# Create (actual) model Dataset and Dataloader objects for model development
train_dataset = TrialDataset(df_train,
                             id_col=id_col,
                             date_col=date_col,
                             categorical_cols=categorical_cols,
                             numerical_cols=numerical_cols,
                             intro_ids_col='intro_input_ids', intro_mask_col='intro_attention_mask',
                             outcomes_ids_col='outcomes_input_ids', outcomes_mask_col='outcomes_attention_mask',
                             criteria_ids_col='criteria_input_ids', criteria_mask_col='criteria_attention_mask',
                             target_term_col=target_term_col, target_score_col=target_score_col
                             )

val_dataset = TrialDataset(df_val,
                             id_col=id_col,
                             date_col=date_col,
                             categorical_cols=categorical_cols,
                             numerical_cols=numerical_cols,
                             intro_ids_col='intro_input_ids', intro_mask_col='intro_attention_mask',
                             outcomes_ids_col='outcomes_input_ids', outcomes_mask_col='outcomes_attention_mask',
                             criteria_ids_col='criteria_input_ids', criteria_mask_col='criteria_attention_mask',
                             target_term_col=target_term_col, target_score_col=target_score_col
                             )

test_dataset = TrialDataset(df_test,
                             id_col=id_col,
                             date_col=date_col,
                             categorical_cols=categorical_cols,
                             numerical_cols=numerical_cols,
                             intro_ids_col='intro_input_ids', intro_mask_col='intro_attention_mask',
                             outcomes_ids_col='outcomes_input_ids', outcomes_mask_col='outcomes_attention_mask',
                             criteria_ids_col='criteria_input_ids', criteria_mask_col='criteria_attention_mask',
                             target_term_col=target_term_col, target_score_col=target_score_col
                             )

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False)

## Create Model

In [None]:
from transformers import AutoModel
import torch.nn as nn
import torch.nn.functional as F

class Terminate_Model(nn.Module):
    def __init__(self,
                 num_categorical_features=None,
                 categorical_embedding_dims=[],
                 num_numerical_features=None,
                 embed_model=bb_model):

        super(Terminate_Model, self).__init__()
        self.biobert = embed_model

        self.categorical_embeddings = nn.ModuleList([
            nn.Embedding(num_embeddings, embedding_dim)
            for num_embeddings, embedding_dim in categorical_embedding_dims
        ])
        self.num_categorical_features = num_categorical_features

        self.numerical_bn = nn.BatchNorm1d(num_numerical_features)  # Creates and assigns layer (a method) designed for given number of numerical features
        self.num_numerical_features = num_numerical_features

        # Combine features
        combined_input_dim = 3 * self.biobert.config.hidden_size  # 3 Outputs of BioBERT
        if num_categorical_features is not None:
            combined_input_dim += sum([dim for _, dim in categorical_embedding_dims])
        if num_numerical_features is not None:
            combined_input_dim += num_numerical_features

        # Weighting layers
        # Modality weights: 3 CLS embeddings
        self.text_modality_weights = nn.Parameter(torch.ones(3))  # shape (3,)
        # Per-feature categorical weights
        self.categorical_feature_weights = nn.Parameter(torch.ones(num_categorical_features))  # shape (num_categorical_features,)
        # Single numerical feature
        self.numerical_feature_weights = nn.Parameter(torch.ones(num_numerical_features)) # shape (num_numerical_features,)

        # Layers
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.3)
        self.linear1 = nn.Linear(combined_input_dim, 1024)
        self.bn1 = nn.BatchNorm1d(1024)
        self.linear2 = nn.Linear(1024, 512) # Second hidden layer
        self.bn2 = nn.BatchNorm1d(512) # BatchNorm for the second hidden layer
        self.linear3 = nn.Linear(512, 256) # Second hidden layer
        self.bn3 = nn.BatchNorm1d(256) # BatchNorm for the second hidden layer
        self.finallinear = nn.Linear(256, 2) # Output layer


    def forward(self, categorical_inputs, numerical_inputs,
                intro_input_ids, intro_attention_mask,
                outcomes_input_ids, outcomes_attention_mask,
                criteria_input_ids, criteria_attention_mask):

        # Embed each text input using BioBERT
        intro_outputs = self.biobert(intro_input_ids, attention_mask=intro_attention_mask)
        intro_embedding = intro_outputs.pooler_output

        outcomes_outputs = self.biobert(outcomes_input_ids, attention_mask=outcomes_attention_mask)
        outcomes_embedding = outcomes_outputs.pooler_output

        criteria_outputs = self.biobert(criteria_input_ids, attention_mask=criteria_attention_mask)
        criteria_embedding = criteria_outputs.pooler_output

        ## --- Concatenate ---
        # Concatenate the embeddings from the three text inputs
        # Apply modality weights (softmax optional for normalized weight distribution)
        modality_weights = F.softmax(self.text_modality_weights, dim=0)
        text_embeddings = [
            intro_embedding * modality_weights[0],
            outcomes_embedding * modality_weights[1],
            criteria_embedding * modality_weights[2]
        ]
        text_concat = torch.cat(text_embeddings, dim=1)  # (batch, 2304)

        # Categorical: embed and apply per-feature weights
        if self.num_categorical_features is not None:
          categorical_embeds = [emb(categorical_inputs[:, i]) for i, emb in enumerate(self.categorical_embeddings)]
          categorical_embeds = [embed * self.categorical_feature_weights[i] for i, embed in enumerate(categorical_embeds)]
          cat_concat = torch.cat(categorical_embeds, dim=1)  # shape (batch, sum(emb_dims))

        # Numerical
        if self.num_numerical_features is not None:
            numerical_inputs = self.numerical_bn(numerical_inputs)  # Input numerical inputs to batch normalization layer
            # numerical_feature_weights: (num_features,) → auto-broadcast to (batch, num_features)
            weighted_numerical = numerical_inputs * self.numerical_feature_weights

        # Final concat
        combined_features = torch.cat([text_concat, cat_concat, weighted_numerical], dim=1)


        # Forward
        x = self.dropout1(combined_features)
        x = F.relu(self.bn1(self.linear1(x)))
        x = self.dropout2(x)
        x = F.relu(self.bn2(self.linear2(x)))
        x = self.dropout2(x)
        x = F.relu(self.bn3(self.linear3(x)))
        x = self.dropout2(x)
        logits = self.finallinear(x)
        return logits

In [None]:
# Define categorical embeddings based on training data
categorical_embedding_dims = []
for col in categorical_cols:
    num_unique_values = len(df_train[col].unique()) + 1 # +1 to account for unseen
    embedding_size = min(20, (num_unique_values + 1) // 2) # A common heuristic
    categorical_embedding_dims.append((num_unique_values, embedding_size))
    print(f"Column: {col}, Unique Values: {num_unique_values}, Embedding Size: {embedding_size}")

Column: phase, Unique Values: 7, Embedding Size: 4
Column: allocation, Unique Values: 4, Embedding Size: 2
Column: intervention_model, Unique Values: 6, Embedding Size: 3
Column: primary_purpose, Unique Values: 10, Embedding Size: 5
Column: dmc_oversight, Unique Values: 4, Embedding Size: 2
Column: fda_drug, Unique Values: 4, Embedding Size: 2
Column: fda_device, Unique Values: 4, Embedding Size: 2
Column: unapproved_device, Unique Values: 3, Embedding Size: 2


In [None]:
# Instantiate the model
model = Terminate_Model(num_categorical_features=len(categorical_cols),
                  categorical_embedding_dims=categorical_embedding_dims,
                  num_numerical_features=len(numerical_cols)
                  )

display(model)

Terminate_Model(
  (biobert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elemen

## Train

In [None]:
import torch.optim as optim
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Apply weighting for class imbalance
all_labels = df_train[target_term_col].values
class_weights = compute_class_weight('balanced', classes=np.unique(all_labels), y=all_labels)
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

# Define loss function. Add class weighting.
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-5) # Adjust learning rate as needed
# Scheduler
# scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)

# Set the number of training epochs
num_epochs = 6

# Move model to the GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# early stopping variables to prevent overfitting
best_f1 = 0
best_model_state = None
patience = 3
patience_counter = 0

# Training loop
for epoch in range(num_epochs):

    # Training
    model.train() # Set the model to training mode
    total_loss = 0
    for batch_idx, batch in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")):
        intro_input_ids_batch = batch['intro_input_ids'].to(device)
        intro_attention_mask_batch = batch['intro_attention_mask'].to(device)
        outcomes_input_ids_batch = batch['outcomes_input_ids'].to(device)
        outcomes_attention_mask_batch = batch['outcomes_attention_mask'].to(device)
        criteria_input_ids_batch = batch['criteria_input_ids'].to(device)
        criteria_attention_mask_batch = batch['criteria_attention_mask'].to(device)
        categorical_batch = batch['categorical_inputs'].to(device)
        numerical_batch = batch['numerical_inputs'].to(device)
        targets_batch = batch['targets_term'].to(device)

        # Zero the gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(categorical_batch, numerical_batch,
                        intro_input_ids_batch, intro_attention_mask_batch,
                        outcomes_input_ids_batch, outcomes_attention_mask_batch,
                        criteria_input_ids_batch, criteria_attention_mask_batch)
        # Calculate loss
        loss = criterion(outputs, targets_batch)
        total_loss += loss.item()
        # Backward pass and optimization
        loss.backward()
        optimizer.step()

    avg_loss = total_loss / len(train_dataloader) # Loss calculated as average batch loss
    print(f"Epoch {epoch+1} completed, Average Loss: {avg_loss:.4f}")

    # Validation
    model.eval() # Set the model to evaluation mode
    total_val_loss = 0
    all_preds = []
    all_labels = []
    with torch.no_grad(): # Disable gradient calculations during validation
        for batch_idx, batch in enumerate(tqdm(val_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} (Val)")):
            intro_input_ids_batch = batch['intro_input_ids'].to(device)
            intro_attention_mask_batch = batch['intro_attention_mask'].to(device)
            outcomes_input_ids_batch = batch['outcomes_input_ids'].to(device)
            outcomes_attention_mask_batch = batch['outcomes_attention_mask'].to(device)
            criteria_input_ids_batch = batch['criteria_input_ids'].to(device)
            criteria_attention_mask_batch = batch['criteria_attention_mask'].to(device)
            categorical_batch = batch['categorical_inputs'].to(device)
            numerical_batch = batch['numerical_inputs'].to(device)
            targets_batch = batch['targets_term'].to(device)

            # Forward pass of validation data
            outputs = model(categorical_batch, numerical_batch,
                        intro_input_ids_batch, intro_attention_mask_batch,
                        outcomes_input_ids_batch, outcomes_attention_mask_batch,
                        criteria_input_ids_batch, criteria_attention_mask_batch)
            # Calculate loss
            loss = criterion(outputs, targets_batch)
            total_val_loss += loss.item()

            # Get predictions (assuming binary classification with logits)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            labels = targets_batch.cpu().numpy()

            all_preds.extend(preds)
            all_labels.extend(labels)

    avg_val_loss = total_val_loss / len(val_dataloader)
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, zero_division=0) # Handle potential division by zero
    recall = recall_score(all_labels, all_preds, zero_division=0)     # Handle potential division by zero
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    print(f"Epoch {epoch+1} Validation Loss: {avg_val_loss:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

    # Step the scheduler based on the validation loss
    # scheduler.step(avg_val_loss)

    # --- Early Stopping Check ---
    # look for best f1 so far
    if f1 > best_f1:
        patience_counter = 0
        best_f1 = f1
        best_model_state = model.state_dict().copy()  # save the model state
        print(f"Best F1: {best_f1:.4f}; saving model.")
    else:
        patience_counter += 1
        print(f"No improvement for {patience_counter} epochs. Best F1: {best_f1:.4f}")

    # check if we should stop training
    if patience_counter >= patience:
        print(f"Early stopping triggered after epoch {epoch+1}")
        break

#load the best model
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f"Loaded best model with F1 score: {best_f1:.4f}")

print("Training finished!")

# Save Model

In [None]:
import joblib

# save the file into pkl file
joblib.dump(model, dir + "/model_terminate.pkl")

In [None]:
PATH = dir + '/model_terminate_state_dict.pth'
torch.save(model.state_dict(), PATH)

print(f"Model state dictionary saved to {PATH}")