# WaveletDiff Training

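This notebook clones the repository and runs `src/train.py` directly. By default, training uses the YAML config for the selected dataset; set `USE_CUSTOM_PARAMS` in Cell 1 to pass epoch, batch-size, and sequence-length overrides on the command line instead.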

In [None]:
# @title Cell 1: Configuration
# All user-tunable settings for the training run.
# The `# @param` comments are Colab form annotations: keep each assignment on
# its own line in the `NAME = value # @param {...}` form so the form keeps working.

# --- Core ---
# Passed to src/train.py as --dataset in both run modes (see Cell 3).
DATASET = "stocks" # @param {type:"string"}
# Passed to src/train.py as --experiment_name in both run modes.
EXPERIMENT_NAME = "default_experiment" # @param {type:"string"}

# --- Override YAML with custom params? ---
# True  -> Cell 3 also passes --epochs / --batch_size / --seq_len from below.
# False -> only --dataset and --experiment_name are passed; train.py is
#          described (by Cell 3's message) as using its YAML defaults.
USE_CUSTOM_PARAMS = False # @param {type:"boolean"}

# --- Custom Overrides (only used if USE_CUSTOM_PARAMS is True) ---
EPOCHS = 5000 # @param {type:"integer"}
BATCH_SIZE = 512 # @param {type:"integer"}
SEQ_LEN = 24 # @param {type:"integer"}

In [None]:
# @title Cell 2: Setup
import os
import sys

REPO_URL = "https://github.com/MilesHoffman/waveletDiff_synth_data.git"
REPO_NAME = "waveletDiff_synth_data"
REPO_PATH = os.path.abspath(REPO_NAME)

if not os.path.exists(REPO_PATH):
    print(f"Cloning {REPO_URL}...")
    !git clone {REPO_URL} {REPO_NAME}
else:
    print("Repo already exists. Pulling latest...")
    !git -C {REPO_NAME} pull

!pip install -q pytorch-lightning pywavelets scipy pandas tqdm lightning
print(f"âœ… Repository ready at: {REPO_PATH}")

In [None]:
# @title Cell 3: Run Training
%cd {REPO_NAME}

if USE_CUSTOM_PARAMS:
    print(f"Running with custom params: epochs={EPOCHS}, batch_size={BATCH_SIZE}, seq_len={SEQ_LEN}")
    !python src/train.py --dataset {DATASET} --experiment_name {EXPERIMENT_NAME} --epochs {EPOCHS} --batch_size {BATCH_SIZE} --seq_len {SEQ_LEN}
else:
    print(f"Running with YAML defaults for dataset: {DATASET}")
    !python src/train.py --dataset {DATASET} --experiment_name {EXPERIMENT_NAME}