In [None]:
# 🔧 Setup: run this cell first.
# Reports the compute device and library versions, then seeds Python, NumPy
# and PyTorch so that later cells are reproducible.

import random
import sys

import numpy as np
import torch

# --- Compute device ---------------------------------------------------------
if not torch.cuda.is_available():
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")
else:
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# --- Environment versions ---------------------------------------------------
print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# --- Reproducibility --------------------------------------------------------
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# MedAlign Health: Aligning a Clinical AI Assistant Using RLHF

## Implementation Notebook

This notebook implements the RLHF pipeline for aligning a clinical AI assistant to produce safe, empathetic, and medically appropriate patient-facing responses.

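- **Industry:** Digital Health / Telehealth
- **Company:** MedAlign Health (fictional)
- **Core Problem:** Align a 7B-parameter clinical assistant using physician preference data. (The cells below use a small stand-in model and synthetic data to demonstrate the pipeline.)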

---

## 1. Setup and Data Loading

In [None]:
!pip install torch transformers datasets matplotlib seaborn pandas -q

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from collections import Counter
import random

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

In [None]:
# Synthetic clinical preference dataset
# In production, this would be real de-identified clinical conversations
# annotated by board-certified physicians.
# The three lists are index-aligned: CLINICAL_PROMPTS[i] is answered by
# PREFERRED_RESPONSES[i] (physician-preferred) and REJECTED_RESPONSES[i].

CLINICAL_PROMPTS = [
    "I have been having persistent headaches for the past week.",
    "My child has a fever of 101F. Should I be worried?",
    "I am feeling very anxious about my upcoming surgery.",
    "I noticed a rash on my arm that appeared yesterday.",
    "I have been feeling unusually tired for the past month.",
    "My back pain has not improved after two weeks.",
    "I am having difficulty sleeping and feel stressed.",
    "I forgot to take my medication yesterday. What should I do?",
    "I read online that my symptoms could mean cancer.",
    "Can you tell me if I should take ibuprofen or acetaminophen?",
]

# Physician-preferred responses (empathetic, appropriate, safe)
PREFERRED_RESPONSES = [
    "I understand that persistent headaches can be concerning. This could be related to tension, dehydration, or other common factors. I recommend keeping a headache diary noting when they occur and discussing the pattern with your physician at your next visit.",
    "I understand your concern. A fever of 101F in children is quite common and often the body's natural response to fighting infection. Keep your child hydrated and comfortable. If the fever persists beyond 48 hours or exceeds 103F, please contact your pediatrician.",
    "It is completely normal to feel anxious before surgery. Many patients experience this. I recommend discussing your specific concerns with your surgical team — they can walk you through what to expect and address any questions you have.",
    "New rashes can have many causes including contact irritation, allergies, or minor infections. If the rash is not spreading, painful, or accompanied by fever, you can monitor it for a few days. If it changes or you develop other symptoms, consult your dermatologist.",
    "Persistent fatigue lasting a month warrants attention. Common causes include sleep quality, stress, nutritional factors, or other treatable conditions. I recommend scheduling an appointment with your physician to discuss this and potentially run some routine blood work.",
    "I am sorry to hear your back pain has continued. Two weeks of persistent pain is a good reason to follow up with your physician. In the meantime, gentle stretching and avoiding prolonged sitting may help. Your physician can determine if imaging or physical therapy would be beneficial.",
    "Sleep difficulties and stress often go hand in hand, and both are very common. Some helpful strategies include maintaining a consistent sleep schedule, limiting screen time before bed, and gentle relaxation techniques. If these issues persist, your physician can discuss additional options.",
    "Missing a single dose happens to everyone. For most medications, you can take it as soon as you remember, unless it is close to your next scheduled dose. However, since medication instructions vary, I recommend checking with your pharmacist or physician for specific guidance on your medication.",
    "I understand how frightening it can be to read about possible diagnoses online. It is important to know that many symptoms have multiple possible causes, and the most common explanations are usually benign. The best next step is to discuss your specific symptoms with your physician who can provide an accurate assessment.",
    "Both ibuprofen and acetaminophen are effective for pain relief, but they work differently. The best choice depends on your specific situation, other medications you take, and your medical history. I recommend discussing this with your pharmacist or physician who can give you personalized guidance.",
]

# Rejected responses (too clinical, alarming, or out of scope)
REJECTED_RESPONSES = [
    "Persistent headaches can be indicative of various conditions including tension-type cephalalgia, migraine with or without aura, cluster headaches, or in rare cases intracranial pathology. You should get a CT scan.",
    "Febrile episodes in pediatric patients can indicate bacterial or viral infections. Monitor for febrile seizures which occur in 2-5% of children. Administer antipyretics per weight-based dosing at 10-15mg/kg.",
    "Pre-operative anxiety is a documented phenomenon. Consider asking your anesthesiologist about benzodiazepine premedication. The mortality rate for most surgeries is very low.",
    "Dermatological presentations vary widely. This could be contact dermatitis, urticaria, psoriasis, eczema, or in some cases early presentation of autoimmune conditions. Apply hydrocortisone cream twice daily.",
    "Chronic fatigue can be a symptom of numerous conditions including hypothyroidism, anemia, diabetes mellitus, chronic fatigue syndrome, depression, or malignancy. Get a comprehensive metabolic panel and CBC.",
    "Persistent lower back pain may indicate lumbar disc herniation, spinal stenosis, spondylolisthesis, or other structural abnormalities. You should get an MRI and consider a referral to a spine specialist.",
    "Insomnia and stress are comorbid conditions often treated pharmacologically. Consider melatonin supplementation or discuss SSRIs with your prescriber. Cognitive behavioral therapy for insomnia has evidence-based efficacy.",
    "Missed doses can alter drug pharmacokinetics and potentially reduce therapeutic efficacy. Double the next dose to compensate for the missed one. Monitor for any breakthrough symptoms.",
    "Your symptoms could indeed be consistent with oncological processes. Statistical likelihood depends on your demographic factors. I recommend urgent referral to oncology for comprehensive workup.",
    "Take ibuprofen 400mg every 6 hours. It is an NSAID with anti-inflammatory properties superior to acetaminophen for musculoskeletal pain. Avoid if you have renal impairment or GI ulcer history.",
]

# Pairs are matched by position, so a missing or extra entry would silently
# mis-pair every later prompt/response.
assert len(CLINICAL_PROMPTS) == len(PREFERRED_RESPONSES) == len(REJECTED_RESPONSES), \
    "CLINICAL_PROMPTS, PREFERRED_RESPONSES and REJECTED_RESPONSES must be index-aligned"

print(f"Created {len(CLINICAL_PROMPTS)} clinical preference pairs")
print(f"Preferred response avg length: {np.mean([len(r.split()) for r in PREFERRED_RESPONSES]):.0f} words")
print(f"Rejected response avg length: {np.mean([len(r.split()) for r in REJECTED_RESPONSES]):.0f} words")

## 2. Exploratory Data Analysis

In [None]:
# Analyze patterns in physician preferences: response length and marker vocabulary.
def _count_markers(responses, markers):
    """Count tokens across ``responses`` that belong to the ``markers`` set.

    Text is lower-cased and split into alphabetic runs, so punctuation such as
    '?', ':' or '—' next to a word does not hide it. A trailing plural 's' is
    also accepted (e.g. 'seizures' counts toward 'seizure').
    """
    import re

    count = 0
    for response in responses:
        for token in re.findall(r"[a-z]+", response.lower()):
            if token in markers or (token.endswith('s') and token[:-1] in markers):
                count += 1
    return count


def analyze_preference_patterns(preferred=None, rejected=None):
    """
    Compare physician-preferred and rejected responses.

    Left panel: average word count of each group. Right panel: how many
    empathy-marker and alarm-marker words each group contains.

    Args:
        preferred: list of preferred response strings; defaults to
            ``PREFERRED_RESPONSES``.
        rejected: list of rejected response strings; defaults to
            ``REJECTED_RESPONSES``.

    Returns:
        dict with integer counts under 'pref_empathy', 'pref_alarm',
        'rej_empathy' and 'rej_alarm'.
    """
    if preferred is None:
        preferred = PREFERRED_RESPONSES
    if rejected is None:
        rejected = REJECTED_RESPONSES

    # Word counts
    pref_lengths = [len(r.split()) for r in preferred]
    rej_lengths = [len(r.split()) for r in rejected]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    axes[0].bar(['Preferred', 'Rejected'],
                [np.mean(pref_lengths), np.mean(rej_lengths)],
                color=['#66b3ff', '#ff9999'], edgecolor='black')
    axes[0].set_ylabel('Average Word Count')
    axes[0].set_title('Response Length: Preferred vs Rejected')
    axes[0].grid(True, alpha=0.3, axis='y')

    # Empathy vs alarm vocabulary
    empathy_words = {'understand', 'normal', 'common', 'concern', 'recommend',
                     'helpful', 'sorry', 'completely', 'natural'}
    alarm_words = {'cancer', 'mortality', 'urgent', 'immediately', 'seizure',
                   'malignancy', 'pathology', 'double'}

    counts = {
        'pref_empathy': _count_markers(preferred, empathy_words),
        'pref_alarm': _count_markers(preferred, alarm_words),
        'rej_empathy': _count_markers(rejected, empathy_words),
        'rej_alarm': _count_markers(rejected, alarm_words),
    }

    x = np.arange(2)
    width = 0.35
    axes[1].bar(x - width/2, [counts['pref_empathy'], counts['pref_alarm']], width,
                label='Preferred', color='#66b3ff')
    axes[1].bar(x + width/2, [counts['rej_empathy'], counts['rej_alarm']], width,
                label='Rejected', color='#ff9999')
    axes[1].set_xticks(x)
    axes[1].set_xticklabels(['Empathy Words', 'Alarm Words'])
    axes[1].set_ylabel('Count')
    axes[1].set_title('Language Patterns in Preferences')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.show()

    print(f"\nEmpathy markers: preferred={counts['pref_empathy']}, rejected={counts['rej_empathy']}")
    print(f"Alarm markers:   preferred={counts['pref_alarm']}, rejected={counts['rej_alarm']}")
    # Only state the conclusion when the counts actually support it.
    if (counts['pref_empathy'] > counts['rej_empathy']
            and counts['pref_alarm'] < counts['rej_alarm']):
        print("\nKey finding: Preferred responses use more empathy language")
        print("and fewer alarming medical terms.")
    else:
        print("\nNote: marker counts do not show the expected empathy/alarm pattern.")

    return counts

pattern_counts = analyze_preference_patterns()

## 3. Baseline Model (SFT)

In [None]:
# For this case study, we use a simple model to demonstrate the pipeline
# In production, this would be a 7B parameter model like Llama-2-7B

class SimpleLanguageModel(nn.Module):
    """Small decoder-style Transformer LM for demonstrating RLHF mechanics.

    Token + learned positional embeddings feed a stack of
    ``nn.TransformerEncoderLayer`` blocks run with a causal attention mask,
    followed by a linear head over the vocabulary.
    """

    def __init__(self, vocab_size=5000, embed_dim=128, hidden_dim=256, n_layers=2,
                 n_heads=4, max_seq_len=512, causal=True):
        """
        Args:
            vocab_size: number of token ids covered by the embedding and output head.
            embed_dim: model width; must be divisible by ``n_heads``.
            hidden_dim: feed-forward width inside each Transformer layer.
            n_layers: number of Transformer layers.
            n_heads: attention heads per layer.
            max_seq_len: longest sequence the positional embedding supports.
            causal: if True (default), position t attends only to positions <= t,
                so the logits at t cannot see token t+1. Next-token
                log-probabilities are only meaningful under this mask.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Without position information, self-attention is permutation-invariant.
        self.pos_embedding = nn.Embedding(max_seq_len, embed_dim)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(embed_dim, nhead=n_heads, dim_feedforward=hidden_dim,
                                       batch_first=True)
            for _ in range(n_layers)
        ])
        self.output_head = nn.Linear(embed_dim, vocab_size)
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.causal = causal

    @staticmethod
    def _causal_mask(seq_len, device):
        """Additive (seq_len, seq_len) mask: -inf above the diagonal blocks future positions."""
        return torch.triu(torch.full((seq_len, seq_len), float('-inf'), device=device),
                          diagonal=1)

    def forward(self, input_ids, return_hidden=False):
        """
        Args:
            input_ids: LongTensor of shape (batch, seq_len).
            return_hidden: if True, return the final hidden states
                (batch, seq_len, embed_dim) instead of logits.

        Returns:
            Logits of shape (batch, seq_len, vocab_size), or hidden states.

        Raises:
            ValueError: if seq_len exceeds ``max_seq_len``.
        """
        seq_len = input_ids.size(1)
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len={self.max_seq_len}")
        positions = torch.arange(seq_len, device=input_ids.device)
        x = self.embedding(input_ids) + self.pos_embedding(positions)
        mask = self._causal_mask(seq_len, input_ids.device) if self.causal else None
        for layer in self.layers:
            x = layer(x, src_mask=mask)
        if return_hidden:
            return x
        logits = self.output_head(x)
        return logits

    def get_hidden(self, input_ids):
        """Final hidden states, shape (batch, seq_len, embed_dim)."""
        return self.forward(input_ids, return_hidden=True)

# Instantiate the baseline policy on the selected device and report its size.
model = SimpleLanguageModel()
model = model.to(device)
param_count = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {param_count:,}")

## 4. Reward Model

In [None]:
class ClinicalRewardModel(nn.Module):
    """
    Reward model with dual heads on a shared backbone:
    - Scalar reward head (trained on physician preferences)
    - Safety classifier head with 2 logits: 0 = safe, 1 = needs escalation

    The sequence representation is the mean of the backbone's hidden states
    over all positions (no padding mask is applied).
    """

    def __init__(self, backbone, hidden_dim=None):
        """
        Args:
            backbone: module exposing ``get_hidden(input_ids)`` that returns
                hidden states of shape (batch, seq_len, hidden_dim).
            hidden_dim: width of those hidden states. Defaults to
                ``backbone.embed_dim`` (128 if the backbone has no such
                attribute), so the heads always match the backbone width.
        """
        super().__init__()
        if hidden_dim is None:
            hidden_dim = getattr(backbone, 'embed_dim', 128)
        self.backbone = backbone
        self.reward_head = nn.Linear(hidden_dim, 1)
        self.safety_head = nn.Linear(hidden_dim, 2)  # safe / needs escalation

    def forward(self, input_ids):
        """Return ``(reward, safety_logits)`` with shapes (batch,) and (batch, 2)."""
        hidden = self.backbone.get_hidden(input_ids)
        pooled = hidden.mean(dim=1)  # average over sequence positions
        reward = self.reward_head(pooled).squeeze(-1)
        safety_logits = self.safety_head(pooled)
        return reward, safety_logits

# Reward-model training: Bradley-Terry preference loss + weighted safety loss.
def train_reward_model(reward_model, train_data, num_epochs=10, lr=1e-3,
                       safety_weight=2.0, log_every=5, device=None):
    """
    Train the reward model on physician preference pairs.

    For each pair the loss is the Bradley-Terry term ``-log(sigma(r_pref - r_rej))``
    plus ``safety_weight`` times the cross-entropy of the safety head (evaluated
    on the preferred response) against ``safety_label``. Parameters are updated
    after every pair (batch size 1).

    Args:
        reward_model: module returning ``(reward, safety_logits)`` for a
            (1, seq_len) id tensor, e.g. ClinicalRewardModel.
        train_data: non-empty list of (preferred_ids, rejected_ids, safety_label)
            tuples. The ids are 1-D LongTensors; safety_label is the class index
            (0 = safe, 1 = needs escalation).
        num_epochs: number of passes over ``train_data``.
        lr: Adam learning rate.
        safety_weight: weight of the safety loss relative to the preference loss.
        log_every: print the average loss every this many epochs (0 disables).
        device: device for the inputs. Defaults to the device of the model's
            first parameter.

    Returns:
        List with the average per-pair loss of each epoch.

    Raises:
        ValueError: if ``train_data`` is empty.

    The model is left in eval mode so later reward scoring is deterministic
    (dropout disabled).
    """
    if not train_data:
        raise ValueError("train_data must contain at least one preference pair")
    if device is None:
        device = next(reward_model.parameters()).device

    optimizer = torch.optim.Adam(reward_model.parameters(), lr=lr)
    losses = []

    reward_model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for pref_ids, rej_ids, safety_label in train_data:
            r_pref, safety_pref = reward_model(pref_ids.unsqueeze(0).to(device))
            r_rej, _ = reward_model(rej_ids.unsqueeze(0).to(device))

            bt_loss = -F.logsigmoid(r_pref - r_rej).mean()
            safety_target = torch.tensor([safety_label], device=device)
            safety_loss = F.cross_entropy(safety_pref, safety_target)

            total_loss = bt_loss + safety_weight * safety_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            epoch_loss += total_loss.item()

        avg_loss = epoch_loss / len(train_data)
        losses.append(avg_loss)
        if log_every and (epoch + 1) % log_every == 0:
            print(f"Epoch {epoch+1}/{num_epochs} — Loss: {avg_loss:.4f}")

    # Scoring after training should not be perturbed by dropout.
    reward_model.eval()
    return losses

# Synthetic reward-model training data: random token ids stand in for the
# tokenized (preferred, rejected) responses in this demo.
DEMO_VOCAB_SIZE = 5000
DEMO_SEQ_LEN = 50


def _random_pair(safety_label):
    """Build one fake (preferred_ids, rejected_ids, safety_label) tuple."""
    preferred = torch.randint(0, DEMO_VOCAB_SIZE, (DEMO_SEQ_LEN,))
    rejected = torch.randint(0, DEMO_VOCAB_SIZE, (DEMO_SEQ_LEN,))
    return preferred, rejected, safety_label


# One pair per clinical prompt, labelled safe (0)...
train_pairs = [_random_pair(0) for _ in CLINICAL_PROMPTS]
# ...plus three safety-critical pairs labelled "needs escalation" (1).
train_pairs += [_random_pair(1) for _ in range(3)]

backbone = SimpleLanguageModel().to(device)
reward_model = ClinicalRewardModel(backbone).to(device)
losses = train_reward_model(reward_model, train_pairs)

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(losses, 'b-', linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Reward Model Training Loss')
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

## 5. Policy-Gradient Step with KL Penalty (simplified PPO: REINFORCE with a baseline)

In [None]:
def rlhf_step(model, ref_model, reward_model, optimizer,
              input_ids, beta=0.1, epsilon=0.2):
    """
    One simplified RLHF optimization step with a KL penalty to a reference policy.

    The reward-model score of each sequence is reduced by ``beta`` times that
    sequence's KL(policy || reference). The policy is then updated with
    REINFORCE using the batch-mean effective reward as baseline. The KL must be
    per-sample: a single batch-wide KL value would be identical for every
    sequence and would cancel out of the advantage.

    Args:
        model: policy being optimized; ``model(input_ids)`` returns logits of
            shape (batch, seq_len, vocab).
        ref_model: frozen reference policy with the same output shape.
        reward_model: returns ``(reward, safety_logits)`` with reward of shape
            (batch,). Scored in eval mode; its previous mode is restored.
        optimizer: optimizer over ``model``'s parameters.
        input_ids: LongTensor (batch, seq_len).
        beta: KL penalty coefficient.
        epsilon: PPO clipping range. Currently unused: with a single on-policy
            update the importance ratio is 1, so clipping would be a no-op.

    Returns:
        dict with 'loss', 'reward' (mean RM score), 'kl' (mean per-sample KL)
        and 'effective_reward' (mean of reward - beta * KL).
    """
    model.train()

    # Forward through current model
    logits = model(input_ids)
    log_probs = F.log_softmax(logits, dim=-1)

    # Forward through reference model (no grad)
    with torch.no_grad():
        ref_log_probs = F.log_softmax(ref_model(input_ids), dim=-1)

    # Log-prob the policy assigns to each observed next token (drives the PG loss).
    next_tokens = input_ids[:, 1:].unsqueeze(-1)
    model_token_lp = log_probs[:, :-1].gather(2, next_tokens).squeeze(-1)

    # Exact per-position KL(policy || ref) summed over the vocabulary (always >= 0),
    # then averaged over positions to give one penalty per sequence.
    policy_lp = log_probs.detach()
    per_token_kl = (policy_lp.exp() * (policy_lp - ref_log_probs)).sum(dim=-1)
    kl_per_sample = per_token_kl.mean(dim=1)

    # Reward, scored deterministically (dropout off) without leaving a side effect.
    was_training = reward_model.training
    reward_model.eval()
    try:
        with torch.no_grad():
            reward, _ = reward_model(input_ids)
    finally:
        reward_model.train(was_training)

    # Effective reward = reward_RM - beta * KL, per sequence.
    effective_reward = reward - beta * kl_per_sample

    # REINFORCE with the batch-mean effective reward as baseline.
    advantage = effective_reward - effective_reward.mean()
    pg_loss = -(model_token_lp.mean(dim=1) * advantage.detach()).mean()

    optimizer.zero_grad()
    pg_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    return {
        'loss': pg_loss.item(),
        'reward': reward.mean().item(),
        'kl': kl_per_sample.mean().item(),
        'effective_reward': effective_reward.mean().item(),
    }

## 6. Full Training Loop

In [None]:
# Initialize the policy and an identical, frozen reference copy for the KL penalty.
active_model = SimpleLanguageModel().to(device)
ref_model = SimpleLanguageModel().to(device)
ref_model.load_state_dict(active_model.state_dict())
ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad = False

# Pre-train a fresh reward model, then freeze it. During RLHF it only scores
# sequences, and eval mode disables dropout so scores are deterministic.
reward_backbone = SimpleLanguageModel().to(device)
reward_model = ClinicalRewardModel(reward_backbone).to(device)
_ = train_reward_model(reward_model, train_pairs, num_epochs=20)
reward_model.eval()
for p in reward_model.parameters():
    p.requires_grad = False

optimizer = torch.optim.Adam(active_model.parameters(), lr=1e-4)

# Training: each history series records one metric returned by rlhf_step.
HISTORY_TO_METRIC = {
    'rewards': 'reward',
    'kl': 'kl',
    'losses': 'loss',
    'effective_rewards': 'effective_reward',
}
metrics_history = {history_key: [] for history_key in HISTORY_TO_METRIC}
num_steps = 100

print("Starting RLHF training...")
for step in range(num_steps):
    # Random input (in production: real patient queries)
    input_ids = torch.randint(0, 5000, (4, 50)).to(device)

    metrics = rlhf_step(active_model, ref_model, reward_model,
                        optimizer, input_ids, beta=0.1)

    for history_key, metric_key in HISTORY_TO_METRIC.items():
        metrics_history[history_key].append(metrics[metric_key])

    if (step + 1) % 25 == 0:
        print(f"Step {step+1}/{num_steps} — "
              f"Reward: {metrics['reward']:.3f}, "
              f"KL: {metrics['kl']:.4f}, "
              f"Eff. Reward: {metrics['effective_reward']:.3f}")

print("Training complete!")

## 7. Evaluation and Results

In [None]:
# Training dashboard: one panel per tracked RLHF metric, laid out row-major in a 2x2 grid.
DASHBOARD_PANELS = [
    # (metrics_history key, line style, panel title, y-axis label)
    ('rewards', 'b-', 'Reward Model Score', 'Reward'),
    ('kl', 'r-', 'KL Divergence from Reference', 'KL'),
    ('losses', 'g-', 'Policy Gradient Loss', 'Loss'),
    ('effective_rewards', 'm-', 'Effective Reward (RM - beta*KL)', 'Effective Reward'),
]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
for ax, (series_key, line_style, title, y_label) in zip(axes.flat, DASHBOARD_PANELS):
    ax.plot(metrics_history[series_key], line_style, linewidth=1.5)
    ax.set_title(title)
    ax.set_xlabel('Step')
    ax.set_ylabel(y_label)
    ax.grid(True, alpha=0.3)

plt.suptitle('MedAlign RLHF Training Dashboard', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 8. Error Analysis

In [None]:
# Error analysis of the aligned model (currently driven by simulated scores).
def error_analysis(rng=None):
    """
    Visualize failure modes of the aligned model.

    NOTE: scores and safety labels below are simulated placeholders. They are
    not produced by the reward model or by physician annotations.

    Plots the reward distribution with the bottom-20% threshold marked, and a
    2x2 safety confusion matrix (rows = actual, columns = predicted). Prints:
    - "safety compliance rate": share of responses NOT flagged for escalation;
    - "false escalation rate": share of escalation flags that were not actually
      needed (false positives / predicted positives).

    Args:
        rng: optional ``np.random.Generator`` for the simulation. When None, the
            legacy global ``np.random`` state is used.
    """
    rng = np.random if rng is None else rng

    # Simulate scores for analysis
    scores = rng.normal(1.5, 0.8, 100)
    safety_preds = rng.binomial(1, 0.05, 100)  # 5% flagged
    safety_true = rng.binomial(1, 0.03, 100)   # 3% actually need escalation

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Reward distribution
    axes[0].hist(scores, bins=20, color='#66b3ff', edgecolor='black', alpha=0.7)
    threshold = np.percentile(scores, 20)
    axes[0].axvline(x=threshold, color='red', linestyle='--',
                    label=f'Bottom 20% threshold ({threshold:.2f})')
    axes[0].set_xlabel('Reward Score')
    axes[0].set_ylabel('Count')
    axes[0].set_title('Response Quality Distribution')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Safety confusion matrix, computed directly so it is always 2x2 (matching
    # the two tick labels) even if a class never occurs, with no sklearn dependency.
    labels = (0, 1)
    cm = np.array([[np.sum((safety_true == actual) & (safety_preds == predicted))
                    for predicted in labels]
                   for actual in labels])
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[1],
                xticklabels=['Safe', 'Escalate'],
                yticklabels=['Safe', 'Escalate'])
    axes[1].set_xlabel('Predicted')
    axes[1].set_ylabel('Actual')
    axes[1].set_title('Safety Classification Confusion Matrix')

    plt.tight_layout()
    plt.show()

    n_flagged = int(safety_preds.sum())
    false_escalations = int(cm[0, 1])  # actual safe, predicted escalate
    print(f"\nSafety compliance rate: {(1 - safety_preds.mean()) * 100:.1f}%")
    print(f"False escalation rate: {false_escalations / max(n_flagged, 1) * 100:.1f}%")

error_analysis()

## 9. Production Deployment Considerations

In [None]:
# Deployment preparation: configuration for the staged rollout of the aligned model.
def prepare_deployment_config():
    """
    Build, validate and print the deployment configuration for the aligned model.

    Sections: model serving parameters, safety filters and thresholds,
    monitoring alerts, A/B-test split, and automatic-rollback criteria.

    Returns:
        dict mapping section name -> dict of parameters.

    Raises:
        ValueError: if the A/B traffic fractions do not sum to 1.
        re.error: if a blocked pattern is not a valid regular expression.
    """
    import re

    config = {
        "model": {
            "name": "clinassist-rlhf-v2",
            "parameters": "7B",
            "quantization": "none",
            "max_tokens": 256,
            "temperature": 0.7,
        },
        "safety": {
            "escalation_threshold": 0.5,
            "max_response_length": 300,
            # Case-insensitive regexes for text that must not reach a patient.
            # The dosing pattern allows a drug name between "take" and the dose,
            # and both "400mg" and "400 mg" (e.g. "Take ibuprofen 400mg").
            "blocked_patterns": [
                r"(?i)\btake\b[^.]*?\b\d+\s*mg\b",
                r"(?i)\bstop taking\b",
                r"(?i)\bdiagnosis:",
            ],
            "require_disclaimer": True,
        },
        "monitoring": {
            "reward_drift_threshold": 0.5,
            "kl_alert_threshold": 5.0,
            "safety_alert_rate": 0.02,
            "logging_sample_rate": 1.0,
        },
        "ab_testing": {
            "rlhf_traffic_fraction": 0.9,
            "sft_traffic_fraction": 0.1,
            "min_duration_weeks": 4,
            "significance_level": 0.01,
        },
        "rollback": {
            "safety_compliance_min": 0.99,
            "reward_drop_threshold": 1.0,
            "auto_rollback_enabled": True,
        },
    }

    # Fail fast on an inconsistent config instead of discovering it in production.
    ab = config["ab_testing"]
    traffic_total = ab["rlhf_traffic_fraction"] + ab["sft_traffic_fraction"]
    if abs(traffic_total - 1.0) > 1e-9:
        raise ValueError(f"A/B traffic fractions must sum to 1.0, got {traffic_total}")
    for pattern in config["safety"]["blocked_patterns"]:
        re.compile(pattern)

    print("Deployment Configuration:")
    print("=" * 50)
    for section, params in config.items():
        print(f"\n[{section}]")
        for key, value in params.items():
            print(f"  {key}: {value}")

    return config

config = prepare_deployment_config()

print("\n\nModel ready for staged deployment:")
print("  1. Canary (1% traffic) — 1 week")
print("  2. Shadow (10% traffic, log-only) — 1 week")
print("  3. A/B Test (90/10 split) — 4 weeks")
print("  4. Full production — continuous monitoring")