# WaveletDiff Evaluation

This notebook evaluates a trained WaveletDiff model directly from your Google Colab environment using a compressed experiment archive from Google Drive.

In [None]:
# @title Cell 1: Global Configuration
import os

# --- Drive Paths ---
DRIVE_MOUNT_PATH = "/content/drive" # @param {type:"string"}
DRIVE_BASE_PATH = "/content/drive/MyDrive/waveletDiff_experiments" # @param {type:"string"}
CHECKPOINT_FOLDER = "checkpoints" # @param {type:"string"}
SAMPLES_FOLDER = "samples" # @param {type:"string"}

MODEL_FILENAME = "stocks_experiment.tar.gz" # @param {type:"string"}

# Recognised model-file suffixes. ".tar.gz" is listed before ".gz" so the compound
# suffix is removed as a unit instead of leaving a dangling ".tar".
MODEL_FILE_SUFFIXES = (".tar.gz", ".tgz", ".zip", ".ckpt", ".gz")


def strip_model_suffixes(filename):
    """Return `filename` with every trailing recognised model-file suffix removed.

    Only suffixes at the END of the name are stripped (repeatedly, so stacked
    suffixes such as "model.ckpt.tar.gz" still collapse to "model"). Suffix text
    that merely appears inside the name (e.g. "run.gzipped.tar.gz") is preserved,
    which a chain of `str.replace` calls would have mangled.

    Args:
        filename: File name as stored in the Drive checkpoint folder.

    Returns:
        The name without its trailing archive/checkpoint suffixes.
    """
    stem = filename
    stripped = True
    while stripped:
        stripped = False
        for suffix in MODEL_FILE_SUFFIXES:
            if stem.endswith(suffix):
                stem = stem[: -len(suffix)]
                stripped = True
                break
    return stem


# Used as the per-model folder name for cached samples on Drive.
MODEL_BASENAME = strip_model_suffixes(MODEL_FILENAME)

DRIVE_CHECKPOINT_PATH = os.path.join(DRIVE_BASE_PATH, CHECKPOINT_FOLDER, MODEL_FILENAME)
DRIVE_SAMPLES_PATH = os.path.join(DRIVE_BASE_PATH, SAMPLES_FOLDER, MODEL_BASENAME)

# --- Repository Settings ---
REPO_BRANCH = "develop" # @param {type:"string"}

# --- Evaluation Settings ---
DATASET = "stocks" # @param {type:"string"}
EXPERIMENT_NAME = "evaluation_run" # @param {type:"string"}
NUM_SAMPLES = 2000 # @param {type:"integer"}
SAMPLING_METHOD = "ddpm" # @param ["ddpm", "ddim"]
COMPILE_MODE = "none" # @param ["none", "default", "reduce-overhead", "max-autotune"]
DEVICE = "cuda" # @param ["cuda", "cpu"]

# --- Evaluation Options ---
EXCLUDE_VOLUME = True # @param {type:"boolean"}
CACHE_SAMPLES_TO_DRIVE = True # @param {type:"boolean"}
USE_CACHED_SAMPLES = True # @param {type:"boolean"}

In [None]:
# @title Cell 2: Setup (Clone, Install, Mount)
import os
import sys
import shutil
import zipfile
import tarfile
from google.colab import drive

REPO_URL = "https://github.com/MilesHoffman/waveletDiff_synth_data.git"
REPO_NAME = "waveletDiff_synth_data"
REPO_PATH = os.path.abspath(REPO_NAME)

# 1. Clone or Pull Repo
if os.path.exists(REPO_PATH):
    nested_path = os.path.join(REPO_PATH, REPO_NAME)
    if os.path.exists(nested_path):
        print(f"⚠️ Detected nested repository at {nested_path}. cleaning up...")
        shutil.rmtree(REPO_PATH)
        print(f"Cloning {REPO_URL}...")
        !git clone {REPO_URL} {REPO_NAME}
        !git -C {REPO_NAME} checkout {REPO_BRANCH}
    else:
        print(f"Pulling latest for {REPO_NAME}...")
        !git -C {REPO_NAME} fetch origin
        !git -C {REPO_NAME} checkout {REPO_BRANCH}
        !git -C {REPO_NAME} pull origin {REPO_BRANCH}
else:
    print(f"Cloning {REPO_URL}...")
    !git clone {REPO_URL} {REPO_NAME}
    !git -C {REPO_NAME} checkout {REPO_BRANCH}

