# WaveletDiff Lightning Trainer (High Performance)

This notebook implements the training loop for WaveletDiff using the high-level **PyTorch Lightning Trainer**. 
It keeps the hardware/performance optimizations of the native Lightning Fabric notebook while offering a more standardized API.

### Optimizations Preserved:
- **Dynamic Precision**: Automatically selects `bf16-true` (TPU) or `bf16-mixed` (GPU), falling back to `32-true` on CPU.
- **BF16 Pre-casting**: Data is cast to BF16 before the loop *only* on TPUs to maximize throughput.
- **Optional Gradient Clipping**: Controlled by `ENABLE_GRAD_CLIPPING` (clips gradients at `1.0` when enabled); disabled by default.

> **Note**: Designed for TPU `v4-8` or `v5e` but compatible with GPUs.

In [None]:
# @title Cell 1: Global Configuration
import os
from datetime import datetime

# Experiment Identity
# RUN_ID is a local-time timestamp taken when this cell runs, so every re-run
# yields a new UNIQUE_RUN_NAME (and therefore a new CHECKPOINT_DIR below).
RUN_NAME = "lightning_trainer_v1" # @param {type:"string"}
RUN_ID = datetime.now().strftime("%Y%m%d_%H%M%S")
UNIQUE_RUN_NAME = f"{RUN_NAME}_{RUN_ID}"

# Dataset Configuration
# NOTE(review): Cell 3 always reads data/stocks/stock_data.csv; DATASET_NAME is
# only forwarded into the config dict, so changing it does not change the data.
DATASET_NAME = "stocks"
# Length of each sliding window passed to create_sliding_windows in Cell 3.
SEQ_LEN = 24 # @param {type:"integer"}

# Model Hyperparameters
# Forwarded to the 'model' / 'attention' sections of the config in Cell 4.
EMBED_DIM = 256 # @param {type:"integer"}
NUM_HEADS = 8 # @param {type:"integer"}
NUM_LAYERS = 8 # @param {type:"integer"}
TIME_EMBED_DIM = 128 # @param {type:"integer"}
DROPOUT = 0.1 # @param {type:"number"}
# The sampler in Cell 6 treats the model output as a noise estimate (eps), which
# matches this setting.
PREDICTION_TARGET = "noise"
USE_CROSS_LEVEL_ATTENTION = True # @param {type:"boolean"}

# Training Parameters
# TOTAL_TRAINING_STEPS becomes the Trainer's max_steps in Cell 5; BATCH_SIZE and
# LEARNING_RATE are forwarded into the config (Cells 3 and 4).
TOTAL_TRAINING_STEPS = 50000 # @param {type:"integer"}
BATCH_SIZE = 512 # @param {type:"integer"}
LEARNING_RATE = 2e-4 # @param {type:"number"}
# NOTE(review): WEIGHT_DECAY and WARMUP_STEPS are not referenced by any later cell
# shown here (Cell 4 sets 'warmup_epochs': 5 instead) - confirm they are meant to
# be wired into the optimizer config.
WEIGHT_DECAY = 1e-5 # @param {type:"number"}
WARMUP_STEPS = 500 # @param {type:"integer"}
# When True, Cell 5 passes gradient_clip_val=1.0 to the Trainer.
ENABLE_GRAD_CLIPPING = False # @param {type:"boolean"}

# Wavelet Configuration
# Presumably a PyWavelets wavelet name; "auto" levels presumably lets the
# datamodule choose the decomposition depth - confirm in data.module.
WAVELET_TYPE = "db2" # @param {type:"string"}
WAVELET_LEVELS = "auto"

# Paths
# NOTE(review): DRIVE_BASE_PATH assumes Google Drive is mounted at /content/drive.
# Nothing in this notebook mounts it; if it is not mounted, the os.makedirs call
# in Cell 2 creates a plain local directory instead.
DRIVE_BASE_PATH = "/content/drive/MyDrive/personal_drive/trading"
CHECKPOINT_DIR = f"{DRIVE_BASE_PATH}/checkpoints/{UNIQUE_RUN_NAME}"
# Clone target for the WaveletDiff repo (Cell 2); its src/ is added to sys.path.
REPO_DIR = "/content/WaveletDiff"

