In [1]:
import os
import pandas as pd
import numpy as np
import json
import re
import nltk
from nltk.corpus import stopwords
from sklearn.preprocessing import MultiLabelBinarizer
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score, precision_score, recall_score
import matplotlib.pyplot as plt

**0. Import data & data clean**

In [2]:
df = pd.read_csv('../data/processed/shuffled_10_data.csv')
print(df.shape)
df.head()

(52800, 7)


Unnamed: 0,AC,PMID,Title,Abstract,Terms,Text_combined,batch_number
0,P06169,2185016,Autoregulation may control the expression of y...,Recently we deleted the pyruvate decarboxylase...,autoregulation,Autoregulation may control the expression of y...,1
1,P0AEM5,12704152,Complete genome sequence and comparative genom...,We determined the complete genome sequence of ...,,Complete genome sequence and comparative genom...,1
2,B8FZE0,22316246,Genome sequence of Desulfitobacterium hafniens...,"The genome of the Gram-positive, metal-reducin...",,Genome sequence of Desulfitobacterium hafniens...,1
3,P14656,12060286,Overlapping expression of cytosolic glutamine ...,In order to estimate whether cytosolic glutami...,,Overlapping expression of cytosolic glutamine ...,1
4,Q7XXS4,27052628,Both overexpression and suppression of an Oryz...,Tight and accurate regulation of immunity and ...,autoactivation,Both overexpression and suppression of an Oryz...,1


In [3]:
# clean text
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))

def clean_text(text):
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    text = re.sub(r'\d+', '', text)
    text = " ".join([word.strip() for word in text.split() if word not in stop_words])
    return text

df['Text_Cleaned'] = df['Text_combined'].apply(clean_text)

# fill nan with 'non-autoregulatory'
df['Terms'] = df['Terms'].fillna('non-autoregulatory')

# keep only selected columns
columns_to_keep = ['batch_number','Text_Cleaned','Terms']
df_cleaned = df[columns_to_keep]

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/fiatlux/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [4]:
print(df_cleaned.shape)
df_cleaned.head()

(52800, 3)


Unnamed: 0,batch_number,Text_Cleaned,Terms
0,1,autoregulation may control expression yeast py...,autoregulation
1,1,complete genome sequence comparative genomics ...,non-autoregulatory
2,1,genome sequence desulfitobacterium hafniense d...,non-autoregulatory
3,1,overlapping expression cytosolic glutamine syn...,non-autoregulatory
4,1,overexpression suppression oryza sativa nblrrl...,autoactivation


In [5]:
# convert terms to list
df_cleaned['Terms_List'] = df_cleaned['Terms'].apply(
    lambda x: [term.strip() for term in x.split(',')]
)
df_cleaned['Terms_List'] = df_cleaned['Terms_List'].apply(lambda x: list(set(x)))

mlb = MultiLabelBinarizer()
labels = mlb.fit_transform(df_cleaned['Terms_List'])
label_columns = mlb.classes_

