# PINN-PBM Colab Training Notebook
Run the cells in order to train the breakage PINN on Google Colab with GPU acceleration.

## Quick Start
1. Select **Runtime ▸ Change runtime type ▸ GPU** (required).
2. Run the GPU check cell.
3. Run the setup cell to clone the repository (or pull the latest changes).
4. Run the install cell to install dependencies. If Colab asks you to restart the runtime, do so and re-run from the setup cell.
5. Run the training cell (about 6 minutes on a T4 GPU).
6. Run the evaluation cell to plot the PINN predictions against the analytical solution.

In [None]:
# Check that Colab sees a GPU
!nvidia-smi

In [None]:
# Clone (or update) the PINN-PBM repository
from pathlib import Path

repo_url = "https://github.com/Glitched404/PINN-PBM.git"
repo_name = "PINN-PBM"
repo_path = Path(repo_name)

if not repo_path.exists():
    !git clone $repo_url
else:
    print("Repository already exists. Pulling latest changes...")
    %cd $repo_name
    !git pull
    %cd ..

%cd $repo_name

In [None]:
# Install project dependencies (Python 3.11 compatible stack)
%pip install -q \
    tensorflow==2.17.0 \
    tensorflow-probability==0.25.0 \
    numpy==1.26.4 \
    scipy==1.12.0 \
    matplotlib==3.8.2 \
    typing-extensions>=4.7.0,<5.0.0 \
    tqdm>=4.66.0,<5.0.0 \
    PyYAML>=6.0.0,<7.0.0 \
    pytest>=7,<9 \
    pytest-cov>=4,<5

# Verify key packages after installation
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

print("numpy:", np.__version__)
print("tensorflow:", tf.__version__)
print("tensorflow_probability:", tfp.__version__)

# Ensure no dependency conflicts remain
!pip check

In [None]:
# Train BreakagePINN on Case 1 for a short run.
#
# Every epoch samples a mini-batch of supervised points (taken from the
# analytical solution on a fixed grid) and a mini-batch of collocation points,
# then performs one optimizer step through `pinn.train_step`. The physics-loss
# weight ramps linearly from 0 up to `target_w_phys` over `ramp_epochs`.
import numpy as np
import tensorflow as tf
from tqdm.notebook import trange

from pinn_pbm.breakage.models import BreakagePINN
from pinn_pbm.breakage.solutions import get_analytical_solution
from pinn_pbm.core.utils import set_random_seed, configure_gpu_memory_growth

set_random_seed(42)
configure_gpu_memory_growth()

# --- Domain configuration ---------------------------------------------------
v_min, v_max = 1e-3, 10.0
t_min, t_max = 0.0, 10.0
case_type = "case1"

# --- Supervised data: analytical solution on a (time x volume) grid ---------
# 41 log-spaced volumes crossed with 4 time slices.
v_points = np.logspace(np.log10(v_min), np.log10(v_max), 41, dtype=np.float32)
t_slices = np.array([0.0, 2.0, 5.0, 10.0], dtype=np.float32)
V_grid, T_grid = np.meshgrid(v_points, t_slices)
F_grid = get_analytical_solution(V_grid, T_grid, case_type)

v_train, t_train, f_train = (
    grid.flatten().astype(np.float32) for grid in (V_grid, T_grid, F_grid)
)
num_data = len(v_train)

# --- Collocation candidates for the physics loss ----------------------------
# Volumes are log-uniform on [v_min, v_max]; times start 1e-3 above t_min.
# Volumes are drawn before times, which fixes the RNG stream.
rng = np.random.default_rng(42)
n_colloc_candidates = 20000
log_v_samples = rng.uniform(np.log(v_min), np.log(v_max), n_colloc_candidates)
v_colloc_candidates = np.exp(log_v_samples).astype(np.float32)
t_colloc_candidates = rng.uniform(t_min + 1e-3, t_max, n_colloc_candidates).astype(np.float32)

# --- Model and optimizer -----------------------------------------------------
pinn = BreakagePINN(
    v_min=v_min, v_max=v_max,
    t_min=t_min, t_max=t_max,
    case_type=case_type,
    n_hidden_layers=8,
    n_neurons=128,
)
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)

# --- Training schedule -------------------------------------------------------
num_epochs = 500
batch_size_data = 512
batch_size_phys = 1024
target_w_phys = 0.01  # final weight applied to the physics loss
ramp_epochs = 200     # epochs over which that weight ramps up linearly

# One (total_loss, data_loss, physics_loss, w_phys) tuple per epoch.
history = []

for epoch in trange(num_epochs, desc="Training"):
    # Index draws happen in a fixed order: data batch first, then collocation batch.
    data_idx = rng.integers(0, num_data, size=batch_size_data)
    colloc_idx = rng.integers(0, n_colloc_candidates, size=batch_size_phys)

    ramp_fraction = min((epoch + 1) / ramp_epochs, 1.0)
    w_phys_value = target_w_phys * ramp_fraction

    total_loss, data_loss, phys_loss = pinn.train_step(
        v_data=tf.constant(v_train[data_idx]),
        t_data=tf.constant(t_train[data_idx]),
        f_data=tf.constant(f_train[data_idx][:, None]),  # column vector
        v_physics=tf.constant(v_colloc_candidates[colloc_idx]),
        t_physics=tf.constant(t_colloc_candidates[colloc_idx]),
        w_data=tf.constant(1.0, dtype=tf.float32),
        w_physics=tf.constant(w_phys_value, dtype=tf.float32),
        optimizer=optimizer,
    )

    loss_values = tuple(float(loss.numpy()) for loss in (total_loss, data_loss, phys_loss))
    history.append(loss_values + (w_phys_value,))

    if (epoch + 1) % 50 == 0:
        total, data, phys, w_val = history[-1]
        print(f"Epoch {epoch + 1:4d}: total={total:.3e}, data={data:.3e}, phys={phys:.3e}, w_phys={w_val:.4f}")

print("Training complete!")

In [None]:
# Evaluate and visualize predictions vs analytical solution.
import math

import matplotlib.pyplot as plt

v_plot = np.logspace(np.log10(v_min), np.log10(v_max), 200, dtype=np.float32)
t_eval = np.array([0.0, 2.0, 5.0, 10.0], dtype=np.float32)

# NOTE(review): assumes pinn.predict(v, t) returns one row per entry of t,
# i.e. shape (len(t_eval), len(v_plot)), as the row indexing below requires.
f_pred = np.asarray(pinn.predict(v_plot, t_eval))
f_exact = np.array([get_analytical_solution(v_plot, float(t), case_type) for t in t_eval])

# Title label derived from the configured case, e.g. "case1" -> "Case 1".
case_label = case_type.replace("case", "Case ", 1)

# Size the panel grid from the number of time slices (2 columns per row).
n_cols = 2
n_rows = max(1, math.ceil(len(t_eval) / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(7 * n_cols, 5 * n_rows), squeeze=False)
axes = axes.flatten()
colors = plt.cm.viridis(np.linspace(0, 1, len(t_eval)))

for i, (ax, t_val, color) in enumerate(zip(axes, t_eval, colors)):
    ax.semilogx(v_plot, f_exact[i], color=color, lw=2.5, label="Analytical")
    ax.semilogx(v_plot, f_pred[i], 'r--', lw=2.0, label="PINN")
    ax.set_title(f"t = {t_val:.1f}")
    ax.set_xlabel('Volume v')
    ax.set_ylabel('f(v, t)')
    ax.grid(True, which='both', ls='--', alpha=0.3)
    ax.legend()

# Hide spare panels when the number of time slices does not fill the grid.
for ax in axes[len(t_eval):]:
    ax.set_visible(False)

plt.suptitle(f"{case_label}: PINN vs Analytical Solution")
plt.tight_layout()
plt.show()

# Error report per time slice.
# A pointwise relative error |pred - exact| / |exact| is ill-conditioned
# wherever the exact solution is zero or vanishingly small: a tiny epsilon in
# the denominator turns those points into huge values that dominate the mean.
# Report instead:
#   * the relative L2 error over the whole slice, and
#   * the mean pointwise relative error restricted to points whose |f_exact|
#     exceeds rel_floor times the slice maximum.
rel_floor = 1e-6
for i, t_val in enumerate(t_eval):
    exact_i = f_exact[i]
    diff = f_pred[i] - exact_i

    exact_norm = np.linalg.norm(exact_i)
    rel_l2 = np.linalg.norm(diff) / exact_norm if exact_norm > 0 else np.nan

    significant = np.abs(exact_i) > rel_floor * np.abs(exact_i).max()
    if significant.any():
        mean_rel = np.mean(np.abs(diff[significant]) / np.abs(exact_i[significant]))
    else:
        mean_rel = np.nan

    print(
        f"t={t_val:.1f}: relative L2 error = {rel_l2:.3e}, "
        f"mean relative error = {mean_rel:.3e} (points with |f| > {rel_floor:.0e} * max|f|)"
    )

## Next Steps
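- Increase `num_epochs` and tune the learning rate and batch sizes (`batch_size_data`, `batch_size_phys`) for higher accuracy.
- Set `case_type` to `'case2'`, `'case3'`, or `'case4'`. The supervised data is regenerated from that case's analytical solution, but the domain (`v_min`, `v_max`, `t_min`, `t_max`) and the time slices may need adjusting to suit it.
- Save results to Google Drive, or download them with `files.download(...)` (after `from google.colab import files`) if needed.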