# WaveletDiff Training (Replication)

This notebook replicates the training logic of `src/train.py`. Instead of being passed on the command line, all run parameters are set in the configuration cell (Cell 1).

In [None]:
# @title Cell 1: Global Configuration
# Notebook-wide run parameters. Each assignment carries a Colab form
# directive (`# @param ...`), so each value must stay on the same line as its
# directive. The configuration-processing cell copies these values into the
# nested `config` dict that is handed to the data module and the model.

# --- Experiment Identity ---
# Name of the run's sub-directory under OUTPUT_DIR (the checkpoint goes there).
EXPERIMENT_NAME = "default_experiment" # @param {type:"string"}

# --- Dataset ---
# Stored under config['dataset'] (name, seq_len) and config['data']
# (normalize_data); valid dataset names are defined by the data module —
# TODO confirm the accepted list there.
DATASET_NAME = "etth1" # @param {type:"string"}
SEQ_LEN = 24 # @param {type:"integer"}
NORMALIZE_DATA = True # @param {type:"boolean"}

# --- Model Hyperparameters ---
# Stored under config['model'] and read by WaveletDiffusionTransformer.
EMBED_DIM = 256 # @param {type:"integer"}
NUM_HEADS = 8 # @param {type:"integer"}
NUM_LAYERS = 8 # @param {type:"integer"}
TIME_EMBED_DIM = 128 # @param {type:"integer"}
DROPOUT = 0.1 # @param {type:"number"}
PREDICTION_TARGET = "noise" # @param ["noise", "coefficient"]

# --- Attention & Energy ---
# The run summary reports the energy term as "Enabled" only when
# ENERGY_WEIGHT > 0.
USE_CROSS_LEVEL_ATTENTION = True # @param {type:"boolean"}
ENERGY_WEIGHT = 0.0 # @param {type:"number"}

# --- Noise Schedule ---
NOISE_SCHEDULE = "exponential" # @param ["exponential", "cosine", "linear"]

# --- Wavelet ---
# WAVELET_LEVELS is a "raw" field: enter an integer or "auto". It is
# normalised (int, or the string "auto") in the configuration-processing cell.
WAVELET_TYPE = "auto" # @param {type:"string"}
WAVELET_LEVELS = "auto" # @param {type:"raw"}

# --- Sampling ---
# DDIM_STEPS is a "raw" field; None is forwarded unchanged to
# config['sampling']['ddim_steps'] (how None is interpreted is up to the
# sampler — TODO confirm).
SAMPLING_METHOD = "ddpm" # @param ["ddpm", "ddim"]
DDIM_ETA = 0.0 # @param {type:"number"}
DDIM_STEPS = None # @param {type:"raw"}

# --- Optimizer ---
# Stored under config['optimizer'] (scheduler_type, warmup_epochs, lr).
SCHEDULER_TYPE = "onecycle" # @param ["onecycle", "cosine_warmup", "plateau_warmup", "cosine", "plateau"]
WARMUP_EPOCHS = 50 # @param {type:"integer"}
LEARNING_RATE = 0.0002 # @param {type:"number"}

# --- Training ---
# NUM_EPOCHS becomes the Trainer's max_epochs; SAVE_MODEL decides whether the
# save cell writes a checkpoint.
NUM_EPOCHS = 5000 # @param {type:"integer"}
BATCH_SIZE = 512 # @param {type:"integer"}
SAVE_MODEL = True # @param {type:"boolean"}

# --- Paths ---
# Relative paths; OUTPUT_DIR is resolved against the kernel's current working
# directory via pathlib. How DATA_DIR is resolved depends on the data module.
DATA_DIR = "../data" # @param {type:"string"}
OUTPUT_DIR = "../outputs" # @param {type:"string"}

# --- Reproducibility ---
# Passed to pl.seed_everything in the setup cell.
SEED = 42 # @param {type:"integer"}

In [None]:
# @title Cell 2: Setup
import os
import sys
import argparse
import time
from datetime import timedelta
from pathlib import Path
import random

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Timer
import numpy as np

# Make the project's source directory importable so that `import models`,
# `import training`, `import data` and `import utils` resolve to
# WaveletDiff_source/src/*. Only the kernel's current working directory is
# searched, so start the notebook from the directory containing
# `WaveletDiff_source/`.
project_root = os.getcwd()
src_path = os.path.join(project_root, 'WaveletDiff_source', 'src')

