# WaveletDiff Fabric Training (Frontend Notebook)

This notebook serves as the frontend for the WaveletDiff training pipeline. The heavy lifting is handled by `src.torch_gpu_waveletDiff.train.trainer`.

### Workflow:
1. **Configuration**: Set your parameters here.
2. **Setup**: Mounts Google Drive (on Colab), clones or updates the repo, and installs dependencies.
3. **Initialization**: Sets up Fabric, Model, and Data via the backend.
4. **Execution**: Runs the high-performance training loop.


In [None]:
# @title Cell 1: Global Configuration
# Single place for every user-tunable setting of the run. The later cells only
# read these names and forward them to the backend `trainer` module, so this
# cell must be executed before any of them.
import random

# Experiment Identity
RUN_NAME = "stocks_ohlcv_v1" # @param {type:"string"}
# Random, zero-padded 3-digit suffix (000-999). It is regenerated every time
# this cell is executed, which also changes UNIQUE_RUN_NAME and CHECKPOINT_DIR.
# With only 1000 possible values, uniqueness across runs is not guaranteed.
RUN_ID =  f"{random.randint(0, 999):03d}"
UNIQUE_RUN_NAME = f"{RUN_NAME}_{RUN_ID}"

# Dataset Configuration
# These values are forwarded to trainer.get_dataloaders in Cell 4.
DATASET_NAME = "stocks"
# Relative path; REPO_DIR is passed to the loader alongside it.
# NOTE(review): presumably resolved relative to the repo root -- confirm in the backend.
DATA_PATH = "src/copied_waveletDiff/data/stocks/stock_data.csv" # @param {type:"string"}
SEQ_LEN = 24 # @param {type:"integer"}

# Model Hyperparameters
# These values are forwarded to trainer.init_model in Cell 5.
EMBED_DIM = 256 # @param {type:"integer"}
NUM_HEADS = 8 # @param {type:"integer"}
NUM_LAYERS = 8 # @param {type:"integer"}
TIME_EMBED_DIM = 128 # @param {type:"integer"}
DROPOUT = 0.1 # @param {type:"number"}
PREDICTION_TARGET = "noise"
USE_CROSS_LEVEL_ATTENTION = True # @param {type:"boolean"}

# Training Parameters
# TOTAL_TRAINING_STEPS and ENABLE_GRAD_CLIPPING go to trainer.train_loop (Cell 6);
# BATCH_SIZE goes to trainer.get_dataloaders (Cell 4); LEARNING_RATE and
# WEIGHT_DECAY go to trainer.init_model (Cell 5).
TOTAL_TRAINING_STEPS = 35000 # @param {type:"integer"}
BATCH_SIZE = 512 # @param {type:"integer"}
LEARNING_RATE = 2e-4 # @param {type:"number"}
WEIGHT_DECAY = 1e-5 # @param {type:"number"}
ENABLE_GRAD_CLIPPING = True # @param {type:"boolean"}
# ^ CRITICAL: Enabled for stability (prevents exploding gradients). Small perf check on TPU.

# Precision
# Both values are forwarded to trainer.setup_fabric in Cell 3.
PRECISION = "bf16-mixed" # @param ["bf16-mixed", "32-true"]
MATMUL_PRECISION = "high" # @param ["medium", "high"]

# Profiler
ENABLE_PROFILER = False # @param {type:"boolean"}

# Optimizer Configuration
# Matched to source repo defaults for OneCycleLR
# NOTE(review): MAX_LR is passed to trainer.init_model together with
# LEARNING_RATE; how the two interact is decided by the backend -- confirm there.
MAX_LR = 1e-3 # @param {type:"number"}
PCT_START = 0.3 # @param {type:"number"}

# Checkpointing
# Forwarded to trainer.train_loop as `save_interval`.
# NOTE(review): presumably measured in training steps -- confirm in the backend.
SAVE_INTERVAL = 5000 # @param {type:"integer"}

# Wavelet Configuration
WAVELET_TYPE = "db2" # @param {type:"string"}
# The string "auto" is handed to the backend unchanged; its meaning is defined there.
WAVELET_LEVELS = "auto"

# Paths
# DRIVE_BASE_PATH lies under /content/drive, the mount point used by Cell 2,
# so checkpoints are written below the mounted Drive folder.
DRIVE_BASE_PATH = "/content/drive/MyDrive/personal_drive/trading"
CHECKPOINT_DIR = f"{DRIVE_BASE_PATH}/checkpoints/{UNIQUE_RUN_NAME}"
REPO_URL = "https://github.com/MilesHoffman/waveletDiff_synth_data.git"
REPO_DIR = "/content/waveletDiff_synth_data"

In [None]:
# @title Cell 2: Setup (Clone, Install, Mount)
# Prepares the runtime for the following cells:
#   1. mounts Google Drive when running on Colab,
#   2. clones the repository, or pulls if a checkout already exists,
#   3. installs every dependency that is not importable yet,
#   4. extends sys.path so the backend package can be imported.
# Reads REPO_DIR and REPO_URL from Cell 1.
import importlib
import importlib.util
import os
import sys
import subprocess
import shutil

# 1. Mount Drive
# Only the import is guarded, so an ImportError raised while mounting is not
# mistaken for "not running on Colab".
try:
    from google.colab import drive
except ImportError:
    print("Not running on Colab. Skipping Drive mount.")
