In [None]:
# Import required libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import random
import pywt
import seaborn as sns
import os
import glob
import torchvision.utils as vutils
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from diffusion.unet import UNet
from diffusion.diffusion import Diffusion
from statsmodels.tsa.stattools import acf
from scipy import stats
from scipy.stats import ks_2samp
from scipy.signal import find_peaks
# Import detrending pipeline
from wavelet_detrending import WaveletDetrendingPipeline, visualize_detrending_effect

print("All libraries imported successfully!")


## 1. Generate Synthetic Cosine-Based Time Series

Create time series by combining multiple cosine waves with:
- Random frequencies (1-10 cycles per series; the time axis spans [0, 1] over the 64-sample window)
- Random amplitudes (0.3-2.0)
- Random phase shifts
- Gaussian noise
- Linear trend component

In [None]:
# Build the synthetic dataset: every series is a sum of random cosines plus
# Gaussian noise and a linear trend.

np.random.seed(42)  # fixes NumPy's global RNG so the dataset is reproducible

# Parameters
num_synthetic_series = 50000  # number of series to draw
series_length = 64  # samples per series (same as window_size for real data)
num_frequencies = 5  # cosine components summed into each series
noise_level = 0.3  # standard deviation of the additive Gaussian noise

# The time axis is the same for every series, so it is built only once.
t = np.linspace(0, 1, series_length)

# One row per series; each row starts at zero and is accumulated in place.
synthetic_volume_series = np.zeros((num_synthetic_series, series_length))

for series in synthetic_volume_series:
    # The draw order (frequency, amplitude, phase for each component, then
    # noise, then trend slope) determines the seeded output, so it is fixed.
    for _ in range(num_frequencies):
        freq = np.random.uniform(1, 10)
        amplitude = np.random.uniform(0.3, 2.0)
        phase = np.random.uniform(0, 2 * np.pi)
        series += amplitude * np.cos(2 * np.pi * freq * t + phase)

    # Additive Gaussian noise.
    series += np.random.randn(series_length) * noise_level

    # Linear trend with a random slope in [-0.5, 0.5).
    series += np.random.uniform(-0.5, 0.5) * t

print(f"Generated {num_synthetic_series} synthetic time series")
print(f"Shape: {synthetic_volume_series.shape}")
print(f"Raw statistics BEFORE normalization:")
for label, value in (
    ("Mean", synthetic_volume_series.mean()),
    ("Std", synthetic_volume_series.std()),
    ("Min", synthetic_volume_series.min()),
    ("Max", synthetic_volume_series.max()),
):
    print(f"  {label}: {value:.4f}")

# Visualize 3 random synthetic series
# np.random.seed() does not seed the stdlib `random` module, so the indices
# are drawn from a dedicated seeded generator. This keeps the plotted series
# (and every later cell that reuses `random_indices`) identical across runs
# without touching the global `random` state.
random_indices = random.Random(42).sample(range(num_synthetic_series), 3)

fig, axes = plt.subplots(3, 1, figsize=(12, 10), sharex=True)

for ax, idx in zip(axes, random_indices):
    ts = synthetic_volume_series[idx]
    ax.plot(ts, color='tab:orange', linewidth=1.5)
    ax.set_ylabel('Synthetic Volume', color='tab:orange')
    ax.set_title(f'Synthetic Cosine-based Time Series #{idx}')
    ax.set_xlabel('Day')
    ax.grid(True, alpha=0.3)

fig.suptitle('Generated Synthetic Time Series (Before Normalization)', fontsize=14, y=1.0)
fig.tight_layout()
plt.show()

In [None]:
# Global z-score normalization: a single mean/std pair computed over every
# value of every series (same as real GOOG data preprocessing).
print(f"{'='*60}")
print(f"GLOBAL Z-SCORE NORMALIZATION")
print(f"{'='*60}")

# Dataset-wide statistics, stored as plain floats so they can be reused for
# the inverse transformation.
synthetic_mean = float(synthetic_volume_series.mean())
synthetic_std = float(synthetic_volume_series.std())

synthetic_normalized = (synthetic_volume_series - synthetic_mean) / synthetic_std

print(f"\nOriginal synthetic data statistics:")
for label, value in (
    ("Mean", synthetic_mean),
    ("Std", synthetic_std),
    ("Min", synthetic_volume_series.min()),
    ("Max", synthetic_volume_series.max()),
):
    print(f"  {label}: {value:.4f}")

print(f"\nNormalized data statistics:")
# Mean and std are shown with six decimals to make the residual error visible.
for label, value in (
    ("Mean", float(synthetic_normalized.mean())),
    ("Std", float(synthetic_normalized.std())),
):
    print(f"  {label}: {value:.6f}")
for label, value in (
    ("Min", synthetic_normalized.min()),
    ("Max", synthetic_normalized.max()),
):
    print(f"  {label}: {value:.4f}")

