In [None]:
# @title Global Configuration
# --- Model / optimisation hyperparameters ---
BATCH_SIZE = 128              # samples per batch (DataModule config and fallback batches)
LEARNING_RATE = 2e-4          # handed to TPUTrainer as learning_rate
EPOCHS = 1_000                # number of train_epoch calls in the training loop
NUM_STEPS = 1_000             # diffusion timesteps for the noise scheduler
MODEL_DIM = 256               # transformer model_dim
NUM_LAYERS = 2                # transformer num_layers_per_level

# --- Dataset / wavelet decomposition ---
DATASET_NAME = "stocks"       # one of: stocks, fmri, ett, eeg
SEQ_LEN = 128                 # time-series window length
WAVELET_TYPE = "db4"          # wavelet family name
NUM_WAVELET_LEVELS = 3        # number of decomposition levels

# --- Filesystem locations (Colab /content layout) ---
REPO_URL = "https://github.com/MilesHoffman/waveletDiff_synth_data"
PROJECT_ROOT = "/content/waveletDiff_synth_data"
DATA_DIR = "/content/data"

In [None]:
# @title Imports & Environment Setup
import os
import sys

# 1. Setup Keras/JAX Backend
os.environ["KERAS_BACKEND"] = "jax"

# 2. Repository Management
if not os.path.exists(PROJECT_ROOT):
    !git clone {REPO_URL}

# Ensure project root and src are in path
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# 3. Dependency Installation
!pip install -q keras-core flax optax pywt numpy lightning torch

import jax
import jax.numpy as jnp
import keras_core as keras
import numpy as np

print(f"Using Backend: {jax.lib.xla_bridge.get_backend().platform}")
print(f"Devices: {jax.devices()}")

In [None]:
# @title Production Data Orchestration
from src.data.module import WaveletTimeSeriesDataModule
from src.tpu_keras.data_bridge import JAXDataBridge

# Configuration dict handed to the project's WaveletTimeSeriesDataModule.
config = {
    'dataset': {'name': DATASET_NAME, 'seq_len': SEQ_LEN},
    'training': {'batch_size': BATCH_SIZE},
    'data': {'data_dir': DATA_DIR, 'normalize_data': True},
    'wavelet': {'type': WAVELET_TYPE, 'levels': NUM_WAVELET_LEVELS}
}

# Seed for the synthetic fallback generator so fallback runs are reproducible.
SYNTHETIC_DATA_SEED = 0


def _fallback_level_dims(seq_len, levels):
    """Return the coefficient length of each wavelet band, coarsest first.

    Each decomposition level halves the length (rounding up), giving
    [approx_L, detail_L, ..., detail_1]. For seq_len=128, levels=3 this is
    [16, 16, 32, 64], the values previously hard-coded here.
    NOTE(review): assumes periodization-style padding — confirm against the
    DataModule's wavelet mode.
    """
    detail_dims = []
    length = seq_len
    for _ in range(levels):
        length = -(-length // 2)  # ceil(length / 2)
        detail_dims.append(length)
    return [length] + detail_dims[::-1]


# Note: the real data path requires the dataset files under DATA_DIR.
USING_SYNTHETIC_DATA = False
try:
    dm = WaveletTimeSeriesDataModule(config=config)
    bridge = JAXDataBridge(dm)
    dataloader = bridge.get_iterator()
    LEVEL_DIMS = bridge.get_level_dims()
    print("Production DataModule and JAX Bridge initialized successfully.")
except Exception as e:
    # Deliberate best-effort fallback so the remaining cells can run without data.
    # Training then consumes random noise, so check USING_SYNTHETIC_DATA before
    # trusting any result.
    USING_SYNTHETIC_DATA = True
    print(f"Warning: Could not load real data: {e}")
    print("Falling back to synthetic data structure for initialization.")
    # Derived from the config so the fallback stays consistent with SEQ_LEN and
    # NUM_WAVELET_LEVELS instead of drifting from the model/loss configuration.
    LEVEL_DIMS = _fallback_level_dims(SEQ_LEN, NUM_WAVELET_LEVELS)
    _synthetic_rng = np.random.default_rng(SYNTHETIC_DATA_SEED)

    def synthetic_gen():
        """Yield batches forever: one float32 array of shape (BATCH_SIZE, dim, 1) per entry of LEVEL_DIMS."""
        while True:
            yield [
                _synthetic_rng.standard_normal((BATCH_SIZE, d, 1)).astype('float32')
                for d in LEVEL_DIMS
            ]

    dataloader = synthetic_gen()

In [None]:
# @title WaveletDiff TPU Backend Optimization
from src.tpu_keras.models.transformer import WaveletDiffusionTransformer
from src.tpu_keras.models.diffusion import DiffusionScheduler
from src.tpu_keras.models.losses import WaveletLoss
from src.tpu_keras.trainer import TPUTrainer

# Passed to TPUTrainer as steps_per_epoch (named instead of an inline magic number).
STEPS_PER_EPOCH = 100

# 1. Model assembly.
# NOTE(review): LEVEL_DIMS has one entry per band (approximation + details, i.e.
# NUM_WAVELET_LEVELS + 1 entries in the synthetic fallback), while num_levels
# receives NUM_WAVELET_LEVELS — confirm which count the transformer expects.
model = WaveletDiffusionTransformer(
    input_dim=1,  # presumably one channel per coefficient, matching the (batch, dim, 1) fallback batches
    model_dim=MODEL_DIM,
    num_levels=NUM_WAVELET_LEVELS,
    num_layers_per_level=NUM_LAYERS,
)

# 2. Noise scheduling: cosine schedule over NUM_STEPS diffusion timesteps.
scheduler = DiffusionScheduler(num_steps=NUM_STEPS, schedule_type='cosine')

# 3. Wavelet-aware loss over the bands described by LEVEL_DIMS.
loss_fn = WaveletLoss(level_dims=LEVEL_DIMS, strategy="coefficient_weighted")

# 4. Trainer.
# NOTE(review): log_interval_percent=1 is meant to limit host<->TPU syncs. If it is
# a percentage of steps_per_epoch, 1% of 100 steps logs every step, which would
# maximize host syncs rather than remove them — confirm against TPUTrainer.
trainer = TPUTrainer(
    model=model,
    scheduler=scheduler,
    loss_fn=loss_fn,
    learning_rate=LEARNING_RATE,
    steps_per_epoch=STEPS_PER_EPOCH,
    log_interval_percent=1,
)

print("Backend modules fully integrated and ready for TPU training.")

In [None]:
# @title Training Loop Execution
# Epoch interval at which the monitoring checkpoint below fires.
SAMPLE_EVERY_EPOCHS = 100

for epoch in range(1, EPOCHS + 1):
    # NOTE(review): intended to run a whole epoch on device without host
    # interruption — confirm how many steps TPUTrainer.train_epoch executes.
    trainer.train_epoch(dataloader, epoch)

    # Periodic checkpoint for quality monitoring.
    if epoch % SAMPLE_EVERY_EPOCHS == 0:
        # TODO: nothing is sampled or inspected here yet despite the message;
        # add a sampling/plotting call (and a checkpoint save) before relying on it.
        print(f" Landmark reached at Epoch {epoch} - Reviewing sample distribution...")