# PINN for Cancer Signaling: Full-Range Training Demo

This notebook demonstrates how to:
1. Train a physics-informed neural network (PINN) on the full time range (0, 1, 4, 8, 24, 48 h)
2. Visualize training fit across the full timeline
3. Use the trained model for downstream inference

**Drug Combination**: Vemurafenib (0.5) + Trametinib (0.3)


In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from train_pinn import train_pinn
from visualize_extrapolation import (
    plot_extrapolation_results, 
    plot_training_history, 
    generate_prediction_table,
    load_pinn_with_data
)
from data_utils import TRAINING_DATA_RAW, SPECIES_ORDER
import os

%matplotlib inline

## Step 1: Configure Training Parameters

The model will train on all available time points (0–48h).


In [None]:
config = {
    'train_until_hour': 48,
    'num_epochs': 10000,
    'learning_rate': 0.001,
    'lr_decay': 0.95,
    'batch_size': 6,  # 6 training time points
    'hidden_size': 100,
    'num_physics_points': 100,
    'weight_decay': 1e-5,
    'weights': {
        'data': 1.0,
        'physics': 0.5,
        'boundary': 0.3,
        'conservation': 0.2
    }
}

print("Training Configuration:")
print("  Training on: [0, 1, 4, 8, 24, 48] hours")
print("  Testing on: None (full-range training)")
print(f"  Epochs: {config['num_epochs']}")
print(f"  Physics weight: {config['weights']['physics']}")
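
The `weights` dictionary suggests that `train_pinn` minimizes a weighted sum of four loss terms, $\mathcal{L} = w_{\text{data}}\mathcal{L}_{\text{data}} + w_{\text{phys}}\mathcal{L}_{\text{phys}} + w_{\text{bc}}\mathcal{L}_{\text{bc}} + w_{\text{cons}}\mathcal{L}_{\text{cons}}$. The sketch below assumes exactly that form. The helper `combine_losses` and the component values are hypothetical illustrations, not code or output from `train_pinn`.

```python
# Hypothetical sketch: assumes train_pinn combines its loss terms as a weighted sum.
def combine_losses(components, weights):
    """Return the weighted sum of named loss components.

    components: dict mapping term name -> scalar loss value
    weights:    dict mapping term name -> weight (every key must be in components)
    """
    missing = set(weights) - set(components)
    if missing:
        raise KeyError(f"missing loss components: {sorted(missing)}")
    return sum(weights[name] * components[name] for name in weights)


# Same weights as config['weights'] above
weights = {'data': 1.0, 'physics': 0.5, 'boundary': 0.3, 'conservation': 0.2}
# Illustrative component values only (not real training output)
components = {'data': 0.10, 'physics': 0.20, 'boundary': 0.05, 'conservation': 0.50}
total = combine_losses(components, weights)
print(f"Total loss: {total:.4f}")
```

Under this assumed form, the weights set each term's relative pull on the gradient: lowering `physics` from 0.5 to 0.25 halves how strongly the ODE residual competes with the data-fit term.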


## Step 2: Train the Model

This will take several minutes. The progress bar shows train loss (test loss is N/A for full-range training).


In [None]:
# Only run if you want to train from scratch
# Otherwise, skip to the next cell to load a pre-trained model

print("="*60)
print("TRAINING: Vem+Tram at [0,1,4,8,24,48]hrs")
print("TESTING: No held-out time points (full data training)")
print("="*60)

model, k_params, history, scalers, train_data, test_data = train_pinn(config)
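
`num_physics_points` presumably sets how many collocation times the ODE residual is evaluated at during training. The real model's ODE system lives inside `train_pinn` and is not shown here, and PINNs typically obtain $dy/dt$ from the network by automatic differentiation. The sketch below instead uses a hypothetical one-species decay ODE, $dy/dt = -k\,y$, and approximates the derivative with `np.gradient` finite differences, only to show what a physics loss measures: how badly a candidate trajectory violates the ODE.

```python
import numpy as np

def physics_residual_loss(y_fn, k, t_max=48.0, num_points=100):
    """Mean squared residual of dy/dt + k*y = 0 at evenly spaced collocation points.

    y_fn: callable mapping an array of times to predicted concentrations
    k:    rate constant of the hypothetical toy ODE
    """
    t = np.linspace(0.0, t_max, num_points)   # collocation points
    y = y_fn(t)
    dydt = np.gradient(y, t)                  # finite-difference derivative (not autograd)
    residual = dydt + k * y                   # zero wherever y obeys the ODE
    return float(np.mean(residual ** 2))


k = 0.1

def exact(t):
    return np.exp(-k * t)                     # true solution of the toy ODE

def wrong(t):
    return 1.0 - t / 48.0                     # matches y(0) = 1 but violates the ODE

print(physics_residual_loss(exact, k))        # tiny: only finite-difference error remains
print(physics_residual_loss(wrong, k))        # much larger: the ODE is violated
```

Both candidates agree at $t = 0$, so a data term anchored at that point could not distinguish them; the residual term can.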


## Step 3: Load Trained Model

Load the best checkpoint (either from training above or a pre-trained model).

In [None]:
if os.path.exists('pinn_model_best.pth'):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model, scalers, train_data, test_data = load_pinn_with_data('pinn_model_best.pth', device)
    print("✓ Model loaded successfully!")
    print(f"  Device: {device}")
    print(f"  Training points: {train_data['t']}")
    if test_data is not None and len(test_data['t']) > 0:
        print(f"  Test points: {test_data['t']}")
    else:
        print("  Test points: None (full-range training)")
else:
    print("⚠ No trained model found. Run the training cell above first.")


## Step 4: Visualize Training Fit

This plots the model's predicted trajectory for every species against the training data across the full 0–48h range.


In [None]:
plot_extrapolation_results()
plt.show()

## Step 5: Training History Analysis

Track how the loss evolved over epochs. The test loss is NaN throughout full-range training because no time points are held out.


In [None]:
plot_training_history()
plt.show()

# Show final losses
history_df = pd.read_csv('training_history.csv')
final_epoch = history_df.iloc[-1]
print("\nFinal Performance:")
print(f"  Train Loss (data term): {final_epoch['l_data']:.6f}")
if pd.notna(final_epoch['l_test']):
    print(f"  Test Loss: {final_epoch['l_test']:.6f}")
    print(f"  Test/Train Ratio: {final_epoch['l_test']/final_epoch['l_data']:.2f}")
else:
    print("  Test Loss: N/A (full-range training)")
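
The final row reports the last epoch, which is not necessarily the lowest-loss one; the `_best` in the checkpoint name suggests the saved model comes from the best epoch, though the selection criterion is not shown here. A small sketch to locate the minimum of a loss curve, using a made-up list in place of `history_df['l_data']` (the helper `best_epoch` is hypothetical):

```python
import numpy as np

def best_epoch(losses):
    """Return (row index, value) of the lowest loss, ignoring NaN entries."""
    losses = np.asarray(losses, dtype=float)
    idx = int(np.nanargmin(losses))
    return idx, float(losses[idx])


# Illustrative loss curve (not real training output); it rises again at the end
curve = [0.9, 0.5, 0.3, float('nan'), 0.25, 0.27, 0.31]
idx, value = best_epoch(curve)
print(f"Lowest loss {value:.3f} at row {idx}; final loss {curve[-1]:.3f}")
```

With the real history, `best_epoch(history_df['l_data'])` returns the row index of the minimum, which can be compared with the last row printed above.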


## Step 6: Detailed Predictions Table

Generate a table with predicted vs. true values for all time points and species.

In [None]:
predictions_df = generate_prediction_table()

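# Show held-out test predictions if they exist; otherwise show the training-set fit
test_predictions = predictions_df[predictions_df['Dataset'] == 'Test']
display_cols = ['Time (hrs)', 'Species', 'True Value', 'Predicted Value', 'Percent Error']
if not test_predictions.empty:
    held_out = ', '.join(str(t) for t in sorted(test_predictions['Time (hrs)'].unique()))
    print(f"\nTest Set Predictions (held-out times: {held_out} hrs):")
    print(test_predictions[display_cols].to_string(index=False))
else:
    print("\nNo held-out test points (full-range training). Training-set predictions:")
    print(predictions_df[display_cols].to_string(index=False))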
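
The exact formula behind the `Percent Error` column is not shown in this notebook. A common definition, assumed here, is the absolute error relative to the true value. The `eps` guard is an added assumption that avoids division by zero when a species is near zero, as a strongly inhibited readout such as pERK may be; percentages for such points should be read with care.

```python
import numpy as np

def percent_error(y_true, y_pred, eps=1e-8):
    """Elementwise absolute percent error: |pred - true| / |true| * 100.

    eps keeps the denominator nonzero when the true value is ~0.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return np.abs(y_pred - y_true) / (np.abs(y_true) + eps) * 100.0


print(percent_error([2.0, 0.5], [2.2, 0.4]))   # approximately [10., 20.]
```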


## Step 7: Species-Specific Analysis

Compute R² per species on the held-out test points if any exist; otherwise fall back to the training points. Training-set R² measures goodness of fit only, not extrapolation.


In [None]:
# Evaluate on held-out test points when available; otherwise on the training points
if not test_predictions.empty:
    eval_df, split_label = test_predictions, 'Test'
else:
    eval_df, split_label = predictions_df[predictions_df['Dataset'] != 'Test'], 'Train'

r2_by_species = {}
for species in SPECIES_ORDER:
    species_data = eval_df[eval_df['Species'] == species]
    y_true = species_data['True Value'].values
    y_pred = species_data['Predicted Value'].values

    ss_res = np.sum((y_true - y_pred)**2)
    ss_tot = np.sum((y_true - np.mean(y_true))**2)
    r2_by_species[species] = 1 - (ss_res / (ss_tot + 1e-8))

# Sort species by fit quality
r2_col = f'{split_label} R²'
r2_df = pd.DataFrame(list(r2_by_species.items()), columns=['Species', r2_col])
r2_df = r2_df.sort_values(r2_col, ascending=False)

title = 'Extrapolation Quality by Species' if split_label == 'Test' else 'Training Fit Quality by Species'
print(f"\n{title}:")
print(r2_df.to_string(index=False))

# Bar plot: green if R² > 0.8, orange if > 0.6, red otherwise
plt.figure(figsize=(10, 6))
colors = ['green' if r2 > 0.8 else 'orange' if r2 > 0.6 else 'red' for r2 in r2_df[r2_col]]
plt.bar(r2_df['Species'], r2_df[r2_col], color=colors)
plt.xticks(rotation=45)
plt.ylabel(r2_col)
plt.title(title)
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()
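
For reference, the cell computes the coefficient of determination per species as below, where $y_j$ are the true values, $\hat{y}_j$ the predictions, $\bar{y}$ the mean true value for that species, and $\varepsilon = 10^{-8}$ guards against a zero denominator. With only a handful of points per species (six in full-range training, fewer when points are held out), $SS_\text{tot}$ can be small, so R² can swing widely or go strongly negative; treat the per-species values as rough indicators.

$$
R^2 = 1 - \frac{SS_\text{res}}{SS_\text{tot} + \varepsilon},
\qquad
SS_\text{res} = \sum_{j} \bigl(y_j - \hat{y}_j\bigr)^2,
\qquad
SS_\text{tot} = \sum_{j} \bigl(y_j - \bar{y}\bigr)^2
$$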


## Step 8: Biological Interpretation

Check whether the model's predictions follow the behavior expected from the drugs' mechanisms of action.

In [None]:
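# Load checkpoint to inspect learned rate constants.
# PyTorch >= 2.6 defaults to weights_only=True, which rejects non-tensor objects
# (e.g. stored scalers); weights_only=False restores the old behavior, so use it
# only with checkpoints you trust.
checkpoint = torch.load('pinn_model_best.pth', map_location='cpu', weights_only=False)
k_params = checkpoint['k_params_state_dict']

print("Learned Kinetic Parameters:")
for name, value in k_params.items():
    # Scalar parameters print as numbers; vector-valued ones as lists
    shown = f"{value.item():.4f}" if value.numel() == 1 else value.tolist()
    print(f"  {name}: {shown}")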

# Expected dynamics from biology
print("\n" + "="*60)
print("Expected Biological Behavior (Vem+Tram):")
print("="*60)
print("- pMEK: Should be LOW (Trametinib inhibits MEK)")
print("- pERK: Should be LOW (blocked MEK → blocked ERK)")
print("- DUSP6: Should DECREASE over time (less ERK activation)")
print("- pAKT: Should INCREASE (compensatory PI3K activation)")
print("- HER3: May INCREASE (feedback receptor upregulation)")

# Check whether predictions at 48 h match the expected direction
expected_at_48h = {'pMEK': 'LOW', 'pERK': 'LOW', 'pAKT': 'ELEVATED'}
pred_at_48h = predictions_df[predictions_df['Time (hrs)'] == 48].set_index('Species')['Predicted Value']

print("\nModel Predictions at 48 h:")
for species, expectation in expected_at_48h.items():
    print(f"  {species}: {pred_at_48h[species]:.3f} (expect {expectation})")
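
"LOW" and "ELEVATED" only mean something relative to a reference. One option, assumed here rather than taken from the original analysis, is to compare each 48 h prediction with the same species' prediction at 0 h and report the log2 fold change. The helpers below are hypothetical, work on plain numbers, and assume positive, unscaled concentrations (log2 is undefined for negative standardized values); the example values are made up.

```python
import math

def log2_fold_change(value_t, value_0, eps=1e-8):
    """log2(value_t / value_0); eps guards against zero values (inputs must be >= 0)."""
    return math.log2((value_t + eps) / (value_0 + eps))

def direction(lfc, threshold=0.5):
    """Classify a log2 fold change; the 0.5 threshold (~1.4x) is an arbitrary choice."""
    if lfc <= -threshold:
        return 'DOWN'
    if lfc >= threshold:
        return 'UP'
    return 'FLAT'


# Made-up predicted values at (0 h, 48 h), for illustration only
example = {'pMEK': (1.0, 0.25), 'pERK': (1.0, 0.1), 'pAKT': (1.0, 1.8)}
for species, (v0, v48) in example.items():
    lfc = log2_fold_change(v48, v0)
    print(f"{species}: log2FC = {lfc:+.2f} ({direction(lfc)})")
```

With the real table, the 0 h and 48 h values for each species can be read from `predictions_df` by filtering on `'Time (hrs)'` and `'Species'`.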

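## Step 9: Smooth Trajectories per Species

Evaluate the model on a dense 0–48h grid and plot each species' predicted trajectory against the measured data.

In [None]: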
import torch
import matplotlib.pyplot as plt
import numpy as np
from visualize_extrapolation import load_pinn_with_data
from data_utils import SPECIES_ORDER, TRAINING_DATA_RAW

# 1. Load the model and its data splits from the checkpoint
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, scalers, train_data, test_data = load_pinn_with_data('pinn_model_best.pth', device)
has_test_data = test_data is not None and len(test_data['t']) > 0

# 2. Generate smooth predictions across 0–48h
t_smooth = np.linspace(0, 48, 500)
drugs_raw = TRAINING_DATA_RAW['drugs']
y_smooth = model.predict(t_smooth, drugs_raw, scalers, device)

# 3. Plotting grid (4 x 3 = 12 panels)
fig, axes = plt.subplots(4, 3, figsize=(18, 16))
axes = axes.flatten()

for i, species in enumerate(SPECIES_ORDER):
    ax = axes[i]

    # Smooth PINN curve
    ax.plot(t_smooth, y_smooth[:, i], label='PINN Prediction', color='blue', linewidth=2)

    # Training points
    ax.scatter(train_data['t'], train_data['y'][:, i],
               color='green', marker='o', s=80, label='Train Data (Input)', zorder=5)

    # Held-out test points, if any, plus a line marking the end of the training window
    if has_test_data:
        ax.scatter(test_data['t'], test_data['y'][:, i],
                   color='red', marker='s', s=80, label='Test Data (Hidden)', zorder=5)
        ax.axvline(x=float(np.max(train_data['t'])), color='black', linestyle='--', alpha=0.3)

    ax.set_title(f"Dynamic Fit: {species}", fontsize=14, fontweight='bold')
    ax.set_xlabel("Time (hours)")
    ax.set_ylabel("Concentration (a.u.)")
    ax.grid(alpha=0.2)
    if i == 0:
        ax.legend()

# Hide panels not used by any species
for ax in axes[len(SPECIES_ORDER):]:
    ax.axis('off')
plt.tight_layout()
plt.show()
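
A numerical companion to the plot: interpolate the smooth curve at the observation times and measure the largest gap to the observed values for each species. This assumes the curve and observations are NumPy-compatible arrays on the same (unscaled) scale, which the plotting code implies but does not guarantee. The helper `max_abs_gap` and the synthetic curves are hypothetical.

```python
import numpy as np

def max_abs_gap(t_curve, y_curve, t_obs, y_obs):
    """Per-species max |curve(t_obs) - y_obs|, interpolating the curve linearly.

    t_curve: (n,) increasing times;  y_curve: (n, species)
    t_obs:   (m,) observation times; y_obs:   (m, species)
    """
    t_curve = np.asarray(t_curve, dtype=float)
    y_curve = np.asarray(y_curve, dtype=float)
    t_obs = np.asarray(t_obs, dtype=float)
    y_obs = np.asarray(y_obs, dtype=float)
    gaps = []
    for j in range(y_curve.shape[1]):
        y_at_obs = np.interp(t_obs, t_curve, y_curve[:, j])
        gaps.append(np.max(np.abs(y_at_obs - y_obs[:, j])))
    return np.array(gaps)


# Synthetic check: two made-up "species" sampled at the notebook's time points
t_curve = np.linspace(0, 48, 500)
y_curve = np.column_stack([np.exp(-0.1 * t_curve), 1 - np.exp(-0.05 * t_curve)])
t_obs = np.array([0, 1, 4, 8, 24, 48], dtype=float)
y_obs = np.column_stack([np.exp(-0.1 * t_obs), 1 - np.exp(-0.05 * t_obs)])
print(max_abs_gap(t_curve, y_curve, t_obs, y_obs))  # near zero: observations lie on the curves
```

With the notebook's objects this would be `max_abs_gap(t_smooth, y_smooth, train_data['t'], train_data['y'])`, provided those are CPU arrays.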


## Summary

### What Did We Learn?

1. **Full-Range Training**: The model was trained on all available time points (0–48h).
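2. **Physics Constraints**: The ODE residual term pulls the network toward trajectories consistent with the kinetic model, which encourages biologically plausible dynamics between the measured time points.
3. **Training Fit Quality**: R² on the training points shows how closely the model reproduces the data it was fit to; with no held-out points, it cannot reveal overfitting or measure predictive accuracy.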

### Next Steps

- Adjust `config['weights']['physics']` to see how physics strength affects the fit
- Use the trained model for new drug combinations (e.g., PI3Ki + Vemurafenib)
- Run multiple training runs with different random seeds for uncertainty quantification
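
For the seed-based uncertainty idea, one simple approach (an assumption, not a documented feature of `train_pinn`) is to train several models that differ only in random seed, for example by calling `torch.manual_seed(seed)` and `np.random.seed(seed)` before each `train_pinn(config)` call, assuming training draws its randomness from those generators. Each model then predicts on a shared time grid, and the spread across runs is summarized. The helper `ensemble_summary` is hypothetical, and the synthetic runs stand in for real `model.predict` outputs.

```python
import numpy as np

def ensemble_summary(predictions):
    """Mean and sample standard deviation across runs.

    predictions: array-like of shape (runs, time_points, species)
    returns: (mean, std), each of shape (time_points, species)
    """
    stack = np.asarray(predictions, dtype=float)
    if stack.shape[0] < 2:
        raise ValueError("need at least two runs to estimate spread")
    return stack.mean(axis=0), stack.std(axis=0, ddof=1)


# Synthetic stand-in: 5 "runs" = a shared curve plus run-specific noise
rng = np.random.default_rng(0)
t = np.linspace(0, 48, 50)
base = np.exp(-0.1 * t)[:, None]                 # shape (50, 1): one species
runs = [base + rng.normal(0, 0.02, size=base.shape) for _ in range(5)]
mean, std = ensemble_summary(runs)
print(mean.shape, std.shape)
print(f"Average spread: {std.mean():.3f}")
```

The resulting `mean ± std` band can be drawn with the same plotting grid as Step 9; seed variation captures optimization variability only, not uncertainty in the ODE model itself.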
