# NEXUS: Bilirubin Regression from MedSigLIP Embeddings

**Novel Task**: Predict continuous bilirubin levels (mg/dL) from neonatal skin images.

## Why This Is Novel
MedSigLIP was trained for medical image-text similarity. Using its frozen embeddings
for **continuous bilirubin regression** is a genuinely novel application that goes
beyond its original zero-shot classification design.

## Architecture
```
Neonatal skin image
       |
  [Frozen MedSigLIP encoder]  (google/medsiglip-448)
       |
  1152-dim embedding
       |
  [Linear(1152, 256) -> ReLU -> Dropout(0.3) -> Linear(256, 1)]
       |
  Predicted bilirubin (mg/dL)
```

## Dataset
- **NeoJaundice**: 2,235 neonatal images with ground truth serum bilirubin (mg/dL)
- Split: 70% train / 15% val / 15% test

## HAI-DEF Model
- **MedSigLIP** (`google/medsiglip-448`) — frozen vision encoder

In [None]:
# Setup & Imports
import sys
import os
import json
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from scipy import stats
from sklearn.model_selection import train_test_split

# Make the project's source and training-script directories importable.
# Each entry is pushed to the front of sys.path, so the last one listed
# ends up being searched first.
for _extra_dir in ('../src', '../scripts/training'):
    sys.path.insert(0, _extra_dir)

# Report the library version and which device type is available.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"PyTorch: {torch.__version__}")
print(f"Device: {_device_name}")

## 1. Load & Explore the NeoJaundice Dataset

In [None]:
# Locations of the dataset's metadata table and image folder.
DATA_DIR = Path("../data/raw/neojaundice")
IMAGES_DIR = DATA_DIR / "images"
CSV_PATH = DATA_DIR / "chd_jaundice_published_2.csv"

# Read the metadata table and print a quick overview of it.
df = pd.read_csv(CSV_PATH)
print(f"Total records: {df.shape[0]}")
print(f"\nColumns: {df.columns.tolist()}")
print(f"\nBilirubin statistics:")
print(df.loc[:, "blood(mg/dL)"].describe())
# Last expression of the cell: the notebook renders the first rows.
df.head()

In [None]:
# Histogram (left) and box plot (right) of the serum bilirubin column.
bili = df["blood(mg/dL)"]
bili_mean = bili.mean()

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
hist_ax, box_ax = axes

hist_ax.hist(bili, bins=40, edgecolor='black', alpha=0.7, color='#2196F3')
# Dashed vertical line marks the mean, which is also shown in the legend.
hist_ax.axvline(bili_mean, color='red', linestyle='--', label=f'Mean: {bili_mean:.1f}')
hist_ax.set_title("Distribution of Serum Bilirubin")
hist_ax.set_xlabel("Bilirubin (mg/dL)")
hist_ax.set_ylabel("Count")
hist_ax.legend()

box_ax.boxplot(bili, vert=True)
box_ax.set_title("Bilirubin Box Plot")
box_ax.set_ylabel("Bilirubin (mg/dL)")

plt.tight_layout()
plt.show()

## 2. Extract MedSigLIP Embeddings

We use the frozen MedSigLIP vision encoder to extract 1152-dimensional embeddings
for each neonatal image. These embeddings are cached for reuse.

In [None]:
# Keep only the rows whose image file is present under IMAGES_DIR, building
# two parallel sequences: the image paths and their bilirubin labels.
image_paths = []
kept_labels = []

for file_name, level in zip(df["image_idx"], df["blood(mg/dL)"]):
    candidate = IMAGES_DIR / file_name
    if not candidate.exists():
        continue  # metadata row without an image on disk
    image_paths.append(candidate)
    kept_labels.append(float(level))

bilirubin_values = np.array(kept_labels, dtype=np.float32)
print(f"Valid image-label pairs: {len(image_paths)}")
print(f"Bilirubin range: {bilirubin_values.min():.1f} - {bilirubin_values.max():.1f} mg/dL")

In [None]:
CACHE_DIR = Path("../models/linear_probes")
EMB_CACHE = CACHE_DIR / "jaundice_regression_embeddings.npy"

