# WaveletDiff Fabric Native Training (High Performance)

This notebook implements a clean, high-performance training loop for WaveletDiff using **Lightning Fabric** natively.

### Key Optimizations:
- **Native Fabric Loop**: Removes overhead of high-level abstractions.
- **Smart Gradient Clipping**: Configurable via `ENABLE_GRAD_CLIPPING` (set to `True` in Cell 1 as shipped); disable it for a 2x TPU speedup.
- **Efficient Logging**: Metrics are only synchronized periodically (every 1% of progress) to prevent XLA graph breaks.
- **BF16 Pre-casting**: When running with `bf16-true` precision on an accelerator, data is cast to BF16 before the loop to maximize TPU throughput.

> **Note**: Designed for TPU `v4-8` or `v5e` but compatible with GPUs.

In [None]:
# @title Cell 1: Global Configuration
import os
from datetime import datetime
import random

# Experiment Identity
RUN_NAME = "stocks_ohlcv_v1" # @param {type:"string"}
# Random, zero-padded three-digit suffix (000-999) drawn each time this cell runs.
# NOTE(review): only 1000 possible values, so two runs can end up sharing a
# checkpoint directory.
RUN_ID =  f"{random.randint(0, 999):03d}"
UNIQUE_RUN_NAME = f"{RUN_NAME}_{RUN_ID}"

# Dataset Configuration
DATASET_NAME = "stocks"
# Window length handed to create_sliding_windows in the data-loading cell.
SEQ_LEN = 24 # @param {type:"integer"}

# Model Hyperparameters
# These are copied into the 'model' / 'attention' sections of the model config
# in the model-initialization cell.
EMBED_DIM = 256 # @param {type:"integer"}
NUM_HEADS = 8 # @param {type:"integer"}
NUM_LAYERS = 8 # @param {type:"integer"}
TIME_EMBED_DIM = 128 # @param {type:"integer"}
DROPOUT = 0.1 # @param {type:"number"}
PREDICTION_TARGET = "noise"
USE_CROSS_LEVEL_ATTENTION = True # @param {type:"boolean"}

# Training Parameters
TOTAL_TRAINING_STEPS = 15000 # @param {type:"integer"}
BATCH_SIZE = 512 # @param {type:"integer"}
LEARNING_RATE = 2e-4 # @param {type:"number"}
WEIGHT_DECAY = 1e-5 # @param {type:"number"}
# 5% of the total step count.
# NOTE(review): not referenced by any other cell shown in this notebook; the
# training loop's OneCycleLR uses pct_start=0.3 instead. Confirm it is still needed.
WARMUP_STEPS = int(.05 * TOTAL_TRAINING_STEPS) # @param
# When True the training loop clips the gradient norm to 1.0 before each optimizer step.
ENABLE_GRAD_CLIPPING = True # @param {type:"boolean"}

# === PROFILER CONFIGURATION ===
# When enabled, the training cell runs only PROFILER_WARMUP + PROFILER_ACTIVE + 2
# steps under torch.profiler and skips periodic logging and checkpointing.
ENABLE_PROFILER = False # @param {type:"boolean"}
PROFILER_WARMUP = 25
PROFILER_ACTIVE = 5

# Logging & Checkpointing
# Log every 1% of the total training steps, but never less often than every step.
LOG_INTERVAL = max(1, int(TOTAL_TRAINING_STEPS * 0.01))
# A checkpoint is written every SAVE_INTERVAL steps.
SAVE_INTERVAL = 5000 # @param {type:"integer"}

# Wavelet Configuration
# Both values are forwarded to the DataModule through the 'wavelet' config section.
WAVELET_TYPE = "db2" # @param {type:"string"}
WAVELET_LEVELS = "auto"

# Paths
# Checkpoints go under the Google Drive mount; the repository is cloned to REPO_DIR.
DRIVE_BASE_PATH = "/content/drive/MyDrive/personal_drive/trading"
CHECKPOINT_DIR = f"{DRIVE_BASE_PATH}/checkpoints/{UNIQUE_RUN_NAME}"
REPO_DIR = "/content/WaveletDiff"

In [None]:
# @title Cell 2: Imports & Fabric Initialization
import sys
import os
import subprocess
import torch
import time
from pathlib import Path
from tqdm.auto import tqdm

