# Bio Bert Model


In [1]:
# Step 1: Imports & Seed Setup
import os, random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import precision_score, recall_score, f1_score
from transformers import BertTokenizer, BertForSequenceClassification
from tqdm import tqdm

In [2]:

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)


In [3]:
# 2) Device setup
device = (
    torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
print(f"Using device: {device}")


Using device: mps


## 1. Read in Data

In [4]:
df = pd.read_csv('../data/preprocessed/shuffled_data.csv')
print(df.shape)
df.head()

(5343, 9)


Unnamed: 0,AC,PMID,Title,Abstract,Terms,Title_clean,Abstract_clean,Text_combined,Term_list
0,Q9FN94,19124768,Tyrosine phosphorylation of the BRI1 receptor ...,Brassinosteroids (BRs) are essential growth-pr...,autophosphorylation,Tyrosine phosphorylation of the BRI1 receptor ...,Brassinosteroids (BRs) are essential growth-pr...,Tyrosine phosphorylation of the BRI1 receptor ...,['autophosphorylation']
1,Q06219,20159955,The mammalian clock component PERIOD2 coordina...,Mammalian circadian clocks provide a temporal ...,,The mammalian clock component PERIOD2 coordina...,Mammalian circadian clocks provide a temporal ...,The mammalian clock component PERIOD2 coordina...,['']
2,Q14129,15461802,A genome annotation-driven approach to cloning...,We have developed a systematic approach to gen...,,A genome annotation-driven approach to cloning...,We have developed a systematic approach to gen...,A genome annotation-driven approach to cloning...,['']
3,Q59WV0,15123810,The diploid genome sequence of Candida albicans.,We present the diploid genome sequence of the ...,,The diploid genome sequence of Candida albicans.,We present the diploid genome sequence of the ...,The diploid genome sequence of Candida albican...,['']
4,O75534,16356927,The autoregulatory translational control eleme...,Repression of poly(A)-binding protein (PABP) m...,autoregulatory,The autoregulatory translational control eleme...,Repression of poly(A)-binding protein (PABP) m...,The autoregulatory translational control eleme...,['autoregulatory']


In [19]:
df2 = pd.read_csv('../data/processed/shuffled_data.csv')
df2

Unnamed: 0,AC,PMID,Title,Abstract,Terms,Title_clean,Abstract_clean,Text_combined,Term_list,autoactivation,...,autoinduction,autoinhibition,autoinhibitory,autokinase,autolysis,autophosphatase,autophosphorylation,autoregulation,autoregulatory,autoubiquitination
0,,,,,,,,,,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,,,,,,,,,,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,,,,,,,,,,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,,,,,,,,,,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,Q9LXN3,22797656.0,Identification of an Arabidopsis fatty alcohol...,"While suberin is an insoluble heteropolymer, a...",,Identification of an Arabidopsis fatty alcohol...,"While suberin is an insoluble heteropolymer, a...",Identification of an Arabidopsis fatty alcohol...,[],0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
534268,,,,,,,,,,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
534269,Q8BYC4,18347022.0,Characterization of an orphan G protein-couple...,GPR20 was isolated as an orphan G protein-coup...,,Characterization of an orphan G protein-couple...,GPR20 was isolated as an orphan G protein-coup...,Characterization of an orphan G protein-couple...,[],0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
534270,O33407,20060837.0,Crystal structure of a full-length autotranspo...,The autotransporter (AT) secretion mechanism i...,,Crystal structure of a full-length autotranspo...,The autotransporter (AT) secretion mechanism i...,Crystal structure of a full-length autotranspo...,[],0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
534271,,,,,,,,,,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


## 2. Preprocess Data

In [5]:
# Step 4: Binarize Multi‑Labels
df['Term_list'] = df['Terms'].apply(
    lambda x: [t.strip() for t in str(x).split(',') if t.strip()]
)

mlb = MultiLabelBinarizer()
Y = mlb.fit_transform(df['Term_list'])

label_classes = mlb.classes_

print("Label classes:", label_classes)
print("Shape of label matrix:", Y.shape)  


labels_df = pd.DataFrame(Y, columns=label_classes)
df = pd.concat([df, labels_df], axis=1)


Label classes: ['autoactivation' 'autocatalysis' 'autocatalytic' 'autofeedback'
 'autoinducer' 'autoinduction' 'autoinhibition' 'autoinhibitory'
 'autokinase' 'autolysis' 'autophosphatase' 'autophosphorylation'
 'autoregulation' 'autoregulatory' 'autoubiquitination' 'nan']
Shape of label matrix: (5343, 16)


In [6]:
label_counts = labels_df.sum(axis=0)
print(label_counts.sort_values(ascending=False))

nan                    3562
autophosphorylation     849
autocatalytic           177
autoregulation          154
autoubiquitination      146
autoinhibition          137
autoregulatory           85
autoinducer              73
autolysis                70
autoinhibitory           60
autoactivation           22
autocatalysis            15
autofeedback             13
autoinduction            11
autokinase                8
autophosphatase           1
dtype: int64


In [7]:
df.head()

Unnamed: 0,AC,PMID,Title,Abstract,Terms,Title_clean,Abstract_clean,Text_combined,Term_list,autoactivation,...,autoinhibition,autoinhibitory,autokinase,autolysis,autophosphatase,autophosphorylation,autoregulation,autoregulatory,autoubiquitination,nan
0,Q9FN94,19124768,Tyrosine phosphorylation of the BRI1 receptor ...,Brassinosteroids (BRs) are essential growth-pr...,autophosphorylation,Tyrosine phosphorylation of the BRI1 receptor ...,Brassinosteroids (BRs) are essential growth-pr...,Tyrosine phosphorylation of the BRI1 receptor ...,[autophosphorylation],0,...,0,0,0,0,0,1,0,0,0,0
1,Q06219,20159955,The mammalian clock component PERIOD2 coordina...,Mammalian circadian clocks provide a temporal ...,,The mammalian clock component PERIOD2 coordina...,Mammalian circadian clocks provide a temporal ...,The mammalian clock component PERIOD2 coordina...,[nan],0,...,0,0,0,0,0,0,0,0,0,1
2,Q14129,15461802,A genome annotation-driven approach to cloning...,We have developed a systematic approach to gen...,,A genome annotation-driven approach to cloning...,We have developed a systematic approach to gen...,A genome annotation-driven approach to cloning...,[nan],0,...,0,0,0,0,0,0,0,0,0,1
3,Q59WV0,15123810,The diploid genome sequence of Candida albicans.,We present the diploid genome sequence of the ...,,The diploid genome sequence of Candida albicans.,We present the diploid genome sequence of the ...,The diploid genome sequence of Candida albican...,[nan],0,...,0,0,0,0,0,0,0,0,0,1
4,O75534,16356927,The autoregulatory translational control eleme...,Repression of poly(A)-binding protein (PABP) m...,autoregulatory,The autoregulatory translational control eleme...,Repression of poly(A)-binding protein (PABP) m...,The autoregulatory translational control eleme...,[autoregulatory],0,...,0,0,0,0,0,0,0,1,0,0


## 3. Split the data into train, validation, and test sets

In [8]:
# Step 5: Train/Validation Split 

X = df['Text_combined']
Y = df[label_classes]
X_train, X_val, Y_train, Y_val = train_test_split(
    X, Y, test_size=0.2, random_state=seed
)
print("Train:", X_train.shape, Y_train.shape)
print("Val:  ", X_val.shape,   Y_val.shape)


Train: (4274,) (4274, 16)
Val:   (1069,) (1069, 16)


## 4. Modeling

In [9]:
# Step 6: Dataset Definition

MAX_LEN = 512

class PubMedDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=MAX_LEN):
        self.texts = texts.reset_index(drop=True)
        self.labels = labels.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts.iloc[idx])
        label = torch.FloatTensor(self.labels.iloc[idx].values.astype(float))
        enc = self.tokenizer(
            text,
            add_special_tokens=True,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            'input_ids': enc['input_ids'].squeeze(0),
            'attention_mask': enc['attention_mask'].squeeze(0),
            'labels': label
        }
    