# Compute the embeddings only when no cached array exists; otherwise reuse it.
# NOTE(review): a cached file is used as-is and is not checked against the
# current `image_paths` (row count or order) — confirm it is still in sync
# if the images on disk have changed since it was written.
if not EMB_CACHE.exists():
    print("Extracting MedSigLIP embeddings (requires HF_TOKEN)...")
    # Imported lazily so the extractor is only needed on a cache miss.
    from train_linear_probes import EmbeddingExtractor
    extractor = EmbeddingExtractor()
    embeddings = extractor.extract_batch_embeddings(image_paths, batch_size=8)
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    np.save(EMB_CACHE, embeddings)
    print(f"Saved embeddings: {embeddings.shape}")
else:
    print("Loading cached embeddings...")
    embeddings = np.load(EMB_CACHE)
    print(f"Embeddings shape: {embeddings.shape}")

## 3. Baseline: Color-Based Bilirubin Estimation

Before training the ML model, we evaluate the existing color-based estimator,
which uses a simple yellow-index formula, on a random sample of up to 200 images.

In [None]:
from PIL import Image

def estimate_bilirubin_color(img_path):
    """Color-based bilirubin estimate (yellow-index formula).

    The per-pixel yellow index is ``(R + G - B) / (R + G + B)``. Its mean over
    the whole image is mapped through the fixed linear rule
    ``(mean - 0.2) * 50`` and clamped at zero.

    Args:
        img_path: Location of the image file, passed to ``Image.open``.

    Returns:
        float: Non-negative estimate. The notebook compares it against the
        serum bilirubin labels, which are in mg/dL.
    """
    # Open through a context manager so the underlying file is released
    # deterministically; convert() returns an independent, fully loaded image.
    with Image.open(img_path) as source:
        img = np.asarray(source.convert("RGB"), dtype=float)
    r, g, b = img[..., 0], img[..., 1], img[..., 2]
    # The small epsilon keeps all-black pixels from dividing by zero.
    yellow_index = (r + g - b) / (r + g + b + 1e-6)
    # Always return a plain float (the clamp used to yield the int 0).
    return float(max(0.0, (np.mean(yellow_index) - 0.2) * 50))

# Score the color-based estimator on a fixed-seed random subset of the images.
# NOTE(review): this subset is drawn from all valid images, whereas the
# regressor later in the notebook is scored on its held-out test split, so
# the two sets of metrics are not computed on the same samples.
n_sample = min(200, len(image_paths))
indices = np.random.RandomState(42).choice(len(image_paths), n_sample, replace=False)

color_preds = np.array([estimate_bilirubin_color(image_paths[i]) for i in indices])
color_actuals = np.array([bilirubin_values[i] for i in indices])

# Error metrics are all derived from the same prediction-minus-label vector.
_color_err = color_preds - color_actuals
color_mae = np.abs(_color_err).mean()
color_rmse = np.sqrt(np.mean(_color_err**2))
color_r, _ = stats.pearsonr(color_preds, color_actuals)

print("=== Color-Based Baseline ===")
print(f"MAE:       {color_mae:.3f} mg/dL")
print(f"RMSE:      {color_rmse:.3f} mg/dL")
print(f"Pearson r: {color_r:.4f}")

## 4. Train Bilirubin Regression Head

Architecture: `Linear(D, 256) -> ReLU -> Dropout(0.3) -> Linear(256, 1)`

- **Loss**: Huber loss (robust to outliers)
- **Optimizer**: Adam with weight decay
- **Scheduler**: ReduceLROnPlateau
- **Early stopping**: patience=15

In [None]:
class BilirubinRegressorHead(nn.Module):
    """Two-layer MLP that maps an embedding vector to one scalar.

    Layer stack: Linear(input_dim, hidden_dim) -> ReLU -> Dropout(dropout)
    -> Linear(hidden_dim, 1).
    """

    def __init__(self, input_dim=1152, hidden_dim=256, dropout=0.3):
        super().__init__()
        # Layers are built in forward order and wrapped in a Sequential named
        # `net`, so state-dict keys are net.0.* and net.3.*.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the MLP and drop the trailing size-1 output dimension."""
        out = self.net(x)
        return out.squeeze(-1)

