# WaveletDiff Evaluation Notebook

This notebook evaluates a trained WaveletDiff model using the following metrics:
- Discriminative Score
- Predictive Score
- Context-FID
- Time-Series Correlation
- DTW Distance

It uses the backend defined in `src/torch_gpu_waveletDiff/eval` and `src/torch_gpu_waveletDiff/inference`.

In [None]:
# @title Cell 1: Global Configuration

# --- Repository -------------------------------------------------------------
# Where the project is cloned from and where it lives on the runtime's disk.
REPO_URL = "https://github.com/MilesHoffman/waveletDiff_synth_data.git"
REPO_DIR = "/content/waveletDiff_synth_data"

# --- Checkpoint -------------------------------------------------------------
# Adjust DRIVE_DIR / RUN_NAME / CHECKPOINT_NAME to match the checkpoint stored
# in your own Google Drive; CHECKPOINT_PATH is derived from the three of them.
DRIVE_DIR = "/content/drive/MyDrive/personal_drive/trading"
RUN_NAME = "stocks_ohlcv_v1_top"
CHECKPOINT_NAME = "last"
CHECKPOINT_PATH = "/".join(
    (DRIVE_DIR, "checkpoints", RUN_NAME, CHECKPOINT_NAME + ".ckpt")
)

# --- Real data used as the comparison baseline --------------------------------
# Default: the stock CSV that ships inside the cloned repository.
DATA_PATH = REPO_DIR + "/src/copied_waveletDiff/data/stocks/stock_data.csv"
# Alternative: a pre-processed .npy file kept in Drive, e.g.
# DATA_PATH = "/content/drive/MyDrive/.../real_samples.npy"

# --- Evaluation run settings --------------------------------------------------
OUTPUT_DIR = "/content/eval_outputs"
NUM_SAMPLES = 2000  # how many samples are generated and evaluated
DEVICE = "cuda"  # switch to "cpu" when no GPU is attached
BATCH_SIZE = 2000  # batch size used during generation

In [None]:
# @title Cell 2: Setup (Clone, Install, Mount)
# Prepares the runtime: mounts Google Drive (when running in Colab), clones or
# updates the repository, installs the Python dependencies, and registers the
# repository's source roots on sys.path. Relies on REPO_URL and REPO_DIR from
# the configuration cell.
import os
import sys
import subprocess
import shutil

# google.colab is only importable inside a Colab runtime. Degrade to "no Drive"
# instead of aborting the whole cell so the clone/install/path steps still run.
try:
    from google.colab import drive
except ImportError:
    drive = None

# 1. Mount Drive
# A mount point that already has content is treated as mounted and left alone.
drive_mount_point = '/content/drive'
mount_point_exists = os.path.exists(drive_mount_point)
already_mounted = mount_point_exists and bool(os.listdir(drive_mount_point))
if not already_mounted:
    if drive is None:
        print("Drive mount failed or not in Colab.")
    elif mount_point_exists:
        # Directory exists but is empty: force a fresh mount.
        drive.mount(drive_mount_point, force_remount=True)
    else:
        try:
            drive.mount(drive_mount_point)
        except Exception as exc:
            # Best-effort: report why the mount failed and keep going, but let
            # KeyboardInterrupt / SystemExit propagate.
            print(f"Drive mount failed or not in Colab. ({exc})")

# 2. Clone or Pull Repository
clone_command = ["git", "clone", REPO_URL, REPO_DIR]
if os.path.exists(REPO_DIR):
    print(f"Repo exists at {REPO_DIR}. Pulling latest changes...")
    try:
        subprocess.run(["git", "-C", REPO_DIR, "pull"], check=True)
    except subprocess.CalledProcessError:
        # A failed pull is resolved by discarding the checkout and cloning anew.
        print("Git pull failed. Removing and re-cloning...")
        shutil.rmtree(REPO_DIR)
        subprocess.run(clone_command, check=True)
else:
    print(f"Cloning {REPO_URL} into {REPO_DIR}...")
    subprocess.run(clone_command, check=True)

# 3. Install Dependencies
print("Installing dependencies...")
deps = ["lightning", "pywavelets", "scipy", "pandas", "tqdm"]
if 'COLAB_TPU_ADDR' in os.environ:
    # The XLA backend is only requested when the TPU address variable is set.
    deps.append("torch_xla[tpu]")