In [11]:
# ↓ Step 7: Instantiate the tokenizer (must come before Step 7!) ↓
tokenizer = BertTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

In [12]:
# Step 8: DataLoader Setup 
train_ds = PubMedDataset(X_train, Y_train, tokenizer)
val_ds   = PubMedDataset(X_val,   Y_val,   tokenizer)
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True)
val_dl   = DataLoader(val_ds,   batch_size=16)

In [13]:
# Step 9: Model, Optimizer & Scheduler

model = BertForSequenceClassification.from_pretrained(
    "dmis-lab/biobert-v1.1",
    num_labels=len(label_classes)
).to(device)

optimizer = AdamW(model.parameters(), lr=1e-5)
scheduler = ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=1, verbose=True
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at dmis-lab/biobert-v1.1 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Step 10: Class Weights & Loss

pos_weights = []
for i in range(Y_train.shape[1]):
    neg = len(Y_train) - Y_train.iloc[:, i].sum()
    pos = Y_train.iloc[:, i].sum()
    pos_weights.append(neg / pos if pos > 0 else 1.0)
pos_weights = torch.FloatTensor(pos_weights).to(device)

loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weights)

In [15]:
# Step 11: Evaluation Function
def evaluate_model(model, dataloader, loss_fn, device):
    model.eval()
    total_loss = 0.0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating"):
            input_ids   = batch['input_ids'].to(device)
            attn_mask   = batch['attention_mask'].to(device)
            labels      = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask=attn_mask)
            logits  = outputs.logits
            loss    = loss_fn(logits, labels)
            total_loss += loss.item()

            preds = (torch.sigmoid(logits) > 0.5).float()
            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    # MPS/CUDA cache cleanup
    if device.type == 'mps':
        torch.mps.empty_cache()
    elif device.type == 'cuda':
        torch.cuda.empty_cache()

    all_preds  = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)

    precision = precision_score(all_labels, all_preds, average='micro', zero_division=0)
    recall    = recall_score(all_labels, all_preds, average='micro', zero_division=0)
    f1        = f1_score(all_labels, all_preds, average='micro', zero_division=0)

    return {
        'loss':      total_loss / len(dataloader),
        'precision': precision,
        'recall':    recall,
        'f1':        f1
    }

