# BioBERT/ClinicalBERT Transfer Learning for Clinical Notes

This notebook implements transfer learning using BioBERT and ClinicalBERT for clinical text analysis. We'll fine-tune these pretrained models on clinical notes for tasks like diagnosis prediction, named entity recognition, and medical text classification.

## Objectives
1. Load and preprocess clinical notes data
2. Implement BioBERT/ClinicalBERT models
3. Fine-tune models on clinical tasks
4. Evaluate performance with F1 and NER metrics
5. Compare with baseline models
6. Save fine-tuned models for deployment


In [7]:
# Import necessary libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer, AutoModel, AutoConfig,
    TrainingArguments, Trainer, EarlyStoppingCallback
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.preprocessing import LabelEncoder
import warnings
warnings.filterwarnings('ignore')

# Choose the compute backend: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Plot appearance. The style sheet is applied first because it resets the
# colour cycle that the seaborn palette call configures afterwards.
plt.style.use('default')
sns.set_palette('viridis')

print("Libraries imported successfully!")


  from .autonotebook import tqdm as notebook_tqdm



Using device: cpu
Libraries imported successfully!


## 1. Load and Prepare Clinical Notes Data


In [8]:
# Load clinical notes data.
# The preprocessed CSV is preferred. When it is missing, the raw MIMIC-3
# diagnosis tables (DIAGNOSES_ICD / D_ICD_DIAGNOSES) are summarised for
# inspection and a synthetic corpus is built for the rest of the notebook.
try:
    # Try to load preprocessed clinical notes
    clinical_notes = pd.read_csv('../src/data/processed/clinical_notes.csv')
    print(f"Loaded preprocessed clinical notes: {clinical_notes.shape}")
except FileNotFoundError:
    # Load raw MIMIC-3 data if preprocessed not available.
    # These tables are only inspected below; the synthetic notes are built
    # from fixed templates and do not use them.
    print("Loading raw MIMIC-3 data...")
    diagnoses = pd.read_csv('../MIMIC-3/DIAGNOSES_ICD.csv')
    diagnoses_icd = pd.read_csv('../MIMIC-3/D_ICD_DIAGNOSES.csv')

    # Merge with ICD descriptions (left join keeps every diagnosis row even
    # when its code has no description)
    diagnoses = pd.merge(diagnoses, diagnoses_icd, on='icd9_code', how='left')
    print(f"Loaded raw diagnoses data: {diagnoses.shape}")

    # Check data quality
    print("Data quality check:")
    print(f"Missing values: {diagnoses.isnull().sum().sum()}")
    print(f"Unique patients: {diagnoses['subject_id'].nunique()}")
    print(f"Unique diagnoses: {diagnoses['icd9_code'].nunique()}")

    # Check most common diagnoses
    print("\nTop 10 most common diagnoses:")
    diagnosis_counts = diagnoses['icd9_code'].value_counts().head(10)
    print(diagnosis_counts)

    # Check if we have diagnosis descriptions
    if 'short_title' in diagnoses.columns:
        print("\nSample diagnosis descriptions:")
        sample_diagnoses = diagnoses[['icd9_code', 'short_title']].drop_duplicates().head(10)
        print(sample_diagnoses)
    else:
        print("No diagnosis descriptions available")

    # Create synthetic clinical notes for demonstration
    print("\nCreating synthetic clinical notes for demonstration...")

    np.random.seed(42)
    n_samples = 1000

    # One (note_text, diagnosis, category) record per template. Keeping the
    # three fields together means they cannot drift out of alignment.
    note_templates = [
        (
            "Patient presents with chest pain and shortness of breath. Vital signs stable. ECG shows ST elevation. Troponin levels elevated. Diagnosis: Acute myocardial infarction.",
            "Myocardial Infarction",
            "Cardiovascular",
        ),
        (
            "Patient has fever, cough, and fatigue. Chest X-ray shows pneumonia. White blood cell count elevated. Treatment with antibiotics started.",
            "Pneumonia",
            "Respiratory",
        ),
        (
            "Patient complains of headache and dizziness. Blood pressure elevated. CT scan normal. Diagnosis: Hypertension with headache.",
            "Hypertension",
            "Cardiovascular",
        ),
        (
            "Patient has abdominal pain and nausea. Physical exam reveals tenderness. Blood work shows elevated liver enzymes. Diagnosis: Hepatitis.",
            "Hepatitis",
            "Gastrointestinal",
        ),
        (
            "Patient presents with joint pain and swelling. Rheumatoid factor positive. X-rays show joint damage. Diagnosis: Rheumatoid arthritis.",
            "Rheumatoid Arthritis",
            "Musculoskeletal",
        ),
    ]

    # The templates are cycled in order until n_samples rows exist (rounded
    # down to a whole number of cycles), so every class is equally represented.
    n_repeats = n_samples // len(note_templates)
    clinical_notes = pd.DataFrame(
        note_templates * n_repeats,
        columns=['note_text', 'diagnosis', 'category'],
    )

    print(f"Created synthetic clinical notes: {clinical_notes.shape}")


