# WaveletDiff TPU Training (Keras 3 Backend)

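Frontend notebook for high-performance TPU training. All environment setup, data loading, model construction, and training logic is delegated to the `w_keras` package in `src/` (imported as `w_keras.trainer_interface`).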


In [None]:
# @title Cell 1: Global Configuration
import os
import sys

# --- Paths ---
# Colab-local locations: the repo clone, the stock CSV shipped inside it,
# and the directory checkpoints are written to.
REPO_DIR = "/content/waveletDiff_synth_data"
DATA_PATH = f"{REPO_DIR}/src/copied_waveletDiff/data/stocks/stock_data.csv"
CHECKPOINT_DIR = "/content/checkpoints"

# --- Data / schedule parameters ---
SEQ_LEN = 24        # sequence length handed to the dataloader
BATCH_SIZE = 256
EPOCHS = 50
# None -> each epoch runs over the full (finite) dataset.
STEPS_PER_EPOCH = None

# --- Model / optimizer settings handed to the w_keras backend ---
CONFIG = dict(
    EMBED_DIM=256,
    NUM_HEADS=8,
    NUM_LAYERS=8,
    DROPOUT=0.1,
    TIME_EMBED_DIM=128,
    PREDICTION_TARGET='noise',
    USE_CROSS_LEVEL_ATTENTION=True,
    LEARNING_RATE=2e-4,
    WEIGHT_DECAY=1e-5,
    # Schedule and output location mirror the module-level values above.
    EPOCHS=EPOCHS,
    STEPS_PER_EPOCH=STEPS_PER_EPOCH,
    CHECKPOINT_DIR=CHECKPOINT_DIR,
)

In [None]:
# @title Cell 2: Setup (Clone & Install)
# 1. Clone
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/MilesHoffman/waveletDiff_synth_data.git {REPO_DIR}
else:
    !cd {REPO_DIR} && git pull

# 2. Dependencies
!pip install keras --upgrade
!pip install pywavelets
!pip install tensorflow

# 3. Import Path
if REPO_DIR not in sys.path:
    sys.path.append(f"{REPO_DIR}/src")


In [None]:
# @title Cell 3: Initialize Environment
# Import the backend facade used by every later cell. This presumably only
# resolves after Cell 2 has appended <REPO_DIR>/src to sys.path.
from w_keras import trainer_interface as trainer

# Backend-defined environment setup; the implementation lives in
# w_keras.trainer_interface (presumably TPU / distribution-strategy
# initialization -- confirm there).
trainer.setup_environment()


In [None]:
# @title Cell 4: Load Data
# Build the training dataset and its metadata dict from the CSV. `info` is
# later passed to trainer.init_model to configure the model.
ds, info = trainer.get_dataloader(DATA_PATH, BATCH_SIZE, SEQ_LEN)

# Show the `level_dims` entry so the wavelet decomposition can be sanity-checked.
print("Wavelet Info:", info['level_dims'])


In [None]:
# @title Cell 5: Initialize Model
model = trainer.init_model(info, CONFIG)

# Run one prediction on a real batch to force the model to build before
# summary() is called. Assumes each dataset element is indexable with the
# model input at position 0 (e.g. an (inputs, targets) tuple) -- TODO confirm.
for batch in ds.take(1):
    try:
        model.predict(batch[0], verbose=0)
    except Exception as err:
        # Best-effort build pass: report the failure rather than hiding it.
        # Otherwise the only symptom would be a confusing "model not built"
        # error from summary(). KeyboardInterrupt/SystemExit still propagate.
        print(
            f"Warning: dummy forward pass failed ({type(err).__name__}: {err}); "
            "model.summary() may report an unbuilt model."
        )
model.summary()


In [None]:
# @title Cell 6: Run Training
# Hand the model, dataset, and CONFIG (which carries EPOCHS, STEPS_PER_EPOCH,
# CHECKPOINT_DIR and the optimizer settings) to the backend training loop.
# The return value is kept as `history`; its type is defined by
# trainer_interface.train_loop (presumably a Keras History -- confirm).
history = trainer.train_loop(model, ds, CONFIG)