In [16]:
# Step 12: Training Loop w/ Validation, Scheduler & Early Stopping
def train_model(
    model, train_dl, val_dl, optimizer, loss_fn, device,
    epochs=10, log_every=50, scheduler=None, early_stop_patience=3
):
    best_val_loss = float('inf')
    no_improve    = 0
    history = {'train_loss': [], 'val_loss': [], 'val_f1': []}

    for epoch in range(1, epochs+1):
        model.train()
        epoch_loss = 0.0
        loop = tqdm(enumerate(train_dl, 1), total=len(train_dl),
                    desc=f"Epoch {epoch}/{epochs}")

        for i, batch in loop:
            input_ids = batch['input_ids'].to(device)
            attn_mask = batch['attention_mask'].to(device)
            labels    = batch['labels'].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attn_mask)
            logits  = outputs.logits
            loss    = loss_fn(logits, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            if i % log_every == 0:
                loop.set_postfix(train_loss=epoch_loss / i)

        avg_train = epoch_loss / len(train_dl)
        history['train_loss'].append(avg_train)
        print(f"→ Epoch {epoch} train loss: {avg_train:.4f}")

        # Validate
        val_metrics = evaluate_model(model, val_dl, loss_fn, device)
        val_loss = val_metrics['loss']
        history['val_loss'].append(val_loss)
        history['val_f1'].append(val_metrics['f1'])
        print(f"→ Val loss: {val_loss:.4f}, Val F1: {val_metrics['f1']:.4f}")

        # Scheduler step
        if scheduler:
            scheduler.step(val_loss)

        # Early stopping + checkpoint
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            no_improve = 0
            torch.save(model.state_dict(), "best_model.pth")
        else:
            no_improve += 1
            if no_improve >= early_stop_patience:
                print(f"Early stopping at epoch {epoch}")
                break

    return history

In [17]:
# Step 12: Run Training
history = train_model(
    model, train_dl, val_dl, optimizer, loss_fn, device,
    epochs=10, log_every=50,
    scheduler=scheduler,
    early_stop_patience=3
)

Epoch 1/10: 100%|██████████| 268/268 [50:03<00:00, 11.21s/it, train_loss=nan]


→ Epoch 1 train loss: nan


Validating: 100%|██████████| 67/67 [01:15<00:00,  1.12s/it]


→ Val loss: nan, Val F1: 0.0000


Epoch 2/10:  10%|▉         | 26/268 [07:19<1:08:14, 16.92s/it]


KeyboardInterrupt: 