# === CRITICAL: FORCE GOOGLE DRIVE MOUNT ===
# Done before anything else so that checkpoints end up on persistent storage.
# Outside Colab the import fails and the mount is skipped; any other failure is
# reported as a warning and the cell keeps going.
try:
    from google.colab import drive
    if not os.path.exists('/content/drive'):
        # No mount point yet: perform a regular mount.
        drive.mount('/content/drive')
    elif not os.listdir('/content/drive'):
        # The directory exists but is empty, i.e. a leftover "ghost" mount point.
        print("Detected ghost directory. Force remounting...")
        drive.mount('/content/drive', force_remount=True)
except ImportError:
    print("Not running on Colab. Skipping Drive mount.")
except Exception as mount_error:
    print(f"Drive mount warning: {mount_error}")

# 1. Performance Polishing: Set Matmul Precision for GPUs (Ampere+/L4/A100)
if torch.cuda.is_available():
    # 'high' uses TensorFloat-32 (TF32) for significantly faster matmuls on L4/A100
    torch.set_float32_matmul_precision('high')

# 2. Conditional Dependency Installation
import importlib.util

deps = ["lightning", "pywavelets", "scipy", "pandas", "tqdm"]

# STRICT CHECK: Only install torch_xla if explicitly on a TPU environment.
# Presence is tested by asking the import system directly, which is independent
# of how `pip list` happens to spell the distribution name.
is_tpu = 'COLAB_TPU_ADDR' in os.environ or 'TPU_NAME' in os.environ
needs_torch_xla = is_tpu and importlib.util.find_spec("torch_xla") is None
if needs_torch_xla:
    deps.append("torch_xla[tpu]")

# Install when lightning is missing OR when torch_xla is required but absent.
# Gating on lightning alone would skip the torch_xla install on a TPU host that
# already ships lightning.
if needs_torch_xla or importlib.util.find_spec("lightning") is None:
    print(f"Installing dependencies ({', '.join(deps)})...")
    # Use the running interpreter's pip so packages land in this environment.
    subprocess.run(
        [sys.executable, "-m", "pip", "install"] + deps,
        check=True,
        stdout=subprocess.DEVNULL,
    )
    # Make the freshly installed packages visible to the imports that follow.
    importlib.invalidate_caches()

import lightning as L
from lightning.fabric import Fabric

# 3. Dynamic Precision Detection
# Mixed bf16 is chosen only for a CUDA device on a non-TPU host; the TPU case
# and the remaining fallback both use true bf16.
PRECISION = "bf16-mixed" if (not is_tpu and torch.cuda.is_available()) else "bf16-true"

# 4. Initialize Fabric
# Accelerator and device count are left to Fabric's auto-detection.
fabric = Fabric(precision=PRECISION, accelerator="auto", devices="auto")
fabric.launch()

# 5. Clone Repository
# REPO_DIR normally comes from Cell 1; fall back to the default location if
# that cell has not been run.
if 'REPO_DIR' not in globals():
    REPO_DIR = "/content/WaveletDiff" # Fallback

REPO_URL = "https://github.com/MilesHoffman/waveletDiff_synth_data.git"
# Only the global-zero process clones, and only when the directory is missing.
if fabric.is_global_zero and not os.path.exists(REPO_DIR):
    print(f"Cloning {REPO_URL}...")
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True, stdout=subprocess.DEVNULL)

# The barrier is deliberately OUTSIDE the existence check so that every rank
# reaches it. If it were nested under the check, a rank that looked after
# rank 0 had already created the directory would skip the barrier, leaving the
# collective mismatched and possibly importing from a half-finished clone.
fabric.barrier() # Wait for clone

# 6. Add to System Path (Revised for new repo structure)
sys.path.append(os.path.join(REPO_DIR, "WaveletDiff_source", "src"))

# 7. Create Checkpoint Directory
# Skipped entirely when Cell 1 has not defined CHECKPOINT_DIR; otherwise only
# the global-zero process creates the directory (parents included).
if 'CHECKPOINT_DIR' in globals() and fabric.is_global_zero:
    Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)
    print(f"Checkpoint Directory Verified: {CHECKPOINT_DIR}")

# Every rank reports its own device and the selected precision.
print(f"[Rank {fabric.global_rank}] Fabric initialized on device: {fabric.device} with precision: {PRECISION}")

  _C._set_float32_matmul_precision(precision)