In [None]:
# @title Cell 2: Imports & Environment Setup
import sys
import os
import subprocess
import torch
import time
from pathlib import Path

import importlib.util

# 1. Performance Polishing: Set Matmul Precision for GPUs (Ampere+/L4/A100)
if torch.cuda.is_available():
    # 'high' lets float32 matmuls use TensorFloat-32 (TF32) on Ampere+ GPUs such
    # as L4/A100, which is significantly faster at a small precision cost.
    torch.set_float32_matmul_precision('high')

# 2. Conditional Dependency Installation
deps = ["lightning", "pywavelets", "scipy", "pandas", "tqdm"]

# STRICT CHECK: only consider torch_xla on an explicit TPU environment.
# find_spec avoids parsing `pip list` output, whose package-name formatting
# (torch_xla vs torch-xla) differs between pip versions.
is_tpu = 'COLAB_TPU_ADDR' in os.environ or 'TPU_NAME' in os.environ
needs_xla = is_tpu and importlib.util.find_spec("torch_xla") is None
if needs_xla:
    deps.append("torch_xla[tpu]")

try:
    import lightning.pytorch as pl
    needs_lightning = False
except ImportError:
    needs_lightning = True

# Install when either lightning or (on TPU) torch_xla is missing: torch_xla must
# be installed even when lightning already imports successfully.
if needs_lightning or needs_xla:
    print(f"Installing dependencies ({', '.join(deps)})...")
    # `sys.executable -m pip` installs into the interpreter running this kernel.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", *deps],
        check=True,
        stdout=subprocess.DEVNULL,
    )
    # NOTE(review): a freshly installed torch_xla may need a runtime restart
    # before torch picks it up.
    import lightning.pytorch as pl

# 3. Dynamic Precision Detection
# Trainer accepts 'bf16-true', 'bf16-mixed', '16-mixed' or '32-true' via `precision`.
if is_tpu:
    PRECISION = "bf16-true"
elif torch.cuda.is_available():
    # Native bf16 tensor cores require compute capability >= 8.0 (Ampere+).
    # Older GPUs (e.g. T4, V100) would only emulate bf16, so use fp16 AMP there.
    cc_major, _ = torch.cuda.get_device_capability()
    PRECISION = "bf16-mixed" if cc_major >= 8 else "16-mixed"
else:
    PRECISION = "32-true"

# 4. Clone Repository
REPO_URL = "https://github.com/GarlicWang/WaveletDiff.git"
if not os.path.exists(REPO_DIR):
    print(f"Cloning {REPO_URL}...")
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True, stdout=subprocess.DEVNULL)

# 5. Add to System Path (guarded so re-running the cell does not add duplicates)
src_dir = os.path.join(REPO_DIR, "src")
if src_dir not in sys.path:
    sys.path.append(src_dir)

# 6. Create Checkpoint Directory
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print(f"Environment Ready. Accelerator: {'TPU' if is_tpu else 'GPU' if torch.cuda.is_available() else 'CPU'}")
print(f"Selected Precision: {PRECISION}")

In [None]:
# @title Cell 3: Data Loading & Optimization
import pandas as pd
from data.loaders import create_sliding_windows
from data.module import WaveletTimeSeriesDataModule

