# WaveletDiff Training (PyTorch Lightning)

This notebook trains the model with the PyTorch Lightning `Trainer` (matching the source repo exactly).

### Workflow:
1. **Configuration**: Set your parameters here.
2. **Setup**: Mount Google Drive (on Colab), clone the repo, and install dependencies.
3. **Data**: Load and prepare data via DataModule.
4. **Model**: Initialize the WaveletDiffusionTransformer.
5. **Train**: Run training with pl.Trainer.


In [None]:
# @title Cell 1: Global Configuration
# Every notebook-wide setting is defined in this cell; the later cells read these
# names directly, so this cell has to run first. Lines tagged `# @param` are
# Colab form fields, so keep the `NAME = value # @param ...` shape when editing.
import random

# Experiment Identity
RUN_NAME = "stocks_baseline_v1" # @param {type:"string"}
# Zero-padded 3-digit suffix drawn when this cell executes, so re-running the
# cell can change UNIQUE_RUN_NAME (and therefore CHECKPOINT_DIR below).
RUN_ID =  f"{random.randint(0, 999):03d}"
UNIQUE_RUN_NAME = f"{RUN_NAME}_{RUN_ID}"

# Dataset Configuration (forwarded to trainer.get_datamodule in Cell 4)
DATASET_NAME = "stocks"
# Relative path; presumably resolved against REPO_DIR by the loader -- confirm.
DATA_PATH = "src/copied_waveletDiff/data/stocks/stock_data.csv" # @param {type:"string"}
SEQ_LEN = 24 # @param {type:"integer"}

# Model Hyperparameters (forwarded to trainer.init_model in Cell 5)
EMBED_DIM = 256 # @param {type:"integer"}
NUM_HEADS = 8 # @param {type:"integer"}
NUM_LAYERS = 8 # @param {type:"integer"}
TIME_EMBED_DIM = 128 # @param {type:"integer"}
DROPOUT = 0.1 # @param {type:"number"}
PREDICTION_TARGET = "noise"
USE_CROSS_LEVEL_ATTENTION = True # @param {type:"boolean"}

# Training Parameters (BATCH_SIZE goes to the datamodule, LEARNING_RATE to the
# model, NUM_EPOCHS and GRADIENT_CLIP_VAL to trainer.train in Cell 6)
NUM_EPOCHS = 100 # @param {type:"integer"}
BATCH_SIZE = 512 # @param {type:"integer"}
LEARNING_RATE = 2e-4 # @param {type:"number"}
GRADIENT_CLIP_VAL = 1.0 # @param {type:"number"}

# Precision (matches source: "32" for FP32)
# PRECISION is passed to trainer.train; MATMUL_PRECISION to setup_environment.
PRECISION = "32" # @param ["32", "bf16-mixed"]
MATMUL_PRECISION = "medium" # @param ["medium", "high"]

# torch.compile (optional, set to None to disable)
# Choosing "None" in the Colab dropdown stores the *string* "None"; Cell 5 maps
# that string back to a real None before calling trainer.init_model.
COMPILE_MODE = None # @param ["None", "default", "reduce-overhead", "max-autotune"]
COMPILE_FULLGRAPH = False # @param {type:"boolean"}

# Logging (all three are forwarded to trainer.train in Cell 6)
LOG_EVERY_N_STEPS = 50 # @param {type:"integer"}
LOG_EVERY_N_EPOCHS = 1 # @param {type:"integer"}
ENABLE_PROGRESS_BAR = True # @param {type:"boolean"}

# Checkpointing
# With the defaults here this equals NUM_EPOCHS; presumably that means a single
# save at the end of training -- confirm against trainer.train.
SAVE_EVERY_N_EPOCHS = 100 # @param {type:"integer"}

# Wavelet Configuration (forwarded to trainer.get_datamodule in Cell 4)
WAVELET_TYPE = "db2" # @param {type:"string"}
# "auto" is passed through as-is; its interpretation is up to the datamodule.
WAVELET_LEVELS = "auto"

# Paths
DRIVE_BASE_PATH = "/content/drive/MyDrive/personal_drive/trading"
# NOTE: both derived paths are computed once, right here. Reassigning
# DRIVE_BASE_PATH later does not update them automatically.
CHECKPOINT_DIR = f"{DRIVE_BASE_PATH}/checkpoints/temp/{UNIQUE_RUN_NAME}"
COMPILE_CACHE_DIR = f"{DRIVE_BASE_PATH}/compile_cache"
REPO_URL = "https://github.com/MilesHoffman/waveletDiff_synth_data.git"
REPO_DIR = "/content/waveletDiff_synth_data"

# Reproducibility Seed (set to None for random)
# Passed to trainer.setup_environment in Cell 3; it is not applied to the
# `random` call that generates RUN_ID above.
SEED = 42 # @param {type:"integer"}

In [None]:
# @title Cell 2: Setup (Clone, Install, Mount)
# Prepares the runtime: mounts Google Drive when running on Colab (or switches to
# local output directories otherwise), clones or updates the project repo, puts
# it on sys.path, and installs the Python dependencies.
# Requires Cell 1 to have run (reads REPO_URL, REPO_DIR, UNIQUE_RUN_NAME).
import importlib
import os
import sys
import subprocess

# 1. Mount Drive
try:
    from google.colab import drive
    if os.path.exists('/content/drive'):
        # The mount point can exist but be empty (stale mount); only then is a
        # forced remount needed. A populated directory is treated as mounted.
        if not os.listdir('/content/drive'):
            print("Force remounting Drive...")
            drive.mount('/content/drive', force_remount=True)
    else:
        drive.mount('/content/drive')
    print("✅ Drive mounted")