INFO: Using bfloat16 Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using bfloat16 Automatic Mixed Precision (AMP)


Checkpoint Directory Verified: /content/drive/MyDrive/personal_drive/trading/checkpoints/stocks_ohlcv_v1_865
[Rank 0] Fabric initialized on device: cuda:0 with precision: bf16-mixed


In [None]:
# @title Cell 3: Data Loading (Fabric Optimized)
import pandas as pd
from data.loaders import create_sliding_windows
from data.module import WaveletTimeSeriesDataModule
from torch.utils.data import TensorDataset, DataLoader
import multiprocessing

def get_dataloaders():
    """Build the training DataLoader and the wavelet DataModule for the stocks data.

    Reads ``stock_data.csv`` from the cloned repository, keeps the OHLCV
    columns, cuts them into sliding windows of ``SEQ_LEN`` steps and wraps the
    result in a shuffled, Fabric-managed ``DataLoader``.

    Returns:
        tuple: ``(loader, datamodule, full_config)`` where ``loader`` is the
        DataLoader returned by ``fabric.setup_dataloaders``, ``datamodule`` is
        the ``WaveletTimeSeriesDataModule`` built from the same windows, and
        ``full_config`` is the config dict that was passed to the DataModule.
    """
    stocks_path = os.path.join(REPO_DIR, "WaveletDiff_source", "data", "stocks", "stock_data.csv")

    if fabric.is_global_zero:
        print(f"Loading data from {stocks_path}...")

    df = pd.read_csv(stocks_path)
    CORE_COLS = ['Open', 'High', 'Low', 'Close', 'Volume']
    df_filtered = df[CORE_COLS]

    # Create windows (Memory intensive, so we do it on CPU)
    custom_data_windows, _ = create_sliding_windows(
        df_filtered.values,
        seq_len=SEQ_LEN,
        normalize=True
    )

    full_data_tensor = torch.FloatTensor(custom_data_windows)

    # CONFIG
    # 'normalize_data' is False here while the windows above were built with
    # normalize=True -- presumably to avoid normalizing twice; confirm.
    full_config = {
        'dataset': {'name': DATASET_NAME, 'seq_len': SEQ_LEN},
        'training': {'batch_size': BATCH_SIZE, 'epochs': 1},
        'data': {'data_dir': os.path.join(REPO_DIR, "WaveletDiff_source", "data"), 'normalize_data': False},
        'wavelet': {'type': WAVELET_TYPE, 'levels': WAVELET_LEVELS}
    }

    # Create DataModule (Handles Wavelet Transforms internally during init if needed, or we use it for metadata)
    datamodule = WaveletTimeSeriesDataModule(config=full_config, data_tensor=full_data_tensor)

    # XLA/GPU OPTIMIZATION: Cast to target dtype BEFORE creating dataset to avoid cast overhead in loop.
    # Only applies to true-bf16 runs on a non-CPU device.
    if PRECISION == "bf16-true" and fabric.device.type != "cpu":
        if fabric.is_global_zero:
            print(f"Optimizing: Casting data to bfloat16 for {fabric.device.type}...")
        full_data_tensor = full_data_tensor.to(torch.bfloat16)

    # NOTE(review): the dataset wraps the sliding-window tensor built above, not
    # anything produced by the DataModule -- confirm this is the input the
    # model's loss expects.
    dataset = TensorDataset(full_data_tensor)

    # WORKER CONFIGURATION
    # Leave two CPUs free for the main process and cap the pool at 4 workers.
    # On a 2-vCPU host this evaluates to 0, i.e. batches are loaded in the main process.
    cpu_count = multiprocessing.cpu_count()
    num_workers = min(4, max(0, cpu_count - 2))
    # Rank-zero only, like the other messages in this function, so multi-process
    # runs do not print one copy per rank.
    if fabric.is_global_zero:
        print(f"Using {num_workers} num_workers...")

    # DataLoader
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True, # Important for XLA compilation stability
        num_workers=num_workers,
        pin_memory=fabric.device.type == "cuda", # Pin only on CUDA
        persistent_workers=num_workers > 0 # Only valid when workers exist
    )

    # Fabric Setup (Handles Sharding/Distributed Sampler)
    loader = fabric.setup_dataloaders(loader)

    return loader, datamodule, full_config