def prepare_optimized_datamodule():
    """Load the stocks CSV, build the WaveletDiff datamodule and cast data for the hardware.

    Steps:
      1. Read ``data/stocks/stock_data.csv`` from the cloned repo and keep the
         Open/High/Low/Close/Volume columns.
      2. Build normalized sliding windows of length ``SEQ_LEN`` on the CPU.
      3. Construct ``WaveletTimeSeriesDataModule`` from those windows.
      4. When running ``bf16-true`` on a TPU, cast the datamodule's tensor to
         bfloat16 once, before training.

    Returns:
        tuple: ``(datamodule, full_config)``, where ``full_config`` is the base
        config dict (dataset/training/data/wavelet sections) extended in Cell 4.
    """
    from torch.utils.data import TensorDataset

    stocks_path = os.path.join(REPO_DIR, "data", "stocks", "stock_data.csv")
    print(f"Loading data from {stocks_path}...")

    df = pd.read_csv(stocks_path)
    core_cols = ['Open', 'High', 'Low', 'Close', 'Volume']
    df_filtered = df[core_cols]

    # Create windows (memory intensive, so this is done on the CPU).
    custom_data_windows, _ = create_sliding_windows(
        df_filtered.values,
        seq_len=SEQ_LEN,
        normalize=True,
    )
    full_data_tensor = torch.FloatTensor(custom_data_windows)

    full_config = {
        'dataset': {'name': DATASET_NAME, 'seq_len': SEQ_LEN},
        'training': {'batch_size': BATCH_SIZE, 'epochs': 1},
        # The windows are already normalized above (normalize=True); presumably
        # that is why the datamodule's own normalization is disabled here.
        'data': {'data_dir': os.path.join(REPO_DIR, "data"), 'normalize_data': False},
        'wavelet': {'type': WAVELET_TYPE, 'levels': WAVELET_LEVELS},
    }

    # The tensor is passed explicitly so the datamodule does not reload the CSV
    # (the wavelet transform happens inside its constructor).
    datamodule = WaveletTimeSeriesDataModule(config=full_config, data_tensor=full_data_tensor)

    # HARDWARE OPTIMIZATION
    # bf16-true (chosen only for TPUs in Cell 2) runs the model in bfloat16, so the
    # data is pre-cast once instead of per batch. Under bf16-mixed / 16-mixed /
    # 32-true the data stays float32, as mixed precision expects float32 inputs.
    data_tensor = datamodule.data_tensor
    if is_tpu and PRECISION == "bf16-true":
        print("Optimizing: Casting data to bfloat16 for TPU execution...")
        data_tensor = data_tensor.to(torch.bfloat16)

    # Re-assign so the datamodule serves the (possibly cast) tensor.
    # NOTE(review): this replaces whatever dataset the datamodule built in its
    # constructor with a plain TensorDataset - confirm its dataloaders read
    # `self.dataset`.
    datamodule.data_tensor = data_tensor
    datamodule.dataset = TensorDataset(data_tensor)

    return datamodule, full_config

# Build the data pipeline once; the model and sampling cells read these globals.
datamodule, model_base_config = prepare_optimized_datamodule()
WAVELET_INFO, INPUT_DIM = datamodule.get_wavelet_info(), datamodule.get_input_dim()

In [None]:
# @title Cell 4: Model Initialization
from models.transformer import WaveletDiffusionTransformer

# Extend the base config from Cell 3 (dataset/training/data/wavelet sections)
# with the model, diffusion and optimizer sections, then build the model.
# NOTE(review): WARMUP_STEPS and WEIGHT_DECAY from Cell 1 are not forwarded here;
# warmup is expressed in epochs. Also 'training': {'epochs': 1} (Cell 3) coexists
# with max_steps in Cell 5 - if the model sizes its one-cycle schedule from epochs,
# confirm it covers the full step budget.
model_base_config.update(
    model=dict(
        embed_dim=EMBED_DIM,
        num_heads=NUM_HEADS,
        num_layers=NUM_LAYERS,
        time_embed_dim=TIME_EMBED_DIM,
        dropout=DROPOUT,
        prediction_target=PREDICTION_TARGET,
    ),
    attention=dict(use_cross_level_attention=USE_CROSS_LEVEL_ATTENTION),
    noise=dict(schedule="exponential"),
    sampling=dict(ddim_eta=0.0, ddim_steps=None),
    energy=dict(weight=0.0),
    optimizer=dict(
        scheduler_type='onecycle',
        lr=LEARNING_RATE,
        warmup_epochs=5,
        cosine_eta_min=1e-6,
    ),
)

model = WaveletDiffusionTransformer(data_module=datamodule, config=model_base_config)

In [None]:
# @title Cell 5: Lightning Trainer Setup & Training
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import CSVLogger

# Callbacks
# Periodic checkpoints into CHECKPOINT_DIR every 5000 training steps; save_top_k=-1
# keeps every one of them (no pruning) and save_last also maintains last.ckpt.
# NOTE(review): '{train_loss:.4f}' assumes the model logs a metric named exactly
# 'train_loss' - confirm against the LightningModule's self.log calls.
checkpoint_callback = ModelCheckpoint(
    dirpath=CHECKPOINT_DIR,
    filename='{step}-{train_loss:.4f}',
    every_n_train_steps=5000,
    save_top_k=-1,
    save_last=True,
)

