# Bio Bert Model


In [None]:
import nltk
from nltk.corpus import stopwords
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import Dataset, DatasetDict
import os
import pandas as pd
import numpy as np
import re
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import torch.nn as nn
from sklearn.preprocessing import MultiLabelBinarizer
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score

## 1. Read in Data

In [2]:
df = pd.read_csv('../data/preprocessed/shuffled_data.csv')
print(df.shape)
df.head()

(5343, 9)


Unnamed: 0,AC,PMID,Title,Abstract,Terms,Title_clean,Abstract_clean,Text_combined,Term_list
0,Q9FN94,19124768,Tyrosine phosphorylation of the BRI1 receptor ...,Brassinosteroids (BRs) are essential growth-pr...,autophosphorylation,Tyrosine phosphorylation of the BRI1 receptor ...,Brassinosteroids (BRs) are essential growth-pr...,Tyrosine phosphorylation of the BRI1 receptor ...,['autophosphorylation']
1,Q06219,20159955,The mammalian clock component PERIOD2 coordina...,Mammalian circadian clocks provide a temporal ...,,The mammalian clock component PERIOD2 coordina...,Mammalian circadian clocks provide a temporal ...,The mammalian clock component PERIOD2 coordina...,['']
2,Q14129,15461802,A genome annotation-driven approach to cloning...,We have developed a systematic approach to gen...,,A genome annotation-driven approach to cloning...,We have developed a systematic approach to gen...,A genome annotation-driven approach to cloning...,['']
3,Q59WV0,15123810,The diploid genome sequence of Candida albicans.,We present the diploid genome sequence of the ...,,The diploid genome sequence of Candida albicans.,We present the diploid genome sequence of the ...,The diploid genome sequence of Candida albican...,['']
4,O75534,16356927,The autoregulatory translational control eleme...,Repression of poly(A)-binding protein (PABP) m...,autoregulatory,The autoregulatory translational control eleme...,Repression of poly(A)-binding protein (PABP) m...,The autoregulatory translational control eleme...,['autoregulatory']


## 2. Preprocess Data

In [None]:


df['Term_list'] = df['Terms'].apply(lambda x: [term.strip() for term in str(x).split(',')] if isinstance(x, str) else [])

df['Term_list'] = df['Term_list'].apply(lambda x: [term for term in x if term != ''])

mlb = MultiLabelBinarizer()
Y = mlb.fit_transform(df['Term_list'])

label_classes = mlb.classes_

print("Label classes:", label_classes)
print("Shape of label matrix:", Y.shape)  


labels_df = pd.DataFrame(Y, columns=label_classes)
df = pd.concat([df, labels_df], axis=1)


Label classes: ['autoactivation' 'autocatalysis' 'autocatalytic' 'autofeedback'
 'autoinducer' 'autoinduction' 'autoinhibition' 'autoinhibitory'
 'autokinase' 'autolysis' 'autophosphatase' 'autophosphorylation'
 'autoregulation' 'autoregulatory' 'autoubiquitination']
Shape of label matrix: (5343, 15)


In [5]:
label_counts = labels_df.sum(axis=0)
print(label_counts.sort_values(ascending=False))

autophosphorylation    849
autocatalytic          177
autoregulation         154
autoubiquitination     146
autoinhibition         137
autoregulatory          85
autoinducer             73
autolysis               70
autoinhibitory          60
autoactivation          22
autocatalysis           15
autofeedback            13
autoinduction           11
autokinase               8
autophosphatase          1
dtype: int64


In [6]:
df.head()

