# WaveletDiff Training (Enhanced)

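This Colab notebook clones the WaveletDiff repository, trains the model with `src/train.py`, and can optionally archive the run's outputs to Google Drive. Cell 1 gives full control over the repository branch, data paths, hyperparameters, precision and compilation settings, logging, and profiling.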

In [None]:
# @title Cell 1: Global Configuration
# Every setting below is forwarded to src/train.py as a command-line flag by the training
# cell. The `# @param` comments render Colab form widgets, so each must stay on the same
# line as its assignment.

# --- Environment ---
BRANCH = "develop" # @param {type:"string"}

# --- Data Paths ---
# NOTE: EXPERIMENT_NAME and DATA_DIR are NOT derived from DATASET — update all three
# together when switching datasets, or training will silently use the stocks data.
# DATA_DIR is passed to train.py verbatim; the training cell runs from the repo root,
# so presumably this relative path is resolved from there — confirm against src/train.py.
DATASET = "stocks" # @param ["stocks", "ett", "fmri", "exchange_rate", "eeg"]
EXPERIMENT_NAME = "stocks_experiment" # @param {type:"string"}
DATA_DIR = "../data/stocks/stock_data.csv" # @param {type:"string"}

# --- Core Hyperparameters ---
EPOCHS = 5000 # @param {type:"integer"}
BATCH_SIZE = 512 # @param {type:"integer"}
SEQ_LEN = 24 # @param {type:"integer"}
LR = 0.0002 # @param {type:"number"}

# --- Performance ---
PRECISION = "bf16-mixed" # @param ["32", "bf16-mixed", "16-mixed"]
MATMUL_PRECISION = "medium" # @param ["highest", "high", "medium"]

# --- Compilation ---
COMPILE_ENABLED = True # @param {type:"boolean"}
COMPILE_MODE = "reduce-overhead" # @param ["default", "reduce-overhead", "max-autotune"]
COMPILE_FULLGRAPH = False # @param {type:"boolean"}

# --- Logging & UI ---
LOG_EVERY_N_EPOCHS = 10 # @param {type:"integer"}
ENABLE_PROGRESS_BAR = True # @param {type:"boolean"}

# --- Profiling ---
PROFILE_ENABLED = False # @param {type:"boolean"}
PROFILE_WAIT_STEPS = 5 # @param {type:"integer"}
PROFILE_WARMUP_STEPS = 3 # @param {type:"integer"}
PROFILE_ACTIVE_STEPS = 5 # @param {type:"integer"}
PROFILE_WAIT_EPOCHS = 0 # @param {type:"integer"}

# --- Google Drive Persistence ---
DRIVE_MOUNT_PATH = "/content/drive" # @param {type:"string"}
DRIVE_BASE_PATH = "/content/drive/MyDrive/waveletDiff_experiments" # @param {type:"string"}
SAVE_TO_DRIVE = True # @param {type:"boolean"}
SAVE_WEIGHTS_ONLY = True # @param {type:"boolean"}

# --- Model Architecture ---
EMBED_DIM = 256
NUM_HEADS = 8
NUM_LAYERS = 8
TIME_EMBED_DIM = 128
DROPOUT = 0.1
PREDICTION_TARGET = "noise"
USE_CROSS_LEVEL_ATTENTION = True

# --- Wavelet & Noise ---
WAVELET_TYPE = "db2"
WAVELET_LEVELS = "auto"
NOISE_SCHEDULE = "exponential"
ENERGY_WEIGHT = 0.0

# --- Optimizer ---
SCHEDULER_TYPE = "onecycle"
WARMUP_EPOCHS = 50

# --- Sanity checks ---
# Fail fast on invalid form input here, rather than deep inside the training subprocess.
for _name, _value in {
    "EPOCHS": EPOCHS,
    "BATCH_SIZE": BATCH_SIZE,
    "SEQ_LEN": SEQ_LEN,
    "LOG_EVERY_N_EPOCHS": LOG_EVERY_N_EPOCHS,
    "EMBED_DIM": EMBED_DIM,
    "NUM_HEADS": NUM_HEADS,
    "NUM_LAYERS": NUM_LAYERS,
    "TIME_EMBED_DIM": TIME_EMBED_DIM,
}.items():
    assert isinstance(_value, int) and _value > 0, f"{_name} must be a positive integer, got {_value!r}"

for _name, _value in {
    "PROFILE_WAIT_STEPS": PROFILE_WAIT_STEPS,
    "PROFILE_WARMUP_STEPS": PROFILE_WARMUP_STEPS,
    "PROFILE_ACTIVE_STEPS": PROFILE_ACTIVE_STEPS,
    "PROFILE_WAIT_EPOCHS": PROFILE_WAIT_EPOCHS,
    "WARMUP_EPOCHS": WARMUP_EPOCHS,
}.items():
    assert isinstance(_value, int) and _value >= 0, f"{_name} must be a non-negative integer, got {_value!r}"
del _name, _value

assert LR > 0, f"LR must be positive, got {LR!r}"
assert 0.0 <= DROPOUT < 1.0, f"DROPOUT must be in [0, 1), got {DROPOUT!r}"
# Assumes the model uses standard multi-head attention, which splits EMBED_DIM evenly
# across heads — confirm against the model definition if this ever needs relaxing.
assert EMBED_DIM % NUM_HEADS == 0, f"EMBED_DIM ({EMBED_DIM}) must be divisible by NUM_HEADS ({NUM_HEADS})"

In [None]:
# @title Cell 2: Setup & environment
import os
import sys
import shutil

REPO_URL = "https://github.com/MilesHoffman/waveletDiff_synth_data.git"
REPO_NAME = "waveletDiff_synth_data"
REPO_PATH = os.path.abspath(REPO_NAME)

if not os.path.exists(REPO_PATH):
    print(f"Cloning {REPO_URL} branch {BRANCH}...")
    !git clone -b {BRANCH} {REPO_URL} {REPO_NAME}
else:
    print(f"Updating {REPO_NAME} to branch {BRANCH}...")
    !git -C {REPO_NAME} fetch --all
    !git -C {REPO_NAME} checkout {BRANCH}
    !git -C {REPO_NAME} pull origin {BRANCH}

# Add repo to path for notebook usage
if REPO_PATH not in sys.path:
    sys.path.append(REPO_PATH)

!pip install -q pytorch-lightning pywavelets scipy pandas tqdm lightning
print(f"✅ Repository ready at: {REPO_PATH}")

In [None]:
# @title Cell 3: Run WaveletDiff Training
%cd {REPO_NAME}

print(f"Running training for {DATASET}...")
!python src/train.py \
    --dataset {DATASET} \
    --experiment_name {EXPERIMENT_NAME} \
    --data_dir {DATA_DIR} \
    --epochs {EPOCHS} \
    --batch_size {BATCH_SIZE} \
    --seq_len {SEQ_LEN} \
    --lr {LR} \
    --precision {PRECISION} \
    --matmul_precision {MATMUL_PRECISION} \
    --compile_enabled {str(COMPILE_ENABLED).lower()} \
    --compile_mode {COMPILE_MODE} \
    --compile_fullgraph {str(COMPILE_FULLGRAPH).lower()} \
    --log_every_n_epochs {LOG_EVERY_N_EPOCHS} \
    --enable_progress_bar {str(ENABLE_PROGRESS_BAR).lower()} \
    --profile_enabled {str(PROFILE_ENABLED).lower()} \
    --profile_wait_steps {PROFILE_WAIT_STEPS} \
    --profile_warmup_steps {PROFILE_WARMUP_STEPS} \
    --profile_active_steps {PROFILE_ACTIVE_STEPS} \
    --profile_wait_epochs {PROFILE_WAIT_EPOCHS} \
    --save_weights_only {str(SAVE_WEIGHTS_ONLY).lower()} \
    --embed_dim {EMBED_DIM} \
    --num_heads {NUM_HEADS} \
    --num_layers {NUM_LAYERS} \
    --time_embed_dim {TIME_EMBED_DIM} \
    --dropout {DROPOUT} \
    --prediction_target {PREDICTION_TARGET} \
    --use_cross_level_attention {str(USE_CROSS_LEVEL_ATTENTION).lower()} \
    --wavelet_type {WAVELET_TYPE} \
    --wavelet_levels {WAVELET_LEVELS} \
    --noise_schedule {NOISE_SCHEDULE} \
    --energy_weight {ENERGY_WEIGHT} \
    --scheduler_type {SCHEDULER_TYPE} \
    --warmup_epochs {WARMUP_EPOCHS}

In [None]:
# @title Cell 4: Save Experiment to Google Drive
import os
import shutil
import tempfile

if SAVE_TO_DRIVE:
    # Imported here so the cell still runs (and just skips) outside Colab when saving is off.
    from google.colab import drive

    # 1. Mount Drive. The mount-point directory can exist without Drive actually being
    # mounted; writing there would put the archive on the ephemeral VM disk while still
    # reporting success, so check for a real mount rather than mere existence.
    if not os.path.ismount(DRIVE_MOUNT_PATH):
        drive.mount(DRIVE_MOUNT_PATH)

    # 2. Prepare paths. Resolve the outputs against REPO_PATH (set in the setup cell)
    # so this cell does not depend on the working directory left by the training cell.
    experiment_output_dir = os.path.join(REPO_PATH, "outputs", EXPERIMENT_NAME)
    archive_name = f"{EXPERIMENT_NAME}.tar.gz"

    # Create destination folder if not exists
    os.makedirs(DRIVE_BASE_PATH, exist_ok=True)
    destination_path = os.path.join(DRIVE_BASE_PATH, archive_name)

    if not os.path.isdir(experiment_output_dir):
        raise FileNotFoundError(
            f"No experiment outputs found at {experiment_output_dir}; did the training cell finish?"
        )

    print(f"Compressing experiment artifacts from {experiment_output_dir}...")

    try:
        # Build the archive in a scratch directory so no tarball is left behind in the
        # repo checkout; the scratch directory is removed once the copy to Drive is done.
        with tempfile.TemporaryDirectory() as scratch_dir:
            # make_archive(base_name, 'gztar', root_dir) writes <base_name>.tar.gz holding
            # the contents of root_dir.
            archive_path = shutil.make_archive(
                os.path.join(scratch_dir, EXPERIMENT_NAME), 'gztar', experiment_output_dir
            )

            print(f"Archive created at: {archive_path}")
            print(f"Copying to Google Drive: {destination_path}...")

            shutil.copy2(archive_path, destination_path)
    except Exception as e:
        print(f"❌ Error saving to Drive: {e}")
        # Re-raise so Run All reports the failure instead of appearing to finish cleanly.
        raise

    print("✅ Successfully saved experiment archive to Google Drive!")
    print(f"Location: {destination_path}")
else:
    print("Skipping Save to Drive (SAVE_TO_DRIVE=False)")