# 2. Install Dependencies
!pip install -q pytorch-lightning pywavelets scipy pandas tqdm scikit-learn tslearn seaborn statsmodels

# 3. Mount Google Drive
if not os.path.exists(DRIVE_MOUNT_PATH):
    drive.mount(DRIVE_MOUNT_PATH)

# 4. Setup Paths
if os.path.join(REPO_PATH, "src") not in sys.path:
    sys.path.append(os.path.join(REPO_PATH, "src"))
if os.path.join(REPO_PATH, "src", "evaluation") not in sys.path:
    sys.path.append(os.path.join(REPO_PATH, "src", "evaluation"))

# 5. Prepare Checkpoint (Handle .zip, .tar.gz, .tgz, .gz, .ckpt)
# The model file on Drive is unpacked (archives) or copied (.ckpt) into the local
# experiment folder. Failures are reported but do not abort the cell; the final check
# below tells the user whether a usable checkpoint.ckpt ended up in place.
local_exp_dir = os.path.join(REPO_PATH, "outputs", EXPERIMENT_NAME)
os.makedirs(local_exp_dir, exist_ok=True)

if os.path.exists(DRIVE_CHECKPOINT_PATH):
    print(f"Found model file at {DRIVE_CHECKPOINT_PATH}")

    if DRIVE_CHECKPOINT_PATH.endswith((".zip", ".tar.gz", ".tgz", ".gz")):
        print(f"Unpacking archive to {local_exp_dir}...")
        try:
            if DRIVE_CHECKPOINT_PATH.endswith(".gz") and not DRIVE_CHECKPOINT_PATH.endswith(".tar.gz"):
                # Training notebook uses shutil.make_archive with 'gztar', so it is a tarball even if named .gz
                # (shutil cannot infer the format from a bare ".gz" name, hence the explicit format).
                shutil.unpack_archive(DRIVE_CHECKPOINT_PATH, local_exp_dir, format='gztar')
            else:
                shutil.unpack_archive(DRIVE_CHECKPOINT_PATH, local_exp_dir)
            print(f"✅ Successfully unpacked archive!")
        except Exception as e:
            print(f"❌ Error unpacking archive: {e}")

    elif DRIVE_CHECKPOINT_PATH.endswith(".ckpt"):
        print("Detected direct .ckpt file. Copying...")
        dst_ckpt = os.path.join(local_exp_dir, "checkpoint.ckpt")
        try:
            shutil.copy2(DRIVE_CHECKPOINT_PATH, dst_ckpt)
            print(f"✅ Copied checkpoint to {dst_ckpt}")
        except Exception as e:
            print(f"❌ Error copying checkpoint: {e}")
    else:
        print(f"⚠️ Unknown file extension. Trying unpack anyway...")
        try:
            shutil.unpack_archive(DRIVE_CHECKPOINT_PATH, local_exp_dir)
            print(f"✅ Successfully unpacked archive!")
        except Exception as e:
            # Catch Exception (not a bare except) so Ctrl-C / kernel interrupts still
            # propagate, and surface the cause instead of hiding it.
            print(f"❌ Could not unpack or identify file format: {e}")

    # Whatever path was taken above, this is the file later steps need.
    if os.path.exists(os.path.join(local_exp_dir, "checkpoint.ckpt")):
        print("✅ Validated checkpoint.ckpt exists")
    else:
        print("⚠️ WARNING: checkpoint.ckpt not found.")
else:
    print(f"❌ ERROR: Model file not found at {DRIVE_CHECKPOINT_PATH}")

# Sync configs
# Mirror <repo>/WaveletDiff_source/configs to <repo>/configs.
# NOTE(review): presumably the scripts under src/ look for configs at the repo root — confirm.
configs_src = os.path.join(REPO_PATH, "WaveletDiff_source", "configs")
configs_dst = os.path.join(REPO_PATH, "configs")

if os.path.exists(configs_src):
    # Remove any previous copy first so the destination is an exact mirror of the source.
    if os.path.exists(configs_dst):
        shutil.rmtree(configs_dst)
    shutil.copytree(configs_src, configs_dst)
    print("✅ Configs synced to repo root")

# Printed unconditionally: earlier steps report their own failures but do not stop the cell.
print("✅ Setup Complete")