if not os.path.isdir(src_path):
    print(f"⚠️ {src_path} does not exist; the project imports below will likely fail.")

if src_path not in sys.path:
    # Prepend rather than append: module names such as `models`, `data`,
    # `utils` and `training` are generic, and an already-installed package
    # with the same name would otherwise shadow the project's modules.
    sys.path.insert(0, src_path)
    print(f"Added {src_path} to sys.path")

try:
    from models import WaveletDiffusionTransformer
    from training import DiffusionTrainer
    from data import WaveletTimeSeriesDataModule
    from utils import ConfigManager
    print("✅ Imports successful")
except ImportError as e:
    print(f"❌ Import failed: {e}")
    print("Please ensure WaveletDiff_source/src exists.")
    # Every later cell needs these names; stopping here gives a clear error
    # instead of a NameError several cells further down.
    raise

# Set Precision (best effort): 'medium' lets float32 matmuls use faster
# reduced-precision kernels where the hardware supports them.
try:
    torch.set_float32_matmul_precision('medium')
    print("Enabled optimized matmul precision")
except Exception as e:
    print(f"Could not set matmul precision: {e}")
    print("Continuing with default precision...")

# Seed Python, NumPy and torch RNGs for reproducibility.
pl.seed_everything(SEED)

In [None]:
# @title Cell 3: Configuration Processing

# Normalise WAVELET_LEVELS. It is a Colab "raw" field, so it may arrive as an
# int or as a string. The result is either an int or the literal "auto".
# Anything else is rejected here rather than being forwarded into `config` as
# an arbitrary string.
_wavelet_levels_text = str(WAVELET_LEVELS).strip()
if _wavelet_levels_text.isdecimal():
    # isdecimal (not isdigit): isdigit also accepts characters such as "²"
    # that int() cannot parse.
    wavelet_levels = int(_wavelet_levels_text)
elif _wavelet_levels_text.lower() == "auto":
    wavelet_levels = "auto"
else:
    raise ValueError(
        f"WAVELET_LEVELS must be a non-negative integer or 'auto', got {WAVELET_LEVELS!r}"
    )

# Nested run configuration handed to the data module and the model. Section
# and key names are looked up by those classes, so they must not change; the
# insertion order of sections is kept as well.
config = dict(
    training=dict(
        epochs=NUM_EPOCHS,
        batch_size=BATCH_SIZE,
        save_model=SAVE_MODEL,
    ),
    model=dict(
        embed_dim=EMBED_DIM,
        num_heads=NUM_HEADS,
        num_layers=NUM_LAYERS,
        time_embed_dim=TIME_EMBED_DIM,
        dropout=DROPOUT,
        prediction_target=PREDICTION_TARGET,
    ),
    attention=dict(use_cross_level_attention=USE_CROSS_LEVEL_ATTENTION),
    energy=dict(weight=ENERGY_WEIGHT),
    noise=dict(schedule=NOISE_SCHEDULE),
    wavelet=dict(type=WAVELET_TYPE, levels=wavelet_levels),
    sampling=dict(
        method=SAMPLING_METHOD,
        ddim_eta=DDIM_ETA,
        ddim_steps=DDIM_STEPS,
    ),
    data=dict(normalize_data=NORMALIZE_DATA, data_dir=DATA_DIR),
    optimizer=dict(
        scheduler_type=SCHEDULER_TYPE,
        warmup_epochs=WARMUP_EPOCHS,
        lr=LEARNING_RATE,
    ),
    dataset=dict(name=DATASET_NAME, seq_len=SEQ_LEN),
    # Fixed value; not exposed as a form parameter.
    evaluation=dict(num_samples=20000),
    paths=dict(output_dir=OUTPUT_DIR),
)

# Run summary. The parenthesised labels ("cross_only", "coefficient_weighted
# (approximation_weight=2)", "level_feature, absolute") are fixed text, not
# derived from `config`. NOTE(review): confirm they match the model's actual
# settings.
print("Starting WaveletDiff Training")
print(f"Dataset: {config['dataset']['name']}")
print(f"Sequence Length: {config['dataset']['seq_len']}")
print(f"Epochs: {config['training']['epochs']}")
print(f"Batch Size: {config['training']['batch_size']}")
print(f"Prediction Target: {config['model']['prediction_target']}")
cross_level_state = "Enabled" if config['attention']['use_cross_level_attention'] else "Disabled"
print(f"Cross-level Attention: {cross_level_state} (cross_only)")
print("Loss Strategy: coefficient_weighted (approximation_weight=2)")
energy_state = "Enabled" if config['energy']['weight'] > 0 else "Disabled"
print(f"Energy Term: {energy_state} (level_feature, absolute)")
print(f"Noise Schedule: {config['noise']['schedule']}")