# Log the learning rate at every step so the scheduler can be inspected.
lr_monitor = LearningRateMonitor(logging_interval='step')

# Trainer: hardware is auto-detected, PRECISION comes from Cell 2, and training
# stops after TOTAL_TRAINING_STEPS optimizer steps.
trainer = pl.Trainer(
    accelerator="auto",
    devices="auto",
    precision=PRECISION,
    max_steps=TOTAL_TRAINING_STEPS,
    gradient_clip_val=1.0 if ENABLE_GRAD_CLIPPING else None,
    logger=CSVLogger(save_dir=DRIVE_BASE_PATH, name=UNIQUE_RUN_NAME),
    log_every_n_steps=100,
    callbacks=[checkpoint_callback, lr_monitor],
    enable_checkpointing=True,
    enable_progress_bar=True,
)

# Start Training
print("Starting Training via pl.Trainer...")
trainer.fit(model, datamodule=datamodule)

In [None]:
# @title Cell 6: Evaluation (Sanity Check)

def sanity_check_sampling(num_samples=2):
    """Draw a few DDPM samples from the trained model and map them back to time series.

    Runs the ancestral DDPM reverse chain over ``model.T`` steps. It starts from
    Gaussian noise with shape ``(num_samples, INPUT_DIM, WAVELET_INFO['n_features'])``
    and treats the model output as the predicted noise (PREDICTION_TARGET == "noise").

    Args:
        num_samples: number of independent samples to draw (default 2).

    Returns:
        The result of ``datamodule.convert_wavelet_to_timeseries`` for the
        generated batch (its ``.shape`` is printed).
    """
    from tqdm.auto import tqdm

    # Pick the evaluation device; a TPU run takes precedence.
    if is_tpu:
        import torch_xla.core.xla_model as xm
        device = xm.xla_device()
    elif torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    model.to(device)
    model.eval()

    # Feed the network inputs in its own parameter dtype. Under "bf16-true" the
    # weights are bf16; under "bf16-mixed"/"16-mixed"/"32-true" they stay float32
    # and no autocast is active here, so a bf16 input would mismatch the weights.
    model_dtype = next(model.parameters()).dtype
    num_steps = model.T

    print("Generating sanity check samples...")

    with torch.no_grad():
        # The running sample is kept in float32: accumulating ~T updates in bf16
        # loses precision. Only the model input is cast to model_dtype.
        shape = (num_samples, INPUT_DIM, WAVELET_INFO['n_features'])
        samples_wavelet = torch.randn(shape, device=device, dtype=torch.float32)

        for i in tqdm(reversed(range(num_steps)), total=num_steps, desc="Sampling"):
            t = torch.full((num_samples,), i, device=device, dtype=torch.long)
            t_norm = t.float() / num_steps

            eps_theta = model(samples_wavelet.to(model_dtype), t_norm).float()

            # Every sample shares timestep i, so the schedule is read as scalars.
            # Indexing with the Python int works whether the schedule tensors live
            # on the model device or on the CPU.
            alpha_t = model.alpha_all[i].to(device=device, dtype=torch.float32)
            alpha_bar_t = model.alpha_bar_all[i].to(device=device, dtype=torch.float32)
            beta_t = model.beta_all[i].to(device=device, dtype=torch.float32)

            # Standard DDPM posterior mean (noise-prediction parameterization).
            mean = (1 / torch.sqrt(alpha_t)) * (
                samples_wavelet - ((1 - alpha_t) / torch.sqrt(1 - alpha_bar_t)) * eps_theta
            )

            if i > 0:
                samples_wavelet = mean + torch.sqrt(beta_t) * torch.randn_like(samples_wavelet)
            else:
                samples_wavelet = mean

    print("Sampling complete. Converting to time series...")
    # Reconstruct on the CPU in float32 for numerical stability.
    samples_ts = datamodule.convert_wavelet_to_timeseries(samples_wavelet.float().cpu())
    print(f"Generated samples shape: {samples_ts.shape}")
    return samples_ts


samples = sanity_check_sampling()