# WaveletDiff Training (Stocks Dataset)

This notebook trains the WaveletDiff model on the **stocks** dataset using the modular `src` directory.

### Workflow:
1. **Configuration**: Tune hyperparameters, paths, and precision.
2. **Setup**: Mount Google Drive (on Colab), clone the repo, and install dependencies.
3. **Cache**: Restore torch.compile cache if available.
4. **Environment**: Set the random seed and the PyTorch float32 matmul precision.
5. **Data**: Load and prepare data.
6. **Model**: Initialize and compile the model.
7. **Train**: Run training.
8. **Cache**: Save torch.compile cache to Drive.

In [None]:
# @title Cell 1: Global Configuration
# Central configuration for the whole notebook. Every later cell reads these
# module-level names directly, so this cell has to be executed first.
# Lines ending in `# @param ...` are Colab form fields; keep those markers intact.
import random
import os

# --- Experiment Identity ---
RUN_NAME = "stocks_baseline_v1" # @param {type:"string"}
# Random zero-padded 3-digit suffix (000-999). It is drawn here, before the
# seeding cell runs, so it is not controlled by SEED.
# Presumably meant to keep runs that share a RUN_NAME in separate output dirs.
RUN_ID = f"{random.randint(0, 999):03d}"
UNIQUE_RUN_NAME = f"{RUN_NAME}_{RUN_ID}"

# --- Dataset (Stocks-specific, fixed) ---
DATASET_NAME = "stocks"
SEQ_LEN = 24
NORMALIZE_DATA = True

# --- Model Hyperparameters ---
# Forwarded to the model through config['model'] in the data-loading cell.
EMBED_DIM = 256 # @param {type:"integer"}
NUM_HEADS = 8 # @param {type:"integer"}
NUM_LAYERS = 8 # @param {type:"integer"}
TIME_EMBED_DIM = 128 # @param {type:"integer"}
DROPOUT = 0.1 # @param {type:"number"}

# --- Training Hyperparameters ---
# NUM_EPOCHS and GRADIENT_CLIP_VAL are also passed straight to pl.Trainer.
NUM_EPOCHS = 5000 # @param {type:"integer"}
BATCH_SIZE = 512 # @param {type:"integer"}
LEARNING_RATE = 2e-4 # @param {type:"number"}
GRADIENT_CLIP_VAL = 1.0 # @param {type:"number"}

# --- Optimizer Hyperparameters ---
WEIGHT_DECAY = 1e-5 # @param {type:"number"}
ONECYCLE_MAX_LR = 1e-3 # @param {type:"number"}
ONECYCLE_PCT_START = 0.3 # @param {type:"number"}

# --- Logging & Checkpointing ---
# LOG_EVERY_N_STEPS goes to pl.Trainer(log_every_n_steps=...);
# SAVE_EVERY_N_EPOCHS goes to ModelCheckpoint(every_n_epochs=...).
LOG_EVERY_N_EPOCHS = 1 # @param {type:"integer"}
LOG_LEVEL_LOSSES_EVERY_N_EPOCHS = 100 # @param {type:"integer"}
LOG_EVERY_N_STEPS = 50 # @param {type:"integer"}
SAVE_EVERY_N_EPOCHS = 1000 # @param {type:"integer"}

# --- Hardware & Precision ---
# PRECISION is handed to pl.Trainer(precision=...); MATMUL_PRECISION is handed
# to torch.set_float32_matmul_precision(...).
PRECISION = "32" # @param ["32", "bf16-mixed"]
MATMUL_PRECISION = "medium" # @param ["medium", "high"]

# --- torch.compile Settings ---
# COMPILE_MODE = None disables torch.compile; the cache restore/save cells
# also become no-ops in that case.
COMPILE_MODE = None # @param [None, "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"]
# Local directory exported as TORCHINDUCTOR_CACHE_DIR when compilation is on.
COMPILE_CACHE_DIR = "/content/torchinductor_cache" # @param {type:"string"}

# --- Paths ---
# NOTE: OUTPUT_DIR and COMPILE_CACHE_DRIVE_PATH are built from DRIVE_BASE_PATH
# right here. Reassigning DRIVE_BASE_PATH later does not update them.
DRIVE_BASE_PATH = "/content/drive/MyDrive/personal_drive/trading/waveletDiff" # @param {type:"string"}
OUTPUT_DIR = f"{DRIVE_BASE_PATH}/checkpoints/temp/{UNIQUE_RUN_NAME}"
COMPILE_CACHE_DRIVE_PATH = f"{DRIVE_BASE_PATH}/compile_cache" # @param {type:"string"}
REPO_URL = "https://github.com/MilesHoffman/waveletDiff_synth_data.git"
REPO_DIR = "/content/waveletDiff_synth_data"