# Histograms before and after normalization, one panel each.
hist_panels = (
    (synthetic_volume_series, 'steelblue', 'Synthetic Data - Before Normalization', 'Value'),
    (synthetic_normalized, 'coral', 'Synthetic Data - After Global Z-Score', 'Normalized Value'),
)
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
for ax, (values, colour, title, x_label) in zip(axes, hist_panels):
    ax.hist(values.flatten(), bins=100, color=colour, edgecolor='black', alpha=0.7, density=True)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel(x_label, fontsize=12)
    ax.set_ylabel('Density', fontsize=12)
    ax.grid(True, alpha=0.3)

# Reference line at zero on the normalized panel only.
axes[1].axvline(x=0, color='red', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.show()

print(f"\n{'='*60}")
print(f"Normalization parameters saved for inverse transformation:")
print(f"{'='*60}")
print(f"Synthetic: mean={synthetic_mean:.6f}, std={synthetic_std:.6f}")

In [None]:
# Autocorrelation analysis for synthetic data

max_lag = 60


def within_series_acf(series, nlags):
    """Return the pooled autocorrelation of a batch of independent series.

    The rows of ``series`` are separate, independently generated windows.
    Running an ACF on their flattened concatenation pairs the tail of one
    window with the head of the next: at lag ``k`` a ``k / series_length``
    share of the products straddle a boundary, which drags the estimate
    toward zero as the lag grows. Here only pairs taken from the same row are
    used, and each lag is averaged over the number of pairs that exist.

    Args:
        series: array-like with one series per row; trailing axes are
            flattened into the time axis.
        nlags: largest lag to evaluate; must be smaller than the row length.

    Returns:
        1-D array of length ``nlags + 1`` whose entry 0 is 1.0.

    Raises:
        ValueError: if ``nlags`` is negative or not smaller than the row length.
    """
    x = np.asarray(series, dtype=float)
    x = x.reshape(len(x), -1)
    n = x.shape[1]
    if not 0 <= nlags < n:
        raise ValueError(f"nlags must be in [0, {n - 1}], got {nlags}")
    # Centre on the global mean, as the flattened estimate did.
    x = x - x.mean()
    variance = np.mean(x * x)
    return np.array(
        [np.mean(x[:, :n - k] * x[:, k:]) / variance for k in range(nlags + 1)]
    )


# Flattened view kept for any later cell that still expects it.
synthetic_flat = synthetic_normalized.flatten()

# Calculate ACF from within-series pairs only
synthetic_acf = within_series_acf(synthetic_normalized, max_lag)

# Plot
plt.figure(figsize=(12, 6))
plt.stem(range(len(synthetic_acf)), synthetic_acf, linefmt='orange', markerfmt='o')
plt.title('Autocorrelation: Synthetic Cosine Volume', fontsize=14)
plt.xlabel('Lag')
plt.ylabel('ACF')
plt.grid(True, alpha=0.3)
plt.ylim(-0.3, 1.0)
plt.tight_layout()
plt.show()

print(f"\nKey Autocorrelation Values:")
print(f"  ACF(1): {synthetic_acf[1]:.4f}")
print(f"  ACF(5): {synthetic_acf[5]:.4f}")
print(f"  ACF(10): {synthetic_acf[10]:.4f}")

## 4. Generate Wavelet Images

Transform normalized time series to wavelet images using SWT (Stationary Wavelet Transform).

In [None]:
# Turn the normalized synthetic series into wavelet images through the
# detrending pipeline (intended to prevent mode collapse by removing mean bias).

wavelet = 'haar'
level = int(np.log2(64))  # Max level for length 64

# The pipeline object is kept: later cells call it again for reconstruction.
synthetic_wavelet_pipeline = WaveletDetrendingPipeline(wavelet=wavelet, level=level)

# Forward pass; return_stats=True also returns a dict of summary statistics.
synthetic_wavelet_images, synth_detrending_stats = synthetic_wavelet_pipeline.preprocess(
    synthetic_normalized,
    return_stats=True,
)

print("=== SYNTHETIC WAVELET PREPROCESSING WITH DETRENDING ===")
print(f"Synthetic wavelet image array shape: {synthetic_wavelet_images.shape}")

print(f"\nDetrending Statistics:")
print(f"  Original data mean: {synth_detrending_stats['original_mean']:.4f}")
print(f"  Original data std: {synth_detrending_stats['original_std']:.4f}")
print(f"  Detrended data mean: {synth_detrending_stats['detrended_mean']:.6f}")
print(f"  Detrended data std: {synth_detrending_stats['detrended_std']:.4f}")

print("\nWavelet coefficients (before normalization):")
print(f"  Mean: {synth_detrending_stats['wavelet_mean_before_norm']:.4f}")
print(f"  Std: {synth_detrending_stats['wavelet_std_before_norm']:.4f}")

print("\nRobust normalization using 5th-95th percentiles:")
print(f"  p5: {synth_detrending_stats['p5']:.4f}")
print(f"  p95: {synth_detrending_stats['p95']:.4f}")

print("\nFinal normalized wavelet images:")
print(f"  Range: [{synthetic_wavelet_images.min():.4f}, {synthetic_wavelet_images.max():.4f}]")
print(f"  Mean: {synth_detrending_stats['final_mean']:.4f}")
print(f"  Std: {synth_detrending_stats['final_std']:.4f}")

# Visualize detrending effect on one fixed series
print("\n=== VISUALIZING DETRENDING EFFECT ===")
visualize_detrending_effect(synthetic_normalized, synthetic_wavelet_pipeline, idx=100)

# Show the wavelet images of the same series that were plotted earlier.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for panel, ax in enumerate(axes):
    # Fall back to the panel number when fewer indices are available.
    idx = random_indices[panel] if panel < len(random_indices) else panel
    if idx >= len(synthetic_wavelet_images):
        continue  # nothing to draw for this panel
    im = ax.imshow(synthetic_wavelet_images[idx, :, :, 0], aspect='auto', origin='lower', cmap='viridis')
    ax.set_title(f'Synthetic Wavelet Image #{idx} (DETRENDED)')
    ax.set_ylabel('Frequency Level')
    ax.set_xlabel('Time')
    plt.colorbar(im, ax=ax, label='Normalized Coefficients')

plt.suptitle('SWT Haar Wavelet Images - DETRENDED (Synthetic Cosine Data)', fontsize=14)
plt.tight_layout()
plt.show()

# Persist both the images and the series they were computed from.
np.save('synthetic_cosine_wavelet_images.npy', synthetic_wavelet_images)
np.save('synthetic_cosine_time_series.npy', synthetic_normalized)
print(f"\nSaved files:")
print(f"  - synthetic_cosine_wavelet_images.npy")
print(f"  - synthetic_cosine_time_series.npy")


## 5. Train Diffusion Model on Synthetic Data

Train the diffusion model using the same architecture and hyperparameters as the real-data experiment.

In [None]:
# Configuration and input preparation for training the diffusion model on the
# synthetic cosine wavelet images (same architecture and hyperparameters as
# real data training).

OUT_DIR_SYNTH = './diffusion_checkpoints_synthetic'
N_SAMPLES_FOR_VIS = 8
BATCH_SIZE = 128
EPOCHS = 20
LR = 1e-3
TIMESTEPS = 1000
# NOTE(review): with roughly 50000 images and BATCH_SIZE=128 a 20-epoch run is
# about 7.8k optimizer steps, so a 20000-step interval would never trigger the
# step checkpoint/sample branch - confirm this is intended.
SAVE_EVERY_STEPS = 20000
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

os.makedirs(OUT_DIR_SYNTH, exist_ok=True)

# Cast to float32 without copying when the dtype already matches.
# NOTE(review): this only changes the dtype; the value range is not checked
# or clipped here - the images are presumably already scaled to [0, 1] by the
# preprocessing pipeline, verify.
imgs_synth = synthetic_wavelet_images.astype(np.float32, copy=False)

# The dataset and the sampler both assume channel-last, single-channel images.
if not (imgs_synth.ndim == 4 and imgs_synth.shape[-1] == 1):
    raise RuntimeError(f"Unexpected shape: {imgs_synth.shape}. Expected (N, H, W, 1)")

H_synth, W_synth = imgs_synth.shape[1:3]

class WaveletDataset(Dataset):
    """Dataset over an indexable collection of channel-last images.

    Each item is moved to channel-first layout, rescaled with ``x * 2 - 1``
    (which maps the interval [0, 1] onto [-1, 1]) and returned as a tensor.
    """

    def __init__(self, arr):
        # Stored by reference; nothing is copied until an item is requested.
        self.arr = arr

    def __len__(self):
        return len(self.arr)

    def __getitem__(self, i):
        # (H, W, C) -> (C, H, W), laid out contiguously for torch.
        channel_first = np.ascontiguousarray(np.transpose(self.arr[i], (2, 0, 1)))
        # The arithmetic allocates a new array, so the returned tensor never
        # shares memory with the stored data.
        return torch.from_numpy(channel_first * 2.0 - 1.0)

# Data pipeline: shuffled mini-batches over the wavelet images.
dataset_synth = WaveletDataset(imgs_synth)
loader_synth = DataLoader(
    dataset_synth,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Single-channel UNet wrapped by the diffusion process (cosine beta schedule).
model_synth = UNet(in_channels=1)
model_synth = model_synth.to(DEVICE)
diffusion_synth = Diffusion(
    model_synth,
    timesteps=TIMESTEPS,
    device=DEVICE,
    beta_schedule='cosine',
)

# AdamW with a fixed weight decay of 0.01.
opt_synth = torch.optim.AdamW(
    model_synth.parameters(),
    lr=LR,
    weight_decay=0.01,
)

# Helper: sample and save grid
@torch.no_grad()
def sample_and_save_synth(step, num_samples=N_SAMPLES_FOR_VIS):
    """Sample images from the model being trained and save them as a PNG grid.

    The model is switched to eval mode for sampling and is always switched
    back to train mode afterwards, including when sampling or saving raises.

    Args:
        step: value embedded in the output file name
            (``samples_step_{step}.png`` inside ``OUT_DIR_SYNTH``).
        num_samples: number of images requested. If sampling that many fails,
            one retry is made with at most 4 images.

    Raises:
        Exception: whatever the retry or the image writer raises; only the
            first sampling failure is handled.
    """
    model_synth.eval()
    try:
        try:
            samples = diffusion_synth.sample((num_samples, 1, H_synth, W_synth))
        except Exception as e:
            # Deliberate best effort: report the failure, release cached GPU
            # memory and retry once with a smaller batch.
            print('Sampling failed:', e)
            torch.cuda.empty_cache()
            samples = diffusion_synth.sample((min(num_samples, 4), 1, H_synth, W_synth))

        # Map the model output from [-1, 1] back to [0, 1] for saving.
        samples = samples.clamp(-1, 1)
        samples = (samples + 1.0) / 2.0
        grid = vutils.make_grid(samples.cpu(), nrow=min(8, num_samples))
        out_path = os.path.join(OUT_DIR_SYNTH, f'samples_step_{step}.png')
        vutils.save_image(grid, out_path)
        print('Saved samples to', out_path)
    finally:
        # Restore train mode even on failure so an interrupted sampling call
        # cannot leave the model in eval mode for later training steps.
        model_synth.train()

# Training loop
model_synth.train()
global_step_synth = 0

print(f"\nTraining on SYNTHETIC cosine data...")
print(f"Dataset size: {len(dataset_synth)} samples")
print(f"Device: {DEVICE}")


def _synth_checkpoint(epoch_value):
    """Collect model/optimizer state plus progress counters for saving."""
    return {
        'model': model_synth.state_dict(),
        'opt': opt_synth.state_dict(),
        'epoch': epoch_value,
        'step': global_step_synth,
    }


for epoch in range(EPOCHS):
    progress = tqdm(loader_synth, desc=f'[SYNTHETIC] Epoch {epoch+1}/{EPOCHS}')
    for batch in progress:
        x0 = batch.to(DEVICE)

        # One uniformly drawn diffusion timestep per image in the batch.
        timesteps = torch.randint(
            0, diffusion_synth.timesteps, (x0.shape[0],), device=DEVICE
        ).long()
        loss = diffusion_synth.p_losses(x0, timesteps, loss_type='huber')

        opt_synth.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model_synth.parameters(), 1.0)
        opt_synth.step()

        global_step_synth += 1
        progress.set_postfix({'loss': loss.item(), 'step': global_step_synth})

        # Periodic step checkpoint together with a grid of sampled images.
        if global_step_synth % SAVE_EVERY_STEPS == 0:
            step_ckpt_path = os.path.join(OUT_DIR_SYNTH, f'ckpt_step_{global_step_synth}.pt')
            torch.save(_synth_checkpoint(epoch), step_ckpt_path)
            print(f'Saved checkpoint to {step_ckpt_path}')
            sample_and_save_synth(global_step_synth)

    # End-of-epoch checkpoint; the stored epoch is the 1-based completed count.
    torch.save(
        _synth_checkpoint(epoch + 1),
        os.path.join(OUT_DIR_SYNTH, f'ckpt_epoch_{epoch+1}.pt'),
    )
    print(f'Saved epoch {epoch+1} checkpoint')

print('Training on synthetic data finished!')

## 6. Generate Samples from Trained Model

Use the trained model to generate new synthetic wavelet images.

In [None]:
# Generate new samples from trained synthetic model
# Compare generated vs original synthetic data

CHECKPOINT_DIR_SYNTH = './diffusion_checkpoints_synthetic'
NUM_SAMPLES_SYNTH = 1000
TEMPERATURE = 1.5
GRID_PREVIEW_SAMPLES = 16  # images shown in the overview grid (4 per row)

# Find latest checkpoint
checkpoint_files_synth = glob.glob(os.path.join(CHECKPOINT_DIR_SYNTH, 'ckpt_epoch_*.pt'))
if not checkpoint_files_synth:
    raise RuntimeError(f"No checkpoint files found in {CHECKPOINT_DIR_SYNTH}")

# Use the modification time: ctime is the creation time on Windows (and the
# metadata-change time on Unix), so a file overwritten by a newer run can
# still look older than a stale checkpoint left by a previous run.
latest_checkpoint_synth = max(checkpoint_files_synth, key=os.path.getmtime)
print(f"Loading synthetic model checkpoint: {latest_checkpoint_synth}")

# Load model. The schedule length comes from the training configuration so
# that sampling cannot silently drift from the schedule used for training.
model_gen_synth = UNet(in_channels=1).to(DEVICE)
diffusion_gen_synth = Diffusion(model_gen_synth, timesteps=TIMESTEPS, device=DEVICE, beta_schedule='cosine')

checkpoint_synth = torch.load(latest_checkpoint_synth, map_location=DEVICE)
model_gen_synth.load_state_dict(checkpoint_synth['model'])
model_gen_synth.eval()

print(f"Generating {NUM_SAMPLES_SYNTH} samples from synthetic-trained model with temperature={TEMPERATURE}...")

# Whether Diffusion.sample disables autograd itself is not visible from this
# notebook, so it is disabled explicitly (the training-time helper does the
# same); this avoids retaining a graph across all reverse steps.
with torch.no_grad():
    generated_from_synth = diffusion_gen_synth.sample(
        (NUM_SAMPLES_SYNTH, 1, H_synth, W_synth), temperature=TEMPERATURE
    )

# Convert from [-1,1] to [0,1]
generated_from_synth = generated_from_synth.clamp(-1, 1)
generated_from_synth = (generated_from_synth + 1.0) / 2.0

# Save every generated sample in channel-last layout (N, H, W, 1)
generated_from_synth_np = generated_from_synth.cpu().numpy()
generated_from_synth_np = np.transpose(generated_from_synth_np, (0, 2, 3, 1))
np.save('generated_from_synthetic_model.npy', generated_from_synth_np)
print(f"Saved generated samples to 'generated_from_synthetic_model.npy'")

# Visualize grid. Only a preview subset is drawn: a grid of all samples at
# four per row would be hundreds of rows tall and unreadable in one figure.
n_grid = min(GRID_PREVIEW_SAMPLES, NUM_SAMPLES_SYNTH)
grid = vutils.make_grid(generated_from_synth[:n_grid].cpu(), nrow=4, padding=2, normalize=False)
grid_np = grid.permute(1, 2, 0).numpy()

plt.figure(figsize=(15, 12))
plt.imshow(grid_np[:, :, 0], cmap='viridis', aspect='auto')
plt.colorbar(label='Normalized Wavelet Coefficients')
plt.title(
    f'Generated Samples from Synthetic-Trained Model (first {n_grid} of {NUM_SAMPLES_SYNTH} samples)',
    fontsize=16,
)
plt.axis('off')
plt.tight_layout()
plt.show()

# Display individual samples
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()
n_individual = min(8, NUM_SAMPLES_SYNTH)

for i, ax in enumerate(axes):
    if i >= n_individual:
        ax.axis('off')  # hide panels that have no sample
        continue
    sample = generated_from_synth[i, 0].cpu().numpy()
    ax.imshow(sample, aspect='auto', origin='lower', cmap='viridis')
    ax.set_title(f'Generated Sample {i+1}')
    ax.set_xlabel('Time')
    ax.set_ylabel('Frequency Level')

plt.tight_layout()
plt.show()

print(f"Generated {NUM_SAMPLES_SYNTH} samples successfully!")

## 7. Reconstruct Time Series from Generated Wavelet Images

Apply inverse wavelet transform to convert back to time series.

In [None]:
# Invert the generated wavelet images back to time series with the same
# detrending pipeline used for preprocessing (its postprocess step is what
# the notebook relies on to restore the DC component).

# Prefer the in-memory samples; fall back to the saved file when the
# generation cell has not been run in this session.
try:
    in_memory_samples = generated_from_synth
except NameError:
    if not os.path.exists('generated_from_synthetic_model.npy'):
        raise RuntimeError("No generated images available")
    generated_imgs = np.load('generated_from_synthetic_model.npy')
    print("Loaded generated images from file")
else:
    # (N, C, H, W) tensor -> (N, H, W, C) array, the layout used for saving.
    generated_imgs = np.transpose(in_memory_samples.cpu().numpy(), (0, 2, 3, 1))

print(f"Generated wavelet images shape: {generated_imgs.shape}")

# The fitted pipeline must exist, otherwise there is nothing to invert with.
if 'synthetic_wavelet_pipeline' not in globals():
    raise RuntimeError("synthetic_wavelet_pipeline not found. Run the wavelet preprocessing cell first.")

print("\n=== RECONSTRUCTION WITH DETRENDING PIPELINE ===")

reconstructed_generated = synthetic_wavelet_pipeline.postprocess(generated_imgs, original_length=64)

print(f"Reconstructed shape: {reconstructed_generated.shape}")
print(f"Reconstruction complete with DC component restored!")

# Summary of the per-series means tracked by the pipeline.
detrending_info = synthetic_wavelet_pipeline.get_detrending_info()
print(f"\n=== DETRENDING STATISTICS ===")
print(f"DC component (mean) statistics:")
print(f"  Mean of means: {detrending_info['mean_of_means']:.6f}")
print(f"  Std of means: {detrending_info['std_of_means']:.6f}")
print(f"  Range: [{detrending_info['min_mean']:.6f}, {detrending_info['max_mean']:.6f}]")

# Plot up to eight reconstructed series on a 4x2 grid.
num_to_vis = min(8, len(reconstructed_generated))
fig, axes = plt.subplots(4, 2, figsize=(14, 12))
axes = axes.flatten()

for i, ax in enumerate(axes):
    if i >= num_to_vis:
        ax.axis('off')  # unused panel
        continue
    ax.plot(reconstructed_generated[i], color='tab:purple', linewidth=1.5)
    ax.set_ylabel('Reconstructed Volume', color='tab:purple')
    ax.set_title(f'Generated Time Series #{i+1} (DC Restored)')
    ax.set_xlabel('Day')
    ax.grid(True, alpha=0.3)

plt.suptitle('Reconstructed Time Series from Generated Wavelet Images (with DC Restoration)', fontsize=14)
plt.tight_layout()
plt.show()

# Persist the reconstructed series for later analysis.
np.save('reconstructed_generated_synthetic.npy', reconstructed_generated)
print(f"\nSaved reconstructed time series to 'reconstructed_generated_synthetic.npy'")


## 8. Autocorrelation Analysis: Original vs Generated

Compare temporal structure between original synthetic data and model-generated data.

In [None]:
# Autocorrelation comparison: Original Synthetic vs Generated
# Analyze if the model learned the temporal structure

max_lag = 60


def within_series_acf(series, nlags):
    """Return the pooled autocorrelation of a batch of independent series.

    The rows of ``series`` are separate windows. Running an ACF on their
    flattened concatenation pairs the tail of one window with the head of the
    next: at lag ``k`` a ``k / series_length`` share of the products straddle
    a boundary, which drags the estimate toward zero as the lag grows. Here
    only pairs taken from the same row are used, and each lag is averaged
    over the number of pairs that exist.

    Args:
        series: array-like with one series per row; trailing axes are
            flattened into the time axis.
        nlags: largest lag to evaluate; must be smaller than the row length.

    Returns:
        1-D array of length ``nlags + 1`` whose entry 0 is 1.0.

    Raises:
        ValueError: if ``nlags`` is negative or not smaller than the row length.
    """
    x = np.asarray(series, dtype=float)
    x = x.reshape(len(x), -1)
    n = x.shape[1]
    if not 0 <= nlags < n:
        raise ValueError(f"nlags must be in [0, {n - 1}], got {nlags}")
    # Centre on the global mean, as the flattened estimate did.
    x = x - x.mean()
    variance = np.mean(x * x)
    return np.array(
        [np.mean(x[:, :n - k] * x[:, k:]) / variance for k in range(nlags + 1)]
    )


# Flattened views kept for any later cell that still expects them.
original_synth_flat = synthetic_normalized.flatten()
generated_flat = reconstructed_generated.flatten()

# Calculate ACFs from within-series pairs only.
# assumes reconstructed_generated holds one series per row - TODO confirm
original_synth_acf = within_series_acf(synthetic_normalized, max_lag)
generated_acf = within_series_acf(reconstructed_generated, max_lag)

# Side-by-side comparison
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
acf_panels = (
    (original_synth_acf, 'orange', 'Original Synthetic Cosine - Autocorrelation'),
    (generated_acf, 'purple', 'Generated from Model - Autocorrelation'),
)
for ax, (acf_values, colour, title) in zip(axes, acf_panels):
    ax.stem(range(len(acf_values)), acf_values, linefmt=colour, markerfmt='o')
    ax.set_title(title)
    ax.set_xlabel('Lag')
    ax.set_ylabel('ACF')
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-0.3, 1.0)

