# WaveletDiff Training (Enhanced Logging)

This notebook gives you full control over hyperparameters and data paths, and now also over logging frequency and progress tracking.

In [None]:
# @title Cell 1: Global Configuration
#
# Central place for every user-tunable setting. Each constant defined in this
# cell is interpolated into the `python src/train.py ...` command line built
# in Cell 3, so this cell must be executed before Cell 3.
#
# The trailing `# @param ...` markers (and `# @title` above) are Google Colab
# form annotations; keep each one on the same line as its assignment.

# --- Data Paths ---
DATASET = "stocks" # @param ["stocks", "ett", "fmri", "exchange_rate", "eeg"]
EXPERIMENT_NAME = "stocks_experiment" # @param {type:"string"}
# NOTE(review): despite the name, the default value is a path to a CSV file,
# and it is relative. Cell 3 changes the working directory into the repository
# checkout before launching train.py, so presumably the path is resolved from
# there -- confirm against train.py's handling of --data_dir.
DATA_DIR = "../data/stocks/stock_data.csv" # @param {type:"string"}

# --- Core Hyperparameters ---
EPOCHS = 5000 # @param {type:"integer"}
BATCH_SIZE = 512 # @param {type:"integer"}
SEQ_LEN = 24 # @param {type:"integer"}
LR = 0.0002 # @param {type:"number"}

# --- Logging & UI ---
LOG_EVERY_N_EPOCHS = 10 # @param {type:"integer"}
# Booleans are converted to lowercase "true"/"false" strings in Cell 3 before
# being passed on the command line.
ENABLE_PROGRESS_BAR = True # @param {type:"boolean"}

# --- Model Architecture ---
EMBED_DIM = 256 # @param {type:"integer"}
NUM_HEADS = 8 # @param {type:"integer"}
NUM_LAYERS = 8 # @param {type:"integer"}
TIME_EMBED_DIM = 128 # @param {type:"integer"}
DROPOUT = 0.1 # @param {type:"number"}
PREDICTION_TARGET = "noise" # @param ["noise", "coefficient"]
USE_CROSS_LEVEL_ATTENTION = True # @param {type:"boolean"}

# --- Wavelet & Noise ---
WAVELET_TYPE = "db2" # @param {type:"string"}
# Passed through verbatim as a string (default "auto").
WAVELET_LEVELS = "auto" # @param {type:"string"}
NOISE_SCHEDULE = "exponential" # @param ["exponential", "cosine", "linear"]
ENERGY_WEIGHT = 0.0 # @param {type:"number"}

# --- Optimizer ---
SCHEDULER_TYPE = "onecycle" # @param ["onecycle", "cosine_warmup", "plateau_warmup", "cosine", "plateau"]
WARMUP_EPOCHS = 50 # @param {type:"integer"}

In [None]:
# @title Cell 2: Setup & environment
import os
import sys
import shutil

REPO_URL = "https://github.com/MilesHoffman/waveletDiff_synth_data.git"
REPO_NAME = "waveletDiff_synth_data"
REPO_PATH = os.path.abspath(REPO_NAME)

if not os.path.exists(REPO_PATH):
    print(f"Cloning {REPO_URL}...")
    !git clone {REPO_URL} {REPO_NAME}
else:
    print("Repo already exists. Pulling latest...")
    !git -C {REPO_NAME} pull

# Add repo to path for notebook usage
if REPO_PATH not in sys.path:
    sys.path.append(REPO_PATH)

!pip install -q pytorch-lightning pywavelets scipy pandas tqdm lightning
print(f"âœ… Repository ready at: {REPO_PATH}")

In [None]:
# @title Cell 3: Run WaveletDiff Training
%cd {REPO_NAME}

print(f"Running training for {DATASET}...")
!python src/train.py \
    --dataset {DATASET} \
    --experiment_name {EXPERIMENT_NAME} \
    --data_dir {DATA_DIR} \
    --epochs {EPOCHS} \
    --batch_size {BATCH_SIZE} \
    --seq_len {SEQ_LEN} \
    --lr {LR} \
    --log_every_n_epochs {LOG_EVERY_N_EPOCHS} \
    --enable_progress_bar {str(ENABLE_PROGRESS_BAR).lower()} \
    --embed_dim {EMBED_DIM} \
    --num_heads {NUM_HEADS} \
    --num_layers {NUM_LAYERS} \
    --time_embed_dim {TIME_EMBED_DIM} \
    --dropout {DROPOUT} \
    --prediction_target {PREDICTION_TARGET} \
    --use_cross_level_attention {str(USE_CROSS_LEVEL_ATTENTION).lower()} \
    --wavelet_type {WAVELET_TYPE} \
    --wavelet_levels {WAVELET_LEVELS} \
    --noise_schedule {NOISE_SCHEDULE} \
    --energy_weight {ENERGY_WEIGHT} \
    --scheduler_type {SCHEDULER_TYPE} \
    --warmup_epochs {WARMUP_EPOCHS}