# WaveletDiff TPU Training (Keras 3 / JAX)

This notebook implements the high-performance training pipeline for WaveletDiff using Keras 3 with the JAX backend, specifically optimized for Cloud TPUs (v5e/v6e).

### Features:
- **JAX Backend**: Uses XLA compilation for maximum throughput.
- **tf.data Pipeline**: Asynchronous prefetching via `tf.data` (avoids Python threading bottlenecks).
- **Fused Steps**: The entire training step (sampling + diffusion + loss) is compiled into a single XLA graph.


In [None]:
# @title Cell 1: Environment Setup & Cloning
import os
import sys

# 1. Set Backend to JAX (Must be done before importing keras)
os.environ["KERAS_BACKEND"] = "jax"

# 2. Clone Repository
REPO_DIR = "/content/waveletDiff_synth_data"
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/MilesHoffman/waveletDiff_synth_data.git {REPO_DIR}
else:
    print("Repo already exists. Pulling latest...")
    !cd {REPO_DIR} && git pull

# 3. Dependencies
!pip install keras --upgrade  # Ensure Keras 3
!pip install pywavelets

# 4. Path Setup
if REPO_DIR not in sys.path:
    sys.path.append(f"{REPO_DIR}/src")
    print(f"Added {REPO_DIR}/src to path")

import keras
import jax

print(f"Keras version: {keras.__version__}")
print(f"Backend: {keras.config.backend()}")

# --- Hardware Optimization ---
try:
    devices = jax.devices()
    device_type = devices[0].platform.upper()
    print(f"Hardware Detected: {device_type} (Count: {len(devices)})")
    
    if device_type == 'TPU':
        keras.mixed_precision.set_global_policy("mixed_bfloat16")
        print("✅ Optimization: Precision set to 'mixed_bfloat16' (TPU Native)")
    elif device_type == 'GPU':
        keras.mixed_precision.set_global_policy("mixed_float16")
        print("✅ Optimization: Precision set to 'mixed_float16'")
except Exception as e:
    print(f"Hardware detection failed: {e}")

In [None]:
# @title Cell 2: Configuration
# Hyperparameters shared by the data pipeline (Cell 3), the model (Cell 3)
# and training (Cell 4). REPO_DIR is defined in Cell 1.

# ----- Dataset -----
DATA_PATH = f"{REPO_DIR}/src/copied_waveletDiff/data/stocks/stock_data.csv"
SEQ_LEN = 24          # sequence length passed to kdata.load_dataset
# Large batch for TPU efficiency.
# NOTE(review): confirm the windowed dataset holds at least this many samples;
# if load_dataset drops the final partial batch, a smaller dataset yields no
# batches at all.
BATCH_SIZE = 4_096

# ----- Model / optimizer -----
EMBED_DIM = 256       # transformer embedding width ('embed_dim')
NUM_HEADS = 8         # attention heads ('num_heads')
NUM_LAYERS = 8        # transformer layers ('num_layers')
DROPOUT = 0.1         # dropout rate ('dropout')
LEARNING_RATE = 0.0002  # AdamW learning rate

In [None]:
# @title Cell 3: Load Data & Initialize Model
# Builds the tf.data pipeline, constructs and compiles the diffusion
# transformer, then runs one batch through it so the summary shows real shapes.
from w_keras import data as kdata
from w_keras import transformer as ktrans
import keras

# 1. Load Data (tf.data Pipeline)
# Note: The first time this runs, it will process the CSV into wavelets.
ds, info = kdata.load_dataset(DATA_PATH, BATCH_SIZE, SEQ_LEN)

print(f"Wavelet Info: {info['level_dims']} coeffs per level")

# 2. Initialize Model
model_config = {
    'embed_dim': EMBED_DIM,
    'num_heads': NUM_HEADS,
    'num_layers': NUM_LAYERS,
    'dropout': DROPOUT,
    'prediction_target': 'noise'
}

model = ktrans.WaveletDiffusionTransformer(info, model_config)

# 3. Compile (XLA Just-In-Time)
optimizer = keras.optimizers.AdamW(learning_rate=LEARNING_RATE, weight_decay=1e-5)
model.compile(optimizer=optimizer, jit_compile=True)

# Dummy build to print summary.
# Keras models are built lazily, so we pass one batch through to shape them.
sample_batch = next(iter(ds.take(1)), None)
if sample_batch is None:
    # Otherwise the build is silently skipped and the problem only surfaces
    # later, inside model.fit().
    raise ValueError(
        f"Dataset built from {DATA_PATH!r} produced no batches; check the data "
        f"path and whether BATCH_SIZE={BATCH_SIZE} exceeds the number of "
        "samples (batches that drop the remainder would then be empty)."
    )
# Assumes each dataset element is a tuple whose first item is the model input
# (same indexing as before) — TODO confirm against kdata.load_dataset.
# batch_size=BATCH_SIZE builds in a single predict step instead of predict()'s
# default of 32 samples per step (128 steps for one 4096-sample batch).
model.predict(sample_batch[0], batch_size=BATCH_SIZE, verbose=0)
model.summary()

In [None]:
# @title Cell 4: Train
# Trains the compiled model on the tf.data pipeline from Cell 3.
EPOCHS = 50
# Only used when the dataset is infinite (repeated) or its size is unknown.
DEFAULT_STEPS_PER_EPOCH = 200  # Adjust based on dataset size

# tf.data reports -1 for an infinite (repeated) dataset and -2 when the number
# of batches cannot be determined statically.
num_batches = int(ds.cardinality())
if num_batches == 0:
    raise ValueError("Training dataset is empty; nothing to fit.")

# A finite dataset is consumed in full once per epoch (steps_per_epoch=None).
# Forcing a fixed step count larger than the dataset exhausts it, which makes
# Keras warn that the input ran out of data and, depending on the Keras
# version, interrupt training early.
steps_per_epoch = DEFAULT_STEPS_PER_EPOCH if num_batches < 0 else None

history = model.fit(
    ds,
    epochs=EPOCHS,
    steps_per_epoch=steps_per_epoch,
)