labels_df = pd.DataFrame(labels, columns=label_columns)
existing_columns = [col for col in label_columns if col in df_cleaned.columns]
df_cleaned = df_cleaned.drop(columns=existing_columns, errors='ignore')
df_cleaned = pd.concat([df_cleaned, labels_df], axis=1)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cleaned['Terms_List'] = df_cleaned['Terms'].apply(
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cleaned['Terms_List'] = df_cleaned['Terms_List'].apply(lambda x: list(set(x)))


In [6]:
print(df_cleaned.shape)
df_cleaned.head()

(52800, 19)


Unnamed: 0,batch_number,Text_Cleaned,Terms,Terms_List,autoactivation,autocatalysis,autocatalytic,autofeedback,autoinducer,autoinduction,autoinhibition,autoinhibitory,autokinase,autolysis,autophosphorylation,autoregulation,autoregulatory,autoubiquitination,non-autoregulatory
0,1,autoregulation may control expression yeast py...,autoregulation,[autoregulation],0,0,0,0,0,0,0,0,0,0,0,1,0,0,0
1,1,complete genome sequence comparative genomics ...,non-autoregulatory,[non-autoregulatory],0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
2,1,genome sequence desulfitobacterium hafniense d...,non-autoregulatory,[non-autoregulatory],0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
3,1,overlapping expression cytosolic glutamine syn...,non-autoregulatory,[non-autoregulatory],0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
4,1,overexpression suppression oryza sativa nblrrl...,autoactivation,[autoactivation],1,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [7]:
# check label distribution
test_df = df_cleaned[df_cleaned['batch_number'] == 1]
numeric_columns = test_df.select_dtypes(include=['int64', 'float64']).columns
label_counts = test_df[numeric_columns].sum(axis=0)
label_columns = [col for col in df_cleaned.columns[3:] if col != "Terms_List"]
print(label_counts.sort_values(ascending=False))

batch_number           5280
non-autoregulatory     3520
autophosphorylation     838
autocatalytic           176
autoregulation          154
autoubiquitination      145
autoinhibition          133
autoregulatory           81
autoinducer              73
autolysis                70
autoinhibitory           60
autoactivation           22
autocatalysis            15
autofeedback             13
autoinduction            11
autokinase                8
dtype: int64


**1. Define Functions for the Model**

In [8]:
# Device Configuration
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Device being used: {device}")

# Load tokenizer and base model
tokenizer = AutoTokenizer.from_pretrained('microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext')
bert_model = AutoModel.from_pretrained('microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext')

Device being used: mps


In [9]:
# data splitting
def split_single_batch_data(batch_number, test_size=0.2, random_state=42):
    """
    Split data from a single batch into train and test sets using stratified sampling.
    Prints Train/Test sizes and Non-auto/Auto counts.
    """
    batch_df = df_cleaned[df_cleaned['batch_number'] == batch_number].copy()
    X = batch_df['Text_Cleaned']
    y = labels_df.loc[batch_df.index].values  # Ensure indexing alignment
    
    # calculate label distribution
    non_auto_count = len(batch_df[batch_df['Terms'] == 'non-autoregulatory'])
    auto_count = len(batch_df[batch_df['Terms'] != 'non-autoregulatory'])
    
    # split data
    msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
    
    for train_idx, test_idx in msss.split(X, y):
        X_train = X.iloc[train_idx]
        X_test = X.iloc[test_idx]
        y_train = y[train_idx]
        y_test = y[test_idx]
    
    print(f"Batch {batch_number} | Train: {len(X_train)}, Test: {len(X_test)} | Non-auto: {non_auto_count}, Auto: {auto_count}")
    
    return X_train, X_test, y_train, y_test

In [10]:
# calculate class weights
def get_data_and_weights(batch_number):
    """
    Get data and calculate class weights for a specific batch.
    """
    X_train, X_test, y_train, y_test = split_single_batch_data(batch_number)
    
    # Calculate class weights
    pos_weights = []
    for i in range(y_train.shape[1]):
        neg_count = len(y_train) - np.sum(y_train[:, i])
        pos_count = np.sum(y_train[:, i])
        pos_weights.append(neg_count / pos_count if pos_count > 0 else 1.0)
    
    pos_weights = torch.FloatTensor(pos_weights).to(device)
    
    return X_train, X_test, y_train, y_test, pos_weights

In [11]:
# create dataset class
class PubMedDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts.iloc[idx])
        label = self.labels[idx]
        
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.FloatTensor(label)
        }

In [12]:
# create data loader
def create_dataset_and_loader(X_train, y_train, X_test, y_test, batch_size):
    train_dataset = PubMedDataset(X_train, y_train, tokenizer)
    test_dataset = PubMedDataset(X_test, y_test, tokenizer)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    
    return train_loader, test_loader

In [13]:
# create model
class PubMedBERTClassifier(nn.Module):
    def __init__(self, n_classes, dropout=0.1):
        super(PubMedBERTClassifier, self).__init__()
        self.bert = bert_model
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits

