# WaveletDiff Evaluation Notebook

This notebook evaluates a trained WaveletDiff model using the following metrics:
- Discriminative Score
- Predictive Score
- Context-FID
- Time-Series Correlation
- DTW Distance

It relies on the evaluation and inference backends in the repository's `src/eval` and `src/inference` packages.

In [None]:
# @title Cell 1: Global Configuration

# --- Repository --------------------------------------------------------------
REPO_URL = "https://github.com/MilesHoffman/waveletDiff_synth_data.git"
REPO_DIR = "/content/waveletDiff_synth_data"

# --- Model checkpoint (on Google Drive) --------------------------------------
# Edit DRIVE_DIR / RUN_NAME / CHECKPOINT_NAME to point at your own checkpoint;
# the final path is <DRIVE_DIR>/checkpoints/<RUN_NAME>/<CHECKPOINT_NAME>.ckpt
DRIVE_DIR = "/content/drive/MyDrive/personal_drive/trading"
RUN_NAME = "stocks_ohlcv_v1_top"
CHECKPOINT_NAME = "last"
CHECKPOINT_PATH = "/".join(
    [DRIVE_DIR, "checkpoints", RUN_NAME, CHECKPOINT_NAME + ".ckpt"]
)

# --- Real data used as the comparison reference ------------------------------
# Default: the CSV shipped inside the cloned repository.
DATA_PATH = "/".join(
    [REPO_DIR, "WaveletDiff_source", "data", "stocks", "stock_data.csv"]
)
# Alternatively, use a pre-processed .npy file stored on Drive, e.g.:
# DATA_PATH = "/content/drive/MyDrive/.../real_samples.npy"

# --- Run settings -------------------------------------------------------------
OUTPUT_DIR = "/content/eval_outputs"
NUM_SAMPLES = 2000  # how many samples to generate and evaluate
DEVICE = "cuda"  # or "cpu"

In [None]:
# @title Cell 2: Setup (Clone, Install, Mount)
import os
import sys
import subprocess
import shutil

# 1. Mount Drive (best-effort). The colab import itself fails outside Colab,
#    so it must live inside the try for the fallback message to be reachable.
try:
    from google.colab import drive
except ImportError:
    drive = None
    print("Drive mount failed or not in Colab.")

if drive is not None:
    mount_point = '/content/drive'
    try:
        if not os.path.exists(mount_point):
            drive.mount(mount_point)
        elif not os.listdir(mount_point):
            # Mount point exists but is empty: remount to recover a stale mount.
            drive.mount(mount_point, force_remount=True)
    except Exception as err:  # best-effort: report and keep going
        print(f"Drive mount failed or not in Colab. ({err})")

# 2. Force Clone Repository
if os.path.exists(REPO_DIR):
    print(f"Removing existing repo at {REPO_DIR} for a fresh sync...")
    shutil.rmtree(REPO_DIR)

print(f"Cloning {REPO_URL} into {REPO_DIR}...")
subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

# Drop modules previously imported from the old clone; otherwise a re-run of
# this cell would keep serving the cached (stale) code to Cell 3's imports.
repo_prefix = os.path.join(os.path.abspath(REPO_DIR), "")
for mod_name, module in list(sys.modules.items()):
    mod_file = getattr(module, "__file__", None)
    if mod_file and os.path.abspath(mod_file).startswith(repo_prefix):
        del sys.modules[mod_name]

# 3. Install Dependencies
print("Installing dependencies...")
deps = ["lightning", "pywavelets", "scipy", "pandas", "tqdm"]
if 'COLAB_TPU_ADDR' in os.environ:
    deps.append("torch_xla[tpu]")
# Use the kernel's own interpreter so packages land where this notebook imports from.
subprocess.run([sys.executable, "-m", "pip", "install", *deps], check=True)

# 4. Setup Paths (skip entries already present so re-runs don't duplicate them)
extra_paths = (
    REPO_DIR,  # Allows "import src.eval"
    os.path.join(REPO_DIR, "WaveletDiff_source", "src"),  # Allows "import models", "import training"
    os.path.join(REPO_DIR, "WaveletDiff_source", "src", "evaluation"),  # Allows "import discriminative_metrics"
)
for extra_path in extra_paths:
    if extra_path not in sys.path:
        sys.path.append(extra_path)

print("Setup Complete.")

In [None]:
# @title Cell 3: Run Evaluation
import os

import torch
from src.eval.evaluator import run_evaluation

# Fall back to CPU when CUDA was requested but no GPU is attached to this
# runtime, rather than handing an unusable device to run_evaluation.
device = DEVICE
if str(device).startswith("cuda") and not torch.cuda.is_available():
    print(f"CUDA requested (DEVICE={DEVICE!r}) but not available; falling back to CPU.")
    device = "cpu"

# Ensure output directory exists
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Run the evaluation; the return value is kept in `results` for later cells.
results = run_evaluation(
    checkpoint_path=CHECKPOINT_PATH,
    data_path=DATA_PATH,
    output_dir=OUTPUT_DIR,
    num_samples=NUM_SAMPLES,
    device=device,
)