Loading raw MIMIC-3 data...
Loaded raw diagnoses data: (1761, 8)
Data quality check:
Missing values: 135
Unique patients: 100
Unique diagnoses: 581

Top 10 most common diagnoses:
icd9_code
4019     53
42731    48
5849     45
4280     39
51881    31
25000    31
2724     29
5990     27
486      26
2859     25
Name: count, dtype: int64

Sample diagnosis descriptions:
  icd9_code               short_title
0     99591                    Sepsis
1     99662  React-oth vasc dev/graft
2      5672                       NaN
3     40391    Hyp kid NOS w cr kid V
4     42731       Atrial fibrillation
5      4280                   CHF NOS
6      4241     Aortic valve disorder
7      4240     Mitral valve disorder
8      2874                       NaN
9     03819  Staphylcocc septicem NEC

Creating synthetic clinical notes for demonstration...
Created synthetic clinical notes: (1000, 3)


In [9]:
# Clean and preprocess clinical notes
def preprocess_clinical_notes(text):
    """Normalise a single clinical note before tokenisation.

    Characters other than ASCII letters, digits, whitespace and the
    punctuation ``. , ; : ( ) -`` are removed, then every run of whitespace
    is collapsed to one space and the ends are trimmed.

    Args:
        text: The raw note. Missing values (``None``/``NaN``) are accepted;
            other non-string scalars are converted with ``str()``.

    Returns:
        The cleaned note as a ``str``; ``""`` for missing input.
    """
    import re

    if pd.isna(text):
        return ""

    # NOTE(review): the character class also drops '/', '%' and '+', so a
    # value such as "120/80" becomes "12080" - confirm that is acceptable for
    # the downstream task.
    text = re.sub(r'[^a-zA-Z0-9\s.,;:()\-]', '', str(text))

    # Collapse whitespace only after the removal above; otherwise deleting a
    # symbol that stood between two spaces would leave a double space behind.
    return ' '.join(text.split())

# Apply preprocessing
clinical_notes['processed_text'] = clinical_notes['note_text'].apply(preprocess_clinical_notes)

# Drop notes whose cleaned text is 10 characters or shorter (empty or too short
# to be useful). The explicit copy makes the result an independent frame, so
# the column assignment below does not write into a slice of the original.
clinical_notes = clinical_notes[clinical_notes['processed_text'].str.len() > 10].copy()

# Encode labels: each distinct diagnosis string is mapped to an integer id
label_encoder = LabelEncoder()
clinical_notes['diagnosis_encoded'] = label_encoder.fit_transform(clinical_notes['diagnosis'])

print(f"Processed dataset shape: {clinical_notes.shape}")
print(f"Unique diagnoses: {clinical_notes['diagnosis'].nunique()}")
print(f"\nSample processed text:")
print(clinical_notes['processed_text'].iloc[0][:200] + "...")


Processed dataset shape: (1000, 5)
Unique diagnoses: 5

Sample processed text:
Patient presents with chest pain and shortness of breath. Vital signs stable. ECG shows ST elevation. Troponin levels elevated. Diagnosis: Acute myocardial infarction....


# BioBERT/ClinicalBERT Transfer Learning for Clinical Notes

This notebook implements transfer learning using BioBERT and ClinicalBERT for clinical text analysis. We'll fine-tune these pretrained models on clinical notes for tasks like diagnosis prediction, named entity recognition, and medical text classification.

## Objectives
1. Load and preprocess clinical notes data
2. Implement BioBERT/ClinicalBERT models
3. Fine-tune models on clinical tasks
4. Evaluate performance with F1 and NER metrics
5. Compare with baseline models
6. Save fine-tuned models for deployment


In [4]:

import pandas as pd
import numpy as np

In [5]:
# Load clinical notes data.
# The preprocessed CSV is preferred. When it is missing, the raw MIMIC-3
# diagnosis tables (DIAGNOSES_ICD / D_ICD_DIAGNOSES) are summarised for
# inspection and a synthetic corpus is built for the rest of the notebook.
try:
    # Try to load preprocessed clinical notes
    clinical_notes = pd.read_csv('../src/data/processed/clinical_notes.csv')
    print(f"Loaded preprocessed clinical notes: {clinical_notes.shape}")