# Specific path to your CSV file
DATA_PATH = "/content/waveletDiff_synth_data/stocks/stock_data.csv" # @param {type:"string"}

# --- Reproducibility ---
# Passed to pl.seed_everything(); seeding is skipped when this is None.
SEED = 42 # @param {type:"integer"}

# --- Fixed Settings (Stocks configuration) ---
# Not exposed as form fields. Most are forwarded through the `config` dict;
# ACCELERATOR and DEVICES go to pl.Trainer, SAVE_TOP_K to ModelCheckpoint.
PREDICTION_TARGET = "noise"
USE_CROSS_LEVEL_ATTENTION = True
ENERGY_WEIGHT = 0.0
NOISE_SCHEDULE = "exponential"
SCHEDULER_TYPE = "onecycle"
WAVELET_TYPE = "db2"
WAVELET_LEVELS = "auto"
DDIM_ETA = 0.0
DDIM_STEPS = None
ACCELERATOR = "gpu"
DEVICES = 1
SAVE_TOP_K = -1
WARMUP_EPOCHS = 50

In [None]:
# @title Cell 2: Setup (Clone, Install, Mount)
# Prepares the runtime:
#   1. mounts Google Drive on Colab, or falls back to a local directory;
#   2. clones the project repo (or pulls if it is already present);
#   3. puts the repo on sys.path so `src` can be imported;
#   4. installs the Python dependencies with pip.
import os
import sys
import subprocess

# Only the import decides whether we are on Colab, so keep the try body minimal.
try:
    from google.colab import drive
except ImportError:
    print("Not running on Colab. Skipping Drive mount.")
    DRIVE_BASE_PATH = "local_checkpoints"
    os.makedirs(DRIVE_BASE_PATH, exist_ok=True)
    # The Drive-derived paths were already built from the old DRIVE_BASE_PATH
    # in the configuration cell. Rebuild them, otherwise checkpoints and the
    # compile cache would still target the non-existent Drive location.
    OUTPUT_DIR = f"{DRIVE_BASE_PATH}/checkpoints/temp/{UNIQUE_RUN_NAME}"
    COMPILE_CACHE_DRIVE_PATH = f"{DRIVE_BASE_PATH}/compile_cache"
else:
    if not os.path.exists('/content/drive'):
        drive.mount('/content/drive')
    print("✅ Drive mounted")

# Fetch the project sources; check=True aborts the cell if git fails.
if not os.path.exists(REPO_DIR):
    print(f"Cloning {REPO_URL} into {REPO_DIR}...")
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
else:
    print(f"Repo exists at {REPO_DIR}, pulling latest changes...")
    subprocess.run(["git", "-C", REPO_DIR, "pull"], check=True)

# Make `import src...` resolve to the cloned repo.
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

print("Installing dependencies...")
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "pytorch-lightning", "pywavelets", "scipy", "pandas", "tqdm", "lightning"], check=True)
print("✅ Dependencies installed")

# Let the import system notice packages that were installed after startup.
import importlib
importlib.invalidate_caches()
print("✅ Env Ready")

In [None]:
# @title Cell 3: Restore Compile Cache
# Restores a previously saved TorchInductor cache so torch.compile can reuse
# earlier compilation results. The cache is only an optimisation: if no archive
# exists, or the archive cannot be unpacked, training compiles from scratch.
import os
import tarfile
import shutil

CACHE_ARCHIVE_NAME = "torchinductor_cache.tar.gz"
LOCAL_CACHE_ARCHIVE = f"/content/{CACHE_ARCHIVE_NAME}"
DRIVE_CACHE_ARCHIVE = f"{COMPILE_CACHE_DRIVE_PATH}/{CACHE_ARCHIVE_NAME}"


def _extract_cache_archive(archive_path, dest_dir):
    """Unpack the gzip-compressed tar at *archive_path* into *dest_dir*.

    Uses tarfile's "data" extraction filter when the running Python provides
    it, so archive members cannot be written outside *dest_dir*.
    """
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(dest_dir, filter="data")
        else:
            # Older Python without extraction filters: legacy behaviour.
            tar.extractall(dest_dir)