In [None]:
# @title Cell 3: Generate or Load Samples
import os
import numpy as np
# Work from the repo's src/ directory: the relative "../outputs/..." paths below and the
# `python sample.py` call both resolve against the current working directory.
# NOTE: `shutil` and `REPO_PATH` are not defined in this cell; they come from Cell 2.
os.chdir(os.path.join(REPO_PATH, "src"))

# Define paths
local_gen_path = f"../outputs/{EXPERIMENT_NAME}/{SAMPLING_METHOD}_samples.npy"
local_real_path = f"../outputs/{EXPERIMENT_NAME}/real_samples.npy"

drive_gen_path = os.path.join(DRIVE_SAMPLES_PATH, f"{SAMPLING_METHOD}_samples.npy")
drive_real_path = os.path.join(DRIVE_SAMPLES_PATH, "real_samples.npy")

samples_loaded = False

# Try to load cached samples from Drive
# The cache is used only when BOTH the generated and the real sample files are present.
# NOTE(review): the cache is keyed by model basename and sampling method only; it is not
# checked against NUM_SAMPLES or DATASET — confirm that is intended.
if USE_CACHED_SAMPLES and os.path.exists(drive_gen_path) and os.path.exists(drive_real_path):
    print(f"✅ Found cached samples in Drive: {DRIVE_SAMPLES_PATH}")
    print("Loading cached samples...")
    
    # Copy into the local experiment folder, where the later cells read the .npy files from.
    os.makedirs(os.path.dirname(local_gen_path), exist_ok=True)
    shutil.copy2(drive_gen_path, local_gen_path)
    shutil.copy2(drive_real_path, local_real_path)
    
    # Loaded here only to report the shapes.
    _gen = np.load(local_gen_path)
    _real = np.load(local_real_path)
    print(f"Loaded Generated: {_gen.shape}, Real: {_real.shape}")
    samples_loaded = True

# Generate samples if not cached
if not samples_loaded:
    print(f"Generating {NUM_SAMPLES} samples using {SAMPLING_METHOD} with compile_mode={COMPILE_MODE}...")

    # `status` receives the command's output as a list of lines.
    status = !python sample.py \
        --experiment_name {EXPERIMENT_NAME} \
        --dataset {DATASET} \
        --num_samples {NUM_SAMPLES} \
        --sampling_method {SAMPLING_METHOD} \
        --compile_mode {COMPILE_MODE}

    print("\n".join(status))

    # Failure is inferred from the text of the output, not from an exit code: any line
    # containing "Traceback" or "Error" marks the run as failed.
    if any("Traceback" in s or "Error" in s for s in status):
        print("\n❌ Generation failed. Check the error log above.")
    else:
        print("\n✅ Sampling Complete")
        
        # Save to Drive if configured
        # NOTE(review): assumes sample.py wrote both local .npy files — copy2 raises otherwise.
        if CACHE_SAMPLES_TO_DRIVE:
            print(f"Saving samples to Drive: {DRIVE_SAMPLES_PATH}...")
            os.makedirs(DRIVE_SAMPLES_PATH, exist_ok=True)
            shutil.copy2(local_gen_path, drive_gen_path)
            shutil.copy2(local_real_path, drive_real_path)
            print("✅ Samples saved to Drive!")

In [None]:
# @title Cell 4: Imports and Setup
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import pandas as pd
import os

from discriminative_metrics import discriminative_score_metrics
from predictive_metrics import predictive_score_metrics
from context_fid import Context_FID
from cross_correlation import CrossCorrelLoss
from metric_utils import display_scores
from dtw import dtw_js_divergence_distance
from advanced_metrics import calculate_distribution_fidelity, calculate_structural_alignment, calculate_financial_reality, calculate_memorization_ratio, calculate_diversity_metrics, calculate_fld

# Set style
# Apply the matplotlib style sheet first, then seaborn's "notebook" context on top of it
# (both write to matplotlib's rcParams, so the order of these two calls matters).
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_context("notebook", font_scale=1.2)
# Fixed colours so "Real" and "Generated" look the same in every figure below.
COLORS = {"Real": "#d62728", "Generated": "#1f77b4"}

### Load Real and Generated Data

In [None]:
# Paths are relative to 'src' directory where we currently are
real_data_path = f"../outputs/{EXPERIMENT_NAME}/real_samples.npy"
gen_data_path = f"../outputs/{EXPERIMENT_NAME}/{SAMPLING_METHOD}_samples.npy"

