<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/gravitational_wave_nn_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install numpy matplotlib tensorflow

In [None]:
#!/usr/bin/env python3
# filename: gw_regression_keras.py

import os
import argparse
import json
import random
import numpy as np

# Quieter TensorFlow C++ logs and deterministic ops. These must be set before
# `import tensorflow` below so TensorFlow picks them up; setdefault() lets a
# value already exported in the environment take precedence over these defaults.
for _env_name, _env_default in (("TF_CPP_MIN_LOG_LEVEL", "2"), ("TF_DETERMINISTIC_OPS", "1")):
    os.environ.setdefault(_env_name, _env_default)

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint


def set_seed(seed: int = 42) -> None:
    """Seed every RNG this script touches with the same value.

    Seeds, in order: Python's ``random`` module, NumPy's legacy global RNG
    (``np.random.seed``) and TensorFlow's global seed.
    """
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)


def generate_gravitational_wave_data(samples: int = 1000, random_state: int = 42, noise: float = 0.0):
    """Build a toy regression dataset of simulated gravitational-wave "distortion".

    Features are (Mass, Distance, Energy), drawn uniformly from [0, 1) as float32.
    The target is ``sin(Mass) + cos(Distance) * exp(-Energy)``. Gaussian noise with
    standard deviation ``noise`` is added only when ``noise > 0``.

    Returns:
        ``(X, Y)`` with shapes ``(samples, 3)`` and ``(samples, 1)``, both float32.
    """
    generator = np.random.default_rng(random_state)
    features = generator.random((samples, 3), dtype=np.float32)

    # Transposing yields one row per feature column, so this unpacks the columns.
    mass, distance, energy = features.T
    target = np.sin(mass) + np.cos(distance) * np.exp(-energy)

    if noise > 0.0:
        target = target + generator.normal(0.0, noise, size=samples)

    return features, target.astype(np.float32).reshape(-1, 1)


def make_datasets(X, Y, val_ratio: float, batch_size: int, seed: int = 42):
    """Shuffle-split ``(X, Y)`` into training and validation ``tf.data`` pipelines.

    Args:
        X: Feature array; rows are samples.
        Y: Target array with the same number of rows as ``X``.
        val_ratio: Fraction of rows held out for validation, in ``[0, 1)``.
            At least one row is always held out, so ``val_ratio == 0`` still
            yields a single validation sample.
        batch_size: Mini-batch size for both datasets.
        seed: Seed for the split permutation and for the training-set shuffle.

    Returns:
        ``(train_ds, val_ds, (X_train, Y_train, X_val, Y_val))``. The training
        dataset is reshuffled every epoch; the validation dataset is not shuffled.

    Raises:
        ValueError: If ``X`` and ``Y`` have different row counts, ``val_ratio``
            is outside ``[0, 1)``, or the split would leave no training rows.
    """
    n = X.shape[0]
    if Y.shape[0] != n:
        raise ValueError(f"X and Y must have the same number of rows, got {n} and {Y.shape[0]}.")
    # Written as a negated range check so NaN is rejected too.
    if not 0.0 <= val_ratio < 1.0:
        raise ValueError(f"val_ratio must be in [0, 1), got {val_ratio}.")

    # Seeded permutation so the train/validation split is reproducible.
    idx = np.arange(n)
    rng = np.random.default_rng(seed)
    rng.shuffle(idx)

    # Hold out at least one row so a validation set always exists.
    n_val = max(1, int(val_ratio * n))
    if n_val >= n:
        # An empty training split is useless, and tf.data rejects shuffle(buffer_size=0).
        raise ValueError(
            f"Not enough samples ({n}) to hold out {n_val} for validation and still train; "
            "increase the sample count or lower val_ratio."
        )
    val_idx = idx[:n_val]
    train_idx = idx[n_val:]

    X_train, Y_train = X[train_idx], Y[train_idx]
    X_val, Y_val = X[val_idx], Y[val_idx]

    train_ds = (
        tf.data.Dataset.from_tensor_slices((X_train, Y_train))
        # A full-size buffer gives a uniform shuffle of the (small, in-memory) training set.
        .shuffle(buffer_size=X_train.shape[0], seed=seed, reshuffle_each_iteration=True)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
    val_ds = (
        tf.data.Dataset.from_tensor_slices((X_val, Y_val))
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
    return train_ds, val_ds, (X_train, Y_train, X_val, Y_val)


def build_model(
    hidden1: int = 64,
    hidden2: int = 32,
    activation: str = "relu",
    learning_rate: float = 1e-3,
    n_features: int = 3,
):
    """Build and compile a two-hidden-layer MLP regressor.

    Args:
        hidden1: Units in the first hidden ``Dense`` layer.
        hidden2: Units in the second hidden ``Dense`` layer.
        activation: Activation used by both hidden layers.
        learning_rate: Adam learning rate.
        n_features: Width of each input row (3 for Mass, Distance, Energy).

    Returns:
        A compiled ``Sequential`` model with MSE loss and an MAE metric, producing
        one linear output per sample.
    """
    model = Sequential(
        [
            # Explicit Input layer instead of Dense(..., input_shape=...): passing
            # input_shape to a layer is deprecated in Keras 3 and emits a warning.
            tf.keras.Input(shape=(n_features,)),
            Dense(hidden1, activation=activation),
            Dense(hidden2, activation=activation),
            Dense(1, activation="linear"),  # regression
        ]
    )
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), loss="mse", metrics=["mae"])
    return model