# Run pip through the current interpreter so the packages are installed into
# the environment this notebook actually imports from.
subprocess.run([sys.executable, "-m", "pip", "install", *deps], check=True)

# 4. Setup Paths
# Three import roots are needed:
#   * REPO_DIR  - the repository root
#   * src_path  - directory containing 'models', 'data', etc.; loader.py does
#                 'from models.transformer import ...', so 'models' must be a
#                 top-level package
#   * repo_src  - the root 'src' directory, for torch_gpu_waveletDiff
src_path = os.path.join(REPO_DIR, "src", "copied_waveletDiff", "src")
repo_src = os.path.join(REPO_DIR, "src")
for import_root in (REPO_DIR, src_path, repo_src):
    # Skip entries that are already present so re-running the cell does not
    # keep growing sys.path.
    if import_root not in sys.path:
        sys.path.append(import_root)

print("Setup Complete.")

In [None]:
# @title Cell 3: Run Evaluation
# Runs the evaluation backend with the settings from the configuration cell.
# `os` and the upper-case configuration names are expected to be in scope from
# the earlier cells; this cell does not define them itself.
import torch
import warnings
from src.torch_gpu_waveletDiff.eval.evaluator import run_evaluation

# 'medium' float32 matmul precision lets Tensor Cores be used on Ampere+ GPUs.
torch.set_float32_matmul_precision('medium')

# The output directory must exist before the evaluation runs.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Collect the arguments first, then launch the evaluation in a single call.
evaluation_kwargs = {
    "checkpoint_path": CHECKPOINT_PATH,
    "data_path": DATA_PATH,
    "output_dir": OUTPUT_DIR,
    "num_samples": NUM_SAMPLES,
    "device": DEVICE,
    "batch_size": BATCH_SIZE,
}
results = run_evaluation(**evaluation_kwargs)


In [None]:
# @title Cell 4: Visualization
# Loads the most recent generated samples and the matching real samples from
# OUTPUT_DIR, runs the visualization backend on them, and displays whichever
# of the known plot files exist afterwards. Relies on OUTPUT_DIR and DATA_PATH
# from the configuration cell.
import glob
import numpy as np
import os
from IPython.display import Image, display
from src.torch_gpu_waveletDiff.eval.visualizer import visualize_evaluation

print("Starting visualization...")

# Find latest generated samples
list_of_files = glob.glob(f'{OUTPUT_DIR}/generated_samples_*.npy')
if not list_of_files:
    print("No generated samples found. Please run evaluation first.")
else:
    latest_gen_file = max(list_of_files, key=os.path.getctime)
    print(f"Loading generated samples from {latest_gen_file}...")
    generated_data = np.load(latest_gen_file)

    # Find saved real samples
    real_file = os.path.join(OUTPUT_DIR, "real_samples_used.npy")
    if os.path.exists(real_file):
        print(f"Loading real samples from {real_file}...")
        real_data = np.load(real_file)
    else:
        print(f"Real samples file not found at {real_file}. Falling back to DATA_PATH...")
        # Best-effort fallback: DATA_PATH may be a CSV, which np.load cannot
        # read. np.load can fail with several exception types depending on the
        # file, so catch Exception here, but report the reason and let
        # KeyboardInterrupt / SystemExit propagate.
        try:
            real_data = np.load(DATA_PATH)
        except Exception as exc:
            print("Could not load real data from DATA_PATH (likely CSV). Please re-run evaluation to generate 'real_samples_used.npy'.")
            print(f"Reason: {type(exc).__name__}: {exc}")
            real_data = None

    if real_data is not None:
        # Run Visualization Backend
        visualize_evaluation(real_data, generated_data, OUTPUT_DIR)

        # Display saved plots. Each entry is (file name in OUTPUT_DIR, caption);
        # a plot is only shown when its file exists.
        print("\n--- Visual Comparisons ---")
        plot_captions = (
            ("sample_comparison.png", "1. Sample Comparison (Real vs Generated)"),
            ("distribution_comparison.png", "\n2. Feature Distribution Comparison"),
            ("pca_projection.png", "\n3. PCA Projection (2D)"),
        )
        for plot_name, caption in plot_captions:
            plot_path = f"{OUTPUT_DIR}/{plot_name}"
            if os.path.exists(plot_path):
                print(caption)
                display(Image(filename=plot_path))
    else:
        print("Skipping visualization due to missing real data.")