# Check BOTH inputs up front so a missing file is reported with its absolute path and a
# pointer to the cell that should have produced it.
for sample_kind, sample_path in (("real", real_data_path), ("generated", gen_data_path)):
    if not os.path.exists(sample_path):
        raise FileNotFoundError(f"Could not find {sample_kind} samples at {os.path.abspath(sample_path)}. Check Cell 3 output.")

real_data = np.load(real_data_path)
generated_data = np.load(gen_data_path)
print(f"Loaded Real: {real_data.shape}, Generated: {generated_data.shape}")

# Optionally exclude volume (last feature)
# The arrays are indexed as 3-D here, with features on the last axis; the same slice is
# applied to both so their feature axes stay aligned.
if EXCLUDE_VOLUME and real_data.shape[2] > 1:
    print(f"⚠️ Excluding volume feature (last dimension). D: {real_data.shape[2]} -> {real_data.shape[2]-1}")
    real_data = real_data[:, :, :-1]
    generated_data = generated_data[:, :, :-1]
    print(f"New shapes: Real: {real_data.shape}, Generated: {generated_data.shape}")

In [None]:
# Evaluate on equally sized sets: both arrays are randomly subsampled (without
# replacement) down to the size of the smaller one.
n_real_available = real_data.shape[0]
n_generated_available = generated_data.shape[0]
num_samples = min(n_real_available, n_generated_available)

if num_samples < n_real_available:
    # Fewer generated than real samples: the generated set is the limiting one.
    print(f"WARNING: Using all {num_samples} generated samples for evaluation.")
else:
    print(f"Number of samples: {num_samples}")

random_indices = np.random.choice(n_real_available, num_samples, replace=False)
real_data = real_data[random_indices]
random_indices = np.random.choice(n_generated_available, num_samples, replace=False)
generated_data = generated_data[random_indices]

In [None]:
# minmax scale the inputs for fair comparison
# Per-feature min/max are taken from the REAL data only (reduced over the first two axes)
# and the same transform is applied to the generated data.
data_min = real_data.min(axis=(0, 1), keepdims=True)
data_max = real_data.max(axis=(0, 1), keepdims=True)

# The small epsilon keeps the division finite for features that are constant in the real data.
data_range = data_max - data_min + 1e-8
real_data = (real_data - data_min) / data_range
generated_data = (generated_data - data_min) / data_range

### Visualizations

In [None]:
# @title t-SNE & PCA Visualization

def plot_distribution_reduction(real, generated, n_samples=1000):
    """Plot 2-D t-SNE and PCA embeddings of real vs generated samples side by side.

    The first `n_samples` rows of each array are flattened to one vector per sample,
    stacked (real first, generated second) and embedded jointly, so both sets share
    the same projection.

    Args:
        real: Array of real samples; the first axis indexes samples.
        generated: Array of generated samples; the first axis indexes samples.
        n_samples: Maximum number of samples taken from each array. Clamped to the
            number of rows actually available.

    Raises:
        ValueError: If either array has no samples.
    """
    import inspect  # local: only needed to pick the TSNE iteration keyword below

    n_samples = min(n_samples, len(real), len(generated))
    if n_samples < 1:
        raise ValueError("plot_distribution_reduction needs at least one real and one generated sample.")

    real_flat = real[:n_samples].reshape(n_samples, -1)
    gen_flat = generated[:n_samples].reshape(n_samples, -1)

    data = np.concatenate([real_flat, gen_flat], axis=0)

    # t-SNE requires perplexity < number of points; keep 40 whenever the data allows it.
    perplexity = min(40, data.shape[0] - 1)

    # scikit-learn renamed TSNE's `n_iter` to `max_iter` (deprecated in 1.5, removed in 1.7).
    # Use whichever keyword the installed version accepts, preferring the new name.
    tsne_params = inspect.signature(TSNE.__init__).parameters
    iter_kwarg = "max_iter" if "max_iter" in tsne_params else "n_iter"

    def scatter_pair(embedding, title):
        # Rows [0, n_samples) are real, rows [n_samples, 2 * n_samples) are generated.
        for label, points in (("Real", embedding[:n_samples]), ("Generated", embedding[n_samples:])):
            sns.scatterplot(x=points[:, 0], y=points[:, 1],
                            color=COLORS[label], alpha=0.3, label=label, s=20)
        plt.title(title)
        plt.legend()

    print("Running t-SNE...")
    tsne = TSNE(n_components=2, perplexity=perplexity, **{iter_kwarg: 300})
    tsne_results = tsne.fit_transform(data)

    plt.figure(figsize=(16, 6))

    plt.subplot(1, 2, 1)
    scatter_pair(tsne_results, "t-SNE Visualization")

    print("Running PCA...")
    pca = PCA(n_components=2)
    pca_results = pca.fit_transform(data)

    plt.subplot(1, 2, 2)
    scatter_pair(pca_results, "PCA Visualization")

    plt.show()

