# WaveletDiff Evaluation

This notebook runs the repository's `src/sample.py` and `src/evaluation/run_eval.py` scripts as shell commands, using a model checkpoint copied from Google Drive.

In [None]:
# @title Cell 1: Global Configuration
# User-tunable settings for the whole notebook. Run this cell first: the setup
# and pipeline cells below read these names. The trailing `# @param` markers
# are Colab form annotations, so keep each assignment on a single line.

# --- Drive Paths ---
# Directory at which Google Drive is mounted by the setup cell.
DRIVE_MOUNT_PATH = "/content/drive" # @param {type:"string"}
# Checkpoint file on Drive; the setup cell copies it into the experiment's output directory.
DRIVE_CHECKPOINT_PATH = "/content/drive/MyDrive/checkpoints/stocks_model/best.ckpt" # @param {type:"string"}

# --- Evaluation Settings ---
# Passed to sample.py as --dataset.
DATASET = "stocks" # @param {type:"string"}
# Names the outputs/<EXPERIMENT_NAME> directory and is passed to sample.py as --experiment_name.
EXPERIMENT_NAME = "evaluation_run" # @param {type:"string"}
# Passed to sample.py as --num_samples.
NUM_SAMPLES = 2000 # @param {type:"integer"}
# Passed to sample.py as --sampling_method; also the prefix of the generated samples file name.
SAMPLING_METHOD = "ddpm" # @param ["ddpm", "ddim"]
# Passed to run_eval.py as --device (it is not forwarded to sample.py in this notebook).
DEVICE = "cuda" # @param ["cuda", "cpu"]

In [None]:
# @title Cell 2: Setup
import os
import sys
import shutil

REPO_URL = "https://github.com/MilesHoffman/waveletDiff_synth_data.git"
REPO_NAME = "waveletDiff_synth_data"
REPO_PATH = os.path.abspath(REPO_NAME)

if not os.path.exists(REPO_PATH):
    print(f"Cloning {REPO_URL}...")
    !git clone {REPO_URL} {REPO_NAME}
else:
    print(f"Pulling latest for {REPO_DIR}...")
    !git -C {REPO_NAME} pull

# Install dependencies (using scikit-learn instead of sklearn)
!pip install -q pytorch-lightning pywavelets scipy pandas tqdm scikit-learn tslearn

# Mount Google Drive
from google.colab import drive
if not os.path.exists(DRIVE_MOUNT_PATH):
    drive.mount(DRIVE_MOUNT_PATH)

# Prepare Checkpoint and Configs
local_exp_dir = os.path.join(REPO_PATH, "outputs", EXPERIMENT_NAME)
os.makedirs(local_exp_dir, exist_ok=True)

if os.path.exists(DRIVE_CHECKPOINT_PATH):
    dst_ckpt = os.path.join(local_exp_dir, "checkpoint.ckpt")
    shutil.copy2(DRIVE_CHECKPOINT_PATH, dst_ckpt)
    print(f"✅ Copied checkpoint to {dst_ckpt}")
else:
    print(f"❌ ERROR: Checkpoint not found at {DRIVE_CHECKPOINT_PATH}")

configs_src = os.path.join(REPO_PATH, "WaveletDiff_source", "configs")
configs_dst = os.path.join(REPO_PATH, "configs")
if not os.path.exists(configs_dst) and os.path.exists(configs_src):
    shutil.copytree(configs_src, configs_dst)
    print("✅ Copied default configs to repo root")

print("✅ Setup Complete")

In [None]:
# @title Cell 3: Run Evaluation Pipeline
%cd {REPO_NAME}

print("1. Generating Samples...")
!python src/sample.py \
    --experiment_name {EXPERIMENT_NAME} \
    --dataset {DATASET} \
    --num_samples {NUM_SAMPLES} \
    --sampling_method {SAMPLING_METHOD}

print("\n2. Running Metrics...")
real_samples_path = f"outputs/{EXPERIMENT_NAME}/real_samples.npy"
fake_samples_path = f"outputs/{EXPERIMENT_NAME}/{SAMPLING_METHOD}_samples.npy"

if os.path.exists(real_samples_path) and os.path.exists(fake_samples_path):
    !python src/evaluation/run_eval.py \
        --real_data {real_samples_path} \
        --fake_data {fake_samples_path} \
        --output_dir outputs/{EXPERIMENT_NAME} \
        --device {DEVICE}
else:
    print("❌ ERROR: Sample files not found. Check generation output above.")