plt.tight_layout()
plt.show()

# Overlay comparison
plt.figure(figsize=(12, 6))
plt.plot(range(len(original_synth_acf)), original_synth_acf, 'o-', color='orange',
         label='Original Synthetic', alpha=0.7, markersize=4)
plt.plot(range(len(generated_acf)), generated_acf, 's-', color='purple',
         label='Generated', alpha=0.7, markersize=4)
plt.title('Autocorrelation Comparison: Original Synthetic vs Generated')
plt.xlabel('Lag')
plt.ylabel('ACF')
plt.legend()
plt.grid(True, alpha=0.3)
plt.ylim(-0.3, 1.0)
plt.tight_layout()
plt.show()

# Statistical comparison at a few representative lags
print(f"\n=== Autocorrelation Comparison ===")
for position, lag in enumerate((1, 5, 10)):
    lead = "" if position == 0 else "\n"  # blank line between lag groups
    print(f"{lead}Original Synthetic ACF({lag}): {original_synth_acf[lag]:.4f}")
    print(f"Generated ACF({lag}): {generated_acf[lag]:.4f}")
    print(f"Difference: {abs(original_synth_acf[lag] - generated_acf[lag]):.4f}")

# ACF similarity metrics over all lags 0..max_lag
acf_correlation = np.corrcoef(original_synth_acf, generated_acf)[0, 1]
acf_mae = np.mean(np.abs(original_synth_acf - generated_acf))
acf_rmse = np.sqrt(np.mean((original_synth_acf - generated_acf)**2))