plot_distribution_reduction(real_data, generated_data)

In [None]:
# @title Probability Density Function (Data Values)

def plot_pdf(real, generated):
    """Overlay kernel-density estimates of all real and all generated values.

    Each array is flattened first, so the curves pool every sample, time step and
    feature into a single distribution per data source.

    Args:
        real: Array of real samples.
        generated: Array of generated samples.
    """
    plt.figure(figsize=(10, 6))

    for label, values in (("Real", real), ("Generated", generated)):
        sns.kdeplot(values.flatten(), fill=True, color=COLORS[label], label=label, alpha=0.3)

    plt.title("Probability Density Function (All Values)")
    plt.xlabel("Data Value")
    plt.ylabel("Density")
    plt.legend()
    plt.show()

plot_pdf(real_data, generated_data)

In [None]:
# @title Sample Visualization (Generated vs Real)

def plot_samples(real, generated, n_samples=5):
    """Plot the first few real and generated sequences in a 2 x n grid.

    The top row shows real samples and the bottom row generated samples; every
    feature of a sample is drawn as its own line on that sample's axes.

    Args:
        real: Real samples, indexed as [sample, time, feature].
        generated: Generated samples, indexed the same way; the feature count is
            taken from `real`.
        n_samples: Number of columns (samples per row). Clamped to the number of
            samples available in both arrays.

    Raises:
        ValueError: If either array has no samples.
    """
    from matplotlib.lines import Line2D

    n_samples = min(n_samples, len(real), len(generated))
    if n_samples < 1:
        raise ValueError("plot_samples needs at least one real and one generated sample.")
    n_features = real.shape[2]

    # squeeze=False keeps `axes` 2-D even for a single column, so axes[row, col] always works.
    fig, axes = plt.subplots(2, n_samples, figsize=(n_samples * 4, 6), sharey=True, squeeze=False)

    for i in range(n_samples):
        for row, (series, title) in enumerate(((real, "Real Sample"), (generated, "Gen Sample"))):
            ax = axes[row, i]
            for f in range(n_features):
                ax.plot(series[i, :, f], alpha=0.8)
            ax.set_title(f"{title} {i}")
            if i == 0:
                ax.set_ylabel("Value (MinMax Scaled)")

    # One shared legend: feature f is drawn f-th on each axes, so it gets colour-cycle entry "C{f}".
    lines = [Line2D([0], [0], color=f"C{i}", lw=2) for i in range(n_features)]
    fig.legend(lines, [f"Feature {i}" for i in range(n_features)], loc='lower center', ncol=n_features)

    plt.tight_layout()
    plt.show()

plot_samples(real_data, generated_data)

### Statistical Metrics

In [None]:
# @title Statistical Distribution Metrics (Real vs Generated Comparison)
from scipy.stats import skew, kurtosis

def calculate_statistical_metrics(real, generated):
    """Compare per-feature summary statistics of real and generated data.

    Both arrays are indexed as [sample, time, feature]. The first two axes are
    collapsed so each feature's statistics cover every observation of that feature.
    A per-feature comparison table is printed as a side effect.

    Args:
        real: Real samples.
        generated: Generated samples.

    Returns:
        A tuple `(aggregate, real_stats, gen_stats)` where `aggregate` maps
        "Mean MAE", "Std MAE", "Skewness MAE" and "Kurtosis MAE" to the mean absolute
        real-vs-generated difference across features, and the two stats dicts map
        "Mean", "Std", "Skewness", "Kurtosis", "Min", "Max" to per-feature arrays.
    """
    def describe(samples):
        # One row per (sample, time step), one column per feature.
        flat = samples.reshape(-1, samples.shape[2])
        return {
            "Mean": np.mean(flat, axis=0),
            "Std": np.std(flat, axis=0),
            "Skewness": skew(flat, axis=0),
            "Kurtosis": kurtosis(flat, axis=0),
            "Min": np.min(flat, axis=0),
            "Max": np.max(flat, axis=0),
        }

    real_stats = describe(real)
    gen_stats = describe(generated)
    n_features = real.shape[2]

    banner = "=" * 80
    print(banner)
    print("STATISTICAL COMPARISON: Real vs Generated Data")
    print(banner)

    column_header = f"{'Metric':<15} {'Real':>12} {'Generated':>12} {'Diff (Abs)':>12}"
    for f in range(n_features):
        print(f"\n--- Feature {f} ---")
        print(column_header)
        print("-" * 55)
        for stat_name, real_values in real_stats.items():
            r_val = real_values[f]
            g_val = gen_stats[stat_name][f]
            print(f"{stat_name:<15} {r_val:>12.4f} {g_val:>12.4f} {abs(r_val - g_val):>12.4f}")

    # Min and Max are reported per feature above but not aggregated.
    aggregate = {
        f"{moment} MAE": np.mean(np.abs(real_stats[moment] - gen_stats[moment]))
        for moment in ("Mean", "Std", "Skewness", "Kurtosis")
    }

    return aggregate, real_stats, gen_stats

