# Module 3: Clinical NLP Model Training

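This notebook fine-tunes a biomedical BERT model (BioBERT, `dmis-lab/biobert-base-cased-v1.1`) for ICD-10 code prediction from clinical notes. The notes are synthetic and are generated within the notebook.

## Objectives:
- Fine-tune BioBERT for ICD-10 classification (a ClinicalBERT checkpoint can be swapped in via `MODEL_NAME`)
- Train a multi-label classifier for diagnosis codes
- Evaluate the model with precision, recall, and F1-score
- Save the trained model, tokenizer, and metadata for automated coding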

---

## 1. Setup and Imports

In [None]:
import sys
import os

# Make the project root and its `src/` directory importable from this notebook.
# `..` resolves against the current working directory, so this assumes the kernel
# runs from a direct subdirectory of the project (e.g. notebooks/) — TODO confirm.
# Each path is checked on its own, so `src/` is still added when the root is
# already on sys.path. Insertion order matches the original: `src/` ends up first.
project_root = os.path.abspath('..')
for _path in (project_root, os.path.join(project_root, 'src')):
    if _path not in sys.path:
        sys.path.insert(0, _path)

print(f"Project root: {project_root}")

In [None]:
# Core dependencies: PyTorch + Hugging Face for the model, scikit-learn for
# metrics and splitting, matplotlib for training curves.
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModel,
    # NOTE(review): transformers' AdamW is deprecated in favour of torch.optim.AdamW
    # and is absent from newer transformers releases — confirm against the pinned version.
    AdamW,
    get_linear_schedule_with_warmup
)
import numpy as np
import pandas as pd
# NOTE(review): classification_report and multilabel_confusion_matrix (and seaborn
# below) do not appear to be used anywhere in this notebook.
from sklearn.metrics import (
    classification_report,
    multilabel_confusion_matrix,
    accuracy_score,
    f1_score,
    precision_score,
    recall_score
)
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm.notebook import tqdm
import json
from datetime import datetime
import warnings
# Silences ALL warnings for the rest of the kernel session, including library
# deprecation warnings; narrow this filter when debugging.
warnings.filterwarnings('ignore')

# Record library versions and GPU availability alongside the run's output.
print(f"PyTorch version: {torch.__version__}")
print(f"Transformers version: {__import__('transformers').__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 2. Generate Synthetic Clinical Data

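Create synthetic clinical notes with associated ICD-10 codes for training. Each of the ten note templates corresponds to the ICD-10 code at the same position in `ICD10_CODES`. Every generated note concatenates 1–3 distinct templates, so each note carries 1–3 positive labels.

**Caveat:** the notes are assembled from only ten fixed sentences, so identical note texts recur many times and can land in both the training and the test split. Scores on this data are therefore likely optimistic and do not reflect performance on real clinical text.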

In [None]:
# Label vocabulary: the ICD-10 codes this classifier predicts. Dict insertion
# order defines each code's label index in the multi-hot targets.
# NOTE(review): some pairings look off against ICD-10-CM — M79.3 is usually
# "Panniculitis, unspecified" (chronic low back pain is typically M54.5x/G89.29),
# and the GERD template mentions esophagitis (K21.0) while K21.9 is "without
# esophagitis". Confirm with a coding reference before using real data.
ICD10_CODES = {
    'I10': 'Essential (primary) hypertension',
    'E11.9': 'Type 2 diabetes mellitus without complications',
    'J44.9': 'Chronic obstructive pulmonary disease, unspecified',
    'I25.10': 'Atherosclerotic heart disease',
    'M79.3': 'Chronic pain',
    'F41.9': 'Anxiety disorder, unspecified',
    'E78.5': 'Hyperlipidemia, unspecified',
    'N18.3': 'Chronic kidney disease, stage 3',
    'J45.909': 'Unspecified asthma',
    'K21.9': 'Gastro-esophageal reflux disease'
}

# Sample clinical note templates. Template i describes the condition coded by
# the i-th entry of ICD10_CODES; the synthetic data generator relies on this.
CLINICAL_TEMPLATES = [
    "Patient presents with elevated blood pressure readings consistently above 140/90. Diagnosed with hypertension. Started on antihypertensive medication.",
    "65-year-old male with history of type 2 diabetes. HbA1c at 7.8%. Patient counseled on diet and exercise. Metformin dose adjusted.",
    "Patient complains of chronic shortness of breath and persistent cough. Spirometry shows reduced FEV1. COPD exacerbation treated with bronchodilators.",
    "Chest pain on exertion. EKG shows ischemic changes. Cardiac catheterization reveals coronary artery disease. Scheduled for angioplasty.",
    "Chronic lower back pain for 6 months. Patient reports pain intensity 7/10. Physical therapy recommended. Prescribed NSAIDs.",
    "Patient experiencing persistent worry, restlessness, and difficulty sleeping. Diagnosed with generalized anxiety disorder. Started on SSRI.",
    "Lipid panel shows elevated LDL cholesterol at 180 mg/dL. Patient advised on lifestyle modifications. Statin therapy initiated.",
    "Lab results show creatinine 1.8 mg/dL, eGFR 45. Diagnosed with stage 3 chronic kidney disease. Nephrology referral made.",
    "Recurrent wheezing and shortness of breath triggered by exercise. Peak flow measurements reduced. Asthma diagnosis confirmed.",
    "Patient reports frequent heartburn and acid reflux, especially after meals. Upper GI endoscopy shows esophagitis. GERD diagnosed."
]

# Fail fast if a code or template is added without its counterpart — otherwise
# labels would silently point at the wrong condition.
assert len(CLINICAL_TEMPLATES) == len(ICD10_CODES), (
    f"{len(CLINICAL_TEMPLATES)} templates but {len(ICD10_CODES)} ICD-10 codes"
)

# Label index -> ICD-10 code, derived from ICD10_CODES so the two cannot drift apart.
CODE_MAPPING = dict(enumerate(ICD10_CODES))

def generate_synthetic_dataset(num_samples=1000, seed=None, max_conditions=3):
    """Generate synthetic multi-label clinical notes with ICD-10 targets.

    Each note concatenates between 1 and ``max_conditions`` distinct templates
    from CLINICAL_TEMPLATES. Choosing template i sets label i, so labels are
    index-aligned with ICD10_CODES.

    Args:
        num_samples: Number of notes to generate.
        seed: Optional seed for a private ``np.random.RandomState``. When None
            (default), numpy's global RNG is used, exactly as before.
        max_conditions: Maximum number of conditions per note. It is capped
            at the number of templates.

    Returns:
        DataFrame with columns ``clinical_note`` (str) and ``labels`` (list of
        0/1 ints of length len(ICD10_CODES)). The columns are present even
        when ``num_samples`` is 0.
    """
    # Keep the global RNG by default so np.random.seed(...) still controls the output.
    rng = np.random if seed is None else np.random.RandomState(seed)
    num_templates = len(CLINICAL_TEMPLATES)
    max_conditions = max(1, min(max_conditions, num_templates))
    data = []

    for _ in range(num_samples):
        # randint's upper bound is exclusive: draws 1..max_conditions conditions.
        num_conditions = rng.randint(1, max_conditions + 1)
        selected_indices = rng.choice(num_templates, num_conditions, replace=False)

        note = " ".join(CLINICAL_TEMPLATES[i] for i in selected_indices)

        # Multi-hot target vector over all ICD-10 codes.
        labels = np.zeros(len(ICD10_CODES), dtype=int)
        labels[selected_indices] = 1

        data.append({
            'clinical_note': note,
            'labels': labels.tolist()
        })

    # Explicit columns keep the schema stable even for an empty dataset.
    return pd.DataFrame(data, columns=['clinical_note', 'labels'])

# Seed numpy's global RNG so the synthetic corpus — and therefore every
# downstream split and metric — is reproducible under Restart & Run All.
np.random.seed(42)

# Generate dataset
df = generate_synthetic_dataset(num_samples=1000)
print(f"Generated {len(df)} clinical notes")
print(f"\nSample note:\n{df.iloc[0]['clinical_note']}")
print(f"\nLabels: {df.iloc[0]['labels']}")

# Show label distribution: number of notes in which each code is positive.
label_counts = np.array(df['labels'].tolist()).sum(axis=0)
print(f"\nLabel distribution:")
for idx, (code, desc) in enumerate(ICD10_CODES.items()):
    print(f"  {code} ({desc[:30]}...): {label_counts[idx]} samples")

## 3. Train/Validation/Test Split (70% / 15% / 15%)

In [None]:
# 70% of the notes go to training. The remaining 30% is halved into validation
# and test (15% each). A fixed random_state makes the partition reproducible.
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

print("\n".join(
    f"{split_name} samples: {len(split_frame)}"
    for split_name, split_frame in (("Training", train_df), ("Validation", val_df), ("Test", test_df))
))

## 4. Create Custom Dataset

In [None]:
class ClinicalNotesDataset(Dataset):
    """Map-style dataset of tokenized clinical notes with multi-hot ICD-10 targets.

    Args:
        dataframe: Frame with a ``clinical_note`` text column and a ``labels``
            column of per-code 0/1 lists. Its index is reset, so item ``i`` is
            the i-th row.
        tokenizer: Hugging Face-style tokenizer, called on one note at a time.
        max_length: Every note is padded or truncated to this many tokens.
    """

    def __init__(self, dataframe, tokenizer, max_length=512):
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return 1-D ``input_ids``/``attention_mask`` tensors plus float ``labels``."""
        row = self.data.iloc[idx]
        encoded = self.tokenizer(
            row['clinical_note'],
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # A single note presumably comes back as (1, max_length); flatten it so
        # the DataLoader can stack items into (batch, max_length).
        sample = {key: encoded[key].reshape(-1) for key in ('input_ids', 'attention_mask')}
        sample['labels'] = torch.tensor(row['labels'], dtype=torch.float)
        return sample

# Tokenizer for the pretrained encoder. BioBERT is used here; a ClinicalBERT
# checkpoint could be swapped in through MODEL_NAME.
MODEL_NAME = 'dmis-lab/biobert-base-cased-v1.1'  # BioBERT for medical text
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

BATCH_SIZE = 16

# One dataset per split, all sharing the same tokenizer.
train_dataset, val_dataset, test_dataset = (
    ClinicalNotesDataset(split_frame, tokenizer)
    for split_frame in (train_df, val_df, test_df)
)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader, test_loader = (
    DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False)
    for eval_dataset in (val_dataset, test_dataset)
)

print(f"Tokenizer loaded: {MODEL_NAME}")
print(f"Vocabulary size: {len(tokenizer)}")

## 5. Define Model Architecture

In [None]:
class ICD10Classifier(nn.Module):
    """Pretrained encoder + dropout + linear head for multi-label ICD-10 coding.

    ``forward`` returns raw logits, one per code. Apply a sigmoid (or train with
    BCEWithLogitsLoss) to get per-code probabilities.

    Args:
        model_name: Hugging Face model id or path, loaded with ``AutoModel``.
        num_labels: Number of ICD-10 codes (output units).
        dropout: Dropout probability applied to the sentence representation.
    """

    def __init__(self, model_name, num_labels, dropout=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Encode the batch and return per-code logits."""
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # Prefer the encoder's pooled output when it exists (for BERT-style
        # models, a learned transform of [CLS]). Encoders without a pooler
        # (attribute missing or None) fall back to the raw [CLS] hidden state,
        # i.e. position 0 of the last layer, so MODEL_NAME can be swapped safely.
        pooled_output = getattr(outputs, 'pooler_output', None)
        if pooled_output is None:
            pooled_output = outputs.last_hidden_state[:, 0]

        return self.classifier(self.dropout(pooled_output))

# Seed torch before building the model. The classification head's initial
# weights, dropout masks and the shuffled training loader all draw on torch's
# RNG, so this makes training reproducible alongside the numpy seed.
torch.manual_seed(42)

# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ICD10Classifier(
    model_name=MODEL_NAME,
    num_labels=len(ICD10_CODES),
    dropout=0.3
).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model initialized on {device}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 6. Training Configuration

In [None]:
# Training hyperparameters
EPOCHS = 10
LEARNING_RATE = 2e-5
THRESHOLD = 0.5  # Per-code decision threshold applied to sigmoid probabilities

# Loss function: independent binary cross-entropy per code (multi-label), on raw logits.
criterion = nn.BCEWithLogitsLoss()

# Optimizer. transformers.AdamW is deprecated (and removed in newer releases),
# so use torch's implementation. eps and weight_decay are pinned to the former
# transformers defaults (1e-6, 0.0) because torch.optim.AdamW defaults differ
# (1e-8, 1e-2) and would silently change training.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    eps=1e-6,
    weight_decay=0.0
)

# Learning-rate schedule: linear warm-up over the first 10% of steps, then linear decay.
total_steps = len(train_loader) * EPOCHS
warmup_steps = total_steps // 10
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Create model directory
model_dir = Path('../models/icd10_coding')
model_dir.mkdir(parents=True, exist_ok=True)

print(f"Training configuration:")
print(f"  Epochs: {EPOCHS}")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Batch Size: {BATCH_SIZE}")
print(f"  Total Steps: {total_steps}")
print(f"  Warmup Steps: {warmup_steps}")
print(f"  Model save path: {model_dir}")

## 7. Training Functions

In [None]:
def _compute_metrics(all_labels, all_preds):
    """Return (accuracy, macro F1, macro precision, macro recall) for 0/1 label matrices.

    For multi-label indicator arrays, ``accuracy_score`` is *subset* accuracy:
    a note counts as correct only when every one of its codes matches.
    """
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)
    return accuracy, f1, precision, recall


def _run_epoch(model, dataloader, criterion, device, threshold, desc, optimizer=None, scheduler=None):
    """Run one pass over ``dataloader``; update weights only when ``optimizer`` is given.

    Returns (avg_loss, accuracy, f1, precision, recall). avg_loss weights each
    batch's loss by its sample count, so a smaller final batch is not
    over-counted (assumes ``criterion`` uses mean reduction — the
    BCEWithLogitsLoss default).

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if len(dataloader) == 0:
        raise ValueError(f"{desc}: dataloader is empty; cannot compute epoch metrics")

    training = optimizer is not None
    total_loss = 0.0
    num_samples = 0
    all_preds = []
    all_labels = []

    progress_bar = tqdm(dataloader, desc=desc)
    # Build the autograd graph only when we are going to back-propagate.
    with torch.set_grad_enabled(training):
        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            if training:
                optimizer.zero_grad()
            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)

            if training:
                loss.backward()
                # Clip to stabilise fine-tuning of the pretrained encoder.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size
            preds = torch.sigmoid(logits.detach()) > threshold
            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

            progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

    avg_loss = total_loss / num_samples
    accuracy, f1, precision, recall = _compute_metrics(
        np.concatenate(all_labels), np.concatenate(all_preds)
    )
    return avg_loss, accuracy, f1, precision, recall


def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, threshold=None):
    """Train for one epoch and return (avg_loss, accuracy, f1, precision, recall).

    Metrics are computed on predictions made during training, i.e. with dropout
    active and weights changing mid-epoch. ``threshold`` defaults to the global
    THRESHOLD. ``scheduler`` may be None.
    """
    model.train()
    return _run_epoch(
        model, dataloader, criterion, device,
        THRESHOLD if threshold is None else threshold,
        desc='Training', optimizer=optimizer, scheduler=scheduler,
    )


def validate_epoch(model, dataloader, criterion, device, threshold=None):
    """Evaluate without gradient updates and return (avg_loss, accuracy, f1, precision, recall).

    ``threshold`` defaults to the global THRESHOLD.
    """
    model.eval()
    return _run_epoch(
        model, dataloader, criterion, device,
        THRESHOLD if threshold is None else threshold,
        desc='Validation',
    )

## 8. Train the Model

In [None]:
# Per-epoch metric history keyed '<split>_<metric>':
# train_loss ... train_recall, then val_loss ... val_recall.
history = {
    f'{split}_{metric}': []
    for split in ('train', 'val')
    for metric in ('loss', 'acc', 'f1', 'precision', 'recall')
}

# Start below any reachable score so the first epoch always writes a checkpoint.
# With 0, a run whose macro-F1 never rose above 0 would save nothing, and the
# test-set cell would fail to load the file — or silently load a stale one
# from an earlier run.
best_val_f1 = float('-inf')
best_model_path = model_dir / 'best_icd10_model.pth'

print("\n" + "="*70)
print("STARTING TRAINING")
print("="*70 + "\n")

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")
    print("-" * 70)

    train_metrics = train_epoch(model, train_loader, criterion, optimizer, scheduler, device)
    val_metrics = validate_epoch(model, val_loader, criterion, device)
    train_loss, train_acc, train_f1, train_prec, train_rec = train_metrics
    val_loss, val_acc, val_f1, val_prec, val_rec = val_metrics

    # Store built-in floats so the history serialises cleanly.
    for split, values in (('train', train_metrics), ('val', val_metrics)):
        for metric, value in zip(('loss', 'acc', 'f1', 'precision', 'recall'), values):
            history[f'{split}_{metric}'].append(float(value))

    # Print summary
    print(f"\nEpoch {epoch + 1} Summary:")
    print(f"  Train - Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f} | P: {train_prec:.4f} | R: {train_rec:.4f}")
    print(f"  Val   - Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f} | P: {val_prec:.4f} | R: {val_rec:.4f}")

    # Keep the checkpoint with the best validation macro-F1.
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        # Built-in floats keep the checkpoint loadable with torch.load(weights_only=True),
        # the default in recent PyTorch, which rejects numpy scalar objects.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_f1': float(val_f1),
            'val_acc': float(val_acc),
            'val_precision': float(val_prec),
            'val_recall': float(val_rec)
        }, best_model_path)
        print(f"  ✓ Best model saved! (Val F1: {val_f1:.4f})")

print("\n" + "="*70)
print("TRAINING COMPLETE!")
print("="*70)

## 9. Visualize Training Progress

In [None]:
# 2x2 grid: loss, accuracy and F1 (train vs. validation), plus precision/recall.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Training Progress - ICD-10 Classification', fontsize=16, fontweight='bold')

epochs_range = range(1, len(history['train_loss']) + 1)

# Each panel: (axis, title, y-label, [(history key, line format, legend label), ...]).
# Series are listed in drawing order, which also fixes the legend order.
panels = [
    (axes[0, 0], 'Loss', 'Loss', [('train_loss', 'b-', 'Train'), ('val_loss', 'r-', 'Val')]),
    (axes[0, 1], 'Accuracy', 'Accuracy', [('train_acc', 'b-', 'Train'), ('val_acc', 'r-', 'Val')]),
    (axes[1, 0], 'F1-Score', 'F1-Score', [('train_f1', 'b-', 'Train'), ('val_f1', 'r-', 'Val')]),
    (axes[1, 1], 'Precision & Recall', 'Score', [
        ('train_precision', 'b--', 'Train Precision'),
        ('val_precision', 'r--', 'Val Precision'),
        ('train_recall', 'b:', 'Train Recall'),
        ('val_recall', 'r:', 'Val Recall'),
    ]),
]

for panel_ax, panel_title, panel_ylabel, series in panels:
    for history_key, line_fmt, legend_label in series:
        panel_ax.plot(epochs_range, history[history_key], line_fmt, label=legend_label, linewidth=2)
    panel_ax.set_title(panel_title, fontsize=12)
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel(panel_ylabel)
    panel_ax.legend()
    panel_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(model_dir / 'training_curves.png', dpi=300, bbox_inches='tight')
plt.show()

## 10. Test Set Evaluation

In [None]:
# Restore the best checkpoint written during training.
# - map_location lets a GPU-trained checkpoint load on a CPU-only machine.
# - weights_only=False is explicit because the checkpoint may hold numpy scalar
#   metrics, which the restricted (weights_only) unpickler rejects. The file is
#   produced by this notebook and is trusted; never load untrusted files this way.
checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"\nLoaded best model from epoch {checkpoint['epoch'] + 1}")
print(f"  Val F1: {checkpoint['val_f1']:.4f}")
print(f"  Val Accuracy: {checkpoint['val_acc']:.4f}")

# Evaluate on the held-out test set (reuses the validation loop; no weight updates).
test_loss, test_acc, test_f1, test_prec, test_rec = validate_epoch(
    model, test_loader, criterion, device
)

print(f"\n" + "="*70)
print("TEST SET RESULTS")
print("="*70)
print(f"  Loss: {test_loss:.4f}")
print(f"  Accuracy: {test_acc:.4f}")
print(f"  F1-Score: {test_f1:.4f}")
print(f"  Precision: {test_prec:.4f}")
print(f"  Recall: {test_rec:.4f}")
print("="*70)

## 11. Per-Class Performance

In [None]:
# Re-run inference on the test split to collect thresholded per-code predictions.
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc='Testing'):
        probs = torch.sigmoid(
            model(batch['input_ids'].to(device), batch['attention_mask'].to(device))
        )
        all_preds.extend((probs > THRESHOLD).cpu().numpy())
        all_labels.extend(batch['labels'].numpy())

all_preds = np.array(all_preds)
all_labels = np.array(all_labels)

# Per-code binary precision / recall / F1 (column i of the matrices = i-th code).
table_rule = "-" * 90
print("\nPer-Class Performance:")
print(table_rule)
print(f"{'ICD-10 Code':<12} {'Description':<40} {'Precision':<12} {'Recall':<12} {'F1-Score':<12}")
print(table_rule)

for col, (code, desc) in enumerate(ICD10_CODES.items()):
    y_true, y_pred = all_labels[:, col], all_preds[:, col]
    scores = [
        metric(y_true, y_pred, zero_division=0)
        for metric in (precision_score, recall_score, f1_score)
    ]
    print(f"{code:<12} {desc[:38]:<40} " + " ".join(f"{score:<12.4f}" for score in scores))

print(table_rule)

## 12. Save Model and Metadata

In [None]:
# Persist per-epoch metrics.
with open(model_dir / 'training_history.json', 'w') as fh:
    json.dump(history, fh, indent=2)

# Model metadata: architecture, label vocabulary (dict order = label index),
# decision threshold, training setup and held-out test scores.
training_config = {
    'epochs': EPOCHS,
    'learning_rate': LEARNING_RATE,
    'batch_size': BATCH_SIZE,
    'train_samples': len(train_df),
    'val_samples': len(val_df),
    'test_samples': len(test_df)
}
test_results = {
    'accuracy': float(test_acc),
    'f1_score': float(test_f1),
    'precision': float(test_prec),
    'recall': float(test_rec)
}
metadata = {
    'model_architecture': 'BioBERT + Linear Classifier',
    'base_model': MODEL_NAME,
    'num_labels': len(ICD10_CODES),
    'icd10_codes': ICD10_CODES,
    'threshold': THRESHOLD,
    'training': training_config,
    'test_results': test_results,
    'timestamp': datetime.now().isoformat()
}

with open(model_dir / 'model_metadata.json', 'w') as fh:
    json.dump(metadata, fh, indent=2)

# Save tokenizer next to the weights so inference uses the same vocabulary.
tokenizer.save_pretrained(model_dir / 'tokenizer')

banner = "=" * 70
saved_files = [
    ('best_icd10_model.pth', 'Trained model'),
    ('training_history.json', 'Training metrics'),
    ('model_metadata.json', 'Model configuration'),
    ('training_curves.png', 'Training visualization'),
    ('tokenizer/', 'BioBERT tokenizer'),
]

print("\n" + banner)
print("MODEL ARTIFACTS SAVED")
print(banner)
print(f"\nModel directory: {model_dir}")
print(f"\nFiles saved:")
for artifact_name, artifact_desc in saved_files:
    print(f"  ✓ {artifact_name} - {artifact_desc}")
print("\n" + banner)
print("✓ Training notebook completed successfully!")
print(banner)