print(f"\n=== ACF Similarity Metrics ===")
print(f"ACF Correlation: {acf_correlation:.4f}")
print(f"ACF MAE: {acf_mae:.4f}")
print(f"ACF RMSE: {acf_rmse:.4f}")
print(f"\nInterpretation:")
print(f"  - Correlation close to 1.0 indicates model captured temporal dependencies")
print(f"  - Lower MAE/RMSE indicates better pattern matching")

## 9. Distribution Analysis: Original vs Generated

Comprehensive statistical comparison of distributions.

In [None]:
# Distribution comparison: Original Synthetic vs Generated
# Summary statistics and a two-sample Kolmogorov-Smirnov test

print("=== DISTRIBUTION ANALYSIS ===")

# Pool every value of every series into one sample per data set.
original_synth_flat = synthetic_normalized.flatten()
generated_flat = reconstructed_generated.flatten()

for heading, values in (
    ("Original Synthetic", original_synth_flat),
    ("Generated", generated_flat),
):
    print(f"\n{heading} Statistics:")
    print(f"  Mean: {values.mean():.4f}")
    print(f"  Median: {np.median(values):.4f}")
    print(f"  Std: {values.std():.4f}")
    print(f"  Min: {values.min():.4f}")
    print(f"  Max: {values.max():.4f}")