# Run the comparison on the min-max-scaled arrays. `stat_results` (aggregate MAEs) is read
# by the summary scorecard; the two *_detail dicts are read by the per-feature table.
stat_results, real_stats_detail, gen_stats_detail = calculate_statistical_metrics(real_data, generated_data)

### Discriminative Score

In [None]:
# Repeat the discriminative evaluation `iterations` times; the per-run scores are collected
# in `discriminative_score`, which the summary scorecard later averages.
iterations = 5
discriminative_score = []

for i in range(iterations):
    # The helper returns three values; only the first is stored.
    # NOTE(review): fake_acc / real_acc look like per-class accuracies — confirm in discriminative_metrics.
    temp_disc, fake_acc, real_acc = discriminative_score_metrics(real_data, generated_data)
    discriminative_score.append(temp_disc)
    print(f'Iter {i}: ', temp_disc, '\n')
      
# NOTE(review): display_scores comes from metric_utils; presumably it summarises the runs — confirm.
display_scores(discriminative_score)
print()

### Predictive Score

In [None]:
# Repeat the predictive evaluation `iterations` times; the per-run scores are collected in
# `predictive_score`, which the summary scorecard later averages.
iterations = 5
predictive_score = []
for i in range(iterations):
    temp_pred = predictive_score_metrics(real_data, generated_data)
    predictive_score.append(temp_pred)
    # Each pass is an independent repetition, so it is labelled "Iter" like the other metric cells.
    print(f'Iter {i}: ', temp_pred, '\n')

# NOTE(review): display_scores comes from metric_utils; presumably it summarises the runs — confirm.
display_scores(predictive_score)
print()

### Context-FID Score

In [None]:
# Number of repetitions is set here, as in the other metric cells, so this cell does not
# depend on `iterations` having been defined by an earlier cell.
iterations = 5
context_fid_score = []

for i in range(iterations):
    context_fid = Context_FID(real_data, generated_data)
    context_fid_score.append(context_fid)
    print(f'Iter {i}: ', 'context-fid =', context_fid, '\n')

# NOTE(review): display_scores comes from metric_utils; presumably it summarises the runs — confirm.
display_scores(context_fid_score)

### Correlational Score

In [None]:
def random_choice(size, num_select=100):
    """Draw `num_select` integer indices uniformly from [0, size), with replacement.

    Args:
        size: Exclusive upper bound of the index range.
        num_select: How many indices to draw.

    Returns:
        A 1-D integer array of length `num_select`.
    """
    return np.random.randint(0, size, num_select)

x_real = torch.from_numpy(real_data)
x_fake = torch.from_numpy(generated_data)

# Number of repetitions is set here, as in the other metric cells, so this cell does not
# depend on `iterations` having been defined by an earlier cell.
iterations = 5
correlational_score = []
# Samples drawn (with replacement) from each set per repetition.
size = 1000

for i in range(iterations):
    real_idx = random_choice(x_real.shape[0], size)
    fake_idx = random_choice(x_fake.shape[0], size)
    # The loss object is built from the real subset and then evaluated on the generated subset.
    corr = CrossCorrelLoss(x_real[real_idx, :, :], name='CrossCorrelLoss')
    loss = corr.compute(x_fake[fake_idx, :, :])
    correlational_score.append(loss.item())
    print(f'Iter {i}: ', 'cross-correlation =', loss.item(), '\n')

