# Optimized Skimlit Model

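This notebook implements an optimized version of the Skimlit model for classifying the sentences of scientific paper abstracts (PubMed RCT) into sections such as BACKGROUND, OBJECTIVE, METHODS, RESULTS and CONCLUSIONS. It fine-tunes BioBERT for biomedical text processing, prefixes each sentence with its position in the abstract, balances the training classes by undersampling, and uses efficient `tf.data` pipelines, mixed precision training, a learning rate schedule with warmup, and comprehensive evaluation metrics and visualizations.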

# Check GPU availability

In [None]:
# Show the attached GPU, its memory usage and the driver/CUDA versions before training.
!nvidia-smi

Thu May 15 06:55:54 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      Off |   00000000:00:03.0 Off |                    0 |
| N/A   40C    P8             11W /   72W |       0MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

# Check available RAM

In [None]:
from psutil import virtual_memory

# Total physical memory expressed in decimal gigabytes (1e9 bytes).
ram_gb = virtual_memory().total / 1e9
print(f'Available RAM: {ram_gb:.1f} GB')
# 20 GB is the threshold this notebook uses to call a runtime "high-RAM".
if ram_gb >= 20:
    print('High-RAM runtime!')
else:
    print('Not using a high-RAM runtime.')

Available RAM: 56.9 GB
High-RAM runtime!


# Install required packages

In [None]:
!pip install tensorflow
!pip install transformers
!pip install scikit-learn
!pip install matplotlib
!pip install seaborn



# Import libraries

In [None]:
import tensorflow as tf
from transformers import TFBertForSequenceClassification, BertConfig, AutoTokenizer, AdamWeightDecay
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from sklearn.utils import resample
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os

# Clone the PubMed RCT repository

In [None]:
# Fetch the PubMed RCT dataset repository once; skipped when the directory already exists.
if not os.path.exists('pubmed-rct'):
    !git clone https://github.com/Franck-Dernoncourt/pubmed-rct.git

# Split files (train.txt / dev.txt / test.txt) are read from this directory; per its name it
# is the 20k-abstract variant with numbers replaced by '@'.
data_dir = 'pubmed-rct/PubMed_20k_RCT_numbers_replaced_with_at_sign/'

Cloning into 'pubmed-rct'...
remote: Enumerating objects: 39, done.[K
remote: Counting objects: 100% (14/14), done.[K
remote: Compressing objects: 100% (9/9), done.[K
remote: Total 39 (delta 8), reused 5 (delta 5), pack-reused 25 (from 1)[K
Receiving objects: 100% (39/39), 177.08 MiB | 16.20 MiB/s, done.
Resolving deltas: 100% (15/15), done.


# Function to parse dataset files with line numbers

In [None]:
def parse_file(filepath, encoding='utf-8'):
    """Parse a PubMed RCT split file into sentences, labels and in-abstract positions.

    Lines starting with ``###`` begin a new abstract, blank lines are skipped, and
    every other line must have the form ``LABEL<TAB>sentence``.

    Args:
        filepath: Path to a split file (e.g. ``train.txt``, ``dev.txt``, ``test.txt``).
        encoding: Encoding used to read the file. Defaults to UTF-8 so the result does
            not depend on the platform's locale encoding.

    Returns:
        Tuple ``(texts, labels, line_numbers)`` of equal-length lists: the stripped
        sentence text, its label string, and its 0-based position within its abstract
        (restarted after every ``###`` line).

    Raises:
        ValueError: If a non-blank, non-``###`` line contains no tab.
    """
    texts, labels, line_numbers = [], [], []
    with open(filepath, 'r', encoding=encoding) as f:
        line_num = 0
        for line in f:
            if line.startswith('###'):
                # New abstract: positions restart at 0.
                line_num = 0
                continue
            if not line.strip():
                continue
            # Split on the first tab only; the sentence itself may contain further tabs.
            label, sep, text = line.partition('\t')
            if not sep:
                raise ValueError(
                    f"{filepath}: expected 'LABEL<TAB>text', got {line.rstrip()!r}")
            labels.append(label)
            texts.append(text.strip())
            line_numbers.append(line_num)
            line_num += 1
    return texts, labels, line_numbers

# Load the train, dev (validation) and test splits

In [None]:
# Path of each split file; dev.txt serves as the validation split.
split_paths = {split: os.path.join(data_dir, f'{split}.txt') for split in ('train', 'dev', 'test')}

train_texts, train_labels, train_line_numbers = parse_file(split_paths['train'])
val_texts, val_labels, val_line_numbers = parse_file(split_paths['dev'])
test_texts, test_labels, test_line_numbers = parse_file(split_paths['test'])

# Add positional information to texts

In [None]:
def _prefix_positions(texts, line_numbers):
    """Prepend "[LINE n] " to each text, where n is its position within its abstract."""
    return [f"[LINE {ln}] {text}" for ln, text in zip(line_numbers, texts)]


train_texts_with_pos = _prefix_positions(train_texts, train_line_numbers)
val_texts_with_pos = _prefix_positions(val_texts, val_line_numbers)
test_texts_with_pos = _prefix_positions(test_texts, test_line_numbers)

# Encode labels

In [None]:
# Learn the label vocabulary from the training split only; dev/test reuse it
# (LabelEncoder.transform raises on labels it has not seen).
label_encoder = LabelEncoder()
label_encoder.fit(train_labels)
train_labels_encoded = label_encoder.transform(train_labels)
val_labels_encoded, test_labels_encoded = (
    label_encoder.transform(split_labels) for split_labels in (val_labels, test_labels)
)

num_classes = label_encoder.classes_.size

# Compute class weights

In [None]:
# sklearn's 'balanced' heuristic: n_samples / (n_classes * count_of_class), keyed by encoded label.
_present_classes = np.unique(train_labels_encoded)
_balanced_weights = compute_class_weight('balanced', classes=_present_classes, y=train_labels_encoded)
class_weights = {class_idx: weight for class_idx, weight in enumerate(_balanced_weights)}

# Create sample weights for training

In [None]:
# Per-example weight = weight of the example's class (vectorized gather by encoded label).
_weight_by_class = np.array([class_weights[class_idx] for class_idx in range(len(class_weights))])
train_sample_weight = _weight_by_class[train_labels_encoded]

# Balance classes by undersampling each one to the size of the smallest class

In [None]:
# Balancing Data: undersample every class to the size of the smallest class.
# (resample, numpy and Counter are imported in the imports cell at the top.)
SEED = 42

class_counts = Counter(train_labels_encoded)
min_class_count = min(class_counts.values())
print(f"Minimum class count: {min_class_count}")

# Keep min_class_count examples of each class, drawn without replacement.
kept_indices = np.concatenate([
    resample(np.where(train_labels_encoded == class_idx)[0],
             replace=False,
             n_samples=min_class_count,
             random_state=SEED)
    for class_idx in range(num_classes)
])

# Interleave the per-class blocks with a seeded generator so re-runs give the same order.
kept_indices = kept_indices[np.random.default_rng(SEED).permutation(len(kept_indices))]

# Update training data (texts stay a plain list of str for the tokenizer).
train_texts_with_pos = [train_texts_with_pos[i] for i in kept_indices]
train_labels_encoded = np.asarray(train_labels_encoded)[kept_indices]
train_line_numbers = np.asarray(train_line_numbers)[kept_indices]
# Every class now has the same count. Reusing the inverse-frequency class weights computed
# on the unbalanced data would correct for imbalance a second time (over-weighting the
# originally rare classes), so use uniform weights instead.
train_sample_weight = np.ones(len(kept_indices))

print("New class distribution:", Counter(train_labels_encoded))

Minimum class count: 13839
New class distribution: Counter({np.int64(4): 13839, np.int64(3): 13839, np.int64(0): 13839, np.int64(2): 13839, np.int64(1): 13839})


# Load BioBERT tokenizer

In [None]:
# BioBERT (cased, v1.1) checkpoint on the Hugging Face Hub; the model below loads the same id.
biobert_checkpoint = 'dmis-lab/biobert-base-cased-v1.1'
tokenizer = AutoTokenizer.from_pretrained(biobert_checkpoint)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/313 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

# Choose max_length from the 95th percentile of training token lengths

In [None]:
# Tokenize the training texts in chunks of 1000 (no truncation, no padding) to collect
# their token counts, then cover 95% of them with max_length.
chunk_size = 1000
token_lengths = []
for start in range(0, len(train_texts_with_pos), chunk_size):
    chunk_ids = tokenizer(train_texts_with_pos[start:start + chunk_size],
                          truncation=False, padding=False)['input_ids']
    token_lengths += map(len, chunk_ids)
max_length = int(np.percentile(token_lengths, 95))
print(f'Setting max_length to {max_length} based on 95th percentile of token lengths')

Setting max_length to 77 based on 95th percentile of token lengths


# Tokenize the texts

In [None]:
# Shared settings: truncate to max_length, pad each split to its longest sequence, return TF tensors.
_tokenize_kwargs = dict(truncation=True, padding=True, max_length=max_length, return_tensors='tf')

train_encodings = tokenizer(train_texts_with_pos, **_tokenize_kwargs)
val_encodings = tokenizer(val_texts_with_pos, **_tokenize_kwargs)
test_encodings = tokenizer(test_texts_with_pos, **_tokenize_kwargs)

# Create tf.data datasets

In [None]:
batch_size = 32
# Fixed shuffle seed so the order of training batches is reproducible across runs.
SHUFFLE_SEED = 42


def _make_eval_dataset(encodings, labels):
    """Batch (features, label) pairs in their original order for validation/testing."""
    return (tf.data.Dataset.from_tensor_slices((dict(encodings), labels))
            .batch(batch_size)
            .prefetch(tf.data.AUTOTUNE))


# Training elements are (features, label, sample_weight); Keras passes the third element
# to the loss as per-example weights.
train_dataset = (tf.data.Dataset.from_tensor_slices((
                     dict(train_encodings),
                     train_labels_encoded,
                     train_sample_weight))
                 .shuffle(buffer_size=10000, seed=SHUFFLE_SEED)
                 .batch(batch_size)
                 .prefetch(tf.data.AUTOTUNE))
val_dataset = _make_eval_dataset(val_encodings, val_labels_encoded)
test_dataset = _make_eval_dataset(test_encodings, test_labels_encoded)

# Enable mixed precision

In [None]:
# Compute in float16 while keeping variables in float32. This is a global policy, so it is
# set before the model is created below.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# Load BioBERT model


In [None]:
biobert_checkpoint = 'dmis-lab/biobert-base-cased-v1.1'

# One output per label class; dropout of 0.4 on both hidden and attention-probability layers.
config = BertConfig.from_pretrained(
    biobert_checkpoint,
    num_labels=num_classes,
    hidden_dropout_prob=0.4,
    attention_probs_dropout_prob=0.4,
)
# from_pt=True converts the checkpoint's PyTorch weights into the TF model.
model = TFBertForSequenceClassification.from_pretrained(
    biobert_checkpoint,
    config=config,
    from_pt=True,
)

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

All PyTorch model weights were used when initializing TFBertForSequenceClassification.

Some weights or buffers of the TF 2.0 model TFBertForSequenceClassification were not initialized from the PyTorch model and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Define learning rate schedule with warmup

In [None]:
initial_lr = 1e-5
epochs = 32
# Size the schedule from the data actually passed to fit(): train_dataset is built from
# train_labels_encoded (the balanced set, much smaller than len(train_texts)), and batch()
# keeps the final partial batch, hence the ceiling.
steps_per_epoch = int(np.ceil(len(train_labels_encoded) / batch_size))
num_train_steps = steps_per_epoch * epochs
# Linear warmup over the first 15% of the planned steps.
warmup_steps = int(0.15 * num_train_steps)

class LinearScheduleWithWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup to ``initial_lr`` followed by linear decay to zero.

    For ``step < warmup_steps`` the rate is ``initial_lr * step / warmup_steps``; after that it
    is ``initial_lr * (total_steps - step) / (total_steps - warmup_steps)``.

    The result is clamped at 0, so steps beyond ``total_steps`` never produce a negative
    learning rate. Denominators are floored at 1, so ``warmup_steps == 0`` or
    ``warmup_steps == total_steps`` do not divide by zero.

    Args:
        initial_lr: Peak learning rate reached at the end of warmup.
        warmup_steps: Number of warmup steps.
        total_steps: Step at which the rate reaches zero.
    """

    def __init__(self, initial_lr, warmup_steps, total_steps):
        super().__init__()
        self.initial_lr = initial_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup = tf.cast(self.warmup_steps, tf.float32)
        total = tf.cast(self.total_steps, tf.float32)
        warmup_lr = self.initial_lr * step / tf.maximum(warmup, 1.0)
        decay_lr = self.initial_lr * (total - step) / tf.maximum(total - warmup, 1.0)
        lr = tf.where(step < warmup, warmup_lr, decay_lr)
        return tf.maximum(lr, 0.0)

    def get_config(self):
        # Constructor arguments, so the schedule can be re-created via from_config().
        return {
            "initial_lr": self.initial_lr,
            "warmup_steps": self.warmup_steps,
            "total_steps": self.total_steps,
        }

lr_schedule = LinearScheduleWithWarmup(
    initial_lr=initial_lr,
    warmup_steps=warmup_steps,
    total_steps=num_train_steps,
)

# Set up optimizer: AdamWeightDecay with weight decay 0.01 and clipnorm=1.0.
# It is built with the constant initial_lr rather than lr_schedule; the scheduled value is
# assigned to it before every batch by the LRUpdateCallback defined next.
adam_settings = {
    'learning_rate': initial_lr,
    'weight_decay_rate': 0.01,
    'epsilon': 1e-6,
    'clipnorm': 1.0,
}
optimizer = AdamWeightDecay(**adam_settings)

# Create a callback to update the learning rate according to the schedule
class LRUpdateCallback(tf.keras.callbacks.Callback):
    """Drive the optimizer's learning rate from a step-based schedule.

    Before every training batch, the schedule is evaluated at the optimizer's iteration
    count and the result is assigned to the optimizer's learning rate.

    Args:
        schedule: Callable mapping the iteration count to a learning rate
            (e.g. a ``LearningRateSchedule``).

    Note:
        Because the rate is overwritten at the start of every batch, any other callback that
        changes the learning rate between batches (e.g. ``ReduceLROnPlateau``) has no lasting
        effect while this callback is active.
    """

    def __init__(self, schedule):
        super().__init__()
        self.schedule = schedule

    def on_train_batch_begin(self, batch, logs=None):
        optimizer = self.model.optimizer
        # `learning_rate` is the canonical attribute; `lr` is only a legacy alias.
        tf.keras.backend.set_value(optimizer.learning_rate, self.schedule(optimizer.iterations))

lr_update_callback = LRUpdateCallback(lr_schedule)

# Set up optimizer

# Compile the model

In [None]:
# Targets are integer class ids; the classifier emits unnormalized scores, hence from_logits=True.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(
    optimizer=optimizer,
    loss=loss_fn,
    metrics=['accuracy'],
)

# Set up callbacks

In [None]:
# Stop after 5 epochs without val_accuracy improvement and restore the best weights seen.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)
# Halve the learning rate after 2 epochs without val_accuracy improvement (floor 1e-6).
# NOTE(review): LRUpdateCallback assigns the scheduled learning rate at the start of every
# training batch, so if both callbacks are passed to fit() this reduction is overwritten on
# the next batch and has no lasting effect.
lr_reducer = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=2, min_lr=1e-6)

# Train the model

In [None]:
# The warmup/decay schedule (applied per batch by lr_update_callback) is the single source of
# the learning rate. ReduceLROnPlateau is intentionally not passed: lr_update_callback would
# overwrite its reductions on the very next batch, making it a silent no-op.
history = model.fit(train_dataset,
                    validation_data=val_dataset,
                    epochs=epochs,
                    callbacks=[early_stopping, lr_update_callback])

Epoch 1/32
Epoch 2/32
Epoch 3/32
Epoch 4/32
Epoch 5/32
Epoch 6/32
Epoch 7/32
Epoch 8/32
Epoch 9/32
Epoch 10/32
Epoch 11/32
Epoch 12/32
Epoch 13/32
Epoch 14/32
Epoch 15/32
Epoch 16/32
Epoch 17/32
Epoch 18/32
Epoch 19/32
Epoch 20/32


# Evaluate on test set

In [None]:
# evaluate() yields the loss followed by the compiled metrics (here only accuracy).
evaluation = model.evaluate(test_dataset)
test_loss, test_accuracy = evaluation
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}')

Test Loss: 0.3060, Test Accuracy: 0.8968


# Get predictions

In [None]:
predictions = model.predict(test_dataset)
# Index of the highest-scoring class for each test example.
predicted_labels = predictions.logits.argmax(axis=1)



# Classification report

In [None]:
print('\nClassification Report:\n')
# Per-class precision/recall/F1, with encoded ids mapped back to their label names.
report_text = classification_report(
    test_labels_encoded,
    predicted_labels,
    target_names=label_encoder.classes_,
)
print(report_text)


Classification Report:

              precision    recall  f1-score   support

  BACKGROUND       0.76      0.83      0.79      3621
 CONCLUSIONS       0.90      0.92      0.91      4571
     METHODS       0.93      0.96      0.94      9897
   OBJECTIVE       0.75      0.61      0.67      2333
     RESULTS       0.94      0.91      0.93      9713

    accuracy                           0.90     30135
   macro avg       0.86      0.85      0.85     30135
weighted avg       0.90      0.90      0.90     30135



# Confusion matrix

In [None]:
# Rows are true labels, columns are predicted labels.
cm = confusion_matrix(test_labels_encoded, predicted_labels)
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=label_encoder.classes_, yticklabels=label_encoder.classes_, ax=ax)
ax.set(xlabel='Predicted', ylabel='True', title='Confusion Matrix')
fig.savefig('confusion_matrix.png', bbox_inches='tight')
# Render inline as well as saving to disk, then release the figure.
plt.show()
plt.close(fig)

# Plot training history

In [None]:
# Left panel: loss; right panel: accuracy; each shows the training and validation curves.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))
for ax, metric, name in ((loss_ax, 'loss', 'Loss'), (acc_ax, 'accuracy', 'Accuracy')):
    ax.plot(history.history[metric], label=f'Train {name}')
    ax.plot(history.history[f'val_{metric}'], label=f'Validation {name}')
    ax.legend()
    ax.set(title=f'{name} Over Epochs', xlabel='Epoch', ylabel=name)
fig.tight_layout()
fig.savefig('training_history.png')
# Render inline as well as saving to disk, then release the figure.
plt.show()
plt.close(fig)

# Save the model, tokenizer, and label encoder

In [None]:
import json
import joblib

model_dir = "skimlit_model"
os.makedirs(model_dir, exist_ok=True)

# Writes the weights together with the model's own config.json, so a separate
# config.save_pretrained call is not needed (and could overwrite that file).
model.save_pretrained(model_dir)

tokenizer.save_pretrained(model_dir)

joblib.dump(label_encoder, os.path.join(model_dir, "label_encoder.joblib"))

# Persist the training-time max_length so inference truncates inputs exactly as in training.
with open(os.path.join(model_dir, "preprocessing.json"), "w", encoding="utf-8") as f:
    json.dump({"max_length": int(max_length)}, f)

In [None]:
import json
import os
import joblib
import numpy as np
from transformers import TFBertForSequenceClassification, BertConfig, AutoTokenizer
from sklearn.metrics import classification_report, confusion_matrix
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns

# Load the saved model, tokenizer, and label encoder
model_dir = "skimlit_model"
model = TFBertForSequenceClassification.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
label_encoder = joblib.load(os.path.join(model_dir, "label_encoder.joblib"))
config = BertConfig.from_pretrained(model_dir)

# Truncate exactly as in training when the training-time max_length was saved next to the
# model; otherwise fall back to the model's maximum number of positions.
preprocessing_path = os.path.join(model_dir, "preprocessing.json")
if os.path.exists(preprocessing_path):
    with open(preprocessing_path, encoding="utf-8") as f:
        inference_max_length = int(json.load(f)["max_length"])
else:
    inference_max_length = config.max_position_embeddings


def classify_sentences(sentences, batch_size=32):
    """Predict a section label for each sentence of a single abstract.

    Args:
        sentences: Sentences of one abstract, in their original order.
        batch_size: Number of sentences per prediction batch.

    Returns:
        NumPy array of decoded label strings, one per sentence (empty for empty input).
    """
    if len(sentences) == 0:
        return np.array([], dtype=object)
    # Same "[LINE n] " prefix used for the training texts, n = 0-based position in the abstract.
    texts_with_pos = [f"[LINE {i}] {text}" for i, text in enumerate(sentences)]
    encodings = tokenizer(texts_with_pos, truncation=True, padding=True,
                          max_length=inference_max_length, return_tensors='tf')
    dataset = (tf.data.Dataset.from_tensor_slices(dict(encodings))
               .batch(batch_size)
               .prefetch(tf.data.AUTOTUNE))
    logits = model.predict(dataset).logits
    return label_encoder.inverse_transform(np.argmax(logits, axis=1))


# Example usage (replace with the sentences of a real abstract, in order).
# Distinct names keep the test-set variables from the evaluation cells intact.
example_texts = ["This is a test sentence about a clinical trial.", "Another sentence related to medical research."]
example_labels = classify_sentences(example_texts)

# Print the decoded predictions
print("Predicted labels:")
for text, label in zip(example_texts, example_labels):
    print(f"Text: {text}, Predicted Label: {label}")

Some layers from the model checkpoint at skimlit_model were not used when initializing TFBertForSequenceClassification: ['dropout_37']
- This IS expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFBertForSequenceClassification were initialized from the model checkpoint at skimlit_model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForSequenceClassification for predictions without further training.


Predicted labels:
Text: This is a test sentence about a clinical trial., Predicted Label: OBJECTIVE
Text: Another sentence related to medical research., Predicted Label: METHODS