In [None]:
# @title Cell 4: Data Module Setup

print(f"\n{'=' * 60}")
print("SETTING UP DATA MODULE")
print("=" * 60)

# The data module receives the full nested config built in the previous cell.
data_module = WaveletTimeSeriesDataModule(config=config)

print("Data module setup complete")
print(f"Input dimension: {data_module.get_input_dim()}")
print(f"Dataset size: {len(data_module.dataset)}")
print(f"Wavelet: {data_module.wavelet_type} with {data_module.wavelet_info['levels']} levels")

# Per-level breakdown of the wavelet decomposition. Each entry of
# 'coeffs_shapes' is a shape; its product is that level's coefficient count.
wavelet_info = data_module.get_wavelet_info()
print(f"   Wavelet levels: {wavelet_info['levels']}")
for level_idx, coeff_shape in enumerate(wavelet_info['coeffs_shapes']):
    coeff_count = np.prod(coeff_shape)
    print(f"     Level {level_idx}: {coeff_shape} -> {coeff_count} coefficients")

In [None]:
# @title Cell 5: Model Initialization

print(f"\n{'=' * 60}")
print("INITIALIZING MODEL")
print("=" * 60)

# The model is constructed from both the data module and the config.
model = WaveletDiffusionTransformer(data_module=data_module, config=config)

# Run naming and output layout: <output_dir>/<EXPERIMENT_NAME>/checkpoint.ckpt.
# (`dataset_name` is kept as a notebook global; no visible cell reads it.)
dataset_name = config['dataset']['name']
experiment_name = EXPERIMENT_NAME
output_root = Path(config['paths']['output_dir'])
experiment_dir = output_root.joinpath(experiment_name)
# Create the directory up front (idempotent) so the save cell has a target.
experiment_dir.mkdir(exist_ok=True, parents=True)

model_filename = "checkpoint.ckpt"
model_path = experiment_dir.joinpath(model_filename)

print(f"Experiment: {experiment_name}")
print(f"Model checkpoint will be saved to: {model_path}")

In [None]:
# @title Cell 6: Training

print("\n" + "="*60)
print("TRAINING MODEL")
print("="*60)

# Pick a distributed strategy that can run inside a notebook kernel.
# "ddp_find_unused_parameters_true" uses Lightning's subprocess ("popen")
# launcher, which re-executes the running script in child processes.
# Lightning rejects that launcher with a MisconfigurationException in
# interactive sessions such as Jupyter/Colab.
#  * <= 1 visible GPU: no process group is needed, so let Lightning choose
#    (it falls back to a single-device strategy).
#  * > 1 GPU: use the fork-based notebook launcher and keep
#    find_unused_parameters=True. Fork-based DDP requires that CUDA has not
#    been initialised in this kernel before trainer.fit() is called.
if torch.cuda.device_count() > 1:
    training_strategy = "ddp_notebook_find_unused_parameters_true"
else:
    training_strategy = "auto"

trainer = pl.Trainer(
    max_epochs=config['training']['epochs'],
    accelerator='gpu',
    devices='auto',
    strategy=training_strategy,
    precision="32",
    callbacks=[Timer()],
    # Automatic checkpointing is off; the save cell writes the checkpoint.
    enable_checkpointing=False,
    enable_progress_bar=False,
    log_every_n_steps=50,
    gradient_clip_val=1.0,
    detect_anomaly=False,
    gradient_clip_algorithm="norm",
    logger=False,
)

# Wall-clock duration of fit(), measured with a monotonic clock so system
# clock adjustments during a long run cannot distort it.
start_time = time.perf_counter()
trainer.fit(model, data_module)
training_time = time.perf_counter() - start_time

print(f"Training completed in {timedelta(seconds=training_time)}")

In [None]:
# @title Cell 7: Save Model

# Write the trained state to <output_dir>/<EXPERIMENT_NAME>/checkpoint.ckpt.
# The Trainer is built with enable_checkpointing=False, so this explicit call
# is what writes the Lightning checkpoint file.
should_save = config['training']['save_model']
if should_save:
    checkpoint_file = str(model_path)
    trainer.save_checkpoint(checkpoint_file)
    print(f"Model saved to {model_path}")