# WaveletDiff Evaluation (Refactored)

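This notebook evaluates a trained WaveletDiff model. It is a thin front end: the cells below handle configuration, environment setup, and sample generation, while the plotting, metric, and reporting logic is delegated to the `src/evaluation` modules.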

In [None]:
# @title Cell 1: Global Configuration
import os

# --- Drive Paths ---
DRIVE_MOUNT_PATH = "/content/drive" # @param {type:"string"}
DRIVE_BASE_PATH = "/content/drive/MyDrive/personal_drive/trading/waveletDiff" # @param {type:"string"}
CHECKPOINT_FOLDER = "checkpoints" # @param {type:"string"}
SAMPLES_FOLDER = "samples" # @param {type:"string"}

MODEL_FILENAME = "stocks_experiment.tar.gz" # @param {type:"string"}


def strip_model_extension(filename):
    """Return *filename* without its trailing archive/checkpoint extension(s).

    Only extensions at the END of the name are removed, so a name such as
    ``my.zipped_model.ckpt`` keeps its ``.zip``-like middle part intact.
    Stacked extensions (e.g. ``model.ckpt.gz``) are stripped one after another.
    """
    # '.tar.gz' is listed before '.gz' so the longer extension wins.
    known_extensions = ('.tar.gz', '.zip', '.ckpt', '.tgz', '.gz')
    stripped = True
    while stripped:
        stripped = False
        for ext in known_extensions:
            if filename.endswith(ext):
                filename = filename[:-len(ext)]
                stripped = True
                break
    return filename


# Extension-free model name; used as the per-model sample cache folder name.
MODEL_BASENAME = strip_model_extension(MODEL_FILENAME)

# Location of the packed model on Drive, and of the cached samples for that model.
DRIVE_CHECKPOINT_PATH = os.path.join(DRIVE_BASE_PATH, CHECKPOINT_FOLDER, MODEL_FILENAME)
DRIVE_SAMPLES_PATH = os.path.join(DRIVE_BASE_PATH, SAMPLES_FOLDER, MODEL_BASENAME)

# --- Repository Settings ---
# Branch that the setup cell checks out after cloning / fetching the repository.
REPO_BRANCH = "develop" # @param {type:"string"}

# --- Evaluation Settings ---
# Forwarded to sample.py as --dataset.
DATASET = "stocks" # @param {type:"string"}
# Names the local outputs/<EXPERIMENT_NAME> directory; also forwarded to sample.py as --experiment_name.
EXPERIMENT_NAME = "evaluation_run" # @param {type:"string"}
# Forwarded to sample.py as --num_samples. The analysis cell separately caps both
# the real and generated sets at 2000 samples.
NUM_SAMPLES = 2000 # @param {type:"integer"}
# Forwarded to sample.py as --sampling_method; also the prefix of the generated-samples filename.
SAMPLING_METHOD = "ddpm" # @param ["ddpm", "ddim"]
# Forwarded to sample.py as --compile_mode.
COMPILE_MODE = "none" # @param ["none", "default", "reduce-overhead", "max-autotune"]
# NOTE(review): DEVICE is not referenced by any other cell visible in this notebook —
# confirm whether it is still needed or should be passed to sample.py.
DEVICE = "cuda" # @param ["cuda", "cpu"]

# --- Evaluation Options ---
# When True (and the data has more than one feature), the last feature channel is dropped before evaluation.
EXCLUDE_VOLUME = True # @param {type:"boolean"}
# When True, freshly generated samples are copied to DRIVE_SAMPLES_PATH.
CACHE_SAMPLES_TO_DRIVE = True # @param {type:"boolean"}
# When True, a generated-samples file already cached under DRIVE_SAMPLES_PATH is reused instead of running sample.py.
USE_CACHED_SAMPLES = True # @param {type:"boolean"}

In [None]:
# @title Cell 2: Setup (Clone, Install, Mount)
import os
import sys
import shutil
from google.colab import drive

REPO_URL = "https://github.com/MilesHoffman/waveletDiff_synth_data.git"
REPO_NAME = "waveletDiff_synth_data"
REPO_PATH = os.path.abspath(REPO_NAME)

# 1. Clone or Pull Repo
if os.path.exists(REPO_PATH):
    nested_path = os.path.join(REPO_PATH, REPO_NAME)
    if os.path.exists(nested_path):
        shutil.rmtree(REPO_PATH)
        !git clone {REPO_URL} {REPO_NAME}
        !git -C {REPO_NAME} checkout {REPO_BRANCH}
    else:
        !git -C {REPO_NAME} fetch origin
        !git -C {REPO_NAME} checkout {REPO_BRANCH}
        !git -C {REPO_NAME} pull origin {REPO_BRANCH}
else:
    !git clone {REPO_URL} {REPO_NAME}
    !git -C {REPO_NAME} checkout {REPO_BRANCH}

# 2. Install Dependencies
!pip install -q pytorch-lightning pywavelets scipy pandas tqdm scikit-learn tslearn seaborn statsmodels

# 3. Mount Drive & Setup Paths
if not os.path.exists(DRIVE_MOUNT_PATH):
    drive.mount(DRIVE_MOUNT_PATH)

for p in [os.path.join(REPO_PATH, "src"), os.path.join(REPO_PATH, "src", "evaluation")]:
    if p not in sys.path: sys.path.append(p)

# 4. Prepare Experiments
local_exp_dir = os.path.join(REPO_PATH, "outputs", EXPERIMENT_NAME)
os.makedirs(local_exp_dir, exist_ok=True)

# Stage the trained model from Drive into the local experiment directory:
# a bare .ckpt is copied as checkpoint.ckpt, anything else is unpacked as an archive.
model_staged = os.path.exists(DRIVE_CHECKPOINT_PATH)
if model_staged:
    print(f"Unpacking model from {DRIVE_CHECKPOINT_PATH}...")
    # Decide by the file name only, so "." fragments in folder names cannot
    # influence the format detection.
    checkpoint_name = os.path.basename(DRIVE_CHECKPOINT_PATH)
    if checkpoint_name.endswith(".ckpt"):
        shutil.copy2(DRIVE_CHECKPOINT_PATH, os.path.join(local_exp_dir, "checkpoint.ckpt"))
    else:
        # shutil infers the format from registered extensions (.tar.gz, .tgz, .zip, ...).
        # A bare ".gz" is not one of them, so such files are explicitly opened
        # as gzipped tarballs.
        is_bare_gz = checkpoint_name.endswith(".gz") and not checkpoint_name.endswith(".tar.gz")
        shutil.unpack_archive(
            DRIVE_CHECKPOINT_PATH,
            local_exp_dir,
            format="gztar" if is_bare_gz else None,
        )
else:
    print(f"❌ Model file not found: {DRIVE_CHECKPOINT_PATH}")

# Sync configs: mirror WaveletDiff_source/configs to <repo>/configs, replacing
# any existing copy so stale files do not linger.
source_configs_dir = os.path.join(REPO_PATH, "WaveletDiff_source", "configs")
if os.path.exists(source_configs_dir):
    repo_configs_dir = os.path.join(REPO_PATH, "configs")
    shutil.rmtree(repo_configs_dir, ignore_errors=True)
    shutil.copytree(source_configs_dir, repo_configs_dir)

# Do not report success when the model could not be staged.
if model_staged:
    print("✅ Setup Complete")
else:
    print("⚠️ Setup finished, but no model checkpoint was staged.")

In [None]:
# @title Cell 3: Generate or Load Samples
import numpy as np
os.chdir(os.path.join(REPO_PATH, "src"))

local_gen_path = f"../outputs/{EXPERIMENT_NAME}/{SAMPLING_METHOD}_samples.npy"
local_real_path = f"../outputs/{EXPERIMENT_NAME}/real_samples.npy"
drive_gen_path = os.path.join(DRIVE_SAMPLES_PATH, f"{SAMPLING_METHOD}_samples.npy")
drive_real_path = os.path.join(DRIVE_SAMPLES_PATH, "real_samples.npy")

if USE_CACHED_SAMPLES and os.path.exists(drive_gen_path):
    print("Loading cached samples...")
    os.makedirs(os.path.dirname(local_gen_path), exist_ok=True)
    shutil.copy2(drive_gen_path, local_gen_path); shutil.copy2(drive_real_path, local_real_path)
else:
    print(f"Generating {NUM_SAMPLES} samples...")
    !python sample.py --experiment_name {EXPERIMENT_NAME} --dataset {DATASET} --num_samples {NUM_SAMPLES} --sampling_method {SAMPLING_METHOD} --compile_mode {COMPILE_MODE}
    if CACHE_SAMPLES_TO_DRIVE:
        os.makedirs(DRIVE_SAMPLES_PATH, exist_ok=True)
        shutil.copy2(local_gen_path, drive_gen_path); shutil.copy2(local_real_path, drive_real_path)

print("✅ Samples Ready")

In [None]:
# @title Cell 4: Initialize Modules
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt

# New Refactored Modules
from evaluation import visualizations as viz
from evaluation import statistics as stats
from evaluation import reporting as report
from evaluation import wrappers

# Load Data
real_path = f"../outputs/{EXPERIMENT_NAME}/real_samples.npy"
gen_path = f"../outputs/{EXPERIMENT_NAME}/{SAMPLING_METHOD}_samples.npy"
real_data, generated_data = np.load(real_path), np.load(gen_path)

# Optionally drop the last feature channel from both arrays.
if EXCLUDE_VOLUME:
    if real_data.shape[2] > 1:
        real_data = real_data[..., :-1]
        generated_data = generated_data[..., :-1]

# Downsample: draw the same number of rows (at most 2000) from each set,
# without replacement and independently for real and generated data.
subset_size = min(2000, len(real_data), len(generated_data))
real_rows = np.random.choice(len(real_data), subset_size, replace=False)
real_data = real_data[real_rows]
generated_rows = np.random.choice(len(generated_data), subset_size, replace=False)
generated_data = generated_data[generated_rows]

# Prepare Raw vs Scaled
# Raw copies keep the original value range.
real_data_raw = real_data.copy()
generated_data_raw = generated_data.copy()

# Min-max scaling per last-axis feature, using statistics of the REAL data for
# both sets; the small epsilon avoids division by zero for constant features.
feature_min = real_data.min(axis=(0, 1), keepdims=True)
feature_max = real_data.max(axis=(0, 1), keepdims=True)
feature_span = feature_max - feature_min + 1e-8
real_data_scaled = (real_data - feature_min) / feature_span
generated_data_scaled = (generated_data - feature_min) / feature_span

print(f"Data Loaded: {real_data.shape}")

### Visualizations

In [None]:
# @title Visual Analysis
# Each plot receives the min-max scaled real and generated arrays prepared in
# the "Initialize Modules" cell.
viz.plot_distribution_reduction(real_data_scaled, generated_data_scaled)
viz.plot_pdf(real_data_scaled, generated_data_scaled)
viz.plot_samples(real_data_scaled, generated_data_scaled)

### Statistical & Model Quality Metrics

In [None]:
# @title Run All Metrics
# Collects every metric under the key names that are handed to the scorecard cell.
metrics_dict = {}
raw_pair = (real_data_raw, generated_data_raw)
scaled_pair = (real_data_scaled, generated_data_scaled)

# 1. Statistics (Raw Data)
# The per-feature detail tables are kept for the "Detailed Feature Stats" cell.
stat_agg, real_stat_det, gen_stat_det = stats.calculate_statistical_metrics(*raw_pair)
metrics_dict['stat_results'] = stat_agg

# 2. Discriminative/Predictive (Scaled Data)
discriminative = wrappers.run_discriminative_benchmark(*scaled_pair)
metrics_dict['discriminative_score'] = discriminative
predictive = wrappers.run_predictive_benchmark(*scaled_pair)
metrics_dict['predictive_score'] = predictive
context_fid = wrappers.run_context_fid_benchmark(*scaled_pair)
metrics_dict['context_fid_score'] = context_fid

# 3. Correlational & DTW
correlational = wrappers.run_cross_correlation_benchmark(*scaled_pair)
metrics_dict['correlational_score'] = correlational
# NOTE(review): the DTW benchmark result is stored under 'js_results' —
# presumably the key the scorecard expects; confirm against the reporting module.
dtw_result = wrappers.run_dtw_benchmark(*scaled_pair)
metrics_dict['js_results'] = dtw_result

# 4. Advanced Metrics (Raw Data)
dist_res, struct_res, fin_res = wrappers.run_advanced_metrics(*raw_pair)
metrics_dict.update(dist_results=dist_res, struct_results=struct_res, fin_results=fin_res)
mem_ratio, div_res, fld_score = wrappers.run_new_metrics(*raw_pair)
metrics_dict.update(mem_ratio=mem_ratio, div_results=div_res, fld_score=fld_score)

### Final Scorecard

In [None]:
# @title Summary Scorecard
# Renders the summary from the metrics collected in the "Run All Metrics" cell.
report.generate_summary_scorecard(metrics_dict)

In [None]:
# @title Detailed Feature Stats
# real_stat_det / gen_stat_det are the second and third values returned by
# stats.calculate_statistical_metrics in the "Run All Metrics" cell.
report.display_feature_stats(real_stat_det, gen_stat_det)