# PINN-PBM Colab Training Notebook
Run the cells in order to train the breakage PINN on Google Colab with GPU acceleration.

## Quick Start
1. **Runtime ▸ Change runtime type ▸ GPU** (required).
2. Run the GPU check cell.
3. Run the setup cell to clone the repo (or pull updates).
4. Run the install cell to install dependencies.
5. Run the training cell (takes ~6 minutes on a T4 GPU).
6. Run the evaluation cell to plot the PINN predictions against the analytical solution.

In [None]:
# Check that Colab sees a GPU.
# nvidia-smi lists the NVIDIA GPUs attached to this runtime. If it shows no
# device (or the command is not found), switch to a GPU runtime via
# Runtime ▸ Change runtime type before running the training cell.
!nvidia-smi

In [None]:
# Clone (or update) the PINN-PBM repository
from pathlib import Path

repo_url = "https://github.com/Glitched404/PINN-PBM.git"
repo_name = "PINN-PBM"
repo_path = Path(repo_name)

if not repo_path.exists():
    !git clone $repo_url
else:
    print("Repository already exists. Pulling latest changes...")
    %cd $repo_name
    !git pull
    %cd ..

%cd $repo_name

In [None]:
# Install dependencies compatible with Colab's Python 3.12 runtime
%pip install -q -r requirements-colab.txt

# Install the project package without pulling additional dependencies
%pip install -e . --no-deps

# Verify key packages after installation
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

print("numpy:", np.__version__)
print("tensorflow:", tf.__version__)
print("tensorflow_probability:", tfp.__version__)

# Ensure no dependency conflicts remain
!pip check

In [None]:
# Put the checkout's src/ directory on sys.path, then run the full breakage
# experiment. Relies on the setup cell having cd'ed into the repository, since
# the source directory is resolved from the current working directory.
import sys
from pathlib import Path

repo_root = Path.cwd()
source_dir = repo_root / "src"
source_entry = str(source_dir)
if source_dir.exists() and source_entry not in sys.path:
    sys.path.insert(0, source_entry)

from pinn_pbm.breakage.experiments import run_case

# --- Experiment options ------------------------------------------------------
CASE_TYPE = "case1"       # one of "case1", "case2", "case3", "case4"
ADAM_EPOCHS = None        # an int overrides the default progressive schedule
L_BFGS_BACKEND = "tfp"    # "tfp", "scipy", or "none"
SEED = 42

result = run_case(
    case_type=CASE_TYPE,
    adam_epochs=ADAM_EPOCHS,
    lbfgs=L_BFGS_BACKEND,
    seed=SEED,
    make_plots=True,
    verbose=True,
)

# Keep the main artifacts as notebook globals for later cells.
config, pinn = result["config"], result["pinn"]
losses = result.get("losses")

# Render whichever figures the run produced (either may be missing).
loss_fig, pred_fig = (result["figures"].get(name) for name in ("loss", "prediction"))
for figure in (loss_fig, pred_fig):
    if figure is not None:
        display(figure)

print("Adam duration (s):", result["adam_duration_sec"])
print("L-BFGS backend:", result["lbfgs_backend"], result["lbfgs"])

if losses:
    print("\nFinal losses:")
    # Labels are left-aligned to width 6 so the values line up in one column.
    for label, loss_key in (("total:", "total"), ("data:", "data"), ("phys:", "physics")):
        print(f"  {label:<6} {losses[loss_key][-1]:.3e}")

relative_errors = result["relative_errors"]
for t_slice, slice_error in zip(config.t_slices, relative_errors):
    print(f"t={t_slice:.1f} → mean relative error {slice_error:.3e}")

In [None]:
# Evaluate and visualize PINN predictions against the analytical solution.
# Requires the training cell to have run (it defines `pinn`, `config`, CASE_TYPE).
import matplotlib.pyplot as plt
import numpy as np

# NOTE(review): no earlier cell defines v_min, v_max or get_analytical_solution,
# so this cell previously always failed with NameError. Values already defined
# in the notebook take precedence; otherwise the volume bounds are read from
# the experiment config (attribute names assumed — confirm against pinn_pbm).
if "v_min" not in globals():
    v_min = getattr(config, "v_min", None)
if "v_max" not in globals():
    v_max = getattr(config, "v_max", None)
if v_min is None or v_max is None:
    raise RuntimeError(
        "Define v_min and v_max (the volume range to plot) before running this cell."
    )
if "get_analytical_solution" not in globals():
    raise RuntimeError(
        "Import get_analytical_solution from the pinn_pbm package before running this cell."
    )

# Compare against the case the model was actually trained on.
case_type = CASE_TYPE

v_plot = np.logspace(np.log10(v_min), np.log10(v_max), 200, dtype=np.float32)
t_eval = np.array([0.0, 2.0, 5.0, 10.0], dtype=np.float32)

# Assumed: predict returns one row of densities per entry of t_eval (it is
# indexed as f_pred[i] below) — TODO confirm against the PINN class.
f_pred = pinn.predict(v_plot, t_eval)
f_exact = np.array([get_analytical_solution(v_plot, float(t), case_type) for t in t_eval])

# Two panels per row, with as many rows as needed to show every time in t_eval.
# For the default four times this is the original 2x2 grid at (14, 10) inches.
n_cols = 2
n_rows = (len(t_eval) + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(7 * n_cols, 5 * n_rows), squeeze=False)
axes = axes.flatten()
colors = plt.cm.viridis(np.linspace(0, 1, len(t_eval)))

for i, (ax, t_val, color) in enumerate(zip(axes, t_eval, colors)):
    ax.semilogx(v_plot, f_exact[i], color=color, lw=2.5, label="Analytical")
    ax.semilogx(v_plot, f_pred[i], 'r--', lw=2.0, label="PINN")
    ax.set_title(f"t = {t_val:.1f}")
    ax.set_xlabel('Volume v')
    ax.set_ylabel('f(v, t)')
    ax.grid(True, which='both', ls='--', alpha=0.3)
    ax.legend()

# Hide leftover panels when t_eval does not fill the last row.
for ax in axes[len(t_eval):]:
    ax.set_visible(False)

# "case1" -> "Case 1" so the title follows the trained case instead of being fixed.
case_label = case_type.replace("case", "Case ")
plt.suptitle(f"{case_label}: PINN vs Analytical Solution")
plt.tight_layout()
plt.show()

# Report mean relative error for each time slice; the 1e-30 offset avoids
# division by zero where the analytical density is exactly zero.
for i, t_val in enumerate(t_eval):
    rel_err = np.abs((f_pred[i] - f_exact[i]) / (f_exact[i] + 1e-30))
    print(f"t={t_val:.1f}: mean relative error = {rel_err.mean():.3e}")

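## Next Steps
- Set `ADAM_EPOCHS` (an int overrides the default progressive schedule) or change `L_BFGS_BACKEND` in the training cell to trade training time for accuracy. Learning rates and batch sizes are not exposed by the training cell.
- Set `CASE_TYPE` to `'case2'`, `'case3'`, or `'case4'` in the training cell, then re-run the training and evaluation cells.
- Save results to Google Drive, or download files with `google.colab.files.download(...)` if needed.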