In [14]:
# training function
def train_epoch(model, data_loader, optimizer, criterion):
    model.train()
    total_loss = 0
    
    for batch in data_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(data_loader)

In [15]:
# set thresholds for each label
def set_thresholds(pos_weights):
    """
    Set thresholds based on normalized pos_weights, scaled to [0.2, 0.8].
    """
    thresholds = []

    if len(pos_weights) != len(label_columns):
        raise ValueError(f"Length mismatch: pos_weights ({len(pos_weights)}) vs label_columns ({len(label_columns)})")

    # Calculate min and max weights for normalization
    min_weight = pos_weights.min().item()
    max_weight = pos_weights.max().item()
    weight_range = max_weight - min_weight

    # Avoid division by zero
    if weight_range == 0:
        weight_range = 1

    # Calculate thresholds
    for weight in pos_weights:
        # Normalize to [0, 1]
        normalized_weight = (weight.item() - min_weight) / weight_range

        # Map to [0.2, 0.8]
        threshold = 0.8 - (0.6 * normalized_weight)

        # Clamp to [0.2, 0.8]
        threshold = max(0.2, min(threshold, 0.8))

        thresholds.append(threshold)

    # Print thresholds with two decimal places
    formatted_thresholds = [f"{t:.2f}" for t in thresholds]
    print(f"\nDynamic Thresholds: {formatted_thresholds}")
    
    return thresholds


In [16]:
# Evaluation function
def evaluate(model, data_loader, criterion, thresholds):
    """
    Evaluate the model with focused metrics output.
    """
    model.eval()
    total_loss = 0
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            probabilities = torch.sigmoid(outputs).cpu().numpy()
            
            # Apply thresholds per label
            predictions = np.array([
                (probabilities[:, i] >= thresholds[i]).astype(int) for i in range(len(thresholds))
            ]).T

            all_predictions.extend(predictions)
            all_labels.extend(labels.cpu().numpy())
    
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)

    # samples metrics
    samples_precision = precision_score(all_labels, all_predictions, average='samples', zero_division=0)
    samples_recall = recall_score(all_labels, all_predictions, average='samples', zero_division=0)
    samples_f1 = f1_score(all_labels, all_predictions, average='samples', zero_division=0)

    # f1 metrics
    micro_f1 = f1_score(all_labels, all_predictions, average='micro', zero_division=0)
    macro_f1 = f1_score(all_labels, all_predictions, average='macro', zero_division=0)
    weighted_f1 = f1_score(all_labels, all_predictions, average='weighted', zero_division=0)

    # average loss
    avg_loss = total_loss / len(data_loader)

    metrics = {
        'loss': avg_loss,
        'micro_f1': micro_f1,
        'macro_f1': macro_f1,
        'weighted_f1': weighted_f1,
        'samples_f1': samples_f1,
        'samples_precision': samples_precision,
        'samples_recall': samples_recall
    }
    
    # output results
    print(f"Loss: {avg_loss:.4f} | Micro F1: {micro_f1:.4f} | Macro F1: {macro_f1:.4f} | Weighted F1: {weighted_f1:.4f} | Samples F1: {samples_f1:.4f}")
    print(f"Samples Precision: {samples_precision:.4f} | Samples Recall: {samples_recall:.4f}")
    
    # return metrics for further analysis
    return metrics

**2. Train model**