if COMPILE_MODE is not None:
    os.makedirs(COMPILE_CACHE_DIR, exist_ok=True)
    # Point TorchInductor at the restored directory.
    os.environ["TORCHINDUCTOR_CACHE_DIR"] = COMPILE_CACHE_DIR

    # Prefer an archive already on local disk; otherwise fetch it from Drive.
    if os.path.exists(LOCAL_CACHE_ARCHIVE):
        print(f"Found local cache archive, extracting...")
        cache_source = "local archive"
    elif os.path.exists(DRIVE_CACHE_ARCHIVE):
        print(f"Found cache archive on Drive, copying and extracting...")
        shutil.copy(DRIVE_CACHE_ARCHIVE, LOCAL_CACHE_ARCHIVE)
        cache_source = "Drive"
    else:
        cache_source = None
        print("No existing cache found, will compile from scratch")

    if cache_source is not None:
        try:
            _extract_cache_archive(LOCAL_CACHE_ARCHIVE, COMPILE_CACHE_DIR)
        except (tarfile.TarError, EOFError, OSError) as err:
            # A corrupt or truncated archive must not block training.
            print(f"⚠️ Could not restore cache ({err}); will compile from scratch")
        else:
            print(f"✅ Cache restored from {cache_source}")
else:
    print("torch.compile disabled, skipping cache restore")

In [None]:
# @title Cell 4: Setup Environment
# Seeds the random number generators (when SEED is set) and applies the
# configured float32 matmul precision.
import torch
import pytorch_lightning as pl

if SEED is not None:
    pl.seed_everything(SEED)

# Best effort: report a failure to set the precision instead of aborting.
try:
    torch.set_float32_matmul_precision(MATMUL_PRECISION)
except Exception as e:
    print(f"Could not set matmul precision: {e}")
else:
    print(f"✅ Matmul precision set to {MATMUL_PRECISION}")

In [None]:
# @title Cell 5: Load Data
# Builds the nested `config` dict from the notebook-level settings and creates
# the data module from it.
import sys
import importlib
import numpy as np

# Re-import any already-loaded data modules so edits to the repo take effect
# without restarting the runtime. Submodules are reloaded before the package.
for module_name in ('src.data.loaders', 'src.data.module', 'src.data'):
    if module_name in sys.modules:
        importlib.reload(sys.modules[module_name])

from src.data import WaveletTimeSeriesDataModule

config = {
    'training': {
        'epochs': NUM_EPOCHS,
        'batch_size': BATCH_SIZE,
        'save_model': True,
        'log_every_n_epochs': LOG_EVERY_N_EPOCHS,
        'log_level_losses_every_n_epochs': LOG_LEVEL_LOSSES_EVERY_N_EPOCHS,
    },
    'dataset': {
        'name': DATASET_NAME,
        'seq_len': SEQ_LEN,
    },
    'data': {
        'data_dir': REPO_DIR,
        'normalize_data': NORMALIZE_DATA,
        'data_path': DATA_PATH,
    },
    'wavelet': {
        'type': WAVELET_TYPE,
        'levels': WAVELET_LEVELS,
    },
    'model': {
        'embed_dim': EMBED_DIM,
        'num_heads': NUM_HEADS,
        'num_layers': NUM_LAYERS,
        'time_embed_dim': TIME_EMBED_DIM,
        'dropout': DROPOUT,
        'prediction_target': PREDICTION_TARGET,
    },
    'attention': {
        'use_cross_level_attention': USE_CROSS_LEVEL_ATTENTION,
    },
    'energy': {
        'weight': ENERGY_WEIGHT,
    },
    'noise': {
        'schedule': NOISE_SCHEDULE,
    },
    'optimizer': {
        'scheduler_type': SCHEDULER_TYPE,
        'warmup_epochs': WARMUP_EPOCHS,
        'lr': LEARNING_RATE,
        'weight_decay': WEIGHT_DECAY,
        'onecycle_max_lr': ONECYCLE_MAX_LR,
        'onecycle_pct_start': ONECYCLE_PCT_START,
    },
    'sampling': {
        'ddim_eta': DDIM_ETA,
        'ddim_steps': DDIM_STEPS,
    },
    'paths': {
        'output_dir': OUTPUT_DIR,
    },
}