# Statistical tests
ks_stat, ks_p_value = ks_2samp(original_synth_flat, generated_flat)
ks_alpha = 0.05
print(f"\n=== STATISTICAL TESTS ===")
print(f"Kolmogorov-Smirnov Test:")
print(f"  Statistic: {ks_stat:.4f}")
# Scientific notation: a fixed four-decimal format shows tiny p-values as 0.0000.
print(f"  P-value: {ks_p_value:.3e}")
# Failing to reject the null hypothesis is not evidence that the two
# distributions are the same, so the wording only reports the test outcome.
if ks_p_value > ks_alpha:
    ks_verdict = f"No significant difference detected at alpha={ks_alpha}"
else:
    ks_verdict = f"Distributions differ significantly at alpha={ks_alpha}"
print(f"  Result: {ks_verdict}")
# The pooled values are consecutive points of autocorrelated series, which
# violates the independence assumption behind the p-value.
print(f"  Note: samples are pooled from autocorrelated series, so the p-value is optimistic;")
print(f"        use the KS statistic (max CDF gap) as the effect size.")

# Visualization: 3x3 grid of distribution diagnostics comparing the pooled
# original values (orange) with the pooled generated values (purple).
fig = plt.figure(figsize=(20, 12))

# 1. Histograms comparison
plt.subplot(3, 3, 1)
plt.hist(original_synth_flat, bins=50, alpha=0.7, label='Original Synthetic', color='orange', density=True)
plt.hist(generated_flat, bins=50, alpha=0.7, label='Generated', color='purple', density=True)
plt.title('Distribution Comparison')
plt.xlabel('Value')
plt.ylabel('Density')
plt.legend()
plt.grid(True, alpha=0.3)