# Build the loader once at cell scope. The returned config dict is extended in
# the model cell, and the DataModule supplies the wavelet metadata and input
# dimension that the sampling cell reads.
train_loader, datamodule, model_base_config = get_dataloaders()
WAVELET_INFO = datamodule.get_wavelet_info()
INPUT_DIM = datamodule.get_input_dim()

Loading data from /content/WaveletDiff/WaveletDiff_source/data/stocks/stock_data.csv...
Raw Data Tensor Shape: torch.Size([3662, 24, 5])
Converting to wavelet coefficients with 3 levels...
Coefficient shapes per level: [(5,), (5,), (8,), (13,)]
Level dimensions: [np.int64(5), np.int64(5), np.int64(8), np.int64(13)]
Total coefficients per feature: 31
Converted torch.Size([3662, 24, 5]) time series to torch.Size([3662, 31, 5]) wavelet coefficients
Wavelet: db2, Levels: 3
Using 4 num_workers...


In [None]:
# @title Cell 4: Model Initialization (Optimized Sandwich)
from models.transformer import WaveletDiffusionTransformer

# Update Config
# Adds the model, attention, noise, sampling, energy and optimizer sections to
# the config dict produced by the data cell. All values are evaluated before the
# dict is modified.
model_base_config.update(
    model={
        'embed_dim': EMBED_DIM,
        'num_heads': NUM_HEADS,
        'num_layers': NUM_LAYERS,
        'time_embed_dim': TIME_EMBED_DIM,
        'dropout': DROPOUT,
        'prediction_target': PREDICTION_TARGET,
    },
    attention={'use_cross_level_attention': USE_CROSS_LEVEL_ATTENTION},
    noise={'schedule': "exponential"},
    sampling={'ddim_eta': 0.0, 'ddim_steps': None},
    energy={'weight': 0.0},
    optimizer={
        'scheduler_type': 'onecycle',
        'lr': LEARNING_RATE,
        'warmup_epochs': 5,
        'cosine_eta_min': 1e-6,
    },
)

def init_system():
    """Create the model and optimizer and hand both to Fabric.

    Order of operations: instantiate the model, move it to ``fabric.device``,
    wrap it with ``torch.compile`` on CUDA, build AdamW (fused on CUDA), call
    ``fabric.setup`` and finally register ``compute_loss`` as a forward method.

    Returns:
        tuple: ``(model, optimizer)`` as returned by ``fabric.setup``.
    """
    # 1. Instantiate (CPU)
    model = WaveletDiffusionTransformer(data_module=datamodule, config=model_base_config)

    # 2. DEVICE PLACEMENT (Manual)
    # Critical: Move to device BEFORE compiling so we generate CUDA graphs, not CPU graphs.
    if fabric.is_global_zero:
        print(f"[Rank 0] Moving model to {fabric.device}...")
    model.to(fabric.device)

    # 3. COMPILATION (Before Setup)
    # We compile the raw model. If this crashes, change mode to "default".
    if fabric.device.type == "cuda":
        # 'reduce-overhead' is critical for fixing the dispatch bottleneck (High CPU/GPU ratio)
        # It uses CUDA Graphs. If this hangs > 3 mins, switch to 'default' or 'max-autotune'.
        COMPILE_MODE = "reduce-overhead"

        # These lines are labelled "[Rank 0]", so only rank zero prints them;
        # otherwise every process would claim to be rank 0.
        if fabric.is_global_zero:
            print(f"[Rank 0] Applying torch.compile(mode='{COMPILE_MODE}')...")
            print("[Rank 0] Note: You will see a ~60s delay at Step 0 while kernels are built.")

        # NOTE: torch.compile builds kernels lazily on the first call, so this
        # guard only covers errors raised while wrapping the model, not failures
        # that surface during the first training step.
        try:
            model = torch.compile(model, mode=COMPILE_MODE)
        except Exception as e:
            print(f"[WARNING] Compilation failed: {e}. Falling back to eager execution.")

    # 4. OPTIMIZER (Fused)
    # 'fused=True' collapses optimizer kernels (huge win for kernel overhead)
    use_fused = fabric.device.type == "cuda"
    if fabric.is_global_zero and use_fused:
        print("[Rank 0] Using Fused AdamW...")

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        fused=use_fused
    )

    # 5. FABRIC SETUP
    # Fabric will see the compiled model and wrap it appropriately.
    model, optimizer = fabric.setup(model, optimizer)

    # 6. REGISTER CUSTOM ENTRY POINT (CRITICAL FIX)
    # Fabric requires us to whitelist any method that acts like forward()
    # so it can apply the correct strategies (DDP sync, precision, etc.)
    model.mark_forward_method('compute_loss')

    return model, optimizer

