# WaveletDiff Training (Stocks Dataset)

This notebook trains the WaveletDiff model on the **stocks** dataset using the modular `src` directory.

### Workflow:
1. **Configuration**: Tune hyperparameters, paths, and precision.
2. **Setup**: Mount Google Drive (on Colab), clone the repo, and install dependencies.
3. **Environment**: Seed the random number generators and set PyTorch matmul precision.
4. **Data**: Load and prepare data.
5. **Model**: Initialize the model.
6. **Train**: Run training.

In [None]:
# @title Cell 1: Global Configuration
# Single place where every tunable used by the later cells is defined.
# Assignments tagged with `# @param` are Colab form fields; keep each
# annotation on the same line as its assignment.
import random

# --- Experiment Identity ---
RUN_NAME = "stocks_baseline_v1" # @param {type:"string"}
# Random, zero-padded three-digit suffix (000-999) appended to RUN_NAME so that
# separate runs usually get separate output directories. Collisions are
# possible because only 1000 values exist.
RUN_ID = f"{random.randint(0, 999):03d}"
UNIQUE_RUN_NAME = f"{RUN_NAME}_{RUN_ID}"

# --- Dataset (Stocks-specific, fixed) ---
# Forwarded through config['dataset'] and config['data'] in Cell 4.
DATASET_NAME = "stocks"
SEQ_LEN = 24
NORMALIZE_DATA = True

# --- Model Hyperparameters ---
# Forwarded through config['model'] in Cell 4.
EMBED_DIM = 256 # @param {type:"integer"}
NUM_HEADS = 8 # @param {type:"integer"}
NUM_LAYERS = 8 # @param {type:"integer"}
TIME_EMBED_DIM = 128 # @param {type:"integer"}
DROPOUT = 0.1 # @param {type:"number"}

# --- Training Hyperparameters ---
# NUM_EPOCHS feeds both config['training'] (Cell 4) and Trainer(max_epochs=...)
# (Cell 6). GRADIENT_CLIP_VAL is used by the Trainer with norm-based clipping.
NUM_EPOCHS = 5000 # @param {type:"integer"}
BATCH_SIZE = 512 # @param {type:"integer"}
LEARNING_RATE = 2e-4 # @param {type:"number"}
GRADIENT_CLIP_VAL = 1.0 # @param {type:"number"}

# --- Optimizer Hyperparameters ---
# Forwarded through config['optimizer'] in Cell 4.
WEIGHT_DECAY = 1e-5 # @param {type:"number"}
ONECYCLE_MAX_LR = 1e-3 # @param {type:"number"}
ONECYCLE_PCT_START = 0.3 # @param {type:"number"}

# --- Logging & Checkpointing ---
# LOG_EVERY_N_STEPS goes to the Trainer; SAVE_EVERY_N_EPOCHS goes to
# ModelCheckpoint(every_n_epochs=...) in Cell 6.
LOG_EVERY_N_STEPS = 50 # @param {type:"integer"}
SAVE_EVERY_N_EPOCHS = 100 # @param {type:"integer"}

# --- Hardware & Precision ---
# PRECISION is passed to Trainer(precision=...) in Cell 6; MATMUL_PRECISION is
# passed to torch.set_float32_matmul_precision in Cell 3.
PRECISION = "bf16-mixed" # @param ["32", "bf16-mixed"]
MATMUL_PRECISION = "medium" # @param ["medium", "high"]

# --- Paths ---
DRIVE_BASE_PATH = "/content/drive/MyDrive/personal_drive/trading" # @param {type:"string"}
# NOTE: the f-string is evaluated once, here. OUTPUT_DIR does not follow any
# later reassignment of DRIVE_BASE_PATH unless it is rebuilt explicitly.
OUTPUT_DIR = f"{DRIVE_BASE_PATH}/checkpoints/temp/{UNIQUE_RUN_NAME}"
REPO_URL = "https://github.com/MilesHoffman/waveletDiff_synth_data.git"
REPO_DIR = "/content/waveletDiff_synth_data"

# --- Reproducibility ---
# Cell 3 skips seeding when this is None.
SEED = 42 # @param {type:"integer"}

# --- Fixed Settings (Stocks configuration) ---
# Not exposed as form fields. Most are forwarded through `config` in Cell 4;
# ACCELERATOR, DEVICES and SAVE_TOP_K are consumed directly in Cell 6.
PREDICTION_TARGET = "noise"
USE_CROSS_LEVEL_ATTENTION = True
ENERGY_WEIGHT = 0.0
NOISE_SCHEDULE = "exponential"
SCHEDULER_TYPE = "onecycle"
WAVELET_TYPE = "db2"
WAVELET_LEVELS = "auto"
DDIM_ETA = 0.0
# NOTE(review): how the sampler interprets None is defined in `src` - confirm.
DDIM_STEPS = None
ACCELERATOR = "gpu"
DEVICES = 1
# Passed to ModelCheckpoint(save_top_k=...); per the Lightning docs, -1 keeps
# every saved checkpoint.
SAVE_TOP_K = -1
WARMUP_EPOCHS = 50

In [None]:
# @title Cell 2: Setup (Clone, Install, Mount)
# Prepares the runtime:
#   1. Mounts Google Drive on Colab, or falls back to a local checkpoint
#      directory elsewhere.
#   2. Clones the repository (or pulls if it is already present) and puts it
#      on sys.path so `src` can be imported.
#   3. Installs the Python dependencies with pip.
import os
import sys
import subprocess