except FileNotFoundError:
    # Load raw MIMIC-3 data if preprocessed not available.
    # These tables are only inspected below; the synthetic notes are built
    # from fixed templates and do not use them.
    print("Loading raw MIMIC-3 data...")
    diagnoses = pd.read_csv('../MIMIC-3/DIAGNOSES_ICD.csv')
    diagnoses_icd = pd.read_csv('../MIMIC-3/D_ICD_DIAGNOSES.csv')

    # Merge with ICD descriptions (left join keeps every diagnosis row even
    # when its code has no description)
    diagnoses = pd.merge(diagnoses, diagnoses_icd, on='icd9_code', how='left')
    print(f"Loaded raw diagnoses data: {diagnoses.shape}")

    # Check data quality
    print("Data quality check:")
    print(f"Missing values: {diagnoses.isnull().sum().sum()}")
    print(f"Unique patients: {diagnoses['subject_id'].nunique()}")
    print(f"Unique diagnoses: {diagnoses['icd9_code'].nunique()}")

    # Check most common diagnoses
    print("\nTop 10 most common diagnoses:")
    diagnosis_counts = diagnoses['icd9_code'].value_counts().head(10)
    print(diagnosis_counts)

    # Check if we have diagnosis descriptions
    if 'short_title' in diagnoses.columns:
        print("\nSample diagnosis descriptions:")
        sample_diagnoses = diagnoses[['icd9_code', 'short_title']].drop_duplicates().head(10)
        print(sample_diagnoses)
    else:
        print("No diagnosis descriptions available")

    # Create synthetic clinical notes for demonstration
    print("\nCreating synthetic clinical notes for demonstration...")

    np.random.seed(42)
    n_samples = 1000

    # One (note_text, diagnosis, category) record per template. Keeping the
    # three fields together means they cannot drift out of alignment.
    note_templates = [
        (
            "Patient presents with chest pain and shortness of breath. Vital signs stable. ECG shows ST elevation. Troponin levels elevated. Diagnosis: Acute myocardial infarction.",
            "Myocardial Infarction",
            "Cardiovascular",
        ),
        (
            "Patient has fever, cough, and fatigue. Chest X-ray shows pneumonia. White blood cell count elevated. Treatment with antibiotics started.",
            "Pneumonia",
            "Respiratory",
        ),
        (
            "Patient complains of headache and dizziness. Blood pressure elevated. CT scan normal. Diagnosis: Hypertension with headache.",
            "Hypertension",
            "Cardiovascular",
        ),
        (
            "Patient has abdominal pain and nausea. Physical exam reveals tenderness. Blood work shows elevated liver enzymes. Diagnosis: Hepatitis.",
            "Hepatitis",
            "Gastrointestinal",
        ),
        (
            "Patient presents with joint pain and swelling. Rheumatoid factor positive. X-rays show joint damage. Diagnosis: Rheumatoid arthritis.",
            "Rheumatoid Arthritis",
            "Musculoskeletal",
        ),
    ]

    # The templates are cycled in order until n_samples rows exist (rounded
    # down to a whole number of cycles), so every class is equally represented.
    n_repeats = n_samples // len(note_templates)
    clinical_notes = pd.DataFrame(
        note_templates * n_repeats,
        columns=['note_text', 'diagnosis', 'category'],
    )

    print(f"Created synthetic clinical notes: {clinical_notes.shape}")

# Display basic information
print(f"\nDataset shape: {clinical_notes.shape}")
print(f"Columns: {clinical_notes.columns.tolist()}")
print(f"\nDiagnosis distribution:")
print(clinical_notes['diagnosis'].value_counts())
print(f"\nCategory distribution:")
print(clinical_notes['category'].value_counts())

# Last expression of the cell: rendered as a table of the first rows
clinical_notes.head()


Loading raw MIMIC-3 data...
Loaded raw diagnoses data: (1761, 8)
Data quality check:
Missing values: 135
Unique patients: 100
Unique diagnoses: 581

Top 10 most common diagnoses:
icd9_code
4019     53
42731    48
5849     45
4280     39
51881    31
25000    31
2724     29
5990     27
486      26
2859     25
Name: count, dtype: int64

Sample diagnosis descriptions:
  icd9_code               short_title