else:
    if not os.path.exists('/content/drive'):
        drive.mount('/content/drive')
    elif not os.listdir('/content/drive'):
        # The mount point exists but is empty: a stale mount that needs a
        # forced remount. A populated mount point is left untouched.
        print("Detected ghost directory. Force remounting...")
        drive.mount('/content/drive', force_remount=True)

# 2. Clone or Update Repository
# Pull only when REPO_DIR really holds a git checkout; a leftover plain
# directory would otherwise make `git pull` fail.
if os.path.exists(os.path.join(REPO_DIR, ".git")):
    print(f"Repo exists at {REPO_DIR}, pulling changes...")
    subprocess.run(["git", "-C", REPO_DIR, "pull"], check=True)
else:
    print(f"Cloning {REPO_URL} into {REPO_DIR}...")
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

# 3. Install Dependencies
print("Installing dependencies (this may take a minute)...")
is_tpu = 'COLAB_TPU_ADDR' in os.environ or 'TPU_NAME' in os.environ

# pip requirement -> top-level module used to detect whether it is installed
# (the two differ for PyWavelets, which is imported as `pywt`).
requirement_modules = {
    "lightning": "lightning",
    "pywavelets": "pywt",
    "scipy": "scipy",
    "pandas": "pandas",
    "tqdm": "tqdm",
}
if is_tpu:
    requirement_modules["torch_xla[tpu]"] = "torch_xla"
deps = list(requirement_modules)

# Check each requirement on its own: an already-installed `lightning` must not
# hide a missing `pywavelets` or `torch_xla`.
missing_deps = [
    requirement
    for requirement, module_name in requirement_modules.items()
    if importlib.util.find_spec(module_name) is None
]
if missing_deps:
    print(f"Missing packages: {', '.join(missing_deps)}")
    # `sys.executable -m pip` installs into the interpreter running this kernel.
    subprocess.run([sys.executable, "-m", "pip", "install", *missing_deps], check=True)
    # Make the freshly installed packages visible to the import system.
    importlib.invalidate_caches()
else:
    print("All dependencies already installed.")

# 4. Setup Paths
# Critical: We add the repo to path so we can import src.torch_gpu_waveletDiff
# Also add the inner source for legacy imports inside the trainer
source_path = os.path.join(REPO_DIR, "src", "copied_waveletDiff", "src")
for extra_path in (REPO_DIR, source_path):
    if extra_path not in sys.path:
        sys.path.append(extra_path)

print("Setup Complete.")

In [None]:
# @title Cell 3: Initialize Fabric
# The backend module is importable once Cell 2 has added the repo to sys.path.
from src.torch_gpu_waveletDiff.train import trainer

# Create the object that the later cells pass to the backend as `fabric`,
# using the precision settings chosen in Cell 1.
fabric = trainer.setup_fabric(
    precision=PRECISION,
    matmul_precision=MATMUL_PRECISION,
)

In [None]:
# @title Cell 4: Load Data
# Collect the dataset and wavelet settings from Cell 1 in one place, then hand
# them to the backend together with the `fabric` object from Cell 3.
loader_kwargs = {
    "fabric": fabric,
    "repo_dir": REPO_DIR,
    "dataset_name": DATASET_NAME,
    "seq_len": SEQ_LEN,
    "batch_size": BATCH_SIZE,
    "wavelet_type": WAVELET_TYPE,
    "wavelet_levels": WAVELET_LEVELS,
    "data_path": DATA_PATH,
}

# The backend returns three values; all of them are reused by later cells.
train_loader, datamodule, config = trainer.get_dataloaders(**loader_kwargs)

In [None]:
# @title Cell 5: Initialize Model
# Keyword arguments for the backend, grouped by role. All values come from
# Cell 1 except `fabric` (Cell 3) and `datamodule` / `config` (Cell 4).
model_kwargs = {
    # Objects created by earlier cells
    "fabric": fabric,
    "datamodule": datamodule,
    "config": config,
    # Architecture
    "embed_dim": EMBED_DIM,
    "num_heads": NUM_HEADS,
    "num_layers": NUM_LAYERS,
    "time_embed_dim": TIME_EMBED_DIM,
    "dropout": DROPOUT,
    "prediction_target": PREDICTION_TARGET,
    "use_cross_level_attention": USE_CROSS_LEVEL_ATTENTION,
    # Optimizer and learning-rate schedule
    "learning_rate": LEARNING_RATE,
    "weight_decay": WEIGHT_DECAY,
    "max_lr": MAX_LR,
    "pct_start": PCT_START,
}

# `config` is rebound to the value returned by the backend.
model, optimizer, config = trainer.init_model(**model_kwargs)

In [None]:
# @title Cell 6: Run Training
# Everything the training loop needs: the objects built in Cells 3-5 plus the
# step budget, checkpointing, and profiler / clipping switches from Cell 1.
loop_kwargs = {
    "fabric": fabric,
    "model": model,
    "optimizer": optimizer,
    "train_loader": train_loader,
    "config": config,
    "total_steps": TOTAL_TRAINING_STEPS,
    "save_interval": SAVE_INTERVAL,
    "checkpoint_dir": CHECKPOINT_DIR,
    "enable_profiler": ENABLE_PROFILER,
    "enable_grad_clipping": ENABLE_GRAD_CLIPPING,
}

# The call's return value is not used.
trainer.train_loop(**loop_kwargs)