# Data split: 70/15/15, done as two consecutive random splits.
# NOTE(review): rows are split per image; if the dataset contains several
# images of the same infant, one infant could appear in more than one
# subset — confirm against the metadata.
SEED = 42
TEST_FRACTION = 0.15
# The validation share is taken from what remains after the test split, so it
# is expressed relative to that remainder.
VAL_FRACTION_OF_REST = 0.15 / 0.85

X_trainval, X_test, y_trainval, y_test = train_test_split(
    embeddings,
    bilirubin_values,
    test_size=TEST_FRACTION,
    random_state=SEED,
)
X_train, X_val, y_train, y_val = train_test_split(
    X_trainval,
    y_trainval,
    test_size=VAL_FRACTION_OF_REST,
    random_state=SEED,
)

print(f"Train: {len(X_train)}, Val: {len(X_val)}, Test: {len(X_test)}")
print(f"Embedding dim: {embeddings.shape[1]}")

In [None]:
# Train the regression head on the pre-computed embeddings with mini-batch
# Adam, Huber loss, LR reduction on plateau and early stopping on val loss.
device = "cuda" if torch.cuda.is_available() else "cpu"
input_dim = embeddings.shape[1]

# Convert to tensors
X_train_t = torch.tensor(X_train, dtype=torch.float32).to(device)
y_train_t = torch.tensor(y_train, dtype=torch.float32).to(device)
X_val_t = torch.tensor(X_val, dtype=torch.float32).to(device)
y_val_t = torch.tensor(y_val, dtype=torch.float32).to(device)
X_test_t = torch.tensor(X_test, dtype=torch.float32).to(device)
y_test_t = torch.tensor(y_test, dtype=torch.float32).to(device)