Unnamed: 0,AC,PMID,Title,Abstract,Terms,Title_clean,Abstract_clean,Text_combined,Term_list,autoactivation,...,autoinduction,autoinhibition,autoinhibitory,autokinase,autolysis,autophosphatase,autophosphorylation,autoregulation,autoregulatory,autoubiquitination
0,Q9FN94,19124768,Tyrosine phosphorylation of the BRI1 receptor ...,Brassinosteroids (BRs) are essential growth-pr...,autophosphorylation,Tyrosine phosphorylation of the BRI1 receptor ...,Brassinosteroids (BRs) are essential growth-pr...,Tyrosine phosphorylation of the BRI1 receptor ...,[autophosphorylation],0,...,0,0,0,0,0,0,1,0,0,0
1,Q06219,20159955,The mammalian clock component PERIOD2 coordina...,Mammalian circadian clocks provide a temporal ...,,The mammalian clock component PERIOD2 coordina...,Mammalian circadian clocks provide a temporal ...,The mammalian clock component PERIOD2 coordina...,[],0,...,0,0,0,0,0,0,0,0,0,0
2,Q14129,15461802,A genome annotation-driven approach to cloning...,We have developed a systematic approach to gen...,,A genome annotation-driven approach to cloning...,We have developed a systematic approach to gen...,A genome annotation-driven approach to cloning...,[],0,...,0,0,0,0,0,0,0,0,0,0
3,Q59WV0,15123810,The diploid genome sequence of Candida albicans.,We present the diploid genome sequence of the ...,,The diploid genome sequence of Candida albicans.,We present the diploid genome sequence of the ...,The diploid genome sequence of Candida albican...,[],0,...,0,0,0,0,0,0,0,0,0,0
4,O75534,16356927,The autoregulatory translational control eleme...,Repression of poly(A)-binding protein (PABP) m...,autoregulatory,The autoregulatory translational control eleme...,Repression of poly(A)-binding protein (PABP) m...,The autoregulatory translational control eleme...,[autoregulatory],0,...,0,0,0,0,0,0,0,0,1,0


## 3. Split the data into train, validation, and test sets

In [None]:


# Define the features (X) and labels (Y)

X = df[['Text_combined']]  
Y = df[label_classes]  

print(X.shape)
print(Y.shape)


(5343, 1)
(5343, 15)


In [8]:

# Split the data into training and testing sets (80% train, 20% test)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

# Check the shapes of the split data
print(f"Training data shape: {X_train.shape}, {Y_train.shape}")
print(f"Testing data shape: {X_test.shape}, {Y_test.shape}")


Training data shape: (4274, 1), (4274, 15)
Testing data shape: (1069, 1), (1069, 15)


## 4. Modeling

In [None]:


# Set up device (MPS, CUDA, or CPU)
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
print(f"Device being used: {device}")


Device being used: mps


In [10]:
class PubMedDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=512):
        """
        Args:
            texts (pandas.Series): Input texts (cleaned)
            labels (pandas.DataFrame): Corresponding labels
            tokenizer (transformers.BertTokenizer): BioBERT tokenizer
            max_length (int): Maximum sequence length
        """
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts.iloc[idx])  # Get the text at idx
        label = self.labels.iloc[idx].values.astype(float)  # Get the multi-label (binary vector)
        
        # Tokenize the text
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.FloatTensor(label)  # Convert labels to float tensor for multi-label
        }


In [11]:
# Load BioBERT tokenizer (or you can use another pre-trained BERT-based model)
tokenizer = BertTokenizer.from_pretrained("dmis-lab/biobert-v1.1")


# Create Dataset instances
train_dataset = PubMedDataset(X_train['Text_combined'], Y_train, tokenizer)
test_dataset = PubMedDataset(X_test['Text_combined'], Y_test, tokenizer)

# Create DataLoader for batching
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=16)


In [12]:
# Load BioBERT pre-trained model for multi-label classification
model = BertForSequenceClassification.from_pretrained("dmis-lab/biobert-v1.1", num_labels=len(label_classes))
model.to(device)  # Move the model to the appropriate device

# Define the optimizer
optimizer = AdamW(model.parameters(), lr=1e-5)

# Calculate class weights based on the frequency of terms in the training set
pos_weights = []

for i in range(Y_train.shape[1]):
    neg_count = len(Y_train) - np.sum(Y_train.iloc[:, i])  # Count of negative examples
    pos_count = np.sum(Y_train.iloc[:, i])  # Count of positive examples
    pos_weights.append(neg_count / pos_count if pos_count > 0 else 1.0)  # Weight calculation