# NOTE(review): display_scores comes from metric_utils; presumably it summarises the runs — confirm.
display_scores(correlational_score)

### DTW Distance

In [None]:
# Repeat the DTW-based comparison `iterations` times; the per-run values are collected in
# `js_results`, which the summary scorecard later averages.
iterations = 5
js_results = []
for i in range(iterations):
    # The helper returns a mapping; only its 'js_divergence' entry is used here.
    # NOTE(review): n_samples=100 presumably limits how many sequences are compared per run — confirm in dtw.py.
    js_dist = dtw_js_divergence_distance(real_data, generated_data, n_samples=100)['js_divergence']
    print("js_dist: ", round(js_dist, 4))
    js_results.append(js_dist)
# NOTE(review): display_scores comes from metric_utils; presumably it summarises the runs — confirm.
display_scores(js_results)

### Advanced Financial Metrics (Wasserstein, KS, PCA, ACF)

In [None]:
# @title Run Advanced Metrics
# Each helper comes from advanced_metrics and is called once on the scaled arrays. The three
# results are kept as notebook globals because the summary scorecard reads them by key.
print("Running Distribution Fidelity Checks (Wasserstein, KS Test)...")
dist_results = calculate_distribution_fidelity(real_data, generated_data)
print("Distribution Fidelity:", dist_results)

print("\nRunning Structural Alignment Checks (PCA, t-SNE)...")
struct_results = calculate_structural_alignment(real_data, generated_data)
print("Structural Alignment:", struct_results)

print("\nRunning Financial Reality Checks (ACF, Cross-Corr, Volatility)...")
fin_results = calculate_financial_reality(real_data, generated_data)
print("Financial Reality:", fin_results)

In [None]:
# @title Run New Advanced Metrics (Memorization, Diversity, FLD)
# Results are kept as notebook globals because the summary scorecard reads them.
print("\nRunning Memorization Check (1/3 Rule)...")
mem_ratio = calculate_memorization_ratio(real_data, generated_data)
# Formatted with :.4f, so the helper is expected to return a single number.
print(f"Memorization Ratio: {mem_ratio:.4f}")

print("\nRunning Diversity Check (Coverage)...")
# The summary scorecard later reads the "Coverage" key from this result.
div_results = calculate_diversity_metrics(real_data, generated_data)
print(f"Diversity Metrics: {div_results}")

print("\nRunning Feature Likelihood Divergence (FLD)...")
fld_score = calculate_fld(real_data, generated_data)
# Formatted with :.4f, so the helper is expected to return a single number.
print(f"FLD Score: {fld_score:.4f}")

### Centralized Metric Summary

In [None]:
# @title Summary Scorecard
import pandas as pd

def _scorecard_rows(category, entries):
    """Expand (metric, value, goal, description) tuples into scorecard row dicts for one category."""
    return [
        {"Category": category, "Metric": metric, "Value": value, "Goal": goal, "Description": description}
        for metric, value, goal, description in entries
    ]


summary_data = []

# 1. Statistical (Aggregated MAEs from stat_results)
summary_data.extend(_scorecard_rows("Statistical (MAE)", [
    (name, value, "lower", "Diff in statistical moments (Mean/Std/Skew/Kurt).")
    for name, value in stat_results.items()
]))

# 2. Discriminative & Predictive
# Repeated-run metrics are reduced to their mean across runs.
summary_data.extend(_scorecard_rows("Model Quality", [
    ("Discriminative Score", np.mean(discriminative_score), "lower",
     "Classifier accuracy deviation from 0.5 (Real vs Fake)."),
    ("Predictive Score", np.mean(predictive_score), "lower",
     "MAE of TSTR (Train on Synthetic, Test on Real)."),
    ("Context-FID", np.mean(context_fid_score), "lower",
     "FID score on embeddings (e.g. Inception/Transformer)."),
    ("Cross-Correl Loss", np.mean(correlational_score), "lower",
     "Difference in cross-correlation matrices."),
    ("DTW (JS Divergence)", np.mean(js_results), "lower",
     "DTW-based distribution distance."),
]))

# 3. Distribution Fidelity
summary_data.extend(_scorecard_rows("Distribution Fidelity", [
    ("Wasserstein (Mean)", dist_results["Wasserstein_Mean"], "lower",
     "Earth Mover's Distance between features."),
    ("KS Test Stat (Mean)", dist_results["KS_Stat_Mean"], "lower",
     "Kolmogorov-Smirnov statistic (max diff in CDF)."),
    ("KS P-Value (Mean)", dist_results["KS_PVal_Mean"], "higher",
     "Statistical significance of KS test."),
]))

