<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/exotic_ai_trainer_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
# exotic_ai_trainer.py

import os
import sys
import json
import time
import argparse
from dataclasses import dataclass, asdict
from typing import Tuple, Optional

import numpy as np
import pandas as pd
import tensorflow as tf


# -----------------------------
# Reproducibility
# -----------------------------
def set_seeds(seed: int = 42):
    """Seed the random number generators used by this script.

    Seeds Python's built-in ``random`` module, NumPy's legacy global
    generator and TensorFlow's global generator, and exports
    ``PYTHONHASHSEED``.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Local import keeps this helper self-contained; the module does not
    # import ``random`` at file level.
    import random

    # NOTE: hash randomization of the *running* interpreter is fixed at
    # start-up, so this only influences interpreters launched afterwards
    # (e.g. worker subprocesses that inherit the environment).
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)


# -----------------------------
# Data loaders
# -----------------------------
def load_mnist(kind: str = "mnist") -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int]:
    """Load MNIST or Fashion-MNIST through ``tf.keras.datasets``.

    Images are flattened to 784 features, cast to float32 and divided by
    255.0. The second split returned by ``load_data()`` is used as the
    validation set. Labels are returned as loaded.

    Args:
        kind: ``"mnist"`` or ``"fashion"``.

    Returns:
        ``(x_train, y_train, x_val, y_val, 784, 10)``.

    Raises:
        ValueError: If ``kind`` is neither ``"mnist"`` nor ``"fashion"``.
    """
    if kind not in ("mnist", "fashion"):
        raise ValueError("Unsupported dataset. Use 'mnist' or 'fashion'.")

    # Only the selected dataset module is touched.
    source = tf.keras.datasets.mnist if kind == "mnist" else tf.keras.datasets.fashion_mnist
    (train_images, train_labels), (val_images, val_labels) = source.load_data()

    def flatten_and_scale(images):
        # One row of 784 float32 values per image.
        return images.reshape((-1, 784)).astype(np.float32) / 255.0

    return (
        flatten_and_scale(train_images),
        train_labels,
        flatten_and_scale(val_images),
        val_labels,
        784,
        10,
    )


def load_synthetic(
    n_train=5000,
    n_val=1000,
    input_dim=30,
    num_classes=10,
    clusters=20,
    seed=42
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int]:
    """Generate a clustered Gaussian classification dataset.

    ``clusters`` centres are drawn from N(0, 3) in ``input_dim`` dimensions.
    Each split spreads its samples as evenly as possible over the clusters
    (the first ``total % clusters`` clusters get one extra sample), perturbs
    them with isotropic Gaussian noise whose scale is drawn uniformly from
    [0.3, 1.5) per cluster and per split, and labels cluster ``k`` with class
    ``k % num_classes``. Both splits are shuffled before being returned.

    Args:
        n_train: Number of training samples.
        n_val: Number of validation samples.
        input_dim: Feature dimensionality.
        num_classes: Number of distinct labels.
        clusters: Number of Gaussian clusters shared by both splits.
        seed: Seed for the NumPy ``Generator`` driving all sampling.

    Returns:
        ``(x_train, y_train, x_val, y_val, input_dim, num_classes)`` with
        float32 features and int32 labels.
    """
    rng = np.random.default_rng(seed)
    centers = rng.normal(0.0, 3.0, size=(clusters, input_dim))

    def sample_split(total):
        # Draws happen in the order: scale, then points, cluster by cluster.
        features, labels = [], []
        for idx in range(clusters):
            base, remainder = divmod(total, clusters)
            count = base + int(idx < remainder)
            spread = rng.uniform(0.3, 1.5)
            features.append(centers[idx] + rng.normal(0, spread, size=(count, input_dim)))
            labels.append(np.full(count, idx % num_classes, dtype=np.int32))
        stacked = np.vstack(features).astype(np.float32)
        targets = np.concatenate(labels).astype(np.int32)
        return stacked, targets

    # Train is sampled before validation so the random stream is consumed
    # in a fixed order.
    x_train, y_train = sample_split(n_train)
    x_val, y_val = sample_split(n_val)

    # Shuffle each split with its own permutation (train first, then val).
    train_order = rng.permutation(len(x_train))
    val_order = rng.permutation(len(x_val))
    x_train, y_train = x_train[train_order], y_train[train_order]
    x_val, y_val = x_val[val_order], y_val[val_order]

    return x_train, y_train, x_val, y_val, input_dim, num_classes


def make_tf_dataset(x, y, batch_size=128, shuffle=True, seed=42):
    """Wrap ``(x, y)`` in a batched, prefetched ``tf.data.Dataset``.

    Args:
        x: Features; sliced along the first axis.
        y: Labels aligned with ``x``.
        batch_size: Number of examples per batch.
        shuffle: When true, shuffle with a buffer of ``min(len(x), 10000)``
            elements, reshuffling on every iteration.
        seed: Seed handed to the shuffle op (unused when ``shuffle`` is false).

    Returns:
        The batched dataset with ``tf.data.AUTOTUNE`` prefetching.
    """
    pipeline = tf.data.Dataset.from_tensor_slices((x, y))
    if shuffle:
        buffer_size = min(len(x), 10000)
        pipeline = pipeline.shuffle(buffer_size, seed=seed, reshuffle_each_iteration=True)
    batched = pipeline.batch(batch_size)
    return batched.prefetch(tf.data.AUTOTUNE)


# -----------------------------
# Model builder
# -----------------------------
def build_exotic_model(
    input_dim: int,
    num_classes: int = 10,
    neg_units: int = 256,
    hyper_units: int = 512,
    negative_activation: str = "tanh",
    hyper_activation: str = "relu",
    dropout: float = 0.3,
    batchnorm: bool = True,
    l2_reg: float = 1e-4,
) -> tf.keras.Model:
    """Build the two-hidden-layer "ExoticAI" classifier.

    Each hidden stage is Dense -> optional BatchNormalization -> Activation
    -> optional Dropout, followed by a softmax Dense output layer.

    Args:
        input_dim: Number of input features.
        num_classes: Width of the softmax output layer.
        neg_units: Units in the first ("negative-energy") Dense layer.
        hyper_units: Units in the second ("hyperdimensional") Dense layer.
        negative_activation: Activation applied after the first stage.
        hyper_activation: Activation applied after the second stage.
        dropout: Dropout rate; Dropout layers are omitted when it is falsy
            or not greater than zero.
        batchnorm: Insert BatchNormalization after each hidden Dense layer;
            those Dense layers then drop their bias term.
        l2_reg: L2 kernel-regularization factor; no regularizer is used when
            it is falsy or not greater than zero.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``"ExoticAI"``.
    """
    layers = tf.keras.layers
    regularizer = tf.keras.regularizers.l2(l2_reg) if l2_reg and l2_reg > 0 else None
    apply_dropout = bool(dropout and dropout > 0)

    def hidden_stage(tensor, prefix, units, activation):
        # The bias is redundant when BatchNormalization follows, so it is
        # switched off in that case.
        tensor = layers.Dense(
            units,
            use_bias=not batchnorm,
            kernel_regularizer=regularizer,
            name=f"{prefix}_dense",
        )(tensor)
        if batchnorm:
            tensor = layers.BatchNormalization(name=f"{prefix}_bn")(tensor)
        tensor = layers.Activation(activation, name=f"{prefix}_act")(tensor)
        if apply_dropout:
            tensor = layers.Dropout(dropout, name=f"{prefix}_dropout")(tensor)
        return tensor

    inputs = tf.keras.Input(shape=(input_dim,), name="inputs")

    # "Negative-energy" layer followed by the "hyperdimensional" layer.
    hidden = hidden_stage(inputs, "neg", neg_units, negative_activation)
    hidden = hidden_stage(hidden, "hyper", hyper_units, hyper_activation)

    # The layer is named "logits" but already applies softmax.
    head = layers.Dense(
        num_classes,
        activation="softmax",
        kernel_regularizer=regularizer,
        name="logits",
    )
    return tf.keras.Model(inputs=inputs, outputs=head(hidden), name="ExoticAI")


# -----------------------------
# Config
# -----------------------------
@dataclass
class Config:
    """Settings for one training run.

    An instance is built from the parsed command-line options in ``main()``
    and serialized to ``config.json`` by ``train()``. Note that ``train()``
    overwrites ``input_dim`` and ``num_classes`` with the values reported by
    the selected data loader.
    """

    # --- Data ---
    dataset: str = "synthetic"        # synthetic | mnist | fashion
    input_dim: int = 30               # used only for synthetic
    num_classes: int = 10             # passed to the synthetic loader; replaced by the loader's value in train()

    # --- Optimization ---
    epochs: int = 15                  # maximum number of epochs passed to model.fit
    batch_size: int = 128             # batch size for both train and validation datasets
    lr: float = 1e-3                  # initial Adam learning rate

    # --- Regularization / architecture (forwarded to build_exotic_model) ---
    dropout: float = 0.3
    batchnorm: bool = True
    l2_reg: float = 1e-4
    neg_units: int = 256
    hyper_units: int = 512
    negative_activation: str = "tanh"
    hyper_activation: str = "relu"

    # --- Callbacks ---
    patience: int = 5                 # EarlyStopping patience
    reduce_lr_patience: int = 3       # ReduceLROnPlateau patience

    # --- Bookkeeping ---
    seed: int = 42                    # seed for set_seeds, synthetic data and dataset shuffling
    results_dir: str = "results"      # parent directory for per-run output folders
    tag: Optional[str] = None         # run-folder suffix; falls back to the dataset name when unset
    save_tflite: bool = False         # also export model.tflite
    save_onnx: bool = False           # also export model.onnx (needs tf2onnx)


def get_run_dir(cfg: Config) -> str:
    """Create and return a timestamped output directory for this run.

    The directory is ``<results_dir>/<YYYYmmdd-HHMMSS>_<label>`` where the
    label is ``cfg.tag`` when set (truthy) and ``cfg.dataset`` otherwise.
    An already existing directory is reused without error.

    Args:
        cfg: Run configuration providing ``results_dir``, ``tag`` and
            ``dataset``.

    Returns:
        Path of the (now existing) run directory.
    """
    stamp = time.strftime("%Y%m%d-%H%M%S")
    label = cfg.tag if cfg.tag else cfg.dataset
    target = os.path.join(cfg.results_dir, f"{stamp}_{label}")
    os.makedirs(target, exist_ok=True)
    return target


# -----------------------------
# Training
# -----------------------------
def train(cfg: Config):
    """Train the ExoticAI classifier described by ``cfg`` and save its artifacts.

    Steps: seed the RNGs, load the requested dataset, build and compile the
    model, fit with early stopping / LR reduction / checkpointing / CSV
    logging, write the artifacts below into a fresh run directory, and
    finally evaluate on the validation split.

    Files written into the run directory:
        ``best_model.keras`` (checkpoint with best ``val_accuracy``),
        ``training_log.csv``, ``history.csv``, ``config.json``,
        ``final_model.keras``, ``savedmodel/``, ``val_metrics.json`` and,
        when requested, ``model.tflite`` / ``model.onnx``.

    Args:
        cfg: Run configuration. ``cfg.input_dim`` and ``cfg.num_classes``
            are overwritten in place with the values returned by the data
            loader so that ``config.json`` records what was actually used.

    Returns:
        ``(model, history, run_dir)``: the trained Keras model, the
        ``History`` object returned by ``fit`` and the run directory path.

    Raises:
        ValueError: If ``cfg.dataset`` is not ``synthetic``, ``mnist`` or
            ``fashion``.
    """
    set_seeds(cfg.seed)

    # Data
    if cfg.dataset in ("mnist", "fashion"):
        x_train, y_train, x_val, y_val, input_dim, num_classes = load_mnist(cfg.dataset)
    elif cfg.dataset == "synthetic":
        x_train, y_train, x_val, y_val, input_dim, num_classes = load_synthetic(
            n_train=5000, n_val=1000, input_dim=cfg.input_dim, num_classes=cfg.num_classes, seed=cfg.seed
        )
    else:
        raise ValueError("dataset must be one of: synthetic | mnist | fashion")

    # Record the dimensions actually in use (the image datasets ignore the
    # values supplied in cfg).
    cfg.input_dim = input_dim
    cfg.num_classes = num_classes

    ds_train = make_tf_dataset(x_train, y_train, batch_size=cfg.batch_size, shuffle=True, seed=cfg.seed)
    ds_val = make_tf_dataset(x_val, y_val, batch_size=cfg.batch_size, shuffle=False)

    # Model
    model = build_exotic_model(
        input_dim=input_dim,
        num_classes=num_classes,
        neg_units=cfg.neg_units,
        hyper_units=cfg.hyper_units,
        negative_activation=cfg.negative_activation,
        hyper_activation=cfg.hyper_activation,
        dropout=cfg.dropout,
        batchnorm=cfg.batchnorm,
        l2_reg=cfg.l2_reg,
    )

    optimizer = tf.keras.optimizers.Adam(learning_rate=cfg.lr)
    model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"])

    # Run directory and callbacks
    run_dir = get_run_dir(cfg)
    ckpt_path = os.path.join(run_dir, "best_model.keras")  # Keras 3 native format
    callbacks = [
        tf.keras.callbacks.EarlyStopping(patience=cfg.patience, restore_best_weights=True, monitor="val_loss"),
        tf.keras.callbacks.ReduceLROnPlateau(patience=cfg.reduce_lr_patience, factor=0.5, min_lr=1e-6),
        tf.keras.callbacks.ModelCheckpoint(ckpt_path, monitor="val_accuracy", save_best_only=True),
        tf.keras.callbacks.CSVLogger(os.path.join(run_dir, "training_log.csv")),
    ]

    # Train
    history = model.fit(ds_train, validation_data=ds_val, epochs=cfg.epochs, callbacks=callbacks, verbose=1)

    # Save artifacts
    pd.DataFrame(history.history).to_csv(os.path.join(run_dir, "history.csv"), index=False)
    with open(os.path.join(run_dir, "config.json"), "w") as f:
        json.dump(asdict(cfg), f, indent=2)

    # Save final model (native Keras format)
    final_model_path = os.path.join(run_dir, "final_model.keras")
    model.save(final_model_path)

    # Export SavedModel for serving/TFLite
    savedmodel_dir = os.path.join(run_dir, "savedmodel")
    model.export(savedmodel_dir)

    # Optional: TFLite. Best-effort: a failed conversion is reported but
    # does not abort the run.
    if cfg.save_tflite:
        try:
            converter = tf.lite.TFLiteConverter.from_saved_model(savedmodel_dir)
            tflite_model = converter.convert()
            with open(os.path.join(run_dir, "model.tflite"), "wb") as f:
                f.write(tflite_model)
        except Exception as e:
            print(f"[WARN] TFLite export failed: {e}")

    # Optional: ONNX (requires tf2onnx). Also best-effort; the import lives
    # here so the dependency is only needed when the export is requested.
    if cfg.save_onnx:
        try:
            import tf2onnx
            onnx_path = os.path.join(run_dir, "model.onnx")
            spec = (tf.TensorSpec((None, input_dim), tf.float32, name="inputs"),)
            # The converter writes the file itself via output_path; its
            # return value is not needed.
            tf2onnx.convert.from_keras(model, input_signature=spec, output_path=onnx_path)
        except Exception as e:
            print(f"[WARN] ONNX export failed (install tf2onnx?): {e}")

    # Final evaluate. Ask Keras for a name -> value mapping directly instead
    # of zipping with model.metrics_names: under Keras 3 that list is
    # ['loss', 'compile_metrics'], which would mislabel the accuracy.
    val_metrics = model.evaluate(ds_val, verbose=0, return_dict=True)
    # Cast to plain floats so the mapping is always JSON-serializable.
    metrics = {name: float(value) for name, value in val_metrics.items()}
    with open(os.path.join(run_dir, "val_metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)

    print(f"Run directory: {run_dir}")
    print("Validation metrics:", metrics)
    return model, history, run_dir


# -----------------------------
# CLI (notebook-safe)
# -----------------------------
def build_arg_parser():
    """Build the command-line parser for the trainer.

    Every option's destination name matches a field of ``Config`` so the
    parsed namespace can be expanded straight into ``Config(**vars(args))``.

    Batch normalization is enabled by default, in line with the ``Config``
    default. ``--batchnorm`` is still accepted (and is a no-op) for backward
    compatibility; use ``--no_batchnorm`` (or ``--no-batchnorm``) to disable
    it.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    p = argparse.ArgumentParser(description="Train ExoticAI with regularization and robust exports.")

    # Data
    p.add_argument("--dataset", choices=["synthetic", "mnist", "fashion"], default="synthetic")
    p.add_argument("--input_dim", type=int, default=30)
    p.add_argument("--num_classes", type=int, default=10)

    # Optimization
    p.add_argument("--epochs", type=int, default=15)
    p.add_argument("--batch_size", type=int, default=128)
    p.add_argument("--lr", type=float, default=1e-3)

    # Regularization / architecture
    p.add_argument("--dropout", type=float, default=0.3)
    # A bare store_true flag would default to False and silently turn batch
    # normalization off for every CLI run; pair it with an explicit
    # off-switch writing to the same destination instead.
    p.add_argument("--batchnorm", dest="batchnorm", action="store_true", default=True,
                   help="enable batch normalization (default)")
    p.add_argument("--no_batchnorm", "--no-batchnorm", dest="batchnorm", action="store_false",
                   help="disable batch normalization")
    p.add_argument("--l2_reg", type=float, default=1e-4)
    p.add_argument("--neg_units", type=int, default=256)
    p.add_argument("--hyper_units", type=int, default=512)
    p.add_argument("--negative_activation", type=str, default="tanh")
    p.add_argument("--hyper_activation", type=str, default="relu")

    # Callbacks
    p.add_argument("--patience", type=int, default=5)
    p.add_argument("--reduce_lr_patience", type=int, default=3)

    # Bookkeeping and exports
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--results_dir", type=str, default="results")
    p.add_argument("--tag", type=str, default=None)
    p.add_argument("--save_tflite", action="store_true")
    p.add_argument("--save_onnx", action="store_true")
    return p


def main():
    """Parse command-line options into a ``Config`` and run training."""
    # parse_known_args() rather than parse_args(): inside a notebook the
    # kernel adds its own arguments (e.g. "-f kernel.json"), which are
    # returned separately here and deliberately discarded.
    known, _ignored = build_arg_parser().parse_known_args()
    settings = Config(**vars(known))
    train(settings)


# Run the CLI only when this file is executed as a script (or notebook cell),
# not when it is imported as a module.
if __name__ == "__main__":
    main()