# Build the Fabric-wrapped model and optimizer once at cell scope; the training
# and sampling cells use these globals.
model, optimizer = init_system()

Initialized coefficient_weighted wavelet loss:
  Level 0: 5 coeffs, weight=0.4988
  Level 1: 5 coeffs, weight=0.2494
  Level 2: 8 coeffs, weight=0.1559
  Level 3: 13 coeffs, weight=0.0959
Using coefficient_weighted loss strategy (no energy term)
Created 4 level-specific transformers (Channel-based):
  Level 0: 5 coefficients, 512 embed_dim, 10 layers
  Level 1: 5 coefficients, 256 embed_dim, 8 layers
  Level 2: 8 coefficients, 256 embed_dim, 8 layers
  Level 3: 13 coefficients, 256 embed_dim, 8 layers
Cross-level attention enabled with common dimension: 512

WAVELET DIFFUSION TRANSFORMER MODEL INFO
Dataset: stocks
Input dimension: 31
Embedding dimension: 256
Time embedding dimension: 128
Prediction target: noise
Noise schedule: exponential
Cross-level attention: Enabled
Number of wavelet levels: 4

Wavelet level details:
  Level 0: 5 coefficients, shape (5,)
  Level 1: 5 coefficients, shape (5,)
  Level 2: 8 coefficients, shape (8,)
  Level 3: 13 coefficients, shape (13,)

Energy Loss 

In [None]:
# @title Cell 5: Native Fabric Training Loop (Hardened)
from contextlib import nullcontext
from torch.profiler import profile, record_function, ProfilerActivity, schedule, tensorboard_trace_handler
import math