In [17]:
def train_model(batch_number, n_epochs, learning_rate, batch_size):
    """
    Training loop for a single batch with dynamic threshold settings.
    """
    X_train, X_test, y_train, y_test, pos_weights = get_data_and_weights(batch_number)
    print(f"\nProcessing Batch {batch_number} ...")

    # Calculate dynamic thresholds
    thresholds = set_thresholds(pos_weights)
    
    train_loader, test_loader = create_dataset_and_loader(X_train, y_train, X_test, y_test, batch_size)

    model = PubMedBERTClassifier(n_classes=y_train.shape[1]).to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    best_samples_f1 = 0.0

    for epoch in range(n_epochs):
        print(f"Epoch {epoch + 1}/{n_epochs} | Batch {batch_number}")

        train_epoch(model, train_loader, optimizer, criterion)
        metrics = evaluate(model, test_loader, criterion, thresholds)

        current_samples_f1 = metrics['samples_f1']

        # Save best model and thresholds
        if current_samples_f1 > best_samples_f1:
            best_samples_f1 = current_samples_f1
            model_path = f"../src/model/best_model_batch_{batch_number}.pt"
            torch.save(model.state_dict(), model_path)
            print(f"  New best model saved to {model_path}")

            # Save thresholds
            thresholds_path = f"../src/model/best_thresholds_batch_{batch_number}.json"
            with open(thresholds_path, "w") as f:
                json.dump(thresholds, f)
            print(f"  Thresholds saved to {thresholds_path}")

In [None]:
# run model for all batches
n_epochs = 7
learning_rate = 2e-5
batch_size = 16
batch_range = 5

for batch_num in range(1, batch_range):
    train_model(batch_number=batch_num, n_epochs=n_epochs, learning_rate=learning_rate, batch_size=batch_size)

Batch 1 | Train: 4223, Test: 1057 | Non-auto: 3520, Auto: 1760

Processing Batch 1 ...

Dynamic Thresholds: ['0.60', '0.50', '0.78', '0.44', '0.74', '0.40', '0.77', '0.73', '0.20', '0.74', '0.80', '0.77', '0.75', '0.77', '0.80']
Epoch 1/7 | Batch 1
Loss: 0.7703 | Micro F1: 0.1101 | Macro F1: 0.1762 | Weighted F1: 0.0782 | Samples F1: 0.0491
Samples Precision: 0.0374 | Samples Recall: 0.1268
  New best model saved to ../src/model/best_model_batch_1.pt
  Thresholds saved to ../src/model/best_thresholds_batch_1.json
Epoch 2/7 | Batch 1
Loss: 0.5501 | Micro F1: 0.5500 | Macro F1: 0.3549 | Weighted F1: 0.7401 | Samples F1: 0.6198
Samples Precision: 0.6008 | Samples Recall: 0.7129
  New best model saved to ../src/model/best_model_batch_1.pt
  Thresholds saved to ../src/model/best_thresholds_batch_1.json
Epoch 3/7 | Batch 1
Loss: 0.3836 | Micro F1: 0.5981 | Macro F1: 0.4176 | Weighted F1: 0.8026 | Samples F1: 0.7025
Samples Precision: 0.6776 | Samples Recall: 0.8136
  New best model saved to 

| Batch | Micro-F1 | Macro-F1 | Weighted-F1 | Samples-F1 | Sample-Precision | Sample-Recall |
|-------|----------|----------|-------------|------------|------------------|---------------|
|   1   |  0.8749  |  0.6532  |    0.9073   |   0.8897   |      0.8766      |     0.9216    |
|   2   |  0.9380  |  0.7721  |    0.9456   |   0.9455   |      0.9406      |     0.9570    |
|   3   |  0.9327  |  0.7529  |    0.9492   |   0.9409   |      0.9336      |     0.9578    |
|   4   |  0.9618  |  0.8444  |    0.9694   |   0.9665   |      0.9631      |     0.9751    |
|   5   |  0.9692  |  0.9292  |    0.9702   |   0.9699   |      0.9707      |     0.9700    |
|   6   |  0.9767  |  0.9363  |    0.9773   |   0.9790   |      0.9788      |     0.9810    |
|   7   |  0.9818  |  0.9781  |    0.9821   |   0.9820   |      0.9826      |     0.9826    |
|   8   |  0.9814  |  0.9073  |    0.9825   |   0.9837   |      0.9836      |     0.9845    |
|   9   |  0.9870  |  0.9695  |    0.9873   |   0.9883   |      0.9886      |     0.9897    |
|  10   |  0.9879  |  0.9558  |    0.9889   |   0.9886   |      0.9883      |     0.9897    |