# 2. Box plots
plt.subplot(3, 3, 2)
plt.boxplot([original_synth_flat, generated_flat])
# Tick labels are set separately: the `labels` keyword of boxplot is
# deprecated in Matplotlib 3.9 (renamed to `tick_labels`), while xticks works
# on every version. Boxes are drawn at positions 1..N by default.
plt.xticks([1, 2], ['Original Synthetic', 'Generated'])
plt.title('Box Plot Comparison')
plt.ylabel('Value')
plt.grid(True, alpha=0.3)

# 3. Q-Q plot against a normal distribution (Original)
plt.subplot(3, 3, 3)
stats.probplot(original_synth_flat, dist="norm", plot=plt)
plt.title('Original Synthetic Q-Q Plot')
plt.grid(True, alpha=0.3)

# 4. Q-Q plot against a normal distribution (Generated)
plt.subplot(3, 3, 4)
stats.probplot(generated_flat, dist="norm", plot=plt)
plt.title('Generated Q-Q Plot')
plt.grid(True, alpha=0.3)

# 5. Empirical CDF: sorted values against their rank fraction
plt.subplot(3, 3, 5)
original_sorted = np.sort(original_synth_flat)
generated_sorted = np.sort(generated_flat)
original_cdf = np.arange(1, len(original_sorted) + 1) / len(original_sorted)
generated_cdf = np.arange(1, len(generated_sorted) + 1) / len(generated_sorted)

