# WaveletDiff Fabric Native Training (High Performance)

This notebook implements a clean, high-performance training loop for WaveletDiff using **Lightning Fabric** natively. 

### Key Optimizations:
- **Native Fabric Loop**: Removes overhead of high-level abstractions.
- **Smart Gradient Clipping**: Disabled by default for 2x TPU speedup (configurable).
- **Efficient Logging**: Metrics are only synchronized periodically (every 1% of progress) to prevent XLA graph breaks.
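- **BF16 Pre-casting**: When running with `bf16-true` precision (TPU), data is cast to BF16 once before the loop to maximize throughput; with `bf16-mixed` (GPU) it stays in FP32 so autocast can handle it.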

> **Note**: Designed for TPU `v4-8` or `v5e` but compatible with GPUs.

In [None]:
# @title Cell 1: Global Configuration
import os
from datetime import datetime

# Experiment Identity
# RUN_ID is a wall-clock timestamp taken when this cell executes, so re-running
# the cell yields a new UNIQUE_RUN_NAME (and therefore a new CHECKPOINT_DIR).
RUN_NAME = "fabric_native_v1" # @param {type:"string"}
RUN_ID = datetime.now().strftime("%Y%m%d_%H%M%S")
UNIQUE_RUN_NAME = f"{RUN_NAME}_{RUN_ID}"

# Dataset Configuration
# DATASET_NAME is forwarded to the DataModule config; SEQ_LEN is the sliding-window
# length used when the CSV is cut into training windows.
DATASET_NAME = "stocks"
SEQ_LEN = 24 # @param {type:"integer"}

# Model Hyperparameters
# These are copied verbatim into the 'model' / 'attention' sections of the config
# dict handed to WaveletDiffusionTransformer.
EMBED_DIM = 256 # @param {type:"integer"}
NUM_HEADS = 8 # @param {type:"integer"}
NUM_LAYERS = 8 # @param {type:"integer"}
TIME_EMBED_DIM = 128 # @param {type:"integer"}
DROPOUT = 0.1 # @param {type:"number"}
PREDICTION_TARGET = "noise"
USE_CROSS_LEVEL_ATTENTION = True # @param {type:"boolean"}

# Training Parameters
# LEARNING_RATE is the AdamW base LR; the training loop's OneCycleLR peaks at 5x this value.
# NOTE(review): WARMUP_STEPS is not referenced by any cell visible in this notebook
# (the training loop controls warm-up through OneCycleLR's pct_start) — confirm intent.
TOTAL_TRAINING_STEPS = 50000 # @param {type:"integer"}
BATCH_SIZE = 512 # @param {type:"integer"}
LEARNING_RATE = 2e-4 # @param {type:"number"}
WEIGHT_DECAY = 1e-5 # @param {type:"number"}
WARMUP_STEPS = 500 # @param {type:"integer"}
ENABLE_GRAD_CLIPPING = False # @param {type:"boolean"}

# Logging & Checkpointing
# Set dynamically to 1% of training progress (500 steps with the default 50000).
# max(1, ...) keeps the interval positive so `step % LOG_INTERVAL` stays valid
# for runs shorter than 100 steps.
LOG_INTERVAL = max(1, int(TOTAL_TRAINING_STEPS * 0.01))
SAVE_INTERVAL = 5000 # @param {type:"integer"}

# Wavelet Configuration
# Both values are forwarded to the DataModule's 'wavelet' config section.
# NOTE(review): the meaning of "auto" is decided inside the DataModule, which is
# not visible here — presumably it derives the level count from SEQ_LEN; confirm.
WAVELET_TYPE = "db2" # @param {type:"string"}
WAVELET_LEVELS = "auto"

# Paths
# NOTE(review): DRIVE_BASE_PATH assumes Google Drive is already mounted at
# /content/drive; no cell visible here performs the mount — confirm before running.
DRIVE_BASE_PATH = "/content/drive/MyDrive/personal_drive/trading"
CHECKPOINT_DIR = f"{DRIVE_BASE_PATH}/checkpoints/{UNIQUE_RUN_NAME}"
REPO_DIR = "/content/WaveletDiff"

In [None]:
# @title Cell 2: Imports & Fabric Initialization
import sys
import os
import subprocess
import torch
import time
from pathlib import Path
from tqdm.auto import tqdm

# 1. Performance Polishing: Set Matmul Precision for GPUs (Ampere+/L4/A100)
if torch.cuda.is_available():
    # 'high' uses TensorFloat-32 (TF32) for significantly faster matmuls on L4/A100
    torch.set_float32_matmul_precision('high')

# 2. Conditional Dependency Installation
import importlib.util

deps = ["lightning", "pywavelets", "scipy", "pandas", "tqdm"]

# STRICT CHECK: Only install torch_xla if explicitly on a TPU environment.
# find_spec() asks the running interpreter directly, which is both faster and more
# reliable than grepping `pip list` (pip may belong to a different environment).
is_tpu = 'COLAB_TPU_ADDR' in os.environ or 'TPU_NAME' in os.environ
needs_torch_xla = is_tpu and importlib.util.find_spec("torch_xla") is None
if needs_torch_xla:
    deps.append("torch_xla[tpu]")

try:
    import lightning.fabric
    needs_lightning = False
except ImportError:
    needs_lightning = True

# Install when *either* lightning or torch_xla is missing. Previously a runtime that
# already had lightning but lacked torch_xla would silently skip the XLA install.
if needs_lightning or needs_torch_xla:
    print(f"Installing dependencies ({', '.join(deps)})...")
    # `sys.executable -m pip` guarantees packages land in this kernel's environment.
    subprocess.run(
        [sys.executable, "-m", "pip", "install"] + deps,
        check=True,
        stdout=subprocess.DEVNULL,
    )
    # Make freshly installed packages visible to the import system of this process.
    importlib.invalidate_caches()

import lightning as L
from lightning.fabric import Fabric

# 3. Dynamic Precision Detection
#    TPU -> pure bf16, CUDA GPU -> bf16 autocast, anything else -> plain fp32.
if is_tpu:
    PRECISION = "bf16-true"
elif torch.cuda.is_available():
    PRECISION = "bf16-mixed"
else:
    PRECISION = "32-true"

# 4. Initialize Fabric
fabric = Fabric(accelerator="auto", devices="auto", precision=PRECISION)
fabric.launch()

# 5. Clone Repository
REPO_URL = "https://github.com/GarlicWang/WaveletDiff.git"
if fabric.is_global_zero and not os.path.exists(REPO_DIR):
    print(f"Cloning {REPO_URL}...")
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True, stdout=subprocess.DEVNULL)
# Every rank must hit this barrier unconditionally. Guarding it with the
# filesystem check is racy: a rank that sees the freshly cloned directory would
# skip the barrier while rank 0 waits in it forever.
fabric.barrier() # Wait for clone

# 6. Add to System Path (guarded so re-running the cell does not add duplicates)
repo_src_dir = os.path.join(REPO_DIR, "src")
if repo_src_dir not in sys.path:
    sys.path.append(repo_src_dir)

# 7. Create Checkpoint Directory
if fabric.is_global_zero:
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print(f"[Rank {fabric.global_rank}] Fabric initialized on device: {fabric.device} with precision: {PRECISION}")

In [None]:
# @title Cell 3: Data Loading (Fabric Optimized)
import pandas as pd
from data.loaders import create_sliding_windows
from data.module import WaveletTimeSeriesDataModule
from torch.utils.data import TensorDataset, DataLoader

def get_dataloaders():
    """Build the Fabric-prepared training loader and the wavelet DataModule.

    Reads the stocks CSV from the cloned repository, cuts it into normalized
    sliding windows of length ``SEQ_LEN``, constructs a
    ``WaveletTimeSeriesDataModule`` from the float32 windows, and wraps the
    windows in a shuffled ``DataLoader`` that Fabric then prepares for the
    current device/strategy.

    Returns:
        A ``(loader, datamodule, full_config)`` tuple.
    """
    csv_path = os.path.join(REPO_DIR, "data", "stocks", "stock_data.csv")

    if fabric.is_global_zero:
        print(f"Loading data from {csv_path}...")

    # Each rank reads the CSV itself; only the five OHLCV columns are used.
    ohlcv_columns = ['Open', 'High', 'Low', 'Close', 'Volume']
    price_frame = pd.read_csv(csv_path)[ohlcv_columns]

    # Windowing happens on the host; the second return value is not needed here.
    windows, _ = create_sliding_windows(
        price_frame.values, seq_len=SEQ_LEN, normalize=True
    )
    window_tensor = torch.FloatTensor(windows)

    # Config consumed by the DataModule (and later extended for the model).
    full_config = dict(
        dataset=dict(name=DATASET_NAME, seq_len=SEQ_LEN),
        training=dict(batch_size=BATCH_SIZE, epochs=1),
        data=dict(data_dir=os.path.join(REPO_DIR, "data"), normalize_data=False),
        wavelet=dict(type=WAVELET_TYPE, levels=WAVELET_LEVELS),
    )

    # The DataModule is built from the float32 windows, before any bf16 cast below.
    datamodule = WaveletTimeSeriesDataModule(config=full_config, data_tensor=window_tensor)

    # Pre-cast once so the training loop never pays for a per-batch dtype conversion.
    # Only done for pure-bf16 runs on an accelerator; bf16-mixed keeps float32 inputs
    # and relies on autocast.
    precast_to_bf16 = PRECISION == "bf16-true" and fabric.device.type != "cpu"
    if precast_to_bf16:
        if fabric.is_global_zero:
            print(f"Optimizing: Casting data to bfloat16 for {fabric.device.type}...")
        window_tensor = window_tensor.to(torch.bfloat16)

    # NOTE(review): the loader serves the sliding-window tensor itself, not anything
    # returned by the DataModule — confirm this is the input model.compute_loss expects.
    raw_loader = DataLoader(
        TensorDataset(window_tensor),
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,    # constant batch shape keeps XLA from recompiling
        num_workers=0,     # in-memory tensors gain nothing from worker processes
        pin_memory=False,
    )

    # Fabric moves batches to the device and installs a distributed sampler if needed.
    prepared_loader = fabric.setup_dataloaders(raw_loader)

    return prepared_loader, datamodule, full_config


# Materialize the loader once and pull the wavelet metadata later cells rely on.
train_loader, datamodule, model_base_config = get_dataloaders()
WAVELET_INFO = datamodule.get_wavelet_info()
INPUT_DIM = datamodule.get_input_dim()

In [None]:
# @title Cell 4: Model Initialization
from models.transformer import WaveletDiffusionTransformer

# Extend the config returned by the data cell with the sections the model reads.
# NOTE(review): the 'optimizer' section is passed to the model only; the training
# loop in this notebook builds its own AdamW + OneCycleLR — confirm the model does
# not depend on these values at training time.
_model_section = {
    'embed_dim': EMBED_DIM,
    'num_heads': NUM_HEADS,
    'num_layers': NUM_LAYERS,
    'time_embed_dim': TIME_EMBED_DIM,
    'dropout': DROPOUT,
    'prediction_target': PREDICTION_TARGET,
}
_optimizer_section = {
    'scheduler_type': 'onecycle',
    'lr': LEARNING_RATE,
    'warmup_epochs': 5,
    'cosine_eta_min': 1e-6,
}

model_base_config.update(
    model=_model_section,
    attention={'use_cross_level_attention': USE_CROSS_LEVEL_ATTENTION},
    noise={'schedule': "exponential"},
    sampling={'ddim_eta': 0.0, 'ddim_steps': None},
    energy={'weight': 0.0},
    optimizer=_optimizer_section,
)

def init_system():
    """Instantiate the diffusion transformer and its AdamW optimizer under Fabric.

    Returns:
        ``(model, optimizer)`` as returned by ``fabric.setup``.
    """
    raw_model = WaveletDiffusionTransformer(config=model_base_config, data_module=datamodule)
    raw_optimizer = torch.optim.AdamW(
        raw_model.parameters(),
        weight_decay=WEIGHT_DECAY,
        lr=LEARNING_RATE,
    )

    # Fabric wraps both objects for the selected accelerator/precision.
    wrapped_model, wrapped_optimizer = fabric.setup(raw_model, raw_optimizer)
    return wrapped_model, wrapped_optimizer


model, optimizer = init_system()

In [None]:
# @title Cell 5: Native Fabric Training Loop

def _save_checkpoint(step_count, scheduler):
    """Write a checkpoint after ``step_count`` completed optimizer steps.

    Besides model and optimizer state, the scheduler state and the step count
    are stored so a run can be resumed on the same LR schedule. Must be called
    on every rank (``fabric.save`` coordinates the write itself).
    """
    save_path = os.path.join(CHECKPOINT_DIR, f"step_{step_count}.ckpt")
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "step": step_count,
    }
    fabric.save(save_path, state)
    if fabric.is_global_zero:
        print(f"\nSaved checkpoint to {save_path}")


def train_loop():
    """Run ``TOTAL_TRAINING_STEPS`` optimizer steps with a native Fabric loop.

    The loader is cycled indefinitely, loss is synchronized to the host only
    every ``LOG_INTERVAL`` steps, and a checkpoint is written every
    ``SAVE_INTERVAL`` steps plus once at the very end if the last step did not
    already fall on a save boundary.

    Raises:
        RuntimeError: if ``train_loader`` yields no batches at all (for example
            when the dataset is smaller than one batch and ``drop_last=True``).
    """
    # MATCHING SOURCE REPO: OneCycleLR with peak LR at 5x base (1e-3) and pct_start at 0.3 for stocks
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=LEARNING_RATE * 5,
        total_steps=TOTAL_TRAINING_STEPS,
        pct_start=0.3
    )

    # Iterator for infinite steps
    train_iter = iter(train_loader)

    # Progress bar ONLY on rank 0
    if fabric.is_global_zero:
        pbar = tqdm(range(TOTAL_TRAINING_STEPS), desc=f"{fabric.device.type.upper()} Training")
    else:
        pbar = range(TOTAL_TRAINING_STEPS)

    model.train()
    t0 = time.time()

    for step in pbar:
        # 1. Fetch Batch (restart the loader when an epoch is exhausted)
        try:
            batch = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            try:
                batch = next(train_iter)
            except StopIteration:
                # A fresh iterator that is immediately exhausted means the loader is empty.
                raise RuntimeError(
                    "train_loader yielded no batches; the dataset may be smaller than "
                    "BATCH_SIZE while drop_last=True."
                ) from None

        x_0 = batch[0]

        # 2. Forward & Loss
        optimizer.zero_grad()

        # Sample one diffusion time step per example in the batch
        t = torch.randint(0, model.T, (x_0.size(0),), device=fabric.device)

        loss = model.compute_loss(x_0, t)

        # 3. Backward
        fabric.backward(loss)

        # 4. Optional Clip Gradients (SLOW on TPU if done every step)
        if ENABLE_GRAD_CLIPPING:
            fabric.clip_gradients(model, optimizer, max_norm=1.0)

        optimizer.step()
        scheduler.step()

        # 5. Logging (Async-friendly)
        # Happens every 1% of progress to keep training history readable
        if (step + 1) % LOG_INTERVAL == 0:
            # .item() forces a device->host sync, so it is only done at log time.
            loss_val = loss.item()

            if fabric.is_global_zero:
                elapsed = time.time() - t0
                steps_per_sec = LOG_INTERVAL / elapsed
                t0 = time.time()

                # PERSISTENT LOGGING: Actually print the values to keep history
                pct = ((step + 1) / TOTAL_TRAINING_STEPS) * 100
                fabric.print(f"[Step {step+1:5d} | {pct:3.0f}%] loss: {loss_val:.4f} | spd: {steps_per_sec:.2f} it/s")
                pbar.set_postfix({"loss": f"{loss_val:.4f}", "spd": f"{steps_per_sec:.2f}it/s"})

        # 6. Checkpointing
        if (step + 1) % SAVE_INTERVAL == 0:
            _save_checkpoint(step + 1, scheduler)

    # Persist the final weights when the run length is not a multiple of
    # SAVE_INTERVAL; otherwise the trailing steps would never reach disk.
    if TOTAL_TRAINING_STEPS % SAVE_INTERVAL != 0:
        _save_checkpoint(TOTAL_TRAINING_STEPS, scheduler)

    # fabric.print emits on rank 0 only, matching the other log lines.
    fabric.print("Training Finished.")

if __name__ == "__main__":
    train_loop()

In [None]:
# @title Cell 6: Brief Evaluation (Sanity Check)

def sanity_check_sampling():
    """Draw two samples with a simplified DDPM reverse process as a smoke test.

    Starts from Gaussian noise shaped ``(2, INPUT_DIM, n_features)``, walks the
    diffusion steps from ``model.T - 1`` down to 0 treating the model output as
    predicted noise, then converts the result back to time series through the
    DataModule.

    NOTE(review): the update rule assumes the model predicts noise
    (``PREDICTION_TARGET = "noise"``) — revisit if that setting changes.

    Returns:
        The reconstructed time-series samples returned by
        ``datamodule.convert_wavelet_to_timeseries``.
    """
    # Generate a few samples to ensure model learned something
    model.eval()
    num_samples = 2
    print("Generating sanity check samples...")

    # Match the dtype of the model's weights instead of hard-coding bfloat16:
    # with the fp32 fallback ("32-true") a bf16 input would not match float32 weights.
    first_param = next(model.parameters(), None)
    sample_dtype = first_param.dtype if first_param is not None else torch.get_default_dtype()

    with torch.no_grad():
        # Start from random noise in wavelet domain
        shape = (num_samples, INPUT_DIM, WAVELET_INFO['n_features'])
        samples_wavelet = torch.randn(shape, device=fabric.device, dtype=sample_dtype)

        # Reverse diffusion
        for i in tqdm(reversed(range(model.T)), total=model.T, desc="Sampling", disable=not fabric.is_global_zero):
            t = torch.full((num_samples,), i, device=fabric.device, dtype=torch.long)
            # The model is called with the step index scaled by model.T
            t_norm = t.float() / model.T

            eps_theta = model(samples_wavelet, t_norm)

            # Standard DDPM update (simplified); schedule values are reshaped to
            # broadcast over the two trailing dimensions of the sample tensor.
            alpha_t = model.alpha_all[t].view(-1, 1, 1)
            alpha_bar_t = model.alpha_bar_all[t].view(-1, 1, 1)
            beta_t = model.beta_all[t].view(-1, 1, 1)

            mean = (1 / torch.sqrt(alpha_t)) * (
                samples_wavelet - ((1 - alpha_t) / torch.sqrt(1 - alpha_bar_t)) * eps_theta
            )

            if i > 0:
                # Add fresh noise on every step except the last one
                noise = torch.randn_like(samples_wavelet)
                samples_wavelet = mean + torch.sqrt(beta_t) * noise
            else:
                samples_wavelet = mean

    print("Sampling complete. Converting to time series...")
    # Convert back to time series (ensure cpu float32 for reconstruction stability)
    samples_wavelet_cpu = samples_wavelet.float().cpu()
    samples_ts = datamodule.convert_wavelet_to_timeseries(samples_wavelet_cpu)
    print(f"Generated samples shape: {samples_ts.shape}")
    return samples_ts

# Sampling is a single-process smoke test, so only rank 0 runs it.
if fabric.is_global_zero:
    samples = sanity_check_sampling()