def train_loop():
    """Run the step-based training loop, or a short profiling pass.

    Uses the cell-level ``fabric``, ``model``, ``optimizer`` and
    ``train_loader`` plus the configuration constants from Cell 1. The data
    loader is re-iterated whenever it is exhausted, so the run length is
    governed by the step count alone.

    With ``ENABLE_PROFILER`` set, only ``PROFILER_WARMUP + PROFILER_ACTIVE + 2``
    steps run under ``torch.profiler``; periodic logging and checkpointing are
    skipped and a summary table is printed by rank zero at the end.

    Returns:
        None
    """
    # 1. Determine Effective Steps
    if ENABLE_PROFILER:
        effective_steps = PROFILER_WARMUP + PROFILER_ACTIVE + 2
        print(f"[Rank {fabric.global_rank}] PROFILER ENABLED: Overriding total steps to {effective_steps}")
    else:
        effective_steps = TOTAL_TRAINING_STEPS

    # The schedule always spans TOTAL_TRAINING_STEPS (a profiling run only walks
    # its first few steps) and peaks at 4x LEARNING_RATE.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=LEARNING_RATE * 4,
        total_steps=TOTAL_TRAINING_STEPS,
        pct_start=0.3
    )

    train_iter = iter(train_loader)

    # Only rank zero owns a progress bar; other ranks iterate a plain range.
    if fabric.is_global_zero:
        desc = "PROFILING" if ENABLE_PROFILER else f"{fabric.device.type.upper()} Training"
        pbar = tqdm(range(effective_steps), desc=desc)
    else:
        pbar = range(effective_steps)

    model.train()
    # The running loss lives on the device in float32, so the host only has to
    # synchronize with the accelerator when a log line is actually produced.
    running_loss = torch.zeros((), device=fabric.device, dtype=torch.float32)

    # 2. Profiler Context
    if ENABLE_PROFILER:
        activities = [ProfilerActivity.CPU]
        if torch.cuda.is_available():
            activities.append(ProfilerActivity.CUDA)
        prof_schedule = schedule(wait=1, warmup=PROFILER_WARMUP, active=PROFILER_ACTIVE, repeat=1)
        handler = tensorboard_trace_handler('./log/profiler')
        profiler_ctx = profile(
            activities=activities, schedule=prof_schedule, on_trace_ready=handler,
            record_shapes=False, profile_memory=False, with_stack=True
        )
    else:
        profiler_ctx = nullcontext()

    # 3. Main Loop
    with profiler_ctx as prof:
        for step in pbar:

            # MARKER: Data Loading
            with record_function("data_loading"):
                try:
                    batch = next(train_iter)
                except StopIteration:
                    # Loader exhausted: start another pass over the data.
                    train_iter = iter(train_loader)
                    batch = next(train_iter)
                x_0 = batch[0]

            # MARKER: Forward & Loss
            optimizer.zero_grad()
            with record_function("forward_pass"):
                # One integer timestep in [0, model.T) per sample in the batch.
                t = torch.randint(0, model.T, (x_0.size(0),), device=fabric.device)
                loss = model.compute_loss(x_0, t)

            # MARKER: Backward
            with record_function("backward_pass"):
                fabric.backward(loss)

            # MARKER: Optimization
            with record_function("optimizer_step"):
                if ENABLE_GRAD_CLIPPING:
                    # error_if_nonfinite=False keeps the run going instead of
                    # raising when the gradient norm is NaN/Inf.
                    # NOTE(review): this does not repair such gradients. Norm
                    # clipping scales by max_norm / (total_norm + eps), so a
                    # non-finite norm still leaves non-finite gradients for the
                    # optimizer step. Consider skipping the step in that case.
                    fabric.clip_gradients(model, optimizer, max_norm=1.0, error_if_nonfinite=False)

                optimizer.step()
                scheduler.step()

            # 4. Logging
            # Accumulate without leaving the device; .item() is deferred to the
            # logging branch so there is no per-step host/device sync.
            running_loss = running_loss + loss.detach().float().reshape(())

            if (step + 1) % LOG_INTERVAL == 0 and not ENABLE_PROFILER:
                avg_loss = running_loss / LOG_INTERVAL
                if fabric.world_size > 1:
                    # Collective: must be reached by every rank.
                    avg_loss = fabric.all_reduce(avg_loss, reduce_op="mean")
                avg_loss = avg_loss.item()

                if fabric.is_global_zero:
                    current_lr = scheduler.get_last_lr()[0]
                    pct = ((step + 1) / TOTAL_TRAINING_STEPS) * 100
                    fabric.print(f"[Step {step+1:5d} | {pct:3.0f}%] loss: {avg_loss:.4f} | lr: {current_lr:.2e}")
                    pbar.set_postfix({"loss": f"{avg_loss:.4f}", "lr": f"{current_lr:.2e}"})
                    sys.stdout.flush()

                running_loss = torch.zeros_like(running_loss)

            # 5. Checkpointing
            if (step + 1) % SAVE_INTERVAL == 0 and not ENABLE_PROFILER:
                save_path = os.path.join(CHECKPOINT_DIR, f"step_{step+1}.ckpt")
                # NOTE(review): `model` may be a torch.compile wrapper here;
                # confirm the saved key names match what the loading code expects.
                state = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
                fabric.save(save_path, state)
                if fabric.is_global_zero:
                    print(f"\nSaved checkpoint to {save_path}", flush=True)

            # 6. Step Profiler
            if ENABLE_PROFILER:
                prof.step()

    print("Training/Profiling Finished.")

    if ENABLE_PROFILER and fabric.is_global_zero:
        print("\n" + "="*80)
        print("PROFILING QUICK REPORT")
        print("="*80)
        print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=15))

# Start training only when this code runs as the main module, not when it is imported.
if __name__ == "__main__":
    train_loop()

CUDA Training:   0%|          | 0/15000 [00:00<?, ?it/s]