def ensure_outdir(path: str) -> None:
    """Create directory ``path``, including missing parents, if it is not already a directory."""
    if not os.path.isdir(path):
        # exist_ok stays on: the directory may appear between the check and the create.
        os.makedirs(path, exist_ok=True)


def _json_default(obj):
    """``json.dump`` fallback that converts NumPy scalars and arrays to built-in types.

    Keras ``History.history`` may hold NumPy scalars (for example the learning rate
    that ReduceLROnPlateau adds to the logs under Keras 2), which ``json`` cannot
    serialize on its own.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_artifacts(outdir: str, model: "tf.keras.Model", config: dict, history: dict, metrics: dict, save_model: bool):
    """Write the run summary (and optionally the final model) into ``outdir``.

    Args:
        outdir: Output directory; created if missing.
        model: Trained model. It is only used, via ``model.save``, when ``save_model`` is true.
        config: Run configuration, stored under ``"config"``.
        history: Per-epoch training history (e.g. ``History.history``), stored under ``"history"``.
        metrics: Final evaluation metrics, stored under ``"metrics"``.
        save_model: If true, also save the model as ``model.keras``.
    """
    ensure_outdir(outdir)
    if save_model:
        # Save the final model (Keras v3 native format)
        final_path = os.path.join(outdir, "model.keras")
        model.save(final_path)
    summary = {"config": config, "history": history, "metrics": metrics}
    with open(os.path.join(outdir, "run_summary.json"), "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, default=_json_default)


def str2bool(v: str) -> bool:
    """Interpret a command-line style flag value as a boolean.

    ``bool`` inputs are returned unchanged. Strings are matched after trimming
    surrounding whitespace and lower-casing:

    - "yes", "true", "t", "1", "y" map to True;
    - "no", "false", "f", "0", "n" map to False.

    Raises:
        argparse.ArgumentTypeError: If the string matches neither set; the
            message shows the normalized value.
    """
    if isinstance(v, bool):
        return v
    normalized = v.strip().lower()
    verdicts = dict.fromkeys(("yes", "true", "t", "1", "y"), True)
    verdicts.update(dict.fromkeys(("no", "false", "f", "0", "n"), False))
    verdict = verdicts.get(normalized)
    if verdict is None:
        raise argparse.ArgumentTypeError(f"Boolean value expected, got: {normalized}")
    return verdict


def parse_args(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: Argument list to parse. ``None`` (the default) means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace``; ``save_model`` is already a ``bool``.
        Unrecognized arguments are ignored (see the comment below).
    """
    parser = argparse.ArgumentParser(
        description="Keras regression on synthetic gravitational wave toy data (Jupyter/Colab-safe).",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--samples", type=int, default=1000, help="Number of synthetic samples.")
    parser.add_argument("--epochs", type=int, default=100, help="Max training epochs.")
    parser.add_argument("--batch_size", type=int, default=32, help="Mini-batch size.")
    parser.add_argument("--val_ratio", type=float, default=0.2, help="Validation split ratio.")
    parser.add_argument("--noise", type=float, default=0.0, help="Stddev of Gaussian noise on target.")
    parser.add_argument("--learning_rate", type=float, default=1e-3, help="Adam learning rate.")
    parser.add_argument("--hidden1", type=int, default=64, help="Units in first hidden layer.")
    parser.add_argument("--hidden2", type=int, default=32, help="Units in second hidden layer.")
    parser.add_argument("--activation", type=str, default="relu", help="Hidden activation.")
    parser.add_argument("--patience", type=int, default=20, help="Early stopping patience (epochs).")
    parser.add_argument("--random_state", type=int, default=42, help="Random seed.")
    parser.add_argument("--outdir", type=str, default="outputs_tf", help="Directory to save artifacts.")
    # type=str2bool makes argparse report bad values as a normal usage error
    # (instead of a traceback). The string default "true" is also passed through
    # str2bool, so args.save_model is always a bool.
    parser.add_argument(
        "--save_model", type=str2bool, default="true", help="Whether to save model artifacts (true/false)."
    )
    parser.add_argument(
        "--test_input",
        type=float,
        nargs=3,
        metavar=("MASS", "DIST", "ENERGY"),
        help="Three feature values to run a single prediction after training.",
    )
    # Important: ignore unknown args (e.g., -f <kernel.json> from notebooks)
    args, _ = parser.parse_known_args(argv)
    return args


def _build_callbacks(args) -> list:
    """Return the early-stopping, LR-schedule and (optional) checkpoint callbacks for ``args``."""
    cbs = [
        EarlyStopping(monitor="val_loss", patience=args.patience, restore_best_weights=True, verbose=1),
        ReduceLROnPlateau(
            monitor="val_loss",
            factor=0.5,
            # Halve the LR after half the early-stopping patience (at least 5 epochs).
            patience=max(5, args.patience // 2),
            min_lr=1e-6,
            verbose=1,
        ),
    ]
    if args.save_model:
        # Make sure the directory exists before the checkpoint writes best.keras into it.
        ensure_outdir(args.outdir)
        cbs.append(
            ModelCheckpoint(
                filepath=os.path.join(args.outdir, "best.keras"),
                monitor="val_loss",
                save_best_only=True,
                save_weights_only=False,
                verbose=1,
            )
        )
    return cbs


def _predict_single(model, values) -> dict:
    """Predict for one (Mass, Distance, Energy) triple; return a JSON-ready record."""
    test_vec = np.array(values, dtype=np.float32).reshape(1, 3)
    pred = model.predict(test_vec, verbose=0).item()
    return {"input": [float(v) for v in values], "prediction": float(pred)}


def main():
    """Run one experiment end to end.

    Steps: parse CLI args, generate data, train the MLP, evaluate on the
    validation split, optionally predict one input, and write artifacts.
    """
    args = parse_args()
    set_seed(args.random_state)

    # Data (only the tf.data pipelines are needed; the raw split arrays are discarded)
    X, Y = generate_gravitational_wave_data(
        samples=args.samples, random_state=args.random_state, noise=args.noise
    )
    train_ds, val_ds, _ = make_datasets(
        X, Y, val_ratio=args.val_ratio, batch_size=args.batch_size, seed=args.random_state
    )

    # Model
    model = build_model(
        hidden1=args.hidden1,
        hidden2=args.hidden2,
        activation=args.activation,
        learning_rate=args.learning_rate,
    )
    model.summary()

    # Train
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=args.epochs,
        verbose=1,
        callbacks=_build_callbacks(args),
    )

    # Evaluate
    val_loss, val_mae = model.evaluate(val_ds, verbose=0)
    print(f"Validation — MSE: {val_loss:.6f}, MAE: {val_mae:.6f}")

    # Optional single prediction
    test_pred = None
    if args.test_input is not None:
        test_pred = _predict_single(model, args.test_input)
        print(f"Test input: {test_pred['input']} -> predicted distortion: {test_pred['prediction']:.6f}")

    # Save artifacts. The config records the parsed run settings; save_model is
    # included so the summary shows whether model files were written.
    config_keys = (
        "samples",
        "epochs",
        "batch_size",
        "val_ratio",
        "noise",
        "learning_rate",
        "hidden1",
        "hidden2",
        "activation",
        "patience",
        "random_state",
        "outdir",
        "save_model",
    )
    config = {key: getattr(args, key) for key in config_keys}
    metrics = {"val_mse": float(val_loss), "val_mae": float(val_mae)}
    if test_pred is not None:
        metrics["test_prediction"] = test_pred

    save_artifacts(args.outdir, model, config=config, history=history.history, metrics=metrics, save_model=args.save_model)
    print(f"Artifacts saved to: {args.outdir}")
    print("Done.")


# Script entry point. Inside a Jupyter/Colab cell __name__ is also "__main__",
# so executing the cell starts training right away; parse_args() uses
# parse_known_args, so the kernel's extra argv entries are ignored.
if __name__ == "__main__":
    main()