# Training

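BioBERT is a BERT model pre-trained on PubMed abstracts and PubMed Central (PMC) full-text articles, which are rich sources of biomedical information. This pre-training lets BioBERT learn contextual representations of words and sentences in the biomedical domain, so it can better understand and process biomedical text. In this notebook the pre-trained BioBERT checkpoint is fine-tuned for multi-label classification of article titles and abstracts.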

In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

# Load dataset
df = pd.read_csv("/kaggle/input/comp-5/comp_1.csv")

# Concatenate title and abstract text.
# fillna('') is applied first: a missing Title or abstractText would otherwise
# make the whole concatenated value NaN, which astype(str) turns into the
# literal string "nan" -- silently discarding the other field's text.
texts = df['Title'].fillna('').astype(str) + " " + df['abstractText'].fillna('').astype(str)

# One binary indicator column per label; each column is an independent target
# (multi-label classification).
labels = df[['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'L', 'M', 'N', 'Z']]

# Split dataset into train (72%), validation (8%) and test (20%) sets.
# A fixed random_state makes the split -- and therefore the reported test
# metrics -- reproducible across runs.
SPLIT_SEED = 42
train_texts, test_texts, train_labels, test_labels = train_test_split(
    texts, labels, test_size=0.2, random_state=SPLIT_SEED)
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_texts, train_labels, test_size=0.1, random_state=SPLIT_SEED)

# Load the BioBERT tokenizer (checkpoint dmis-lab/biobert-base-cased-v1.2, the
# same checkpoint the classifier below is loaded from).
tokenizer = BertTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')

# Tokenize each split: truncate to 512 tokens and pad every sequence to the
# longest one within that split (padding=True).
train_encodings = tokenizer(train_texts.tolist(), truncation=True, padding=True, max_length=512)
val_encodings = tokenizer(val_texts.tolist(), truncation=True, padding=True, max_length=512)
test_encodings = tokenizer(test_texts.tolist(), truncation=True, padding=True, max_length=512)

# Convert labels to float tensors (BCEWithLogitsLoss expects float targets).
train_labels_tensor = torch.tensor(np.array(train_labels), dtype=torch.float32)
val_labels_tensor = torch.tensor(np.array(val_labels), dtype=torch.float32)
test_labels_tensor = torch.tensor(np.array(test_labels), dtype=torch.float32)

# Create PyTorch datasets yielding (input_ids, attention_mask, labels) and
# their DataLoaders; only the training loader is shuffled.
train_dataset = TensorDataset(torch.tensor(train_encodings['input_ids']),
                              torch.tensor(train_encodings['attention_mask']),
                              train_labels_tensor)
val_dataset = TensorDataset(torch.tensor(val_encodings['input_ids']),
                            torch.tensor(val_encodings['attention_mask']),
                            val_labels_tensor)
test_dataset = TensorDataset(torch.tensor(test_encodings['input_ids']),
                             torch.tensor(test_encodings['attention_mask']),
                             test_labels_tensor)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define BioBERT model for classification
class BioBERTClassifier(nn.Module):
    """Multi-label text classifier built on a pretrained BioBERT checkpoint.

    ``forward`` returns the raw logits produced by
    ``BertForSequenceClassification`` (one value per label). No sigmoid is
    applied, so pair this module with a logits-based loss such as
    ``nn.BCEWithLogitsLoss`` and apply ``torch.sigmoid`` at inference time.
    """

    def __init__(self, num_labels, model_name='dmis-lab/biobert-base-cased-v1.2'):
        """Load the pretrained checkpoint with a fresh ``num_labels``-way head.

        Args:
            num_labels: number of output logits (label columns).
            model_name: Hugging Face checkpoint to load; defaults to BioBERT v1.2.
        """
        super().__init__()
        self.bert = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

    def forward(self, input_ids, attention_mask):
        """Return the classification logits for a batch of token ids and masks."""
        return self.bert(input_ids=input_ids, attention_mask=attention_mask).logits

# Instantiate model, optimizer, and loss function.
# The output size is taken from the label DataFrame (one column per label) so
# it cannot drift out of sync with the dataset. train_labels is used because,
# unlike `labels`, it is never rebound by later loops.
model = BioBERTClassifier(num_labels=train_labels.shape[1]).to(device)
optimizer = optim.AdamW(model.parameters(), lr=5e-5)
# BCEWithLogitsLoss applies a sigmoid per output, i.e. an independent binary
# cross-entropy for every label -- the standard multi-label objective.
criterion = nn.BCEWithLogitsLoss()

def train_epoch(model, train_loader, optimizer, criterion, device=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` that
            returns logits.
        train_loader: iterable of ``(input_ids, attention_mask, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        criterion: loss function called as ``criterion(logits, labels)``.
        device: device the batches are moved to. Defaults to the device of the
            model's first parameter, so no module-level ``device`` is required.

    Returns:
        The average of ``loss.item()`` over all batches in the epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches (the mean is undefined).
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    total_loss = 0.0
    num_batches = 0
    for input_ids, attention_mask, labels in train_loader:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Count batches actually seen rather than relying on len(loader),
        # which also works for loaders without a length.
        num_batches += 1
    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return total_loss / num_batches

# Validation loop
def validate(model, val_loader, criterion, device=None):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` that
            returns logits; switched to eval mode.
        val_loader: iterable of ``(input_ids, attention_mask, labels)`` batches.
        criterion: loss function called as ``criterion(logits, labels)``.
        device: device the batches are moved to. Defaults to the device of the
            model's first parameter, so no module-level ``device`` is required.

    Returns:
        ``(val_loss, all_preds, all_targets)``: the mean per-batch loss, the
        sigmoid probabilities, and the targets. The last two are NumPy arrays
        with the batches stacked along the first axis in loader order.

    Raises:
        ValueError: if ``val_loader`` yields no batches (the mean is undefined).
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    val_loss = 0.0
    num_batches = 0
    pred_chunks = []
    target_chunks = []
    with torch.no_grad():
        for input_ids, attention_mask, labels in val_loader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)
            logits = model(input_ids, attention_mask)
            val_loss += criterion(logits, labels).item()
            num_batches += 1
            pred_chunks.append(torch.sigmoid(logits).cpu().numpy())
            target_chunks.append(labels.cpu().numpy())
    if num_batches == 0:
        raise ValueError("val_loader yielded no batches")
    return val_loss / num_batches, np.concatenate(pred_chunks), np.concatenate(target_chunks)

# Training loop with early stopping: training stops once the validation loss
# has failed to reach a new minimum for `patience` consecutive epochs. Every
# time a new minimum is reached the weights are checkpointed to disk.
max_epochs = 15  # Adjust number of epochs as needed
patience = 3  # Number of epochs to wait for improvement
best_val_loss = float('inf')
no_improvement = 0

for epoch in range(max_epochs):
    train_loss = train_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_preds, val_targets = validate(model, val_loader, criterion)
    val_f1_micro = f1_score(val_targets, np.round(val_preds), average='micro')
    print(f"Epoch {epoch + 1}: Train Loss {train_loss:.4f}, Val Loss {val_loss:.4f}, Val F1 Micro {val_f1_micro:.4f}")

    improved = val_loss < best_val_loss
    if not improved:
        # No new best this epoch: spend one unit of the patience budget.
        no_improvement += 1
        if no_improvement >= patience:
            print("Early stopping triggered. No improvement in validation loss.")
            break
        continue

    # New best validation loss: reset the counter and save the weights.
    best_val_loss = val_loss
    no_improvement = 0
    torch.save(model.state_dict(), '/kaggle/working/best_bio_model.pth')  # Save the best model weights



# Inference

In [None]:
# Test loop: reload the best checkpoint saved during training and collect the
# sigmoid probabilities for the whole test set.
# map_location loads the saved tensors onto the active device, so a checkpoint
# written on GPU can still be read when this cell runs on CPU.
model.load_state_dict(torch.load('/kaggle/working/best_bio_model.pth', map_location=device))  # Load the best model weights
model.eval()
all_preds = []
all_targets = []
with torch.no_grad():
    # Targets are bound to `batch_labels`, not `labels`: this loop runs at
    # module level and would otherwise overwrite the `labels` DataFrame that
    # was built when the dataset was loaded.
    for input_ids, attention_mask, batch_labels in test_loader:
        logits = model(input_ids.to(device), attention_mask.to(device))
        all_preds.append(torch.sigmoid(logits).cpu().numpy())
        all_targets.append(batch_labels.cpu().numpy())
# Stack the per-batch arrays into (num_test_samples, num_labels) matrices.
all_preds = np.concatenate(all_preds)
all_targets = np.concatenate(all_targets)

# Binarize probabilities with a 0.5 threshold. `> 0.5` reproduces the earlier
# np.round behaviour exactly for values in [0, 1] (np.round(0.5) == 0).
threshold = 0.5
binary_preds = (all_preds > threshold).astype(int)
true_labels = all_targets.astype(int)
# Derived from the label columns instead of repeating the hard-coded list.
label_names = list(train_labels.columns)

# Compute per-label evaluation metrics. average=None returns one score per
# label column; zero_division=0 reports 0 (without a warning) for a label that
# is never predicted / never present.
test_precision = precision_score(true_labels, binary_preds, average=None, zero_division=0)
test_recall = recall_score(true_labels, binary_preds, average=None, zero_division=0)
test_f1 = f1_score(true_labels, binary_preds, average=None, zero_division=0)

print("Test Metrics:")
for label, precision, recall, f1 in zip(label_names, test_precision, test_recall, test_f1):
    print(f"{label}: Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")

# For multi-label targets accuracy_score is subset (exact-match) accuracy: a
# sample counts as correct only if every one of its labels is predicted correctly.
test_accuracy = accuracy_score(true_labels, binary_preds)

# Micro averaging pools TP/FP/FN over all labels before computing the score.
test_f1_micro = f1_score(true_labels, binary_preds, average='micro', zero_division=0)
test_precision_micro = precision_score(true_labels, binary_preds, average='micro', zero_division=0)
test_recall_micro = recall_score(true_labels, binary_preds, average='micro', zero_division=0)
print(f"Test F1 Micro: {test_f1_micro:.4f}, Test Precision Micro: {test_precision_micro:.4f}, "
      f"Test Recall Micro: {test_recall_micro:.4f}, Test Accuracy: {test_accuracy:.4f}")

# Macro averaging is the unweighted mean of the per-label scores.
test_f1_macro = f1_score(true_labels, binary_preds, average='macro', zero_division=0)
test_precision_macro = precision_score(true_labels, binary_preds, average='macro', zero_division=0)
test_recall_macro = recall_score(true_labels, binary_preds, average='macro', zero_division=0)
print(f"Test F1 Macro: {test_f1_macro:.4f}, Test Precision Macro: {test_precision_macro:.4f}, "
      f"Test Recall Macro: {test_recall_macro:.4f}, Test Accuracy: {test_accuracy:.4f}")

![image.png](attachment:6d95a46d-82a3-47ce-9737-c742587b3d6b.png)

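# Since the class labels are imbalanced, micro-averaged metrics should be used to evaluate the model.

Micro-averaging pools true positives, false positives and false negatives across all labels before computing a score, so every prediction counts equally. Macro-averaging instead computes the score for each label separately and takes an unweighted mean, so rare labels count as much as frequent ones.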

Reference: [What is the difference between micro and macro averaging? (Educative)](https://www.educative.io/answers/what-is-the-difference-between-micro-and-macro-averaging)