plt.plot(original_sorted, original_cdf, label='Original Synthetic', color='orange', alpha=0.8)
plt.plot(generated_sorted, generated_cdf, label='Generated', color='purple', alpha=0.8)
plt.title('Empirical CDF Comparison')
plt.xlabel('Value')
plt.ylabel('Cumulative Probability')
plt.legend()
plt.grid(True, alpha=0.3)

# 6. Kernel density estimation
plt.subplot(3, 3, 6)
sns.kdeplot(data=original_synth_flat, label='Original Synthetic', color='orange', alpha=0.7)
sns.kdeplot(data=generated_flat, label='Generated', color='purple', alpha=0.7)
plt.title('Kernel Density Estimation')
plt.xlabel('Value')
plt.ylabel('Density')
plt.legend()
plt.grid(True, alpha=0.3)

# 7. Percentile comparison (1st to 99th percentile)
plt.subplot(3, 3, 7)
percentiles = np.arange(1, 100, 1)
original_percentiles = np.percentile(original_synth_flat, percentiles)
generated_percentiles = np.percentile(generated_flat, percentiles)

plt.plot(percentiles, original_percentiles, label='Original Synthetic', color='orange', marker='o', markersize=2)
plt.plot(percentiles, generated_percentiles, label='Generated', color='purple', marker='s', markersize=2)
plt.title('Percentile Comparison')
plt.xlabel('Percentile')
plt.ylabel('Value')
plt.legend()
plt.grid(True, alpha=0.3)