0     99591                    Sepsis
1     99662  React-oth vasc dev/graft
2      5672                       NaN
3     40391    Hyp kid NOS w cr kid V
4     42731       Atrial fibrillation
5      4280                   CHF NOS
6      4241     Aortic valve disorder
7      4240     Mitral valve disorder
8      2874                       NaN
9     03819  Staphylcocc septicem NEC

Creating synthetic clinical notes for demonstration...
Created synthetic clinical notes: (1000, 3)

Dataset shape: (1000, 3)
Columns: ['note_text', 'diagnosis', 'category']

Diagnosis distribution:
diagnosis
Myocar

Unnamed: 0,note_text,diagnosis,category
0,Patient presents with chest pain and shortness...,Myocardial Infarction,Cardiovascular
1,"Patient has fever, cough, and fatigue. Chest X...",Pneumonia,Respiratory
2,Patient complains of headache and dizziness. B...,Hypertension,Cardiovascular
3,Patient has abdominal pain and nausea. Physica...,Hepatitis,Gastrointestinal
4,Patient presents with joint pain and swelling....,Rheumatoid Arthritis,Musculoskeletal


## 2. Preprocess Clinical Notes


In [10]:
# Clean and preprocess clinical notes
def preprocess_clinical_notes(text):
    """Normalise a single clinical note before tokenisation.

    Characters other than ASCII letters, digits, whitespace and the
    punctuation ``. , ; : ( ) -`` are removed, then every run of whitespace
    is collapsed to one space and the ends are trimmed.

    Args:
        text: The raw note. Missing values (``None``/``NaN``) are accepted;
            other non-string scalars are converted with ``str()``.

    Returns:
        The cleaned note as a ``str``; ``""`` for missing input.
    """
    import re

    if pd.isna(text):
        return ""

    # NOTE(review): the character class also drops '/', '%' and '+', so a
    # value such as "120/80" becomes "12080" - confirm that is acceptable for
    # the downstream task.
    text = re.sub(r'[^a-zA-Z0-9\s.,;:()\-]', '', str(text))

    # Collapse whitespace only after the removal above; otherwise deleting a
    # symbol that stood between two spaces would leave a double space behind.
    return ' '.join(text.split())

# Apply preprocessing
clinical_notes['processed_text'] = clinical_notes['note_text'].apply(preprocess_clinical_notes)

# Drop notes whose cleaned text is 10 characters or shorter (empty or too short
# to be useful). The explicit copy makes the result an independent frame, so
# the column assignment below does not write into a slice of the original.
clinical_notes = clinical_notes[clinical_notes['processed_text'].str.len() > 10].copy()

# Encode labels: each distinct diagnosis string is mapped to an integer id
label_encoder = LabelEncoder()
clinical_notes['diagnosis_encoded'] = label_encoder.fit_transform(clinical_notes['diagnosis'])

print(f"Processed dataset shape: {clinical_notes.shape}")
print(f"Unique diagnoses: {clinical_notes['diagnosis'].nunique()}")
print(f"\nSample processed text:")
print(clinical_notes['processed_text'].iloc[0][:200] + "...")


Processed dataset shape: (1000, 5)
Unique diagnoses: 5

Sample processed text:
Patient presents with chest pain and shortness of breath. Vital signs stable. ECG shows ST elevation. Troponin levels elevated. Diagnosis: Acute myocardial infarction....


## 3. Create Clinical Notes Dataset Class


In [11]:
class ClinicalNotesDataset(Dataset):
    """Map-style dataset pairing clinical note texts with integer labels.

    Each item is tokenised on access and returned as a dict with the keys
    ``input_ids``, ``attention_mask`` and ``labels``.

    Args:
        texts: Iterable of note texts (list, array, pandas Series, ...).
        labels: Iterable of labels, in the same order as ``texts``.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=...,
            padding=..., max_length=..., return_tensors='pt')`` that returns
            a mapping containing ``input_ids`` and ``attention_mask`` tensors.
        max_length: Length every sequence is padded or truncated to.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        # Store plain lists so that __getitem__ is always positional. Indexing
        # a pandas Series with an int is label-based, which raises KeyError or
        # returns the wrong row once the index is no longer 0..n-1 (for
        # example after filtering rows or a train/test split).
        self.texts = list(texts)
        self.labels = list(labels)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of notes in the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenise and return the note at position ``idx``.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` flattened to one
            dimension, and ``labels`` as a ``torch.long`` tensor.
        """
        text = str(self.texts[idx])
        label = self.labels[idx]

        # Tokenize to a fixed length so items can be stacked into a batch
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        return {
            # flatten() removes the leading batch dimension added by
            # return_tensors='pt'
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

print("Clinical notes dataset class created!")


Clinical notes dataset class created!