datamodule = WaveletTimeSeriesDataModule(config=config)
print(f"✅ Data loaded: {datamodule.raw_data_tensor.shape}")
print(f"✅ Wavelet dimension: {datamodule.get_input_dim()}")

In [None]:
# @title Cell 6: Initialize Model
# Creates the diffusion transformer from the data module and config, then
# wraps it with torch.compile when a compile mode is configured.
from src.models import WaveletDiffusionTransformer

model = WaveletDiffusionTransformer(data_module=datamodule, config=config)

if COMPILE_MODE is None:
    print("✅ Model initialized (no compilation)")
else:
    print(f"Compiling model with mode='{COMPILE_MODE}'...")
    model = torch.compile(model, mode=COMPILE_MODE)
    print("✅ Model compiled")

In [None]:
# @title Cell 7: Run Training
# Creates the output directory, configures checkpointing and the Lightning
# Trainer, and runs the training loop on `model` with `datamodule`.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Timer
import os

os.makedirs(OUTPUT_DIR, exist_ok=True)

callbacks = [
    Timer(),
    # Write a checkpoint every SAVE_EVERY_N_EPOCHS epochs into OUTPUT_DIR;
    # SAVE_TOP_K controls how many of them are kept.
    ModelCheckpoint(
        dirpath=OUTPUT_DIR,
        filename='checkpoint-{epoch:02d}',
        save_top_k=SAVE_TOP_K,
        every_n_epochs=SAVE_EVERY_N_EPOCHS
    )
]

# Strategy selection. The previous hard-coded "ddp_find_unused_parameters_true"
# launches worker processes by re-running the script, which Lightning does not
# support from an interactive notebook session (see the Trainer `strategy`
# docs). A single device needs no distributed strategy at all; for several
# devices use the notebook-compatible DDP variant with the same
# find_unused_parameters setting.
if DEVICES == 1:
    TRAINER_STRATEGY = "auto"
else:
    TRAINER_STRATEGY = "ddp_notebook_find_unused_parameters_true"

trainer = pl.Trainer(
    max_epochs=NUM_EPOCHS,
    accelerator=ACCELERATOR,
    devices=DEVICES,
    strategy=TRAINER_STRATEGY,
    precision=PRECISION,
    gradient_clip_val=GRADIENT_CLIP_VAL,
    gradient_clip_algorithm="norm",
    callbacks=callbacks,
    enable_checkpointing=True,
    logger=False,
    log_every_n_steps=LOG_EVERY_N_STEPS,
    enable_progress_bar=False
)

print("Starting training (Progress bar disabled to match source output style)...")
trainer.fit(model, datamodule)
print(f"✅ Training finished. Checkpoints saved to {OUTPUT_DIR}")

In [None]:
# @title Cell 8: Save Compile Cache to Drive
# Packs the contents of the local compile cache directory into a .tar.gz and
# copies that archive to the Drive cache folder. Does nothing when
# torch.compile is disabled or the cache directory is missing or empty.
import os
import tarfile
import shutil

if COMPILE_MODE is None:
    print("torch.compile disabled, no cache to save")
else:
    CACHE_ARCHIVE_NAME = "torchinductor_cache.tar.gz"
    LOCAL_CACHE_ARCHIVE = f"/content/{CACHE_ARCHIVE_NAME}"
    DRIVE_CACHE_ARCHIVE = f"{COMPILE_CACHE_DRIVE_PATH}/{CACHE_ARCHIVE_NAME}"

    cache_is_populated = os.path.exists(COMPILE_CACHE_DIR) and bool(os.listdir(COMPILE_CACHE_DIR))
    if not cache_is_populated:
        print("No cache files to save")
    else:
        print("Packaging compile cache...")

        # Store each top-level entry under its own name so the archive can be
        # unpacked directly into a cache directory.
        with tarfile.open(LOCAL_CACHE_ARCHIVE, "w:gz") as archive:
            for entry in os.listdir(COMPILE_CACHE_DIR):
                archive.add(os.path.join(COMPILE_CACHE_DIR, entry), arcname=entry)

        # Publish the archive to Drive.
        os.makedirs(COMPILE_CACHE_DRIVE_PATH, exist_ok=True)
        shutil.copy(LOCAL_CACHE_ARCHIVE, DRIVE_CACHE_ARCHIVE)

        size_mb = os.path.getsize(LOCAL_CACHE_ARCHIVE) / (1024 * 1024)
        print(f"✅ Cache saved to Drive ({size_mb:.1f} MB)")