[Step   150 |   1%] loss: 0.5553 | lr: 3.41e-05
[Step   300 |   2%] loss: 0.1828 | lr: 4.04e-05
[Step   450 |   3%] loss: 0.1317 | lr: 5.08e-05
[Step   600 |   4%] loss: 0.1030 | lr: 6.52e-05
[Step   750 |   5%] loss: 0.0858 | lr: 8.35e-05
[Step   900 |   6%] loss: 0.0749 | lr: 1.05e-04
[Step  1050 |   7%] loss: 0.0660 | lr: 1.31e-04
[Step  1200 |   8%] loss: 0.0621 | lr: 1.59e-04
[Step  1350 |   9%] loss: 0.0583 | lr: 1.90e-04
[Step  1500 |  10%] loss: 0.0555 | lr: 2.24e-04
[Step  1650 |  11%] loss: 0.0535 | lr: 2.60e-04
[Step  1800 |  12%] loss: 0.0518 | lr: 2.97e-04
[Step  1950 |  13%] loss: 0.0505 | lr: 3.36e-04
[Step  2100 |  14%] loss: 0.0495 | lr: 3.76e-04
[Step  2250 |  15%] loss: 0.0485 | lr: 4.16e-04
[Step  2400 |  16%] loss: 0.0472 | lr: 4.56e-04
[Step  2550 |  17%] loss: 0.0471 | lr: 4.96e-04
[Step  2700 |  18%] loss: 0.0470 | lr: 5.35e-04
[Step  2850 |  19%] loss: 0.0453 | lr: 5.72e-04
[Step  3000 |  20%] loss: 0.0450 | lr: 6.08e-04
[Step  3150 |  21%] loss: 0.0454 | lr: 6

In [None]:
# @title Cell 6: Brief Evaluation (Sanity Check)

def sanity_check_sampling():
    """Draw two samples by reverse diffusion and convert them to time series.

    Puts the global ``model`` into eval mode, starts from bfloat16 Gaussian
    noise of shape ``(2, INPUT_DIM, WAVELET_INFO['n_features'])`` and walks the
    timesteps from ``model.T - 1`` down to 0 using the model's ``alpha_all``,
    ``alpha_bar_all`` and ``beta_all`` tables. The result is moved to CPU as
    float32 and passed to ``datamodule.convert_wavelet_to_timeseries``.

    Returns:
        The value returned by ``datamodule.convert_wavelet_to_timeseries``.
    """
    model.eval()
    num_samples = 2
    print("Generating sanity check samples...")

    n_steps = model.T
    with torch.no_grad():
        # Initial state: pure noise in the wavelet domain.
        x_t = torch.randn(
            (num_samples, INPUT_DIM, WAVELET_INFO['n_features']),
            device=fabric.device,
            dtype=torch.bfloat16,
        )

        # Reverse diffusion; the progress bar is shown on rank zero only.
        progress = tqdm(
            reversed(range(n_steps)),
            total=n_steps,
            desc="Sampling",
            disable=not fabric.is_global_zero,
        )
        for step_idx in progress:
            t = torch.full((num_samples,), step_idx, device=fabric.device, dtype=torch.long)
            # The model is called with the timestep divided by model.T.
            predicted_noise = model(x_t, t.float() / n_steps)

            # Per-sample schedule values, shaped to broadcast over the last two dims.
            alpha_t = model.alpha_all[t].view(-1, 1, 1)
            alpha_bar_t = model.alpha_bar_all[t].view(-1, 1, 1)
            beta_t = model.beta_all[t].view(-1, 1, 1)

            # Standard DDPM update (simplified): posterior mean from the noise estimate.
            noise_coef = (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t)
            mean = (1 / torch.sqrt(alpha_t)) * (x_t - noise_coef * predicted_noise)

            if step_idx == 0:
                # Last step: keep the mean, add no further noise.
                x_t = mean
                continue

            fresh_noise = torch.randn_like(x_t)
            x_t = mean + torch.sqrt(beta_t) * fresh_noise

    print("Sampling complete. Converting to time series...")
    # Reconstruction is done on CPU in float32 for stability.
    samples_ts = datamodule.convert_wavelet_to_timeseries(x_t.float().cpu())
    print(f"Generated samples shape: {samples_ts.shape}")
    return samples_ts

# Only the global-zero process runs the sampling check, so `samples` is
# defined on that process alone.
if fabric.is_global_zero:
    samples = sanity_check_sampling()

Generating sanity check samples...


Sampling:   0%|          | 0/1000 [00:00<?, ?it/s]

Online softmax is disabled on the fly since Inductor decides to
split the reduction. Cut an issue to PyTorch if this is an
important use case and you want to speed it up with online
softmax.



Sampling complete. Converting to time series...
Generated samples shape: torch.Size([2, 24, 5])
