# Bio Bert Model


In [1]:
# Step 1: Imports & Seed Setup
import os, random, json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import precision_score, recall_score, f1_score
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit  # Added import for this function
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import nltk
from nltk.corpus import stopwords
import re

In [2]:
# Set random seeds for reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

In [3]:
# Device setup
device = (
    torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
print(f"Using device: {device}")

Using device: mps


## 1. Read in Data 

In [5]:
# Load data
df = pd.read_csv('../data/processed/shuffled_10_data.csv')
print(df.shape)
df.head()

(53130, 7)


Unnamed: 0,AC,PMID,Title,Abstract,Terms,Text_combined,batch_number
0,Q8NML3,17183211,"RamA, the transcriptional regulator of acetate...",The RamA protein represents a LuxR-type transc...,autoregulation,"RamA, the transcriptional regulator of acetate...",1
1,Q9SCZ4,17673660,The FERONIA receptor-like kinase mediates male...,"In flowering plants, signaling between the mal...",autophosphorylation,The FERONIA receptor-like kinase mediates male...,1
2,Q81WX1,12721629,The genome sequence of Bacillus anthracis Ames...,Bacillus anthracis is an endospore-forming bac...,,The genome sequence of Bacillus anthracis Ames...,1
3,P14410,8521865,Phosphorylation of the N-terminal intracellula...,This paper reports the phosphorylation of the ...,,Phosphorylation of the N-terminal intracellula...,1
4,P36898,14523231,Mutations in bone morphogenetic protein recept...,Brachydactyly (BD) type A2 is an autosomal dom...,autophosphorylation,Mutations in bone morphogenetic protein recept...,1


## 2. Clean Data

In [6]:
# clean text
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))

def clean_text(text):
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    text = re.sub(r'\d+', '', text)
    text = " ".join([word.strip() for word in text.split() if word not in stop_words])
    return text

df['Text_Cleaned'] = df['Text_combined'].apply(clean_text)

# Fill nan with 'non-autoregulatory'
df['Terms'] = df['Terms'].fillna('non-autoregulatory')

# Keep only selected columns
columns_to_keep = ['batch_number','Text_Cleaned','Terms']
df_cleaned = df[columns_to_keep]

[nltk_data] Downloading package stopwords to /Users/halao/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [7]:
print(df_cleaned.shape)
df_cleaned.head()

(53130, 3)


Unnamed: 0,batch_number,Text_Cleaned,Terms
0,1,rama transcriptional regulator acetate metabol...,autoregulation
1,1,feronia receptorlike kinase mediates malefemal...,autophosphorylation
2,1,genome sequence bacillus anthracis ames compar...,non-autoregulatory
3,1,phosphorylation nterminal intracellular tail s...,non-autoregulatory
4,1,mutations bone morphogenetic protein receptor ...,autophosphorylation


## 3. Preprocess Data