try:
    from google.colab import drive
    if not os.path.exists('/content/drive'):
        drive.mount('/content/drive')
    print("✅ Drive mounted")
except ImportError:
    print("Not running on Colab. Skipping Drive mount.")
    DRIVE_BASE_PATH = "local_checkpoints"
    # OUTPUT_DIR was built from the Drive path in Cell 1. Rebuild it from the
    # fallback base, otherwise checkpoints would still be written under
    # /content/drive even though Drive is not mounted.
    OUTPUT_DIR = f"{DRIVE_BASE_PATH}/checkpoints/temp/{UNIQUE_RUN_NAME}"
    os.makedirs(DRIVE_BASE_PATH, exist_ok=True)
    # NOTE(review): REPO_DIR still points under /content in this branch;
    # confirm that location is writable when running outside Colab.

# check=True makes a failed git/pip command raise instead of continuing with a
# half-prepared environment.
if not os.path.exists(REPO_DIR):
    print(f"Cloning {REPO_URL} into {REPO_DIR}...")
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
else:
    print(f"Repo exists at {REPO_DIR}, pulling latest changes...")
    subprocess.run(["git", "-C", REPO_DIR, "pull"], check=True)

# Put the repo first on the import path so `from src...` resolves to it.
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

print("Installing dependencies...")
# Use the running interpreter's pip so packages land in this kernel's environment.
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "pytorch-lightning", "pywavelets", "scipy", "pandas", "tqdm", "lightning"], check=True)
print("✅ Dependencies installed")

# Make the import system notice packages installed after the kernel started.
import importlib
importlib.invalidate_caches()
print("✅ Env Ready")

In [None]:
# @title Cell 3: Setup Environment
import torch
import pytorch_lightning as pl

# Seed only when a seed is configured; SEED = None skips seeding entirely.
if SEED is not None:
    pl.seed_everything(SEED)

# Best effort: a failure is reported but does not stop the notebook.
try:
    torch.set_float32_matmul_precision(MATMUL_PRECISION)
except Exception as err:
    print(f"Could not set matmul precision: {err}")
else:
    print(f"✅ Matmul precision set to {MATMUL_PRECISION}")

In [None]:
# @title Cell 4: Load Data
from src.data import WaveletTimeSeriesDataModule
import numpy as np

# Nested configuration shared by the data module (here) and the model (Cell 5).
# Every value comes from the constants defined in Cell 1.
config = dict(
    training=dict(epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, save_model=True),
    dataset=dict(name=DATASET_NAME, seq_len=SEQ_LEN),
    data=dict(data_dir=REPO_DIR, normalize_data=NORMALIZE_DATA),
    wavelet=dict(type=WAVELET_TYPE, levels=WAVELET_LEVELS),
    model=dict(
        embed_dim=EMBED_DIM,
        num_heads=NUM_HEADS,
        num_layers=NUM_LAYERS,
        time_embed_dim=TIME_EMBED_DIM,
        dropout=DROPOUT,
        prediction_target=PREDICTION_TARGET,
    ),
    attention=dict(use_cross_level_attention=USE_CROSS_LEVEL_ATTENTION),
    energy=dict(weight=ENERGY_WEIGHT),
    noise=dict(schedule=NOISE_SCHEDULE),
    optimizer=dict(
        scheduler_type=SCHEDULER_TYPE,
        warmup_epochs=WARMUP_EPOCHS,
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        onecycle_max_lr=ONECYCLE_MAX_LR,
        onecycle_pct_start=ONECYCLE_PCT_START,
    ),
    sampling=dict(ddim_eta=DDIM_ETA, ddim_steps=DDIM_STEPS),
    paths=dict(output_dir=OUTPUT_DIR),
)

datamodule = WaveletTimeSeriesDataModule(config=config)

# Report what was loaded as a quick sanity check.
print(f"✅ Data loaded: {datamodule.raw_data_tensor.shape}")
print(f"✅ Wavelet dimension: {datamodule.get_input_dim()}")

In [None]:
# @title Cell 5: Initialize Model
from src.models import WaveletDiffusionTransformer

# The model receives the data module and the same config dict built in Cell 4.
model = WaveletDiffusionTransformer(
    config=config,
    data_module=datamodule,
)
print("✅ Model initialized")

In [None]:
# @title Cell 6: Run Training
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Timer, TQDMProgressBar
import os

# The checkpoint directory must exist before the first save.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Callback order is Timer, progress bar, then periodic checkpointing.
callbacks = [Timer(), TQDMProgressBar(refresh_rate=1)]
callbacks.append(
    ModelCheckpoint(
        dirpath=OUTPUT_DIR,
        filename='checkpoint-{epoch:02d}',
        every_n_epochs=SAVE_EVERY_N_EPOCHS,
        save_top_k=SAVE_TOP_K,
    )
)

trainer = pl.Trainer(
    # Hardware and numerics
    accelerator=ACCELERATOR,
    devices=DEVICES,
    precision=PRECISION,
    # Schedule and gradient clipping
    max_epochs=NUM_EPOCHS,
    gradient_clip_val=GRADIENT_CLIP_VAL,
    gradient_clip_algorithm="norm",
    # Callbacks, checkpointing and logging (no logger is attached)
    callbacks=callbacks,
    enable_checkpointing=True,
    logger=False,
    log_every_n_steps=LOG_EVERY_N_STEPS,
)

print("Starting training...")
trainer.fit(model, datamodule)
print(f"✅ Training finished. Checkpoints saved to {OUTPUT_DIR}")