except ImportError:
    # google.colab is unavailable: fall back to directories relative to the
    # current working directory.
    print("Not running on Colab. Skipping Drive mount.")
    DRIVE_BASE_PATH = "local_checkpoints"
    os.makedirs(DRIVE_BASE_PATH, exist_ok=True)
    COMPILE_CACHE_DIR = "local_compile_cache"
    # CHECKPOINT_DIR was already built from the Drive path in Cell 1, so it
    # must be rebuilt here; otherwise checkpoints would still be written under
    # /content/drive/... even though Drive is not mounted.
    CHECKPOINT_DIR = f"{DRIVE_BASE_PATH}/checkpoints/temp/{UNIQUE_RUN_NAME}"

# 2. Clone/Update Repo
if not os.path.exists(REPO_DIR):
    print(f"Cloning {REPO_URL} into {REPO_DIR}...")
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
else:
    print(f"Repo exists at {REPO_DIR}, pulling latest changes...")
    subprocess.run(["git", "-C", REPO_DIR, "pull"], check=True)

# 3. Add to Path
# Makes `src.*` packages inside the cloned repo importable from later cells.
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

# 4. Install Dependencies
# Installed with the kernel's own interpreter so packages land in the active
# environment; check=True aborts the cell if pip fails.
print("Installing dependencies...")
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "pytorch-lightning", "pywavelets", "scipy", "pandas", "tqdm", "lightning"], check=True)
print("✅ Dependencies installed")

# Freshly installed/cloned packages may not be visible to the import system's
# cached directory listings until the caches are invalidated.
importlib.invalidate_caches()
print("✅ Env Ready")

In [None]:
# @title Cell 3: Setup Environment
# Imports the project's training helper module from the cloned repo. The import
# resolves because Cell 2 inserts REPO_DIR into sys.path, so run Cell 2 first.
from src.torch_gpu_waveletDiff.train import trainer

# Hands the matmul precision and seed from Cell 1 to the project helper.
# NOTE(review): presumably this configures torch matmul precision and seeds the
# RNGs -- the helper's body is not visible here, confirm in the repo.
trainer.setup_environment(matmul_precision=MATMUL_PRECISION, seed=SEED)

In [None]:
# @title Cell 4: Load Data
# Collect the loader settings from Cell 1 in one place, then build the
# datamodule. The call returns the datamodule together with a config object
# that the following cells pass along.
_datamodule_kwargs = {
    "repo_dir": REPO_DIR,
    "dataset_name": DATASET_NAME,
    "seq_len": SEQ_LEN,
    "batch_size": BATCH_SIZE,
    "wavelet_type": WAVELET_TYPE,
    "wavelet_levels": WAVELET_LEVELS,
    "data_path": DATA_PATH,
}
datamodule, config = trainer.get_datamodule(**_datamodule_kwargs)

In [None]:
# @title Cell 5: Initialize Model
# The Colab dropdown stores "disabled" as the string "None"; translate that to a
# real None and leave every other value (including an actual None) untouched.
_resolved_compile_mode = None if COMPILE_MODE == "None" else COMPILE_MODE

# Architecture and optimiser settings taken from Cell 1.
_architecture_kwargs = {
    "embed_dim": EMBED_DIM,
    "num_heads": NUM_HEADS,
    "num_layers": NUM_LAYERS,
    "time_embed_dim": TIME_EMBED_DIM,
    "dropout": DROPOUT,
    "prediction_target": PREDICTION_TARGET,
    "use_cross_level_attention": USE_CROSS_LEVEL_ATTENTION,
    "learning_rate": LEARNING_RATE,
}

# torch.compile-related settings.
_compile_kwargs = {
    "compile_mode": _resolved_compile_mode,
    "compile_fullgraph": COMPILE_FULLGRAPH,
    "compile_cache_dir": COMPILE_CACHE_DIR,
}

# Builds the model from the datamodule and config produced in Cell 4; `config`
# is rebound to the value returned by init_model.
model, config = trainer.init_model(
    datamodule=datamodule,
    config=config,
    **_architecture_kwargs,
    **_compile_kwargs,
)

In [None]:
# @title Cell 6: Run Training
# Objects produced by Cells 4 and 5.
_training_inputs = {
    "model": model,
    "datamodule": datamodule,
    "config": config,
}

# Loop length, numeric precision, and gradient clipping.
_loop_settings = {
    "num_epochs": NUM_EPOCHS,
    "precision": PRECISION,
    "gradient_clip_val": GRADIENT_CLIP_VAL,
}

# How often metrics are logged.
_logging_settings = {
    "log_every_n_steps": LOG_EVERY_N_STEPS,
    "log_every_n_epochs": LOG_EVERY_N_EPOCHS,
}

# Where and how often checkpoints are written.
_checkpoint_settings = {
    "checkpoint_dir": CHECKPOINT_DIR,
    "save_every_n_epochs": SAVE_EVERY_N_EPOCHS,
}

# Launch training; the helper returns the trainer object and the trained model.
pl_trainer, trained_model = trainer.train(
    **_training_inputs,
    **_loop_settings,
    **_logging_settings,
    **_checkpoint_settings,
    enable_progress_bar=ENABLE_PROGRESS_BAR,
)

In [None]:
# @title Cell 7: Persist Compilation Cache
# Run this cell after training (or after the first few steps) to persist the
# compilation cache. COMPILE_CACHE_DIR points at Google Drive on Colab and at a
# local directory otherwise (see Cells 1 and 2).
# NOTE(review): the helper is expected to store the compiled kernels as a single
# archive in that directory -- its implementation is not visible here, confirm
# in the repo.
trainer.persist_compilation_cache(COMPILE_CACHE_DIR)