model = BilirubinRegressorHead(input_dim=input_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
criterion = nn.HuberLoss(delta=2.0)

print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

# Training loop
EPOCHS = 100
BATCH_SIZE = 64
PATIENCE = 15

best_val_loss = float("inf")
best_state = None
patience_ctr = 0
history = {"train_loss": [], "val_loss": [], "val_mae": []}

for epoch in range(EPOCHS):
    model.train()
    # Fresh shuffle of the training rows for every epoch.
    perm = torch.randperm(len(X_train_t))
    epoch_loss, n_batches = 0.0, 0
    for i in range(0, len(perm), BATCH_SIZE):
        batch_idx = perm[i:i+BATCH_SIZE]
        optimizer.zero_grad()
        pred = model(X_train_t[batch_idx])
        loss = criterion(pred, y_train_t[batch_idx])
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1
    avg_train = epoch_loss / max(n_batches, 1)

    # Full-batch validation pass with dropout disabled.
    model.eval()
    with torch.no_grad():
        val_pred = model(X_val_t)
        val_loss = criterion(val_pred, y_val_t).item()
        val_mae = torch.abs(val_pred - y_val_t).mean().item()
    scheduler.step(val_loss)

    history["train_loss"].append(avg_train)
    history["val_loss"].append(val_loss)
    history["val_mae"].append(val_mae)

    if (epoch+1) % 10 == 0 or epoch == 0:
        print(f"Epoch {epoch+1:3d} | Train: {avg_train:.4f} | Val: {val_loss:.4f} | Val MAE: {val_mae:.2f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # state_dict() hands back references to the live parameter tensors and
        # dict.copy() is shallow, so the tensors must be cloned; otherwise the
        # "best" snapshot silently follows every later optimizer step.
        best_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        patience_ctr = 0
    else:
        patience_ctr += 1
        if patience_ctr >= PATIENCE:
            print(f"Early stopping at epoch {epoch+1}")
            break

# Restore the weights from the epoch with the lowest validation loss.
if best_state is not None:
    model.load_state_dict(best_state)
print(f"\nBest val loss: {best_val_loss:.4f}")

## 5. Training Curves

In [None]:
# Left panel: train/validation loss per epoch. Right panel: validation MAE.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
loss_ax, mae_ax = axes

for history_key, curve_label in (("train_loss", "Train"), ("val_loss", "Validation")):
    loss_ax.plot(history[history_key], label=curve_label, alpha=0.8)
loss_ax.set_title("Training & Validation Loss")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Huber Loss")
loss_ax.legend()
loss_ax.grid(True, alpha=0.3)

mae_ax.plot(history["val_mae"], color='green', alpha=0.8)
mae_ax.set_title("Validation MAE")
mae_ax.set_xlabel("Epoch")
mae_ax.set_ylabel("MAE (mg/dL)")
mae_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 6. Test Set Evaluation

In [None]:
# Predict on the held-out test tensors with dropout off and no gradient
# tracking, then move everything to NumPy for the metric computations.
model.eval()
with torch.no_grad():
    test_pred = model(X_test_t).cpu().numpy()
test_actual = y_test_t.cpu().numpy()

_residual = test_pred - test_actual
ml_mae = np.mean(np.abs(_residual))
ml_rmse = np.sqrt(np.mean(_residual**2))
ml_r, ml_p = stats.pearsonr(test_pred, test_actual)

_banner = "=" * 50
print(_banner)
print("TEST SET RESULTS — MedSigLIP Regressor")
print(_banner)
print(f"MAE:       {ml_mae:.3f} mg/dL")
print(f"RMSE:      {ml_rmse:.3f} mg/dL")
print(f"Pearson r: {ml_r:.4f} (p={ml_p:.2e})")

## 7. Comparison Plots

In [None]:
# Three diagnostic panels for the test-set predictions: predicted-vs-actual
# scatter, Bland-Altman agreement plot, and histogram of prediction errors.
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# 7a. Predicted vs Actual scatter
ax = axes[0]
ax.scatter(test_actual, test_pred, alpha=0.4, s=15, color='#2196F3')
# Shared axis range so the identity line spans both predictions and labels.
lims = [min(test_actual.min(), test_pred.min()), max(test_actual.max(), test_pred.max())]
ax.plot(lims, lims, 'r--', linewidth=1.5, label='Perfect prediction')
ax.set_xlabel("Actual Bilirubin (mg/dL)")
ax.set_ylabel("Predicted Bilirubin (mg/dL)")
ax.set_title(f"Predicted vs Actual (r={ml_r:.3f})")
ax.legend()
ax.grid(True, alpha=0.3)

# 7b. Bland-Altman plot
ax = axes[1]
diff = test_pred - test_actual
mean_vals = (test_pred + test_actual) / 2
mean_diff = np.mean(diff)
# Limits of agreement are estimated from the sample, so use the sample
# standard deviation (ddof=1) instead of NumPy's population default (ddof=0).
std_diff = np.std(diff, ddof=1)
loa_upper = mean_diff + 1.96 * std_diff
loa_lower = mean_diff - 1.96 * std_diff

ax.scatter(mean_vals, diff, alpha=0.4, s=15, color='#4CAF50')
ax.axhline(mean_diff, color='red', linestyle='-', label=f'Mean: {mean_diff:.2f}')
ax.axhline(loa_upper, color='orange', linestyle='--', label=f'+1.96 SD: {loa_upper:.2f}')
ax.axhline(loa_lower, color='orange', linestyle='--', label=f'-1.96 SD: {loa_lower:.2f}')
ax.set_xlabel("Mean of Predicted & Actual (mg/dL)")
ax.set_ylabel("Difference (Predicted - Actual)")
ax.set_title("Bland-Altman Plot")
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

# 7c. Error distribution
ax = axes[2]
# Same quantity as the Bland-Altman differences; reuse it instead of recomputing.
errors = diff
ax.hist(errors, bins=30, edgecolor='black', alpha=0.7, color='#FF9800')
ax.axvline(0, color='red', linestyle='--', linewidth=1.5)
ax.set_xlabel("Prediction Error (mg/dL)")
ax.set_ylabel("Count")
ax.set_title(f"Error Distribution (MAE={ml_mae:.2f})")
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 8. Before vs After Comparison

| Metric | Color-Based (Before) | MedSigLIP Regressor (After) |
|--------|---------------------|-----------------------------|
| MAE    | Higher              | Lower (trained on data)     |
| RMSE   | Higher              | Lower                       |
| Pearson r | Lower            | Higher (uses learned embedding features) |

The MedSigLIP-based regressor leverages deep learned features from a medical vision
model, capturing patterns invisible to simple color analysis.

In [None]:
# Side-by-side table of the color baseline and the trained regressor.
_rule = "=" * 60
print("\n" + _rule)
print("BEFORE vs AFTER COMPARISON")
print(_rule)
print(f"{'Metric':<15} {'Color-Based':<20} {'MedSigLIP Regressor':<20}")
print("-" * 55)

# (row label, baseline value, regressor value, number format)
_rows = (
    ("MAE (mg/dL)", color_mae, ml_mae, ".3f"),
    ("RMSE (mg/dL)", color_rmse, ml_rmse, ".3f"),
    ("Pearson r", color_r, ml_r, ".4f"),
)
for _label, _before, _after, _fmt in _rows:
    print(f"{_label:<15} {_before:<20{_fmt}} {_after:<20{_fmt}}")
print(_rule)

# The relative MAE reduction is only reported when the regressor is better.
if ml_mae < color_mae:
    improvement = 100 * (1 - ml_mae / color_mae)
    print(f"\nMAE improvement: {improvement:.1f}% reduction with MedSigLIP regressor")

## 9. Save Trained Model

In [None]:
# Persist the trained head (weights + test metrics) and a JSON summary of the
# evaluation next to it.
output_dir = Path("../models/linear_probes")
output_dir.mkdir(parents=True, exist_ok=True)

# Rounded once here so the checkpoint and the JSON report carry equal values.
_mae = round(float(ml_mae), 3)
_rmse = round(float(ml_rmse), 3)
_pearson = round(float(ml_r), 4)

model_path = output_dir / "bilirubin_regressor.pt"
_checkpoint = {
    "model_state_dict": model.state_dict(),
    "input_dim": input_dim,
    "hidden_dim": 256,
    "metrics": {"mae": _mae, "rmse": _rmse, "pearson_r": _pearson},
}
torch.save(_checkpoint, model_path)
print(f"Model saved: {model_path}")

# Save detailed results
_bland_altman = {
    key: round(float(value), 3)
    for key, value in (
        ("mean_diff", mean_diff),
        ("std_diff", std_diff),
        ("loa_upper", loa_upper),
        ("loa_lower", loa_lower),
    )
}
results = {
    "method": "MedSigLIP Bilirubin Regressor",
    "hai_def_model": "google/medsiglip-448",
    "architecture": f"Linear({input_dim},256)->ReLU->Dropout(0.3)->Linear(256,1)",
    "loss": "HuberLoss(delta=2.0)",
    "test_mae": _mae,
    "test_rmse": _rmse,
    "test_pearson_r": _pearson,
    "baseline_color_mae": round(float(color_mae), 3),
    "baseline_color_rmse": round(float(color_rmse), 3),
    "baseline_color_pearson_r": round(float(color_r), 4),
    "bland_altman": _bland_altman,
    "epochs_trained": len(history["train_loss"]),
    "test_size": len(test_actual),
}

results_path = output_dir / "bilirubin_regression_results.json"
results_path.write_text(json.dumps(results, indent=2))
print(f"Results saved: {results_path}")

## Summary

This notebook demonstrates a **novel application** of the MedSigLIP HAI-DEF model:
continuous bilirubin regression from neonatal skin images.

**Key findings**:
- Frozen MedSigLIP embeddings capture medically relevant features for bilirubin estimation
- A lightweight 2-layer MLP trained on these embeddings outperforms simple color analysis
- The Bland-Altman analysis shows the prediction spread and systematic bias
- This approach enables community health workers (CHWs) to get quantitative bilirubin estimates from phone photos

**Clinical relevance**: Non-invasive bilirubin screening can reduce the need for heel-prick
blood draws in newborns and enable earlier detection of hyperbilirubinemia in
resource-limited settings.