In [8]:
# Convert terms to list
df_cleaned['Terms_List'] = df_cleaned['Terms'].apply(
    lambda x: [term.strip() for term in x.split(',')]
)
df_cleaned['Terms_List'] = df_cleaned['Terms_List'].apply(lambda x: list(set(x)))

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cleaned['Terms_List'] = df_cleaned['Terms'].apply(
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cleaned['Terms_List'] = df_cleaned['Terms_List'].apply(lambda x: list(set(x)))


In [9]:
# Step 3: Binarize multi-labels
mlb = MultiLabelBinarizer()
labels = mlb.fit_transform(df_cleaned['Terms_List'])
label_columns = mlb.classes_

labels_df = pd.DataFrame(labels, columns=label_columns)
existing_columns = [col for col in label_columns if col in df_cleaned.columns]
df_cleaned = df_cleaned.drop(columns=existing_columns, errors='ignore')
df_cleaned = pd.concat([df_cleaned, labels_df], axis=1)

In [10]:
print(df_cleaned.shape)
df_cleaned.head()

(53130, 19)


Unnamed: 0,batch_number,Text_Cleaned,Terms,Terms_List,autoactivation,autocatalysis,autocatalytic,autofeedback,autoinducer,autoinduction,autoinhibition,autoinhibitory,autokinase,autolysis,autophosphorylation,autoregulation,autoregulatory,autoubiquitination,non-autoregulatory
0,1,rama transcriptional regulator acetate metabol...,autoregulation,[autoregulation],0,0,0,0,0,0,0,0,0,0,0,1,0,0,0
1,1,feronia receptorlike kinase mediates malefemal...,autophosphorylation,[autophosphorylation],0,0,0,0,0,0,0,0,0,0,1,0,0,0,0
2,1,genome sequence bacillus anthracis ames compar...,non-autoregulatory,[non-autoregulatory],0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
3,1,phosphorylation nterminal intracellular tail s...,non-autoregulatory,[non-autoregulatory],0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
4,1,mutations bone morphogenetic protein receptor ...,autophosphorylation,[autophosphorylation],0,0,0,0,0,0,0,0,0,0,1,0,0,0,0


## 4. Train Model

In [11]:
# Step 6: Instantiate the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")
bert_model = AutoModel.from_pretrained("dmis-lab/biobert-v1.1")


In [12]:
# Step 5: Train/Validation Split 

def split_single_batch_data(batch_number, test_size=0.2, random_state=42):
    """
    Split data from a single batch into train and test sets using stratified sampling.
    Prints Train/Test sizes and Non-auto/Auto counts.
    """
    batch_df = df_cleaned[df_cleaned['batch_number'] == batch_number].copy()
    X = batch_df['Text_Cleaned']
    y = labels_df.loc[batch_df.index].values  # Ensure indexing alignment
    
    # Calculate label distribution
    non_auto_count = len(batch_df[batch_df['Terms'] == 'non-autoregulatory'])
    auto_count = len(batch_df[batch_df['Terms'] != 'non-autoregulatory'])
    
    # Split data
    msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
    
    for train_idx, test_idx in msss.split(X, y):
        X_train = X.iloc[train_idx]
        X_test = X.iloc[test_idx]
        y_train = y[train_idx]
        y_test = y[test_idx]
    
    print(f"Batch {batch_number} | Train: {len(X_train)}, Test: {len(X_test)} | Non-auto: {non_auto_count}, Auto: {auto_count}")
    
    return X_train, X_test, y_train, y_test

## 3. Split the data into train, validation, and test sets

In [13]:
# Calculate class weights
def get_data_and_weights(batch_number):
    """
    Get data and calculate class weights for a specific batch.
    """
    X_train, X_test, y_train, y_test = split_single_batch_data(batch_number)
    
    # Calculate class weights
    pos_weights = []
    for i in range(y_train.shape[1]):
        neg_count = len(y_train) - np.sum(y_train[:, i])
        pos_count = np.sum(y_train[:, i])
        pos_weights.append(neg_count / pos_count if pos_count > 0 else 1.0)
    
    pos_weights = torch.FloatTensor(pos_weights).to(device)
    
    return X_train, X_test, y_train, y_test, pos_weights

In [14]:
# create dataset class
class BioBERTDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts.iloc[idx])
        label = self.labels[idx]
        
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.FloatTensor(label)
        }


In [15]:
# create data loader
def create_dataset_and_loader(X_train, y_train, X_test, y_test, batch_size):
    train_dataset = BioBERTDataset(X_train, y_train, tokenizer)
    test_dataset = BioBERTDataset(X_test, y_test, tokenizer)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    
    return train_loader, test_loader

In [16]:
# Create model
class BioBERTClassifier(nn.Module):
    def __init__(self, n_classes, dropout=0.1):
        super(BioBERTClassifier, self).__init__()
        self.bert = bert_model
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits

In [17]:
# training function
def train_epoch(model, data_loader, optimizer, criterion):
    model.train()
    total_loss = 0
    
    for batch in data_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(data_loader)

In [18]:
# Set thresholds for each label
def set_thresholds(pos_weights):
    """
    Set thresholds based on normalized pos_weights, scaled to [0.2, 0.8].
    """
    thresholds = []

    if len(pos_weights) != len(label_columns):
        raise ValueError(f"Length mismatch: pos_weights ({len(pos_weights)}) vs label_columns ({len(label_columns)})")

    # Calculate min and max weights for normalization
    min_weight = pos_weights.min().item()
    max_weight = pos_weights.max().item()
    weight_range = max_weight - min_weight

    # Avoid division by zero
    if weight_range == 0:
        weight_range = 1

    # Calculate thresholds
    for weight in pos_weights:
        # Normalize to [0, 1]
        normalized_weight = (weight.item() - min_weight) / weight_range

        # Map to [0.2, 0.8]
        threshold = 0.8 - (0.6 * normalized_weight)

        # Clamp to [0.2, 0.8]
        threshold = max(0.2, min(threshold, 0.8))

        thresholds.append(threshold)

    # Print thresholds with two decimal places
    formatted_thresholds = [f"{t:.2f}" for t in thresholds]
    print(f"\nDynamic Thresholds: {formatted_thresholds}")
    
    return thresholds

