# WaveletDiff Training (Enhanced)

This notebook provides a robust environment for training the WaveletDiff model, with full control over repository branches, hyperparameters, compilation, and logging.

In [None]:
# @title Cell 1: Global Configuration
# Central place for every tunable setting used by this notebook. Cell 2 reads
# BRANCH to choose the repository branch; Cell 3 forwards every other value
# below to `src/train.py` as a command-line flag of the same (lower-cased) name.
# Lines carrying `# @param` are presumably rendered as Colab form fields; the
# remaining constants are plain assignments that are edited directly in code.

# --- Environment ---
# Git branch that Cell 2 clones, or checks out and pulls when the repository
# directory already exists.
BRANCH = "develop" # @param {type:"string"}

# --- Data Paths ---
# DATASET -> --dataset, EXPERIMENT_NAME -> --experiment_name,
# DATA_DIR -> --data_dir.
# NOTE(review): the DATA_DIR default is a relative path to a CSV file, and
# Cell 3 changes into the repository before launching training — confirm which
# directory the training script resolves it against.
DATASET = "stocks" # @param ["stocks", "ett", "fmri", "exchange_rate", "eeg"]
EXPERIMENT_NAME = "stocks_experiment" # @param {type:"string"}
DATA_DIR = "../data/stocks/stock_data.csv" # @param {type:"string"}

# --- Core Hyperparameters ---
# Forwarded as --epochs, --batch_size, --seq_len and --lr respectively.
EPOCHS = 5000 # @param {type:"integer"}
BATCH_SIZE = 512 # @param {type:"integer"}
SEQ_LEN = 24 # @param {type:"integer"}
LR = 0.0002 # @param {type:"number"}

# --- Compilation ---
# Forwarded as --compile_enabled (lower-cased to "true"/"false" by Cell 3) and
# --compile_mode. NOTE(review): the mode names match torch.compile modes —
# confirm that is what the training script uses them for.
COMPILE_ENABLED = False # @param {type:"boolean"}
COMPILE_MODE = "default" # @param ["default", "reduce-overhead", "max-autotune"]

# --- Logging & UI ---
# Forwarded as --log_every_n_epochs and --enable_progress_bar (the boolean is
# lower-cased to "true"/"false" by Cell 3).
LOG_EVERY_N_EPOCHS = 10 # @param {type:"integer"}
ENABLE_PROGRESS_BAR = True # @param {type:"boolean"}

# --- Model Architecture ---
# Forwarded as --embed_dim, --num_heads, --num_layers, --time_embed_dim,
# --dropout, --prediction_target and --use_cross_level_attention. These have no
# `# @param` annotation, so they are changed by editing the values here.
EMBED_DIM = 256
NUM_HEADS = 8 
NUM_LAYERS = 8
TIME_EMBED_DIM = 128 
DROPOUT = 0.1
PREDICTION_TARGET = "noise"
USE_CROSS_LEVEL_ATTENTION = True

# --- Wavelet & Noise ---
# Forwarded as --wavelet_type, --wavelet_levels, --noise_schedule and
# --energy_weight. The accepted values (e.g. what "auto" selects) are defined
# by the training script, not by this notebook.
WAVELET_TYPE = "db2"
WAVELET_LEVELS = "auto"
NOISE_SCHEDULE = "exponential"
ENERGY_WEIGHT = 0.0

# --- Optimizer ---
# Forwarded as --scheduler_type and --warmup_epochs.
SCHEDULER_TYPE = "onecycle"
WARMUP_EPOCHS = 50 

In [None]:
# @title Cell 2: Setup & environment
import os
import sys
import shutil

REPO_URL = "https://github.com/MilesHoffman/waveletDiff_synth_data.git"
REPO_NAME = "waveletDiff_synth_data"
REPO_PATH = os.path.abspath(REPO_NAME)

if not os.path.exists(REPO_PATH):
    print(f"Cloning {REPO_URL} branch {BRANCH}...")
    !git clone -b {BRANCH} {REPO_URL} {REPO_NAME}
else:
    print(f"Updating {REPO_NAME} to branch {BRANCH}...")
    !git -C {REPO_NAME} fetch --all
    !git -C {REPO_NAME} checkout {BRANCH}
    !git -C {REPO_NAME} pull origin {BRANCH}

# Add repo to path for notebook usage
if REPO_PATH not in sys.path:
    sys.path.append(REPO_PATH)

!pip install -q pytorch-lightning pywavelets scipy pandas tqdm lightning
print(f"âœ… Repository ready at: {REPO_PATH}")

In [None]:
# @title Cell 3: Run WaveletDiff Training
%cd {REPO_NAME}

print(f"Running training for {DATASET}...")
!python src/train.py \
    --dataset {DATASET} \
    --experiment_name {EXPERIMENT_NAME} \
    --data_dir {DATA_DIR} \
    --epochs {EPOCHS} \
    --batch_size {BATCH_SIZE} \
    --seq_len {SEQ_LEN} \
    --lr {LR} \
    --compile_enabled {str(COMPILE_ENABLED).lower()} \
    --compile_mode {COMPILE_MODE} \
    --log_every_n_epochs {LOG_EVERY_N_EPOCHS} \
    --enable_progress_bar {str(ENABLE_PROGRESS_BAR).lower()} \
    --embed_dim {EMBED_DIM} \
    --num_heads {NUM_HEADS} \
    --num_layers {NUM_LAYERS} \
    --time_embed_dim {TIME_EMBED_DIM} \
    --dropout {DROPOUT} \
    --prediction_target {PREDICTION_TARGET} \
    --use_cross_level_attention {str(USE_CROSS_LEVEL_ATTENTION).lower()} \
    --wavelet_type {WAVELET_TYPE} \
    --wavelet_levels {WAVELET_LEVELS} \
    --noise_schedule {NOISE_SCHEDULE} \
    --energy_weight {ENERGY_WEIGHT} \
    --scheduler_type {SCHEDULER_TYPE} \
    --warmup_epochs {WARMUP_EPOCHS}