# 4. Structural Alignment
summary_data.extend(_scorecard_rows("Structural Alignment", [
    ("PCA EVR Correlation", struct_results["PCA_EVR_Corr"], "higher",
     "Correlation of PCA Explained Variance Ratios."),
]))
# t-SNE 1-NN: ideal is 0.5, so we show distance from 0.5
tsne_val = struct_results["tSNE_1NN_Acc"]
tsne_error = abs(tsne_val - 0.5)
summary_data.extend(_scorecard_rows("Structural Alignment", [
    ("t-SNE 1-NN (|x-0.5|)", tsne_error, "lower",
     "Classifier accuracy in t-SNE space (Ideal=0.5)."),
]))

# 5. Financial Reality
summary_data.extend(_scorecard_rows("Financial Reality", [
    ("ACF MSE (Lags 1,5,20)", fin_results["ACF_MSE"], "lower",
     "MSE of Autocorrelation Functions."),
    ("Cross-Corr Matrix Diff", fin_results["CrossCorr_Norm_Diff"], "lower",
     "Norm difference of correlation matrices."),
    ("Volatility Clustering MSE", fin_results["Volatility_MSE"], "lower",
     "MSE of squared returns ACF (Volatility)."),
]))

# 6. New Metrics
summary_data.extend(_scorecard_rows("New Metrics", [
    ("Memorization Ratio (1/3 Rule)", mem_ratio, "lower",
     "Fraction of samples that are near-duplicates of training data."),
    ("Diversity (Coverage)", div_results["Coverage"], "higher",
     "Fraction of real data covered by synthetic samples."),
    ("FLD (Likelihood Divergence)", fld_score, "lower",
     "Divergence in likelihood under real data density (GMM)."),
]))

# Create DataFrame
df_results = pd.DataFrame(summary_data)

# Custom styling based on Goal (with black text for readability)
def style_value(row):
    """Return the CSS string for a scorecard row's value cell.

    The background colour is chosen from the row's 'Value' according to its 'Goal':
    for 'higher' the value is compared against descending floors, for any other goal
    against ascending ceilings. A value that passes no threshold (including NaN) gets
    the red fallback. Text is always black for readability.

    Args:
        row: Mapping-like object with 'Value' and 'Goal' entries.

    Returns:
        A 'background-color: ...; color: black' CSS declaration.
    """
    val = row['Value']
    goal = row['Goal']
    fallback = '#e74c3c'

    if goal == 'higher':
        # Higher is better: first floor the value reaches wins.
        floors = ((0.99, '#2ecc71'), (0.9, '#82e0aa'), (0.7, '#f9e79f'))
        background = next((color for floor, color in floors if val >= floor), fallback)
    else:
        # Lower is better: first ceiling the value stays under wins.
        ceilings = ((0.01, '#2ecc71'), (0.05, '#82e0aa'), (0.15, '#f9e79f'), (0.3, '#f5b041'))
        background = next((color for ceiling, color in ceilings if val <= ceiling), fallback)

    return f'background-color: {background}; color: black'

# With axis=1 the lambda runs once per row and returns one CSS string per column:
# the style_value() result for the 'Value' column and '' (no styling) for every other column.
styled = df_results.style.apply(lambda row: [style_value(row) if col == 'Value' else '' for col in df_results.columns], axis=1)
# Lift pandas' row-count and column-width display limits.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_colwidth', None)
display(styled)

### Per-Feature Statistical Comparison

In [None]:
# @title Feature-Level Stats Table
# One row per (feature, statistic) pair, built from the per-feature stats dicts.
n_features = len(real_stats_detail['Mean'])
feature_data = [
    {
        "Feature": f,
        "Stat": stat_name,
        "Real": real_stats_detail[stat_name][f],
        "Synthetic": gen_stats_detail[stat_name][f],
        "Abs Diff": abs(real_stats_detail[stat_name][f] - gen_stats_detail[stat_name][f]),
    }
    for f in range(n_features)
    for stat_name in real_stats_detail
]

df_features = pd.DataFrame(feature_data)
# Reversed red-yellow-green colormap on the 'Abs Diff' column only.
display(df_features.style.background_gradient(cmap='RdYlGn_r', subset=['Abs Diff']))