In [20]:
# Evaluation function
def evaluate(model, data_loader, criterion, thresholds):
    """
    Evaluate the model with focused metrics output.
    """
    model.eval()
    total_loss = 0
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            probabilities = torch.sigmoid(outputs).cpu().numpy()
            
            # Apply thresholds per label
            predictions = np.array([
                (probabilities[:, i] >= thresholds[i]).astype(int) for i in range(len(thresholds))
            ]).T

            all_predictions.extend(predictions)
            all_labels.extend(labels.cpu().numpy())
    
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)

    # Samples metrics
    samples_precision = precision_score(all_labels, all_predictions, average='samples', zero_division=0)
    samples_recall = recall_score(all_labels, all_predictions, average='samples', zero_division=0)
    samples_f1 = f1_score(all_labels, all_predictions, average='samples', zero_division=0)

    # F1 metrics
    micro_f1 = f1_score(all_labels, all_predictions, average='micro', zero_division=0)
    macro_f1 = f1_score(all_labels, all_predictions, average='macro', zero_division=0)
    weighted_f1 = f1_score(all_labels, all_predictions, average='weighted', zero_division=0)

    # Average loss
    avg_loss = total_loss / len(data_loader)

    metrics = {
        'loss': avg_loss,
        'micro_f1': micro_f1,
        'macro_f1': macro_f1,
        'weighted_f1': weighted_f1,
        'samples_f1': samples_f1,
        'samples_precision': samples_precision,
        'samples_recall': samples_recall
    }
    
    # Output results
    print(f"Loss: {avg_loss:.4f} | Micro F1: {micro_f1:.4f} | Macro F1: {macro_f1:.4f} | Weighted F1: {weighted_f1:.4f} | Samples F1: {samples_f1:.4f}")
    print(f"Samples Precision: {samples_precision:.4f} | Samples Recall: {samples_recall:.4f}")
    
    # Return metrics for further analysis
    return metrics

In [21]:
def train_model(batch_number, n_epochs, learning_rate, batch_size):
    """
    Training loop for a single batch with dynamic threshold settings.
    """
    X_train, X_test, y_train, y_test, pos_weights = get_data_and_weights(batch_number)
    print(f"\nProcessing Batch {batch_number} ...")

    # Calculate dynamic thresholds
    thresholds = set_thresholds(pos_weights)
    
    train_loader, test_loader = create_dataset_and_loader(X_train, y_train, X_test, y_test, batch_size)

    model = BioBERTClassifier(n_classes=y_train.shape[1]).to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    best_samples_f1 = 0.0

    for epoch in range(n_epochs):
        print(f"Epoch {epoch + 1}/{n_epochs} | Batch {batch_number}")

        train_epoch(model, train_loader, optimizer, criterion)
        metrics = evaluate(model, test_loader, criterion, thresholds)

        current_samples_f1 = metrics['samples_f1']

        # Save best model and thresholds
        if current_samples_f1 > best_samples_f1:
            best_samples_f1 = current_samples_f1
            model_path = f"../src/model/best_model_batch_{batch_number}.pt"
            torch.save(model.state_dict(), model_path)
            print(f"  New best model saved to {model_path}")

            # Save thresholds
            thresholds_path = f"../src/model/best_thresholds_batch_{batch_number}.json"
            with open(thresholds_path, "w") as f:
                json.dump(thresholds, f)
            print(f"  Thresholds saved to {thresholds_path}")

In [None]:
# Run model for all batches
n_epochs = 7
learning_rate = 2e-5
batch_size = 16

for batch_num in range(1, 11):
    train_model(batch_number=batch_num, n_epochs=n_epochs, learning_rate=learning_rate, batch_size=batch_size)

Batch 1 | Train: 4248, Test: 1065 | Non-auto: 3542, Auto: 1771

Processing Batch 1 ...

Dynamic Thresholds: ['0.60', '0.50', '0.78', '0.44', '0.74', '0.40', '0.77', '0.73', '0.20', '0.74', '0.80', '0.77', '0.75', '0.77', '0.80']
Epoch 1/7 | Batch 1