pos_weights = torch.FloatTensor(pos_weights).to(device)  # Move weights to the same device as model


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at dmis-lab/biobert-v1.1 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
# Define Binary Cross-Entropy with logits loss function and apply class weights
loss_fn = nn.BCEWithLogitsLoss(weight=pos_weights)


In [None]:


def train_model(
    model, 
    dataloader, 
    optimizer, 
    loss_fn, 
    device, 
    epochs: int = 3, 
    log_every: int = 10
):
    """
    Train `model` on data from `dataloader`, using `optimizer` and `loss_fn`.
    Prints out progress every `log_every` batches and per-epoch summaries.
    
    Args:
        model: your torch.nn.Module
        dataloader: DataLoader for training data
        optimizer: the optimizer (e.g. AdamW)
        loss_fn: loss function (e.g. BCEWithLogitsLoss)
        device: torch.device
        epochs: number of epochs to train
        log_every: how many batches between printouts (set to 0 to disable batch logs)
    
    Returns:
        history: dict with keys 'train_loss' (list of epoch losses)
    """
    history = {'train_loss': []}
    
    model.to(device)
    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0.0
        
        loop = enumerate(dataloader, 1)
        if log_every:
            loop = tqdm(loop, total=len(dataloader), desc=f"Epoch {epoch}/{epochs}")
        
        for batch_idx, batch in loop:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            
            if log_every and (batch_idx % log_every == 0):
                loop.set_postfix(loss=epoch_loss / batch_idx)
        
        avg_loss = epoch_loss / len(dataloader)
        history['train_loss'].append(avg_loss)
        print(f"→ Epoch {epoch}/{epochs} finished — avg loss: {avg_loss:.4f}")
    
    return history


In [None]:
# assume model, train_dataloader, optimizer, loss_fn, device already defined
history = train_model(
    model=model,
    dataloader=train_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    device=device,
    epochs=5,
    log_every=20
)


Epoch 1/5:  40%|███▉      | 106/268 [16:52<24:19,  9.01s/it, loss=nan]

## 5. Evaluate the model

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score

def evaluate_model(model, dataloader, loss_fn, device):
    """
    Evaluate a multi‐label classification model.
    
    Args:
        model: A fine‐tuned BertForSequenceClassification
        dataloader: DataLoader yielding dicts with keys 
                    'input_ids', 'attention_mask', 'labels'
        loss_fn: Loss function (e.g. nn.BCEWithLogitsLoss with class weights)
        device: torch.device
    
    Returns:
        A dict with keys: 'loss', 'precision', 'recall', 'f1'
    """
    model.eval()
    total_loss = 0.0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            # Unpack the batch dict
            input_ids     = batch['input_ids'].to(device)
            attention_mask= batch['attention_mask'].to(device)
            labels        = batch['labels'].to(device)
            
            # Forward pass (no labels arg so we can apply custom loss)
            outputs = model(input_ids, attention_mask=attention_mask)
            logits  = outputs.logits
            
            # Compute loss
            loss = loss_fn(logits, labels)
            total_loss += loss.item()
            
            # Predictions: sigmoid→probabilities→threshold→0/1
            probs = torch.sigmoid(logits)
            preds = (probs > 0.5).float()
            
            # Collect for metrics
            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())
    
    # Aggregate
    avg_loss    = total_loss / len(dataloader)
    all_preds   = np.vstack(all_preds)
    all_labels  = np.vstack(all_labels)
    
    precision = precision_score(all_labels, all_preds, average='micro', zero_division=0)
    recall    = recall_score(all_labels, all_preds, average='micro', zero_division=0)
    f1        = f1_score(all_labels, all_preds, average='micro', zero_division=0)
    
    return {
        'loss':      avg_loss,
        'precision': precision,
        'recall':    recall,
        'f1':        f1
    }


In [None]:
metrics = evaluate_model(
    model, 
    test_dataloader, 
    loss_fn,    # your nn.BCEWithLogitsLoss(weight=pos_weights)
    device
)

print(f"Test   Loss: {metrics['loss']:.4f}")
print(f"Precision: {metrics['precision']:.4f}")
print(f"Recall   : {metrics['recall']:.4f}")
print(f"Micro F1 : {metrics['f1']:.4f}")
