# WaveletDiff Evaluation

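This notebook evaluates a trained WaveletDiff model in Google Colab. It copies a checkpoint from Google Drive, generates synthetic samples with `sample.py`, and compares them with real data using visualizations (t-SNE/PCA, value densities, example windows), summary statistics, and the discriminative, predictive, Context-FID, cross-correlation, and DTW metrics.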

In [None]:
# @title Cell 1: Global Configuration
import os

# --- Drive Paths ---
# Mount point for Google Drive; the setup cell mounts Drive here.
DRIVE_MOUNT_PATH = "/content/drive" # @param {type:"string"}
# File name of the trained checkpoint inside DRIVE_CHECKPOINT_FOLDER.
CHECKPOINT_NAME = "checkpoint.ckpt" # @param {type:"string"}
# NOTE(review): this folder is spelled out under /content/drive instead of being
# built from DRIVE_MOUNT_PATH; keep the two consistent if either one changes.
DRIVE_CHECKPOINT_FOLDER = "/content/drive/MyDrive/personal_drive/trading/waveletDiff/checkpoints" # @param {type:"string"}
# Full Drive path of the checkpoint that the setup cell copies into outputs/<EXPERIMENT_NAME>/.
DRIVE_CHECKPOINT_PATH = os.path.join(DRIVE_CHECKPOINT_FOLDER, CHECKPOINT_NAME)

# --- Repository Settings ---
# Git branch that the setup cell checks out (and pulls on re-runs).
REPO_BRANCH = "develop" # @param {type:"string"}

# --- Evaluation Settings ---
# Passed to sample.py as --dataset.
DATASET = "stocks" # @param {type:"string"}
# Sub-folder of outputs/ holding the copied checkpoint and the real/generated .npy samples.
EXPERIMENT_NAME = "evaluation_run" # @param {type:"string"}
# Passed to sample.py as --num_samples.
NUM_SAMPLES = 2000 # @param {type:"integer"}
# Passed to sample.py as --sampling_method; also selects "<method>_samples.npy" for evaluation.
SAMPLING_METHOD = "ddpm" # @param ["ddpm", "ddim"]
# Passed to sample.py as --compile_mode.
COMPILE_MODE = "none" # @param ["none", "default", "reduce-overhead", "max-autotune"]
# NOTE(review): DEVICE is not referenced by any later cell (it is not forwarded to
# sample.py) — confirm whether it is meant to be wired through.
DEVICE = "cuda" # @param ["cuda", "cpu"]

In [None]:
# @title Cell 2: Setup (Clone, Install, Mount)
import os
import sys
import shutil
from google.colab import drive

import subprocess

# Cell 3 changes the working directory to <repo>/src. Resolving the repo with
# os.path.abspath(REPO_NAME) would then point *inside* the repo on a re-run and
# create a nested clone, so anchor everything to Colab's fixed working directory.
WORK_DIR = "/content"
os.chdir(WORK_DIR)

REPO_URL = "https://github.com/MilesHoffman/waveletDiff_synth_data.git"
REPO_NAME = "waveletDiff_synth_data"
REPO_PATH = os.path.join(WORK_DIR, REPO_NAME)


def _run(cmd):
    """Run ``cmd`` (an argument list), echo its combined stdout/stderr, and raise on failure.

    Unlike ``!cmd``, a failing git/pip command raises ``subprocess.CalledProcessError``
    and stops the cell instead of being silently ignored. Passing a list (no shell)
    also keeps the interpolated config values from being shell-interpreted.
    """
    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    if result.stdout:
        print(result.stdout, end="")
    result.check_returncode()


def _clone_repo():
    """Clone REPO_URL into REPO_PATH and check out REPO_BRANCH."""
    print(f"Cloning {REPO_URL}...")
    _run(["git", "clone", REPO_URL, REPO_PATH])
    _run(["git", "-C", REPO_PATH, "checkout", REPO_BRANCH])


# 1. Clone or Pull Repo
if os.path.exists(REPO_PATH):
    # A nested copy can be left over from earlier cwd-relative runs: start over cleanly.
    nested_path = os.path.join(REPO_PATH, REPO_NAME)
    if os.path.exists(nested_path):
        print(f"⚠️ Detected nested repository at {nested_path}. cleaning up...")
        shutil.rmtree(REPO_PATH)
        _clone_repo()
    else:
        print(f"Pulling latest for {REPO_NAME}...")
        _run(["git", "-C", REPO_PATH, "fetch", "origin"])
        _run(["git", "-C", REPO_PATH, "checkout", REPO_BRANCH])
        _run(["git", "-C", REPO_PATH, "pull", "origin", REPO_BRANCH])
else:
    _clone_repo()

# 2. Install Dependencies (using scikit-learn instead of sklearn) into the kernel's
#    own interpreter — the equivalent of %pip.
_run([
    sys.executable, "-m", "pip", "install", "-q",
    "pytorch-lightning", "pywavelets", "scipy", "pandas", "tqdm",
    "scikit-learn", "tslearn", "seaborn",
])

# 3. Mount Google Drive. ismount (not exists): an empty, leftover mount-point
#    directory must not be mistaken for a mounted Drive.
if not os.path.ismount(DRIVE_MOUNT_PATH):
    drive.mount(DRIVE_MOUNT_PATH)

# 4. Setup Paths
for extra_path in (os.path.join(REPO_PATH, "src"),
                   os.path.join(REPO_PATH, "src", "evaluation")):
    if extra_path not in sys.path:
        sys.path.append(extra_path)

# 5. Prepare Checkpoint and Configs
local_exp_dir = os.path.join(REPO_PATH, "outputs", EXPERIMENT_NAME)
os.makedirs(local_exp_dir, exist_ok=True)

# Fail fast: sampling cannot succeed without the checkpoint.
if not os.path.exists(DRIVE_CHECKPOINT_PATH):
    raise FileNotFoundError(f"❌ ERROR: Checkpoint not found at {DRIVE_CHECKPOINT_PATH}")
dst_ckpt = os.path.join(local_exp_dir, "checkpoint.ckpt")
shutil.copy2(DRIVE_CHECKPOINT_PATH, dst_ckpt)
print(f"✅ Copied checkpoint to {dst_ckpt}")

# Sync configs to repo root so scripts can find them at ../configs
configs_src = os.path.join(REPO_PATH, "WaveletDiff_source", "configs")
configs_dst = os.path.join(REPO_PATH, "configs")

if os.path.exists(configs_src):
    if os.path.exists(configs_dst):
        shutil.rmtree(configs_dst)
    shutil.copytree(configs_src, configs_dst)
    print("✅ Configs synced to repo root")
else:
    print(f"⚠️ No configs found at {configs_src}; leaving {configs_dst} untouched")

print("✅ Setup Complete")

In [None]:
# @title Cell 3: Generate Samples
import os
import subprocess
import sys

os.chdir(os.path.join(REPO_PATH, "src"))

print(f"Generating {NUM_SAMPLES} samples using {SAMPLING_METHOD} with compile_mode={COMPILE_MODE}...")

# Run sample.py with the kernel's own interpreter. Success is judged primarily by
# the exit code: grepping the log for "Error" misfires on harmless log lines and
# misses failures that exit non-zero without printing a traceback.
result = subprocess.run(
    [
        sys.executable, "sample.py",
        "--experiment_name", str(EXPERIMENT_NAME),
        "--dataset", str(DATASET),
        "--num_samples", str(NUM_SAMPLES),
        "--sampling_method", str(SAMPLING_METHOD),
        "--compile_mode", str(COMPILE_MODE),
    ],
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,  # keep stderr interleaved with stdout, as `!cmd` capture did
    text=True,
)
status = result.stdout.splitlines()
print("\n".join(status))

if result.returncode != 0 or any("Traceback" in line for line in status):
    # Stop Run-All here; otherwise later cells could silently evaluate stale
    # samples left over from a previous run.
    raise RuntimeError(
        f"❌ Generation failed (exit code {result.returncode}). Check the error log above."
    )
print("\n✅ Sampling Complete")

In [None]:
# @title Cell 4: Imports and Setup
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import pandas as pd
import os

from discriminative_metrics import discriminative_score_metrics
from predictive_metrics import predictive_score_metrics
from context_fid import Context_FID
from cross_correlation import CrossCorrelLoss
from metric_utils import display_scores
from dtw import dtw_js_divergence_distance

import random

# Set style. The "seaborn-v0_8-*" names exist only in matplotlib >= 3.6;
# older versions ship the same style as "seaborn-whitegrid".
try:
    plt.style.use('seaborn-v0_8-whitegrid')
except OSError:
    plt.style.use('seaborn-whitegrid')
sns.set_context("notebook", font_scale=1.2)
COLORS = {"Real": "#d62728", "Generated": "#1f77b4"} # Red for Real, Blue for Generated

# Seed the Python, NumPy and PyTorch global RNGs so subsampling, t-SNE and the
# repeated metric runs reproduce under Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

### Load Real and Generated Data

In [None]:
# Paths are relative to 'src' directory where we currently are (Cell 3 chdir's there)
real_data_path = f"../outputs/{EXPERIMENT_NAME}/real_samples.npy"
gen_data_path = f"../outputs/{EXPERIMENT_NAME}/{SAMPLING_METHOD}_samples.npy"

# Check both inputs up front so a missing file gets an actionable message
# instead of a bare np.load error.
for label, path in (("real", real_data_path), ("generated", gen_data_path)):
    if not os.path.exists(path):
        raise FileNotFoundError(f"Could not find {label} samples at {os.path.abspath(path)}. Check Cell 3 output.")

real_data = np.load(real_data_path)
generated_data = np.load(gen_data_path)
print(f"Loaded Real: {real_data.shape}, Generated: {generated_data.shape}")

# Later cells index arrays as [sample, time, feature] and compare them element-wise,
# so both sets must be 3-D with matching per-sample shape.
if real_data.ndim != 3 or real_data.shape[1:] != generated_data.shape[1:]:
    raise ValueError(
        f"Expected 3-D arrays with matching (time, feature) dims, got "
        f"real {real_data.shape} and generated {generated_data.shape}"
    )

In [None]:
# Evaluate on equally sized sets: randomly subsample (without replacement) the
# larger of the two down to the size of the smaller one.
num_samples = min(real_data.shape[0], generated_data.shape[0])
if real_data.shape[0] > num_samples:
    print(f"WARNING: Using all {num_samples} generated samples for evaluation.")
elif generated_data.shape[0] > num_samples:
    # Previously this case fell through silently to the plain sample-count message.
    print(f"WARNING: Using all {num_samples} real samples for evaluation "
          f"(generated set subsampled from {generated_data.shape[0]}).")
else:
    print(f"Number of samples: {num_samples}")

real_indices = np.random.choice(len(real_data), num_samples, replace=False)
real_data = real_data[real_indices]
gen_indices = np.random.choice(len(generated_data), num_samples, replace=False)
generated_data = generated_data[gen_indices]

In [None]:
# Min-max scale both sets with the *real* data's min/max, reduced over axes 0 and 1
# (one value per last-axis feature), so generated data is judged on the same scale.
data_min = np.min(real_data, axis=(0, 1), keepdims=True)
data_max = np.max(real_data, axis=(0, 1), keepdims=True)
# A constant real feature has zero range; divide by 1 there instead of 0 to avoid NaN/inf.
data_range = np.where(data_max > data_min, data_max - data_min, 1)

real_data = (real_data - data_min) / data_range
generated_data = (generated_data - data_min) / data_range

### Visualizations

In [None]:
# @title t-SNE & PCA Visualization

def plot_distribution_reduction(real, generated, n_samples=1000, random_state=None):
    """Embed real and generated windows jointly in 2-D with t-SNE and PCA and scatter them.

    Each window is flattened to one vector (``(N, T*D)`` for ``(N, T, D)`` input —
    the standard approach in TimeGAN / Diffusion-TS), and the first ``n_samples``
    windows of each set are embedded together.

    Args:
        real: array whose first axis indexes samples.
        generated: array with the same per-sample shape as ``real``.
        n_samples: maximum number of windows taken from each set.
        random_state: forwarded to ``TSNE``; ``None`` matches the previous behaviour.

    Raises:
        ValueError: if either input is empty.
    """
    import inspect

    n_samples = min(n_samples, len(real), len(generated))
    if n_samples == 0:
        raise ValueError("Need at least one real and one generated sample to plot.")

    real_flat = real[:n_samples].reshape(n_samples, -1)
    gen_flat = generated[:n_samples].reshape(n_samples, -1)
    data = np.concatenate([real_flat, gen_flat], axis=0)

    # t-SNE requires perplexity < number of points; keep the usual 40 when there is room.
    perplexity = min(40, len(data) - 1)
    # scikit-learn 1.5 renamed TSNE's `n_iter` to `max_iter` (the old name was later
    # removed); use whichever keyword this installation accepts.
    iter_kwarg = "max_iter" if "max_iter" in inspect.signature(TSNE).parameters else "n_iter"

    print("Running t-SNE...")
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=random_state,
                **{iter_kwarg: 300})
    tsne_results = tsne.fit_transform(data)

    # PCA panel is a linear view, useful as a variance sanity check next to t-SNE.
    print("Running PCA...")
    pca_results = PCA(n_components=2).fit_transform(data)

    fig, (ax_tsne, ax_pca) = plt.subplots(1, 2, figsize=(16, 6))
    panels = ((ax_tsne, tsne_results, "t-SNE Visualization"),
              (ax_pca, pca_results, "PCA Visualization"))
    for ax, coords, title in panels:
        # Rows [:n_samples] are real, rows [n_samples:] are generated (concatenation order).
        sns.scatterplot(x=coords[:n_samples, 0], y=coords[:n_samples, 1],
                        color=COLORS["Real"], alpha=0.3, label="Real", s=20, ax=ax)
        sns.scatterplot(x=coords[n_samples:, 0], y=coords[n_samples:, 1],
                        color=COLORS["Generated"], alpha=0.3, label="Generated", s=20, ax=ax)
        ax.set_title(title)
        ax.legend()

    plt.show()

plot_distribution_reduction(real_data, generated_data)

In [None]:
# @title Probability Density Function (Data Values)

def plot_pdf(real, generated):
    """Overlay kernel-density estimates of every value in ``real`` and ``generated``.

    Both arrays are flattened, so this compares the pooled value distribution
    across all samples, time steps and features at once.
    """
    _, ax = plt.subplots(figsize=(10, 6))

    for series_name, values in (("Real", real), ("Generated", generated)):
        sns.kdeplot(values.flatten(), fill=True, color=COLORS[series_name],
                    label=series_name, alpha=0.3, ax=ax)

    ax.set_title("Probability Density Function (All Values)")
    ax.set_xlabel("Data Value")
    ax.set_ylabel("Density")
    ax.legend()
    plt.show()

plot_pdf(real_data, generated_data)

In [None]:
# @title Sample Visualization (Generated vs Real)

def plot_samples(real, generated, n_samples=5):
    """Plot the first ``n_samples`` real windows (top row) and generated windows (bottom row).

    Arrays are indexed as ``[sample, time, feature]``. When there is more than one
    feature the last one is left out (NOTE(review): assumed to be volume, as in the
    stocks data — confirm for other datasets).

    Raises:
        ValueError: if there is nothing to plot.
    """
    from matplotlib.lines import Line2D

    # Never index past the end of either array.
    n_samples = min(n_samples, len(real), len(generated))
    if n_samples == 0:
        raise ValueError("No samples to plot.")

    n_features = real.shape[2]
    plot_features = n_features - 1 if n_features > 1 else n_features

    # squeeze=False keeps `axes` 2-D even when n_samples == 1.
    fig, axes = plt.subplots(2, n_samples, figsize=(n_samples * 4, 6),
                             sharey=True, squeeze=False)

    rows = ((real, "Real Sample"), (generated, "Gen Sample"))
    for row, (data, title_prefix) in enumerate(rows):
        for i in range(n_samples):
            ax = axes[row, i]
            for f in range(plot_features):
                ax.plot(data[i, :, f], alpha=0.8)
            ax.set_title(f"{title_prefix} {i}")
        axes[row, 0].set_ylabel("Value (MinMax Scaled)")

    # Proxy handles: each axes restarts the color cycle, so feature f is always C{f}.
    lines = [Line2D([0], [0], color=f"C{i}", lw=2) for i in range(plot_features)]
    fig.legend(lines, [f"Feature {i}" for i in range(plot_features)],
               loc='lower center', ncol=plot_features)

    # Reserve a strip at the bottom so tight_layout does not place axes under the legend.
    fig.tight_layout(rect=(0, 0.08, 1, 1))
    print(f"Note: Showing {plot_features} features (excluding last/volume feature if D>1)")
    plt.show()

plot_samples(real_data, generated_data)

### Statistical Metrics

In [None]:
# @title Statistical Distribution Metrics (Real vs Generated Comparison)
from scipy.stats import skew, kurtosis

def calculate_statistical_metrics(real, generated):
    """Print per-feature summary statistics for real vs generated data and return aggregate errors.

    Values are pooled over axes 0 and 1 (samples and time) and summarised per
    feature (axis 2) with mean, std, skewness, kurtosis (scipy's default, excess
    kurtosis), min and max.

    Args:
        real: array indexed ``[sample, time, feature]``.
        generated: array with the same number of features as ``real``.

    Returns:
        dict mapping "Mean MAE", "Std MAE", "Skewness MAE" and "Kurtosis MAE" to the
        absolute real-vs-generated difference of that statistic, averaged over features.

    Raises:
        ValueError: if ``real`` and ``generated`` have a different number of features.
    """
    n_features = real.shape[2]
    if generated.shape[2] != n_features:
        raise ValueError(
            f"Feature count mismatch: real has {n_features}, generated has {generated.shape[2]}"
        )

    # One table of statistics, applied identically to both datasets (previously copy-pasted).
    stat_fns = {
        "Mean": lambda x: np.mean(x, axis=0),
        "Std": lambda x: np.std(x, axis=0),
        "Skewness": lambda x: skew(x, axis=0),
        "Kurtosis": lambda x: kurtosis(x, axis=0),
        "Min": lambda x: np.min(x, axis=0),
        "Max": lambda x: np.max(x, axis=0),
    }
    # Flatten over samples and time to get distribution of values per feature
    r_flat = real.reshape(-1, n_features)
    g_flat = generated.reshape(-1, n_features)
    real_stats = {name: fn(r_flat) for name, fn in stat_fns.items()}
    gen_stats = {name: fn(g_flat) for name, fn in stat_fns.items()}

    # Print comparison table
    print("=" * 80)
    print("STATISTICAL COMPARISON: Real vs Generated Data")
    print("=" * 80)

    for f in range(n_features):
        print(f"\n--- Feature {f} ---")
        print(f"{'Metric':<15} {'Real':>12} {'Generated':>12} {'Diff (Abs)':>12}")
        print("-" * 55)
        for stat_name in stat_fns:
            r_val = real_stats[stat_name][f]
            g_val = gen_stats[stat_name][f]
            print(f"{stat_name:<15} {r_val:>12.4f} {g_val:>12.4f} {abs(r_val - g_val):>12.4f}")

    # Aggregate metrics (averaged across features)
    print("\n" + "=" * 80)
    print("AGGREGATE METRICS (Averaged Across All Features)")
    print("=" * 80)

    aggregate = {
        f"{name} MAE": np.mean(np.abs(real_stats[name] - gen_stats[name]))
        for name in ("Mean", "Std", "Skewness", "Kurtosis")
    }

    # NOTE(review): the fixed 0.05 / 0.15 thresholds are applied to all four metrics,
    # although skewness/kurtosis are not bounded by the min-max scale — treat as rough labels.
    print(f"{'Metric':<20} {'Value':>12} {'Interpretation'}")
    print("-" * 60)
    for k, v in aggregate.items():
        if np.isnan(v):
            # e.g. a constant feature makes skewness/kurtosis undefined; don't call that "High".
            quality = "⚠️ Undefined (NaN)"
        elif v < 0.05:
            quality = "✅ Good"
        elif v < 0.15:
            quality = "⚠️ Moderate"
        else:
            quality = "❌ High"
        print(f"{k:<20} {v:>12.6f} {quality}")

    return aggregate

# Aggregate real-vs-generated statistic errors (dict of "<stat> MAE" -> value).
stat_results = calculate_statistical_metrics(real=real_data, generated=generated_data)

### Discriminative Score

In [None]:
iterations = 5
discriminative_score = []

for i in range(iterations):
    # Only the first return value is aggregated; the other two (bound to
    # fake_acc / real_acc) are left in the namespace for interactive inspection.
    temp_disc, fake_acc, real_acc = discriminative_score_metrics(real_data, generated_data)
    print(f'Iter {i}: ', temp_disc, '\n')
    discriminative_score += [temp_disc]

display_scores(discriminative_score)
print()
print()

### Predictive Score

In [None]:
iterations = 5
predictive_score = []

for i in range(iterations):
    temp_pred = predictive_score_metrics(real_data, generated_data)
    predictive_score.append(temp_pred)
    # Each pass is an independent repetition of the metric, not a training epoch;
    # label it the same way as the other metric cells.
    print(f'Iter {i}: ', temp_pred, '\n')

display_scores(predictive_score)
print()

### Context-FID Score

In [None]:
# Set the repetition count here rather than relying on `iterations` left over from
# an earlier cell (hidden state that breaks if this cell is run on its own).
iterations = 5
context_fid_score = []

for i in range(iterations):
    context_fid = Context_FID(real_data, generated_data)
    context_fid_score.append(context_fid)
    print(f'Iter {i}: ', 'context-fid =', context_fid, '\n')

display_scores(context_fid_score)

### Correlational Score

In [None]:
def random_choice(size, num_select=100):
    """Return ``num_select`` random indices in ``[0, size)``, drawn *with* replacement."""
    select_idx = np.random.randint(low=0, high=size, size=(num_select,))
    return select_idx

# Set the repetition count here rather than relying on `iterations` left over from
# an earlier cell (hidden state that breaks if this cell is run on its own).
iterations = 5

x_real = torch.from_numpy(real_data)
x_fake = torch.from_numpy(generated_data)

correlational_score = []
size = 1000  # indices drawn per repetition (with replacement, so may exceed the set size)

for i in range(iterations):
    real_idx = random_choice(x_real.shape[0], size)
    fake_idx = random_choice(x_fake.shape[0], size)
    corr = CrossCorrelLoss(x_real[real_idx, :, :], name='CrossCorrelLoss')
    loss = corr.compute(x_fake[fake_idx, :, :])
    correlational_score.append(loss.item())
    print(f'Iter {i}: ', 'cross-correlation =', loss.item(), '\n')

display_scores(correlational_score)

### DTW Distance (Jensen–Shannon Divergence)

In [None]:
iterations = 5
js_results = []

for i in range(iterations):
    # DTW is slow, so n_samples=100 is passed (the usual size for this comparison).
    dtw_stats = dtw_js_divergence_distance(real_data, generated_data, n_samples=100)
    js_dist = dtw_stats['js_divergence']
    print("js_dist: ", round(js_dist, 4))
    js_results.append(js_dist)

display_scores(js_results)