# 8. Scatter plot of percentiles; points on the dashed diagonal match exactly
plt.subplot(3, 3, 8)
plt.scatter(original_percentiles, generated_percentiles, alpha=0.6, color='green')
diagonal = [original_percentiles.min(), original_percentiles.max()]
plt.plot(diagonal, diagonal, 'r--', label='Perfect Match')
plt.title('Percentile Scatter Plot')
plt.xlabel('Original Synthetic Percentiles')
plt.ylabel('Generated Percentiles')
plt.legend()
plt.grid(True, alpha=0.3)

# 9. Distribution moments (central moments of order 2, 3 and 4)
plt.subplot(3, 3, 9)
# The order is passed positionally because the keyword name of this argument
# differs between SciPy releases (`moment` in older ones, `order` in newer).
moments = [
    (stats.moment(original_synth_flat, order), stats.moment(generated_flat, order))
    for order in range(2, 5)
]

x = np.arange(2, 5)
orig_vals = [pair[0] for pair in moments]
gen_vals = [pair[1] for pair in moments]

width = 0.35
plt.bar(x - width/2, orig_vals, width, label='Original Synthetic', color='orange')
plt.bar(x + width/2, gen_vals, width, label='Generated', color='purple')
plt.xlabel('Moment Order')
plt.ylabel('Moment Value')
plt.title('Distribution Moments Comparison')
plt.xticks(x)
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Side-by-side table of a few key quantiles for both data sets.
print(f"\n=== KEY QUANTILES ===")
quantiles = [0.05, 0.25, 0.50, 0.75, 0.95]

# Evaluate every requested quantile in one call per data set.
original_quantiles = np.quantile(original_synth_flat, quantiles)
generated_quantiles = np.quantile(generated_flat, quantiles)

for q, orig_q, gen_q in zip(quantiles, original_quantiles, generated_quantiles):
    print(f"  Q{q*100:02.0f}: Original={orig_q:.4f}, Generated={gen_q:.4f}, Diff={abs(orig_q-gen_q):.4f}")

print(f"\n*** INTERPRETATION ***")
print(f"This analysis compares the LEARNED distribution (generated) with the INPUT distribution (original synthetic).")
print(f"Good match indicates the diffusion model successfully learned the data distribution.")

## Key Observation: Mode Collapse Despite Good Autocorrelation

**Important Finding**: The synthetic cosine experiment shows:
- ✅ **Excellent autocorrelation matching** - Model learns temporal patterns perfectly
- ❌ **Severe distribution collapse** - Generated data has very narrow tails, concentrated near 0

**This reveals the true problem**: The mode collapse is NOT caused by the DC component in the frequency domain (the cosine series have varying means), but by the **diffusion model's tendency to generate safe, low-variance samples** during the reverse process.

**Why detrending may still help**:
- Removes per-series mean bias before training
- Forces model to learn variance/patterns independent of absolute level
- Restores realistic means during generation
- May prevent model from "playing it safe" by regressing to global mean
