In [16]:
"""
Conditional DDPM for Financial Wavelet Inputs

Model each asset window as a 2D time-frequency representation (wavelet scalogram) and 
train a conditional diffusion model to generate realistic return dynamics in the
wavelet domain, conditioned on macroeconomic information.

Primary artifacts:
- X_all: (T, N_assets, N_scales) wavelet coefficients
- C_all: (T, C_features) conditioning variables

Structured as a production ML experiment:
- explicit configuration
- dataset with leakage-aware splits
- U-Net backbone with time + conditioning embeddings
- diffusion wrapper + training loop + EMA
"""

import os
import json
import math
import warnings
import copy
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import itertools
import random
import gc
from datetime import datetime
from pathlib import Path


# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('using device:', device)

# Base hyperparameters, grouped by concern, in one dict that is passed around as `config`.
CONFIG = {
    'paths': {
        # Directory the exported arrays (X_all.npy, C_all.npy, ...) are read from.
        'export_dir': 'artifacts_all',
        # The default is machine-specific; set the DDPM_CHECKPOINT_DIR environment
        # variable to relocate checkpoints without editing the notebook.
        'checkpoint_dir': os.environ.get(
            'DDPM_CHECKPOINT_DIR',
            '/home/dsranelli/bigproject/checkpoint_dir',
        ),
    },
    'data': {
        'window': 128,        # window length used by WaveletDDPMDataset
        'width': 128,         # NOTE(review): not read by the code visible in this section
        'train_split': 0.8,   # fraction of time windows assigned to training
        'stride': 1,          # step between consecutive window starts
        'asset_limit': 10,    # load_data keeps only the first `asset_limit` assets
        'filter_energy_sigma': 5.0,  # NOTE(review): not read by the code visible in this section
    },
    'model': {
        'input_channels': 1,
        'base_channels': 128,
        'time_emb_dim': 128,
    },
    'diffusion': {
        'timesteps': 1000,
        'beta_start': 1e-4,
        'beta_end': 0.02,
        'lambda_x0': 0.1,     # weight of the auxiliary x0 reconstruction loss
        'lambda_spec': 0.01,  # weight of the auxiliary spectral loss
    },
    'training': {
        'batch_size': 128,
        'num_epochs': 300,
        'learning_rate': 3e-5,
        'weight_decay': 0.0,
        'max_grad_norm': 1.0,
        'early_stop_patience': 20,
        'dropout_rate': 0.1,
    },
}


Starting DDPM setup...
using device: cpu
Configuration loaded successfully


In [17]:
def load_data(config):
    """
    Load, sanitise and normalise the wavelet tensor, conditioning matrix and dates.

    Files read from ``config["paths"]["export_dir"]``:
    ``basket_assets.json``, ``X_all.npy``, ``C_all.npy``, ``date_index.npy``.

    Side effects
    ------------
    - Writes ``x_mean`` / ``x_std`` (statistics of the second, post-clip
      normalisation) and ``x_mean_raw`` / ``x_std_raw`` (statistics of the first
      normalisation over finite raw values) into ``config["data"]``.
    - Creates a ``raw`` directory (``config["paths"]["raw_dir"]`` if given,
      otherwise ``<export_dir>/raw``).

    Returns
    -------
    X_all: np.ndarray
        Wavelet tensor of shape (T, number of assets, number of scales)
    C_all: np.ndarray
        Conditioning array: (T, N_assets, C_SCALES) when the feature count is a
        multiple of the full asset count, otherwise (T, conditional features)
    date_index: np.ndarray
        Array of date strings
    N_ASSETS: int
        Number of assets kept (after ``asset_limit``)
    C_SCALES: int
        Conditioning feature dimension
    N_SCALES: int
        Wavelet scale dimension (height of scalogram)

    Raises
    ------
    ValueError
        If the asset list is empty, X_all has no finite values, or the flattened
        feature count is not a multiple of the number of assets.
    """
    paths = config["paths"]
    export_dir = paths["export_dir"]
    asset_limit = config["data"].get("asset_limit", None)

    # Asset universe metadata: only its length is used here.
    with open(os.path.join(export_dir, "basket_assets.json"), "r") as f:
        basket_assets = json.load(f)
    N_ASSETS_FULL = len(basket_assets)
    if N_ASSETS_FULL == 0:
        raise ValueError("basket_assets.json lists no assets")

    # NOTE(review): allow_pickle=True executes arbitrary code for untrusted .npy
    # files; kept because the exported arrays may have been saved as object arrays.
    X_all = np.load(os.path.join(export_dir, "X_all.npy"), allow_pickle=True)

    X_raw = X_all.astype(np.float32)
    print("Loaded wavelet array:", X_raw.shape, "ndim:", X_raw.ndim, "size:", X_raw.size)

    # Stage 1: standardise using finite entries only; non-finite entries are
    # replaced by the mean so they map to 0 after standardisation.
    mask = np.isfinite(X_raw)
    if not mask.any():
        raise ValueError("X_all contains no finite values")
    raw_mean = X_raw[mask].mean()
    raw_std = X_raw[mask].std() + 1e-8

    X_raw = np.where(mask, X_raw, raw_mean).astype(np.float32)

    X_all = (X_raw - raw_mean) / raw_std
    X_all = np.clip(X_all, -10, 10)

    # Stage 2: re-standardise after clipping.
    X_mean = X_all.mean()
    X_std = X_all.std() + 1e-8
    print(f"Global X mean/std before norm: {X_mean:.4f}, {X_std:.4f}")
    X_all = (X_all - X_mean) / X_std

    config["data"]["x_mean"] = float(X_mean)
    config["data"]["x_std"] = float(X_std)
    # Stage-1 statistics are also needed to map samples back to raw units.
    config["data"]["x_mean_raw"] = float(raw_mean)
    config["data"]["x_std_raw"] = float(raw_std)

    # Flatten trailing axes so the scale dimension can be derived from the data.
    if X_all.ndim > 2:
        X_all = X_all.reshape(X_all.shape[0], -1)

    T, TOTAL_X_FEAT = X_all.shape
    if TOTAL_X_FEAT % N_ASSETS_FULL != 0:
        raise ValueError(
            f"X_all has {TOTAL_X_FEAT} features per timestep, which is not a "
            f"multiple of the {N_ASSETS_FULL} assets in basket_assets.json"
        )
    N_SCALES = TOTAL_X_FEAT // N_ASSETS_FULL

    X_all = X_all.reshape(T, N_ASSETS_FULL, N_SCALES)

    # Optionally keep only the first `asset_limit` assets.
    if asset_limit and asset_limit < N_ASSETS_FULL:
        print(f"Slicing X_all from {N_ASSETS_FULL} to {asset_limit} assets.")
        X_all = X_all[:, :asset_limit, :]
        N_ASSETS = asset_limit
    else:
        N_ASSETS = N_ASSETS_FULL
    X_all = np.nan_to_num(X_all, nan=0.0, posinf=1.0, neginf=-1.0)
    X_all = np.clip(X_all, -10, 10)

    # Conditioning matrix (same allow_pickle caveat as above).
    C_all = np.load(os.path.join(export_dir, "C_all.npy"), allow_pickle=True)

    if C_all.ndim == 3:
        C_all = C_all.reshape(C_all.shape[0], -1)

    T_C, TOTAL_C_FEAT = C_all.shape
    if T_C != T:
        warnings.warn(
            f"C_all has {T_C} timesteps but X_all has {T}; windows will be misaligned"
        )

    # Heuristic: a feature count divisible by the asset count is treated as
    # per-asset conditioning, anything else as market-wide conditioning.
    if TOTAL_C_FEAT % N_ASSETS_FULL == 0:
        C_SCALES = TOTAL_C_FEAT // N_ASSETS_FULL
        C_all = C_all.reshape(T_C, N_ASSETS_FULL, C_SCALES)
        if asset_limit and asset_limit < N_ASSETS_FULL:
            C_all = C_all[:, :asset_limit, :]
        print(f"C_all shape (asset-specific): {C_all.shape}")
    else:
        C_SCALES = TOTAL_C_FEAT
        print(f"C_all shape (market-wide): {C_all.shape}")

    C_all = np.nan_to_num(C_all, nan=0.0, posinf=1.0, neginf=-1.0)
    C_all = np.clip(C_all, -10, 10)

    # Date labels for the time axis.
    date_index = np.load(os.path.join(export_dir, "date_index.npy"), allow_pickle=True)
    date_index = np.array(date_index, dtype=str)
    if date_index.shape[0] != T:
        warnings.warn(
            f"date_index has {date_index.shape[0]} entries but X_all has {T} timesteps"
        )

    print(f"X_all final: {X_all.shape}, C_all final: {C_all.shape}")
    # Derived from export_dir (overridable via config) instead of an absolute
    # home-directory path, so the notebook also runs on other machines.
    raw_dir = paths.get("raw_dir", os.path.join(export_dir, "raw"))
    os.makedirs(raw_dir, exist_ok=True)
    return X_all, C_all, date_index, N_ASSETS, C_SCALES, N_SCALES


In [17]:
class WaveletDDPMDataset(Dataset):
    """
    Sliding-window dataset over the wavelet tensor.

    Each item is a pair ``(x, c)``:
    x: float tensor (1, H, W) - one asset's window, transposed so scales are rows
       and time is columns, treated as a one-channel image for the U-Net.
    c: float tensor - the window-mean of the conditioning features with the
       window's RMS energy appended as the last entry.
    """

    def __init__(self, config, X_all, C_all, N_ASSETS, C_SCALES):
        """
        Index every (time window, asset) pair whose RMS energy exceeds 1e-6.

        `C_SCALES` is accepted for interface compatibility and is not used here.
        """
        data_cfg = config['data']

        self.config = config
        self.X_all = X_all
        self.C_all = C_all
        self.W = data_cfg['window']
        self.stride = data_cfg['stride']
        self.N_ASSETS = N_ASSETS
        self.N_SCALES = X_all.shape[2]

        self.T = self.X_all.shape[0]
        self.time_starts = list(range(0, self.T - self.W + 1, self.stride))
        self.num_time_windows = len(self.time_starts)

        # Keep only windows that carry signal; order is window-major, asset-minor.
        min_energy = 1e-6  # tune
        self.samples = [
            (window_idx, asset)
            for window_idx, start in enumerate(self.time_starts)
            for asset in range(N_ASSETS)
            if np.sqrt((self.X_all[start:start + self.W, asset, :].T ** 2).mean()) > min_energy
        ]

        print(f"Dataset: {self.num_time_windows} windows × {self.N_ASSETS} assets = {len(self.samples)} samples (after filtering)")
        print(f"C_all shape: {C_all.shape}, C_all ndim: {C_all.ndim}")

    def __len__(self):
        """Number of retained (window, asset) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the (scalogram, conditioning) tensor pair for sample `idx`."""
        window_idx, asset = self.samples[idx]
        start = self.time_starts[window_idx]
        stop = start + self.W

        # (W, scales) slice transposed to (scales, W).
        scalogram = self.X_all[start:stop, asset, :].T

        # RMS energy of the window, appended to the conditioning vector below.
        rms_energy = np.sqrt((scalogram ** 2).mean())

        # Market-wide conditioning is 2-D; per-asset conditioning is 3-D.
        if self.C_all.ndim == 2:
            cond_rows = self.C_all[start:stop, :]
        else:
            cond_rows = self.C_all[start:stop, asset, :]
        cond_mean = cond_rows.mean(axis=0).astype(np.float32)

        features = np.concatenate(
            [cond_mean, np.array([rms_energy], dtype=np.float32)],
            axis=0,
        )

        return (
            torch.from_numpy(scalogram).float().unsqueeze(0),
            torch.from_numpy(features).float(),
        )

def build_dataloaders(config, X_all, C_all, N_ASSETS, C_SCALES):
    """
    Build chronologically split train/validation DataLoaders.

    Time windows are split by window index: the first ``train_split`` fraction is
    the training candidate set, the rest is validation. Because consecutive
    windows overlap whenever ``stride < window``, training windows that extend
    into the first validation window are purged, so no timestep is seen by both
    sides. Set ``config["data"]["purge_overlap"] = False`` to keep the plain
    index split.

    Parameters
    ----------
    config:
        Reads ``data.train_split``, optional ``data.purge_overlap`` (default
        True) and ``training.batch_size``.
    X_all, C_all, N_ASSETS, C_SCALES:
        Forwarded unchanged to ``WaveletDDPMDataset``.

    Returns
    -------
    (train_loader, val_loader):
        The training loader shuffles and drops the last partial batch; the
        validation loader does neither.
    """
    dataset = WaveletDDPMDataset(config, X_all, C_all, N_ASSETS, C_SCALES)

    total_windows = dataset.num_time_windows
    train_split = float(config["data"]["train_split"])

    # Clamp so both sides get at least one window whenever two or more exist.
    split_idx = int(train_split * total_windows)
    split_idx = max(1, min(split_idx, total_windows - 1))

    # Training keeps windows [0, train_cutoff); validation keeps [split_idx, end).
    train_cutoff = split_idx
    if config["data"].get("purge_overlap", True) and 0 < split_idx < total_windows:
        val_start_time = dataset.time_starts[split_idx]
        # A training window is safe only if it ends at or before the first
        # validation timestep.
        purged_cutoff = sum(
            1
            for start in dataset.time_starts[:split_idx]
            if start + dataset.W <= val_start_time
        )
        if purged_cutoff > 0:
            train_cutoff = purged_cutoff
        else:
            warnings.warn(
                "Too few time windows to purge train/validation overlap; "
                "falling back to the plain index split (windows overlap in time)."
            )

    train_indices = [i for i, (t, _) in enumerate(dataset.samples) if t < train_cutoff]
    val_indices = [i for i, (t, _) in enumerate(dataset.samples) if t >= split_idx]

    train_subset = Subset(dataset, train_indices)
    val_subset = Subset(dataset, val_indices)

    batch_size = config["training"]["batch_size"]
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, drop_last=False)

    return train_loader, val_loader




In [18]:
class SinusoidalTimeEmbedding(nn.Module):
    """
    Sinusoidal Embedding for Diffusion Timesteps

    Diffusion models condition the denoising network on the current timestep `t`.
    Maps timesteps into a continuous vector using sine/cosine features
    (similar to transformer positional embeddings). No learned parameters.

    Input
    t : torch.Tensor
        Shape (B,) or (B, 1); it is flattened to (B,) before use.

    Output
    emb : torch.Tensor
        Shape (B, dim): ``dim // 2`` sine features followed by ``dim // 2``
        cosine features, plus one padding column when `dim` is odd.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Compute the sinusoidal embedding of the timesteps `t`."""
        # Flatten so a (B, 1) input gives (B, dim) rather than (B, 1, dim).
        t = t.reshape(-1).float()
        half_dim = self.dim // 2
        # Periods spaced geometrically between 1 and 10000.
        freqs = torch.exp(
            torch.linspace(
                math.log(1.0),
                math.log(10000.0),
                half_dim,
                device=t.device,
            )
        )
        # (B, 1) / (1, half_dim) -> (B, half_dim)
        args = t.unsqueeze(1) / freqs.unsqueeze(0)
        # Concatenate sine and cosine features.
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        # Odd `dim`: repeat the first sine column to reach the requested width;
        # when there is no column to repeat (dim == 1), pad with zeros instead.
        if self.dim % 2 == 1:
            pad = emb[:, :1] if half_dim > 0 else emb.new_zeros(t.shape[0], 1)
            emb = torch.cat([emb, pad], dim=-1)
        return emb

In [19]:
# Residual Block
class ResBlock2d(nn.Module):
    """
    Residual Block with Time/Conditioning Embedding

    Fundamental unit of the conditional U-Net within the diffusion model.
    It processes a 2D feature map (wavelet scalogram) while incorporating a
    per-sample conditioning embedding (time step + macro/fundamental).

    Design choices:
    - Residual connections stabilize training in deep diffusion models.
    - Conditioning is injected additively after the first convolution block.
    - GroupNorm is used rather than BatchNorm, so normalisation does not depend
      on the batch size.
    - SiLU activation.
    """
    def __init__(self, in_ch, out_ch, emb_dim, kernel_size=3, dropout = 0.1):
        """
        in_ch:
            Number of input feature channels.
        out_ch:
            Number of output feature channels.
        emb_dim:
            Dimensionality of the conditioning embedding (time + macro features).
        kernel_size:
            Spatial kernel size for convolutions (padding is ``kernel_size // 2``).
        dropout:
            Dropout probability applied after the second convolution block.
        """
        super().__init__()
        padding = kernel_size // 2

        # GroupNorm needs the channel count to be divisible by the group count:
        # use 8 groups when possible, otherwise the largest common divisor.
        num_groups = math.gcd(8, out_ch)

        # Main conv path
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding)
        self.gn1 = nn.GroupNorm(num_groups, out_ch)
        self.act = nn.SiLU()

        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size, padding=padding)
        self.gn2 = nn.GroupNorm(num_groups, out_ch)

        # Conditioning projection: maps the embedding into channel space so it
        # can be broadcast over the spatial dimensions and added to the features.
        self.emb_proj = nn.Linear(emb_dim, out_ch)

        # Skip connection: a 1x1 conv matches channels when they change.
        if in_ch != out_ch:
            self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        else:
            self.skip = nn.Identity()

        # Regularization
        self.dropout = nn.Dropout2d(dropout)

    def forward(self, x, emb):
        """
        Forward pass.

        Parameters
        ----------
        x:
            Input feature map of shape (B, in_ch, H, W).
        emb:
            Conditioning embedding of shape (B, emb_dim).

        Returns
        -------
        output:
            Output feature map of shape (B, out_ch, H, W).
        """
        # First conv block
        h = self.conv1(x)
        h = self.gn1(h)
        h = self.act(h)

        # Project the embedding to (B, out_ch, 1, 1) and broadcast-add it.
        emb_out = self.emb_proj(emb).unsqueeze(-1).unsqueeze(-1)
        h = h + emb_out

        # Second conv block
        h = self.conv2(h)
        h = self.gn2(h)
        h = self.act(h)
        h = self.dropout(h)

        # Residual sum with the (possibly projected) input.
        return h + self.skip(x)

In [20]:
class UNET2DCond(nn.Module):
    """
    Conditional 2D U-Net for DDPM Noise Prediction

    Predicts the noise component epsilon added at diffusion timestep t, given:
      - x_t: a noisy 2D image representation (wavelet scalogram)
      - t: diffusion timestep
      - cond: conditioning vector (macro information)

    Input/Output contract
    x : (B, in_channels, H, W), with H and W both divisible by 8
        (three stride-2 downsamplings); e.g. H = 32 scales, W = 128 timesteps
    t : (B,) or (B, 1) timesteps
    cond : (B, cond_dim) conditioning vector

    returns:
    eps_hat : (B, in_channels, H, W) predicted noise

    Notes
    - U-Net with skip connections to preserve fine-scale detail
    - ResBlocks inject the combined time + conditioning embedding at every resolution
    - `dropout_rate` is forwarded to every ResBlock
    """

    def __init__(self, in_channels, cond_dim, base_channels=64, time_emb_dim=128, dropout_rate = 0.1):
        super().__init__()

        # Time embedding: sinusoidal features followed by a small MLP.
        self.time_mlp = nn.Sequential(
            SinusoidalTimeEmbedding(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim * 2),
            nn.SiLU(),
            nn.Linear(time_emb_dim * 2, time_emb_dim)
        )

        # Conditional embedding: maps the conditioning features into the same
        # space so they can be summed with the timestep embedding.
        self.cond_mlp = nn.Sequential(
            nn.Linear(cond_dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )

        emb_dim = time_emb_dim
        ch = base_channels
        drop = dropout_rate

        # Input projection
        self.in_conv = nn.Conv2d(in_channels, ch, kernel_size=3, padding=1)

        # Downsampling path: two ResBlocks per resolution, then a strided conv
        # that halves H and W and doubles the channels.
        self.down1 = ResBlock2d(ch, ch, emb_dim, dropout=drop)
        self.down2 = ResBlock2d(ch, ch, emb_dim, dropout=drop)
        self.downsample1 = nn.Conv2d(ch, ch * 2, kernel_size=4, stride=2, padding=1)

        self.down3 = ResBlock2d(ch * 2, ch * 2, emb_dim, dropout=drop)
        self.down4 = ResBlock2d(ch * 2, ch * 2, emb_dim, dropout=drop)
        self.downsample2 = nn.Conv2d(ch * 2, ch * 4, kernel_size=4, stride=2, padding=1)

        self.down5 = ResBlock2d(ch * 4, ch * 4, emb_dim, dropout=drop)
        self.down6 = ResBlock2d(ch * 4, ch * 4, emb_dim, dropout=drop)
        self.downsample3 = nn.Conv2d(ch * 4, ch * 8, kernel_size=4, stride=2, padding=1)

        # Bottleneck (lowest resolution)
        self.bot1 = ResBlock2d(ch * 8, ch * 8, emb_dim, dropout=drop)
        self.bot2 = ResBlock2d(ch * 8, ch * 8, emb_dim, dropout=drop)

        # Upsampling path: transposed conv doubles H and W, the matching encoder
        # activation is concatenated, and a ResBlock fuses the result.
        self.upsample1 = nn.ConvTranspose2d(ch * 8, ch * 4, kernel_size=4, stride=2, padding=1)
        self.up1 = ResBlock2d(ch * 8, ch * 4, emb_dim, dropout=drop)

        self.upsample2 = nn.ConvTranspose2d(ch * 4, ch * 2, kernel_size=4, stride=2, padding=1)
        self.up2 = ResBlock2d(ch * 4, ch * 2, emb_dim, dropout=drop)

        self.upsample3 = nn.ConvTranspose2d(ch * 2, ch, kernel_size=4, stride=2, padding=1)
        self.up3 = ResBlock2d(ch * 2, ch, emb_dim, dropout=drop)

        # Final projection (back to the noise prediction channels)
        self.out_conv = nn.Sequential(
            nn.Conv2d(ch, ch, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(ch, in_channels, kernel_size=3, padding=1)
        )

        # Kept as an attribute; forward() does not apply it (dropout happens
        # inside the ResBlocks). Always defined, for either sign of dropout_rate.
        if dropout_rate > 0:
            self.dropout = nn.Dropout2d(dropout_rate)
        else:
            self.dropout = nn.Identity()

    def forward(self, x, t, cond):
        """
        Parameters
        ----------
        x:
            Noisy input scalogram, shape (B, in_channels, H, W)
        t:
            Diffusion step indices, shape (B,) or (B, 1)
        cond:
            Conditioning features, shape (B, cond_dim)

        Returns
        -------
        eps_hat:
            Predicted noise tensor, shape (B, in_channels, H, W)

        Raises
        ------
        ValueError
            If H or W is not divisible by 8 (decoder and skip shapes would not match).
        """
        height, width = x.shape[-2], x.shape[-1]
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"UNET2DCond needs H and W divisible by 8, got H={height}, W={width}"
            )

        # Combined embedding used by all ResBlocks (t flattened to (B,) first).
        t_emb = self.time_mlp(t.reshape(-1))
        c_emb = self.cond_mlp(cond)
        emb = t_emb + c_emb

        # Encoder. Shape comments use C = base_channels.
        x0 = self.in_conv(x)

        d1 = self.down1(x0, emb)   # (B, C, H, W)
        d2 = self.down2(d1, emb)   # (B, C, H, W)        -> skip for up3
        d3 = self.downsample1(d2)  # (B, 2C, H/2, W/2)

        d4 = self.down3(d3, emb)   # (B, 2C, H/2, W/2)
        d5 = self.down4(d4, emb)   # (B, 2C, H/2, W/2)   -> skip for up2
        d6 = self.downsample2(d5)  # (B, 4C, H/4, W/4)

        d7 = self.down5(d6, emb)   # (B, 4C, H/4, W/4)
        d8 = self.down6(d7, emb)   # (B, 4C, H/4, W/4)   -> skip for up1
        x = self.downsample3(d8)   # (B, 8C, H/8, W/8)

        # Bottleneck
        x = self.bot1(x, emb)      # (B, 8C, H/8, W/8)
        x = self.bot2(x, emb)      # (B, 8C, H/8, W/8)

        # Decoder with skip connections
        x = self.upsample1(x)             # (B, 4C, H/4, W/4)
        x = torch.cat([x, d8], dim=1)     # (B, 8C, H/4, W/4)
        x = self.up1(x, emb)              # (B, 4C, H/4, W/4)

        x = self.upsample2(x)             # (B, 2C, H/2, W/2)
        x = torch.cat([x, d5], dim=1)     # (B, 4C, H/2, W/2)
        x = self.up2(x, emb)              # (B, 2C, H/2, W/2)

        x = self.upsample3(x)             # (B, C, H, W)
        x = torch.cat([x, d2], dim=1)     # (B, 2C, H, W)
        x = self.up3(x, emb)              # (B, C, H, W)

        # Final projection
        return self.out_conv(x)           # (B, in_channels, H, W)

In [21]:
# Diffusion Model
class GaussianDiffusion2D(nn.Module):
    """
    Gaussian Diffusion Wrapper (DDPM) for 2D Inputs (B, C, H, W)

    Implements:
    - Forward process q(x_t | x_0): add noise according to beta schedule
    - Training objective: predict noise epsilon (using epsilon-prediction)
    - Reverse process p(x_{t-1} | x_t): iterative denoising sampler

    Notes
    - Classic DDPM setup with a linear beta schedule.
    - Include optional auxiliary losses:
        - x0 reconstruction (lambda_x0)
        - spectral consistency along the time axis (lambda_spec)
    """
    def __init__(
        self,
        model,
        timesteps: int = 100,
        beta_start: float = 1e-4,
        beta_end: float = 0.02,
        device: str = "cpu",
        lambda_x0: float = 0.0,
        lambda_spec: float = 0.0,
    ):
        super().__init__()
        self.model = model
        self.device = device
        self.timesteps = timesteps
        self.lambda_x0 = lambda_x0
        self.lambda_spec = lambda_spec

        # Noise schedule (linear beta)
        betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        # Buffers with modeule
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alphas_cumprod", alphas_cumprod)
        self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod",
            torch.sqrt(1.0 - alphas_cumprod),
        )

    def q_sample(
        self,
        x_start: torch.Tensor,
        t: torch.Tensor,
        noise: torch.Tensor | None = None
    ):
        """
        Sample x_t from the forward process q(x_t | x_0, t):

            x_t = sqrt(a-bar_t) * x_0 + sqrt(1 - a-bar_t) * epsilon,   epsilon ~ N(0, I)

        Parameters
        ----------
        x_start:
            Clean input x_0 of shape (B, C, H, W).
        t:
            Timesteps (B,) is of dtype long.
        noise:
            Optional noise tensor epsilon. If None, sampled from standard Normal.

        Returns
        -------
        x_t:
            Noisy input at timestep t, shape (B, C, H, W).
        noise:
            The noise epsilon used.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        # Per-sample scalars and reshape for broadcast
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)

        x_t = sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
        return x_t, noise
    
    def predict_start_from_noise(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """
        Reconstruct x0 from x_t and predicted noise epsilon:

            x0 = (x_t - sqrt(1 - a-bar_t) * epsilon) / sqrt(a-bar_t)
        """ 
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
    
        x0_pred = (x_t - sqrt_one_minus_alphas_cumprod_t * noise) / (sqrt_alphas_cumprod_t + 1e-8)
        return x0_pred

    @staticmethod
    def _spectral_loss(
        x_fake: torch.Tensor,
        x_real: torch.Tensor,
        eps: float = 1e-8
    ) -> torch.Tensor:
        """
        Spectral loss along the time axis (dimension W).

        For wavelet scalograms, we care that generated samples preserve
        temporal frequency content. We match log-magnitude spectra using rFFT.

        x_*: (B, C, H, W)
        """
        # Loss must be applied to the time dimension (W)
        Xf = torch.fft.rfft(x_fake, dim=-1)
        Xr = torch.fft.rfft(x_real, dim=-1)
        mag_f = torch.log(torch.abs(Xf) + eps)
        mag_r = torch.log(torch.abs(Xr) + eps)
        return F.mse_loss(mag_f, mag_r)

        # Training loss (epsilon prediction and auxiliary terms
    def p_losses(self, x_start, cond, t=None):
        """
        Computes DDPM training loss.

        Primary objective (epsilon prediction):
            L_eps = ||epsilon_theta(x_t, t, cond) - epsilon||^2

        """
        if t is None:
            t = torch.randint(0, self.timesteps, (x_start.size(0),), device=x_start.device)

        noise = torch.randn_like(x_start)
        x_noisy, _ = self.q_sample(x_start=x_start, t=t, noise=noise)
        eps_pred = self.model(x_noisy, t, cond)

        # Predict the noise with the conditional U-Net
        loss_eps = F.mse_loss(eps_pred, noise)

        # Weighted epsilon loss 
        weights = 1.0 / (1.0 - self.alphas_cumprod[t] + 1e-8)
        weights = weights.view(-1, 1, 1, 1)

        loss_eps = F.mse_loss(eps_pred, noise, reduction='none')
        loss_eps = (weights * loss_eps).mean()


        # x_0 reconstruction
        loss_x0 = 0.0
        if self.lambda_x0 > 0:
            x0_pred = self.predict_start_from_noise(x_noisy, t, eps_pred)
            loss_x0 = F.mse_loss(x0_pred, x_start)

        # Spectral term along time axis
        loss_spec = 0.0
        if self.lambda_spec > 0:
            real_spec = torch.fft.rfft(x_start, dim=-1).abs()
            gen_spec = torch.fft.rfft(x0_pred, dim=-1).abs()
            loss_spec = F.mse_loss(gen_spec, real_spec)

        return loss_eps + self.lambda_x0 * loss_x0 + self.lambda_spec * loss_spec


    # ---------- reverse diffusion for sampling ----------
    @torch.no_grad()
    def p_sample(self, x_t: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """
        Reverse step.

        For DDPM with epsilon prediction:
            x_{t-1} = mean(x_t, epsilon_theta, t) + sigma_t * z,  z ~ N(0, I) if t > 0 else 0
        """
        betas_t = self.betas[t].view(-1, 1, 1, 1)
        alphas_t = self.alphas[t].view(-1, 1, 1, 1)
        alphas_cumprod_t = self.alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)

        # Predict epsilon for timestep
        eps_theta = self.model(x_t, t, cond)
        sqrt_alphas_cumprod_t = torch.sqrt(alphas_cumprod_t)
        x0_hat = (x_t - sqrt_one_minus_alphas_cumprod_t * eps_theta) / (
            sqrt_alphas_cumprod_t + 1e-8
        )

        # Compute mean of posterior q(x_{t-1} | x_t, x0_hat)
        mean = (1.0 / torch.sqrt(alphas_t)) * (
            x_t - (betas_t / (sqrt_one_minus_alphas_cumprod_t + 1e-8)) * eps_theta
        )

        # add noise except for t == 0
        noise = torch.randn_like(x_t)
        nonzero_mask = (t > 0).float().view(-1, 1, 1, 1)
        x_prev = mean + nonzero_mask * torch.sqrt(betas_t) * noise
        return x_prev

    @torch.no_grad()
    def sample(
        self,
        cond: torch.Tensor,
        shape: tuple[int, int, int, int],
        x_T: torch.Tensor | None = None,
        seed: int | None = None,
    ) -> torch.Tensor:
        """
        Reverse chain to sample x_0.

        Parameters
        cond:
            Conditioning tensor of shape (B, cond_dim).
        shape:
            Output sample shape (B, C, H, W).
        x_T:
            Starting noise. Enables deterministic counterfactuals.
        seed:
            For reproducible sampling (only used when x_T is None).
        """
        b = shape[0]
        cond = cond.to(self.device)

        if x_T is not None:
            x_t = x_T.to(self.device).clone()
        else:
            if seed is not None:
                g = torch.Generator(device=self.device)
                g.manual_seed(seed)
                x_t = torch.randn(shape, device=self.device, generator=g)
            else:
                x_t = torch.randn(shape, device=self.device)

        for step in reversed(range(self.timesteps)):
            t = torch.full((b,), step, device=self.device, dtype=torch.long)
            x_t = self.p_sample(x_t, t, cond)

        return x_t


In [22]:
def build_model_plus_diffusion(config, X_all, C_all):
    """
    Build the conditional U-Net and wrap it in the DDPM Gaussian diffusion module.

    The result is placed on the module-level `device`.

    Parameters
    config:
        Experiment configuration dict containing `model` and `diffusion` blocks.
        ``training.dropout_rate`` is used for the U-Net dropout when present
        (default 0.1).
    X_all:
        Wavelet tensor. Accepted for interface symmetry; not used here.
    C_all:
        Conditioning array, 2-D (T, features) or 3-D (T, assets, features). Its
        last axis plus one (the energy column appended by the dataset) gives the
        model's conditioning dimension.

    Returns
    diffusion:
        GaussianDiffusion2D wrapping a UNET2DCond
    """
    # +1 for the energy column appended to every conditioning vector.
    if C_all.ndim == 3:
        cond_dim = C_all.shape[2] + 1
    else:
        cond_dim = C_all.shape[1] + 1

    print(f"Conditioning dimension (reduced): {cond_dim}")

    m_cfg = config['model']
    d_cfg = config['diffusion']
    # Honour the configured dropout instead of a hard-coded 0.1.
    dropout_rate = config.get('training', {}).get('dropout_rate', 0.1)

    unet = UNET2DCond(
        in_channels=m_cfg['input_channels'],
        cond_dim=cond_dim,
        base_channels=m_cfg['base_channels'],
        time_emb_dim=m_cfg['time_emb_dim'],
        dropout_rate=dropout_rate,
    )

    diffusion = GaussianDiffusion2D(
        unet,
        timesteps=d_cfg['timesteps'],
        beta_start=d_cfg['beta_start'],
        beta_end=d_cfg['beta_end'],
        device=device,
        lambda_x0=d_cfg.get('lambda_x0', 0.0),
        lambda_spec=d_cfg.get('lambda_spec', 0.0),
    ).to(device)

    print('Model parameters (in millions): ',
          sum(p.numel() for p in diffusion.parameters()) / 1e6)

    return diffusion

In [23]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch

import numpy as np

import numpy as np

"""
Diagnostics and Visualization Utilities


Helpers evaluate whether generated samples reproduce key "stylized facts"
of financial returns:

- Fat tails (excess kurtosis, tail exceedance probability, survival plots)
- Volatility clustering (ACF of r^2 or |r|)
- Regime counterfactual tests (same noise seed, different conditioning)

All functions are written to be:
- batch-safe (work on tensors from DataLoader)
- numerically stable
- easy to log per epoch
"""



def clustering_score_from_acf(acf_vals: np.ndarray, lag_lo: int = 1, lag_hi: int = 5) -> float:
    """
    Collapse an ACF curve into one number: the mean over lags lag_lo..lag_hi
    (inclusive). `lag_hi` is capped at the last available lag; NaN is returned
    when the capped range is empty.
    """
    values = np.asarray(acf_vals)
    last_lag = min(lag_hi, len(values) - 1)
    if lag_lo <= last_lag:
        return float(np.mean(values[lag_lo:last_lag + 1]))
    return float("nan")


def max_drawdown(x: np.ndarray) -> float:
    """
    Largest drop of the cumulative sum of `x` below its running maximum,
    returned as a non-positive number. A rough diagnostic, not a backtest.
    """
    running_total = np.cumsum(np.asarray(x))
    running_peak = np.maximum.accumulate(running_total)
    return float((running_total - running_peak).min())

def skewness(x: np.ndarray) -> float:
    """Third standardized moment of `x` (population std, plus 1e-12 to avoid dividing by zero)."""
    values = np.asarray(x)
    centre = values.mean()
    scale = values.std() + 1e-12
    standardized = (values - centre) / scale
    return float(np.mean(standardized ** 3))

def excess_kurtosis(x: np.ndarray) -> float:
    """Fourth standardized moment of `x` minus 3, so a Gaussian scores 0."""
    values = np.asarray(x)
    centre = values.mean()
    scale = values.std() + 1e-12
    standardized = (values - centre) / scale
    fourth_moment = np.mean(standardized ** 4)
    return float(fourth_moment - 3.0)

def tail_prob(x: np.ndarray, k: float = 2.0) -> float:
    """
    Tail probability: the fraction of entries with |x| > k * std(x).
    `x` is not mean-centred before taking the absolute value.
    Fat-tailed returns should exceed the Gaussian expectation for the same k.
    """
    values = np.asarray(x)
    threshold = k * (values.std() + 1e-12)
    exceeds = np.abs(values) > threshold
    return float(np.mean(exceeds))

def wave_stats(x: np.ndarray, name: str = "", k: float = 2.0) -> dict:
    """
    Bundle the basic diagnostics of a return series into one dict:
    label, mean, std, max drawdown, skewness, excess kurtosis and the
    k-sigma tail probability.
    """
    series = np.asarray(x)
    stats = {"name": name}
    stats["mean"] = float(series.mean())
    stats["std"] = float(series.std())
    stats["mdd"] = max_drawdown(series)
    stats["skew"] = skewness(series)
    stats["ex_kurt"] = excess_kurtosis(series)
    stats["tailP(|r|>kσ)"] = tail_prob(series, k=k)
    return stats


# Autocorrelation and volatility clustering
def acf_1d(x: np.ndarray, nlags: int = 20) -> np.ndarray:
    """
    Autocorrelation of a 1D series for lags 0..nlags, with acf[0] fixed at 1.
    The series is mean-centred and every lag is normalized by the centred sum of
    squares (plus 1e-12). Uses numpy only.
    """
    centred = np.asarray(x, dtype=np.float64)
    centred = centred - centred.mean()
    energy = np.dot(centred, centred) + 1e-12

    acf = np.empty(nlags + 1, dtype=np.float64)
    acf[0] = 1.0
    acf[1:] = [
        np.dot(centred[:-lag], centred[lag:]) / energy
        for lag in range(1, nlags + 1)
    ]
    return acf

def vol_cluster_curves(r: np.ndarray, nlags: int = 20, kind: str = "sq") -> np.ndarray:
    """
    Volatility clustering diagnostic: ACF of a volatility proxy of ``r``.

    kind='sq'  -> proxy is r^2
    kind='abs' -> proxy is |r|
    Any other ``kind`` raises ValueError.
    """
    returns = np.asarray(r, dtype=np.float64)
    if kind not in ("sq", "abs"):
        raise ValueError("kind must be 'sq' or 'abs'")
    proxy = returns ** 2 if kind == "sq" else np.abs(returns)
    return acf_1d(proxy, nlags=nlags)

def mean_acf_over_waves(waves: np.ndarray, nlags: int = 20, kind: str = "sq") -> tuple[np.ndarray, np.ndarray]:
    """
    Average the volatility-clustering ACF across several return windows.

    waves: (N, T), one window per row
    returns: (mean_acf, stderr_acf), each of length nlags + 1, where the
    standard error is the across-window std divided by sqrt(N).
    """
    waves = np.asarray(waves)
    n_windows = waves.shape[0]
    curves = [
        vol_cluster_curves(waves[idx], nlags=nlags, kind=kind)
        for idx in range(n_windows)
    ]
    stacked = np.stack(curves, axis=0)
    spread = stacked.std(axis=0)
    stderr = spread / np.sqrt(stacked.shape[0] + 1e-12)
    return stacked.mean(axis=0), stderr
def mean_acf_r2_over_windows(waves: np.ndarray, nlags: int = 20) -> np.ndarray:
    """
    Mean autocorrelation of squared returns over several return windows.

    waves: (N, T) return windows, one per row
    returns: mean ACF of r^2 across windows; the number of lags is capped
    at T - 2.
    """
    windows = np.asarray(waves, dtype=np.float64)
    window_len = windows.shape[1]
    n_lags = int(min(nlags, window_len - 2))
    per_window = [
        acf_1d(windows[idx] ** 2, nlags=n_lags)
        for idx in range(windows.shape[0])
    ]
    return np.stack(per_window, axis=0).mean(axis=0)

@torch.no_grad()
def save_regime_ab_plot(
    diffusion,
    c_base: torch.Tensor,
    epoch: int,
    save_dir: str,
    shape=(1, 1, 32, 128),
    seed: int = 123,
    inverse_fn=None,
):
    """
    Sample two scalograms from the same starting noise under two conditioning
    vectors that differ only in their last entry (1.0 vs 2.0), and save a
    2x2 comparison figure.

    Panels: scalogram A, scalogram B, reconstructed waves A vs B, and the ACF
    of squared waves (volatility clustering). The two wave-based panels need
    ``inverse_fn``; when it is not given they are replaced by a placeholder
    instead of crashing.

    Parameters
    ----------
    diffusion :
        Object exposing ``.device`` and ``.sample(cond, shape, x_T=...)``.
    c_base : torch.Tensor
        Base conditioning; cloned, then index ``[-1]`` is overwritten.
        # assumes a 1-D vector whose last entry encodes the regime — TODO confirm
    epoch : int
        Used in the output file name.
    save_dir : str
        Output directory (created if missing).
    shape : tuple
        Shape of the starting noise / generated sample.
    seed : int
        Seed for the noise generator so A and B share the same ``x_T``.
    inverse_fn : callable, optional
        Maps a 2-D scalogram array to a 1-D series.

    Returns
    -------
    str
        Path of the saved PNG.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Build A/B conditioning with last dim different, for regime comparison
    cA = c_base.clone().float()
    cB = c_base.clone().float()
    cA[-1] = 1.0
    cB[-1] = 2.0

    cA = cA.view(1, -1).to(diffusion.device)
    cB = cB.view(1, -1).to(diffusion.device)

    # Shared starting noise isolates the effect of the conditioning.
    g = torch.Generator(device=diffusion.device)
    g.manual_seed(seed)
    x_T = torch.randn(shape, device=diffusion.device, generator=g)

    xA = diffusion.sample(cA, shape, x_T=x_T)
    xB = diffusion.sample(cB, shape, x_T=x_T)

    scalA = xA[0, 0].detach().cpu().numpy()
    scalB = xB[0, 0].detach().cpu().numpy()

    nlags = 20
    if inverse_fn:
        waveA = inverse_fn(scalA)
        waveB = inverse_fn(scalB)
        acfA = vol_cluster_curves(waveA, nlags=nlags, kind="sq")
        acfB = vol_cluster_curves(waveB, nlags=nlags, kind="sq")
    else:
        waveA = waveB = acfA = acfB = None

    fig = plt.figure(figsize=(12, 7))
    try:
        # --- Scalogram A ---
        ax1 = fig.add_subplot(2, 2, 1)
        im1 = ax1.imshow(scalA, aspect="auto", cmap="viridis")
        ax1.set_title("Generated Scalogram (Regime A: regime=1.0)")
        ax1.set_xlabel("Time"); ax1.set_ylabel("Scale")
        fig.colorbar(im1, ax=ax1, fraction=0.046, pad=0.04)

        # --- Scalogram B ---
        ax2 = fig.add_subplot(2, 2, 2)
        im2 = ax2.imshow(scalB, aspect="auto", cmap="viridis")
        ax2.set_title("Generated Scalogram (Regime B: regime=2.0)")
        ax2.set_xlabel("Time"); ax2.set_ylabel("Scale")
        fig.colorbar(im2, ax=ax2, fraction=0.046, pad=0.04)

        ax3 = fig.add_subplot(2, 2, 3)
        ax4 = fig.add_subplot(2, 2, 4)
        ax3.set_title("Inverse Wave (A vs B) — same noise, different conditioning")
        ax4.set_title("Volatility Clustering Diagnostic: ACF of r²")

        if waveA is not None:
            # --- Waves A vs B ---
            ax3.plot(waveA, label="Wave A (last=1.0)")
            ax3.plot(waveB, label="Wave B (last=2.0)")
            ax3.legend()
            ax3.set_xlabel("Time"); ax3.set_ylabel("Value")

            # --- Volatility clustering proof: ACF of squared returns ---
            lags = np.arange(nlags + 1)
            ax4.plot(lags, acfA, marker="o", label="ACF(r²) A")
            ax4.plot(lags, acfB, marker="o", label="ACF(r²) B")
            ax4.axhline(0.0, linewidth=1)
            ax4.set_xlabel("Lag"); ax4.set_ylabel("ACF")
            ax4.legend()
            ax4.set_xlim(0, nlags)
        else:
            # Without an inverse transform there is no wave to plot.
            for ax in (ax3, ax4):
                ax.axis("off")
                ax.text(0.5, 0.5, "inverse_fn not provided",
                        ha="center", va="center", transform=ax.transAxes)

        fig.tight_layout()
        out = os.path.join(save_dir, f"epoch_{epoch:04d}_regime_AB.png")
        fig.savefig(out, dpi=200)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    return out


def plot_scalogram_comparison(data_matrix, title, plot_path):
    """
    Render one (Features, Time) matrix as a 2D scalogram image and write it
    to ``plot_path``.

    Torch tensors are detached and moved to CPU first; anything else is
    passed to ``imshow`` as-is.
    """
    matrix = data_matrix
    if torch.is_tensor(matrix):
        matrix = matrix.detach().cpu().numpy()

    n_features, n_steps = matrix.shape[0], matrix.shape[1]

    # (Features, Time) is already the orientation imshow expects
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    image = ax.imshow(matrix, aspect='auto', interpolation='none', cmap='viridis')

    # Colorbar gives a magnitude reference
    fig.colorbar(image, ax=ax, label='Feature Magnitude')

    ax.set_xlabel(f'Time Step in Window (W={n_steps})')
    ax.set_ylabel(f'Wavelet Feature Index (C={n_features})')
    ax.set_title(title)
    fig.tight_layout()

    # Write to disk, then release the figure to save memory
    fig.savefig(plot_path)
    plt.close(fig)

def save_scalogram_plot(data_matrix, title, save_path):
    """
    Save a scalogram image of a 2D (scales, time) matrix.

    Parameters
    ----------
    data_matrix : array-like or torch.Tensor
        2D matrix to draw. Tensors are detached and moved to CPU first.
    title : str
        Figure title.
    save_path : str or path-like
        Output file. Its parent directory is created when the path has one;
        a bare file name is written to the current working directory.
    """
    if torch.is_tensor(data_matrix):
        data_matrix = data_matrix.detach().cpu().numpy()

    # os.makedirs("") raises, so only create a directory when one is named.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    try:
        # Plot the data
        im = ax.imshow(data_matrix, aspect='auto', interpolation='none', cmap='viridis')

        # Add colorbar
        fig.colorbar(im, ax=ax, label='Feature Magnitude')

        # Labels and title
        ax.set_xlabel(f'Time Step in Window (W={data_matrix.shape[1]})')
        ax.set_ylabel(f'Wavelet Scales (C={data_matrix.shape[0]})')
        ax.set_title(title)

        fig.tight_layout()

        # Save to file
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

def inverse_wavelet_from_scalogram(scalogram_2d, scales=np.arange(1, 33)):
    """
    Pseudo-inverse CWT reconstruction, used as a stand-in for a Morlet CWT
    when a true icwt is unavailable.

    Input:
        scalogram_2d: (n_scales, n_times); a (n_times, n_scales) matrix is
        transposed automatically when only its second axis matches.

    Output:
        recon: (n_times,) float32 series, the 1/sqrt(scale)-weighted sum over
        scales divided by sqrt(n_scales).

    Raises ValueError when neither axis has length ``len(scales)``.
    """
    coeffs = np.asarray(scalogram_2d, dtype=np.float64)
    n_scales = len(scales)

    # Orient as (n_scales, n_times)
    rows_match = coeffs.shape[0] == n_scales
    if not rows_match and coeffs.shape[1] == n_scales:
        coeffs = coeffs.T

    if coeffs.shape[0] != n_scales:
        raise ValueError(f"Expected one dimension = n_scales={len(scales)}, got {coeffs.shape}")

    # One weight per scale, broadcast across the time axis
    scale_weights = (1.0 / np.sqrt(scales)).reshape(-1, 1)
    weighted_sum = (coeffs * scale_weights).sum(axis=0)

    # Energy normalisation
    recon = weighted_sum / np.sqrt(n_scales)
    return recon.astype(np.float32)



def to_daily_log_returns(price_series, eps=1e-12):
    """
    Compute log returns, diff(log(price)), from a price series.

    NOTE: this file defines ``to_daily_log_returns`` a second time further
    down, and that later definition is the one in effect once both have run.
    This version accepts the same ``eps`` keyword and casts to float64 so the
    contract does not depend on which definition ran last.

    Parameters
    ----------
    price_series : array-like
        Price levels; values below ``eps`` are clipped up to ``eps``.
    eps : float
        Lower clip bound that keeps ``log`` finite.

    Returns
    -------
    np.ndarray
        float64 array with one fewer element than the input.
    """
    prices = np.asarray(price_series, dtype=np.float64)
    prices = np.clip(prices, eps, None)   # avoid log(0)
    return np.diff(np.log(prices))


import numpy as np
import matplotlib.pyplot as plt

# Finds log returns
def to_daily_log_returns(series_1d, eps=1e-12):
    """
    Log returns of a price series: successive differences of log(price).

    Prices are cast to float64 and clipped from below at ``eps`` so the
    logarithm stays finite for zero or negative inputs.
    """
    prices = np.clip(np.asarray(series_1d, dtype=np.float64), a_min=eps, a_max=None)
    log_prices = np.log(prices)
    return log_prices[1:] - log_prices[:-1]

# --- fat tail plots ---
def save_fat_tail_diagnostics(
    x_real_batch,
    x_fake_batch,
    *,
    epoch,
    save_path,
    inverse_fn,
    max_rows=64,
    use_log_returns=True,
):
    """
    Saves a 2x2 diagnostic figure:
    - histogram of pooled returns (real vs gen) with a normal fit to gen
    - QQ plot (generated vs normal)
    - survival plot of |generated returns| (log y-axis)
    - volatility clustering: mean ACF of r^2 (real vs gen) with +/- 1 s.e.

    Parameters
    ----------
    x_real_batch, x_fake_batch : torch.Tensor
        Batches indexed as [sample, channel, ...]; only channel 0 is used and
        each per-sample slice is handed to ``inverse_fn``.
    epoch :
        Epoch label used in the figure title and the printed summary.
    save_path : str
        Output image path.
    inverse_fn : callable
        Maps one per-sample slice to a 1-D series.
    max_rows : int
        Maximum number of samples to process.
    use_log_returns : bool
        If True, each reconstructed series is passed through
        ``to_daily_log_returns`` (i.e. treated as a price path); if False it
        is used directly as a return series.

    Returns
    -------
    None. When no sample yields at least 10 finite returns for both the real
    and the generated series, a message is printed and no figure is written.
    """
    # scipy is already used by this file; imported locally so this function
    # does not depend on which cell ran first.
    from scipy import stats as sp_stats

    # Tensors to numpy
    real = x_real_batch.detach().float().cpu().numpy()[:, 0]  # (B, S, T)
    fake = x_fake_batch.detach().float().cpu().numpy()[:, 0]

    # Never index past the end of the smaller batch.
    n_rows = min(max_rows, real.shape[0], fake.shape[0])

    real_windows = []
    fake_windows = []

    # Waves and returns
    for i in range(n_rows):
        real_series = inverse_fn(real[i])
        fake_series = inverse_fn(fake[i])

        if use_log_returns:
            real_ret = to_daily_log_returns(real_series)
            fake_ret = to_daily_log_returns(fake_series)
        else:
            real_ret = np.asarray(real_series)
            fake_ret = np.asarray(fake_series)

        real_ret = real_ret[np.isfinite(real_ret)]
        fake_ret = fake_ret[np.isfinite(fake_ret)]

        # Too few points for a meaningful ACF / histogram contribution.
        if len(real_ret) < 10 or len(fake_ret) < 10:
            continue

        real_windows.append(real_ret)
        fake_windows.append(fake_ret)

    if not real_windows:
        print(
            f"[Epoch {epoch}] Fat-tail diagnostics skipped: "
            f"no sample produced at least 10 finite returns."
        )
        return

    # Fat tails: pool every kept window
    real_r = np.concatenate(real_windows)
    fake_r = np.concatenate(fake_windows)

    # Volatility clustering: align lengths so windows can be stacked
    min_len = min(map(len, real_windows + fake_windows))
    real_waves = np.stack([w[:min_len] for w in real_windows], axis=0)
    fake_waves = np.stack([w[:min_len] for w in fake_windows], axis=0)

    nlags = min(20, min_len - 2)

    real_acf_mean, real_acf_se = mean_acf_over_waves(real_waves, nlags=nlags, kind="sq")
    gen_acf_mean,  gen_acf_se  = mean_acf_over_waves(fake_waves,  nlags=nlags, kind="sq")

    real_cluster_score = clustering_score_from_acf(real_acf_mean, 1, 5)
    gen_cluster_score  = clustering_score_from_acf(gen_acf_mean,  1, 5)

    print(
        f"[Epoch {epoch}] Vol clustering score "
        f"(mean ACF r^2 lags 1..5): real={real_cluster_score:.4f}, "
        f"gen={gen_cluster_score:.4f}"
    )
    mu, sd = fake_r.mean(), fake_r.std(ddof=1)
    xs = np.linspace(mu-6*sd, mu+6*sd, 600)

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    try:
        fig.suptitle(f"Epoch {epoch} — Fat Tails & Volatility Clustering", fontsize=14)

        # Histogram
        ax = axes[0, 0]
        ax.hist(real_r, bins=80, density=True, alpha=0.6, label="Real")
        ax.hist(fake_r, bins=80, density=True, alpha=0.6, label="Generated")
        # Gaussian density with the generated mean/std; the 1/(sd*sqrt(2*pi))
        # factor is what makes it comparable to density=True histograms.
        if sd > 0:
            norm_pdf = np.exp(-0.5 * ((xs - mu) / sd) ** 2) / (sd * np.sqrt(2 * np.pi))
        else:
            norm_pdf = np.zeros_like(xs)
        ax.plot(xs, norm_pdf, linewidth=2, label="Normal fit (gen)")
        ax.set_title("Histogram (real vs gen) + normal overlay")
        ax.legend()

        # QQ plot - generated vs normal
        ax = axes[0, 1]
        fake_sorted = np.sort(fake_r)
        n = len(fake_sorted)
        # Theoretical quantiles q_i = Phi^{-1}((i - 0.5) / n)
        p = (np.arange(1, n+1) - 0.5) / n
        z = sp_stats.norm.ppf(p)
        ax.plot(z, fake_sorted, marker='.', linestyle='none', markersize=3)
        # reference line: where the points would fall if gen were N(mu, sd^2)
        ax.plot([z.min(), z.max()], [mu + sd*z.min(), mu + sd*z.max()])
        ax.set_xlabel("Normal quantile")
        ax.set_ylabel("Generated return quantile")
        ax.set_title("QQ plot vs Normal (generated)")

        # Survival plot
        ax = axes[1, 0]
        x = np.sort(np.abs(fake_r))
        surv = 1.0 - np.arange(1, len(x) + 1) / len(x)
        ax.plot(x, surv)
        ax.set_yscale("log")
        ax.set_xlabel("|r|")
        ax.set_ylabel("P(|r| > x)")
        ax.set_title("Survival of |Returns| (log y)")

        # Volatility clustering
        ax = axes[1, 1]
        lags = np.arange(nlags + 1)

        ax.plot(lags, real_acf_mean, label="Real ACF(r²)")
        ax.fill_between(lags, real_acf_mean - real_acf_se, real_acf_mean + real_acf_se, alpha=0.2)

        ax.plot(lags, gen_acf_mean, label="Gen ACF(r²)")
        ax.fill_between(lags, gen_acf_mean - gen_acf_se, gen_acf_mean + gen_acf_se, alpha=0.2)

        ax.axhline(0.0, linewidth=1)
        ax.set_title("Volatility Clustering (ACF of r²)")
        ax.set_xlabel("Lag")
        ax.set_ylabel("ACF")
        ax.legend()

        ax.text(
            0.02, 0.02,
            f"cluster score (lags1–5):\n"
            f"  real = {real_cluster_score:.3f}\n"
            f"  gen  = {gen_cluster_score:.3f}",
            transform=ax.transAxes,
            fontsize=9,
            family="monospace",
            va="bottom"
        )

        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


import os
import numpy as np
import matplotlib.pyplot as plt

def save_epoch_metrics_png(
    metrics_history: list[dict],
    *,
    epoch: int,
    save_dir: str,
):
    """
    Saves a metrics-only dashboard PNG per epoch.
    Uses the history list (one dict per epoch).

    Panels: train/val loss, std ratio, tail-quantile ratios, excess-kurtosis
    gap, volatility-clustering ratio, and a text snapshot of the latest entry.

    Parameters
    ----------
    metrics_history : list[dict]
        One dict per epoch. Keys read here: epoch, train_loss, val_loss,
        real_std/gen_std, real_skew/gen_skew, real_exkurt/gen_exkurt,
        real_abs_q990/gen_abs_q990, real_abs_q995/gen_abs_q995,
        real_cluster_1_5/gen_cluster_1_5. Missing or None values become NaN.
        Must be non-empty (the last entry feeds the text snapshot).
    epoch : int
        Epoch number used in the title and file name.
    save_dir : str
        Output directory (created if missing). The file is written as
        ``metrics_epoch_<epoch:04d>.png``.
    """

    os.makedirs(save_dir, exist_ok=True)

    def series(key, default=np.nan):
        """Collect `key` across the history as a float array (None -> NaN)."""
        values = []
        for record in metrics_history:
            value = record.get(key, default)
            values.append(np.nan if value is None else value)
        return np.array(values, dtype=float)

    # x-axis epochs; fall back to 1..N when no entry carries an "epoch" key
    ep = series("epoch")
    if np.all(np.isnan(ep)):
        ep = np.arange(1, len(metrics_history) + 1)

    train_loss = series("train_loss")
    val_loss   = series("val_loss")

    # distribution alignment
    std_ratio = series("gen_std") / (series("real_std") + 1e-12)
    exk_gap   = series("gen_exkurt") - series("real_exkurt")

    # tails (abs quantiles)
    q99_ratio  = series("gen_abs_q990") / (series("real_abs_q990") + 1e-12)
    q995_ratio = series("gen_abs_q995") / (series("real_abs_q995") + 1e-12)

    # volatility clustering score
    cluster_ratio = series("gen_cluster_1_5") / (series("real_cluster_1_5") + 1e-12)

    # Epoch snapshot
    latest = metrics_history[-1]

    def f(key, fmt="{:.4g}", default="NA"):
        """Format one value of the latest entry; missing/NaN -> `default`."""
        v = latest.get(key, None)
        if v is None or (isinstance(v, float) and np.isnan(v)):
            return default
        try:
            return fmt.format(v)
        except Exception:
            # Non-numeric values cannot take the numeric format spec.
            return str(v)

    # Figure
    fig = plt.figure(figsize=(14, 10))
    try:
        fig.suptitle(f"Epoch {epoch} — Metrics Only (Fat Tails & Vol Clustering)", fontsize=16)

        # Loss
        ax = fig.add_subplot(2, 3, 1)
        ax.plot(ep, train_loss, label="train")
        ax.plot(ep, val_loss, label="val")
        ax.set_title("Loss")
        ax.set_xlabel("Epoch")
        ax.legend()

        # Std ratio (gen/real)
        ax = fig.add_subplot(2, 3, 2)
        ax.plot(ep, std_ratio)
        ax.axhline(1.0, linewidth=1)
        ax.set_title("Vol level match: std(gen)/std(real)")
        ax.set_xlabel("Epoch")

        # Tail ratio (q99, q99.5)
        ax = fig.add_subplot(2, 3, 3)
        ax.plot(ep, q99_ratio, label="|r| q99 ratio")
        ax.plot(ep, q995_ratio, label="|r| q99.5 ratio")
        ax.axhline(1.0, linewidth=1)
        ax.set_title("Tail match (Generated / Real)")
        ax.set_xlabel("Epoch")
        ax.legend()

        # Excess kurtosis gap
        ax = fig.add_subplot(2, 3, 4)
        ax.plot(ep, exk_gap)
        ax.axhline(0.0, linewidth=1)
        ax.set_title("Excess kurtosis gap: gen − real")
        ax.set_xlabel("Epoch")

        # Volatility clustering ratio
        ax = fig.add_subplot(2, 3, 5)
        ax.plot(ep, cluster_ratio)
        ax.axhline(1.0, linewidth=1)
        ax.set_title("Vol clustering match: score(gen)/score(real)")
        ax.set_xlabel("Epoch")

        # Text summary (epoch snapshot)
        ax = fig.add_subplot(2, 3, 6)
        ax.axis("off")
        text = (
            f"Snapshot @ epoch {epoch}\n\n"
            f"Loss:\n"
            f"  train={f('train_loss')}  val={f('val_loss')}\n\n"
            f"Distribution:\n"
            f"  real std={f('real_std')}  gen std={f('gen_std')}\n"
            f"  real skew={f('real_skew')} gen skew={f('gen_skew')}\n"
            f"  real exk={f('real_exkurt')} gen exk={f('gen_exkurt')}\n\n"
            f"Tails (abs):\n"
            f"  real q99={f('real_abs_q990')}  gen q99={f('gen_abs_q990')}\n"
            f"  real q99.5={f('real_abs_q995')} gen q99.5={f('gen_abs_q995')}\n\n"
            f"Vol clustering:\n"
            f"  real score={f('real_cluster_1_5')}  gen score={f('gen_cluster_1_5')}\n"
        )
        ax.text(0.02, 0.98, text, va="top", family="monospace", fontsize=10)

        fig.tight_layout()

        out_path = os.path.join(save_dir, f"metrics_epoch_{epoch:04d}.png")
        fig.savefig(out_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


def plot_scalogram_and_waveform(
    x_real_batch,
    x_fake_batch,
    epoch,
    save_path,
    max_rows=3,
):
    """
    Save a grid with one row per sample and three columns:

        Real scalogram | Generated scalogram | Reconstructed waveform

    x_real_batch, x_fake_batch: tensors (or arrays) of shape
    (B, 1, n_scales, n_times). At most ``max_rows`` samples are drawn; the
    waveform is the inverse-wavelet reconstruction of the generated scalogram.
    """
    def _to_numpy(batch):
        # Tensors are detached and moved to CPU; arrays pass through untouched
        return batch.detach().cpu().numpy() if torch.is_tensor(batch) else batch

    x_real = _to_numpy(x_real_batch)
    x_fake = _to_numpy(x_fake_batch)

    batch_size, _, _, _ = x_real.shape
    rows = min(batch_size, max_rows)

    # squeeze=False keeps axes 2D even for a single row
    fig, axes = plt.subplots(rows, 3, figsize=(12, 4 * rows), squeeze=False)

    for i in range(rows):
        generated = x_fake[i, 0]
        panels = (
            (x_real[i, 0], f"Real Scalogram #{i}"),
            (generated, f"Generated Scalogram #{i}"),
        )

        # Columns 1-2: real and generated scalograms
        for col, (scalogram, panel_title) in enumerate(panels):
            ax = axes[i, col]
            image = ax.imshow(scalogram, aspect='auto', origin='lower', cmap='viridis')
            ax.set_title(panel_title)
            ax.set_xlabel("Time")
            ax.set_ylabel("Scale")
            fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

        # Column 3: time series recovered from the generated scalogram
        waveform = inverse_wavelet_from_scalogram(generated)
        ax = axes[i, 2]
        ax.plot(waveform)
        ax.set_title(f"Inverse Wavelet (Gen #{i})")
        ax.set_xlabel("Time")
        ax.set_ylabel("Value")

    fig.suptitle(f"Epoch {epoch:04d} – Real vs Generated vs Reconstructed", fontsize=14)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)

import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

def plot_fat_tail_checks(returns, title="Daily returns"):
    """
    Show a 2x2 fat-tail check for one return series (non-finite values are
    dropped first):

    1) histogram with a fitted normal density
    2) QQ plot against the normal distribution
    3) survival curve of |returns| on a log y-axis
    4) text panel with n, mean, std, skew and excess kurtosis

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    sample = np.asarray(returns)
    sample = sample[np.isfinite(sample)]

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    ax_hist, ax_qq = axes[0, 0], axes[0, 1]
    ax_surv, ax_text = axes[1, 0], axes[1, 1]

    # 1) Histogram + normal overlay
    ax_hist.hist(sample, bins=80, density=True)
    mu = sample.mean()
    sigma = sample.std(ddof=1)
    grid = np.linspace(mu - 5*sigma, mu + 5*sigma, 800)
    ax_hist.plot(grid, stats.norm.pdf(grid, mu, sigma))
    ax_hist.set_title(f"{title} (hist + normal overlay)")

    # 2) QQ plot vs Normal
    stats.probplot(sample, dist="norm", plot=ax_qq)
    ax_qq.set_title("QQ plot vs Normal")

    # 3) Tail survival plot (log y-scale)
    magnitudes = np.sort(np.abs(sample))
    n_obs = len(magnitudes)
    survival = 1.0 - (np.arange(1, n_obs + 1) / n_obs)
    ax_surv.plot(magnitudes, survival)
    ax_surv.set_yscale("log")
    ax_surv.set_title("Survival of |returns| (log scale)")
    ax_surv.set_xlabel("|r|")
    ax_surv.set_ylabel("P(|r| > x)")

    # 4) Summary stats (bias-corrected estimators)
    kurt = stats.kurtosis(sample, fisher=True, bias=False)  # excess kurtosis
    skew = stats.skew(sample, bias=False)
    ax_text.axis("off")
    summary = (
        f"n = {len(sample)}\nmean = {mu:.4g}\nstd = {sigma:.4g}\n"
        f"skew = {skew:.4g}\nexcess kurtosis = {kurt:.4g}"
    )
    ax_text.text(0.02, 0.98, summary, va="top")
    ax_text.set_title("Moments (fat tails → high excess kurtosis)")

    plt.tight_layout()
    plt.show()



In [24]:
def basic_stats(x: torch.Tensor):
    """
    Mean and (unbiased) variance along the sample axis.

    x: (B, C, T) -> batch and time are pooled, giving per-channel stats of
       shape (C,).
    Any other rank (e.g. (B, D)) is reduced over dim 0 only.

    Returns:
        mean
        var
    """
    samples = x
    if x.ndim == 3:
        n_channels = x.size(1)
        # Move channels last, then fold batch and time into one axis
        samples = x.permute(0, 2, 1).reshape(-1, n_channels)
    return samples.mean(dim=0), samples.var(dim=0)

def avg_pairwise_dist(x: torch.Tensor) -> float:
    """
    Average pairwise RMS distance between the samples in a batch.

    Each sample is flattened; the distance between two samples is the square
    root of the *mean* squared difference over features (i.e. the Euclidean
    distance divided by sqrt(D)). All B*(B-1)/2 distinct pairs are used, so
    memory grows as O(B^2 * D).

    Parameters
    x:
        Tensor whose first dimension is the batch. Non-contiguous tensors
        (e.g. the result of ``permute``) are accepted.

    Returns
    float
        Mean pairwise distance (NaN when the batch has fewer than 2 samples).
    """
    B = x.size(0)
    # reshape (not view) so non-contiguous inputs do not raise.
    x_flat = x.reshape(B, -1)
    diff = x_flat.unsqueeze(1) - x_flat.unsqueeze(0)
    dist_mat = (diff ** 2).mean(dim=-1).sqrt()
    # Strict upper triangle: each unordered pair once, no self-distances.
    mask = torch.triu(torch.ones_like(dist_mat), diagonal=1) > 0
    return dist_mat[mask].mean().item()

def rbf_kernel(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """
    Gaussian (RBF) kernel matrix, K[i, j] = exp(-||x_i - y_j||^2 / (2*sigma^2)).

    x : (N, D)
    y : (M, D)
    returns : (N, M)
    """
    # Broadcast (N, 1, D) against (1, M, D) to get every pairwise difference
    pairwise_diff = x[:, None] - y[None]
    sq_dist = pairwise_diff.pow(2).sum(dim=-1)
    bandwidth = 2 * sigma ** 2
    return torch.exp(-sq_dist / bandwidth)

def mmd_flat(x: torch.Tensor, y: torch.Tensor, sigma: float = 10.0) -> float:
    """
    Maximum Mean Discrepancy (MMD) between two sample sets, RBF kernel.

    Each kernel matrix is averaged over all of its entries (diagonal
    included), and the estimate is mean(Kxx) + mean(Kyy) - 2 * mean(Kxy).

    Parameters
    ----------
    x, y:
        Tensors with the batch on dim 0; each sample is flattened to (B, D).
    sigma:
        Kernel bandwidth

    Returns
    float
        MMD estimate (scalar).
    """
    flat_x = x.view(x.size(0), -1)
    flat_y = y.view(y.size(0), -1)
    k_xx, k_yy, k_xy = (
        rbf_kernel(left, right, sigma).mean()
        for left, right in ((flat_x, flat_x), (flat_y, flat_y), (flat_x, flat_y))
    )
    return (k_xx + k_yy - 2 * k_xy).item()

def compute_fid(real_features, fake_features):
    """
    Frechet distance between Gaussians fitted to two feature sets.

        FID = ||mu_r - mu_f||^2 + Tr(S_r) + Tr(S_f) - 2 * Tr((S_r S_f)^(1/2))

    The last term needs a *matrix* square root. Its trace equals the sum of
    the square roots of the eigenvalues of S_r @ S_f, which is what is
    computed here (an element-wise sqrt would give a different, possibly NaN
    or negative, result).

    Parameters
    ----------
    real_features, fake_features : torch.Tensor
        Shape (N, D) with D >= 2; rows are samples.

    Returns
    -------
    float
        The Frechet distance (approximately 0 for identical inputs).
    """
    # float64 keeps the eigen-decomposition well conditioned.
    real = real_features.to(torch.float64)
    fake = fake_features.to(torch.float64)

    mu_real, sigma_real = real.mean(dim=0), torch.cov(real.T)
    mu_fake, sigma_fake = fake.mean(dim=0), torch.cov(fake.T)

    diff = mu_real - mu_fake

    # Eigenvalues of a product of PSD matrices are real and >= 0 in exact
    # arithmetic; drop numerical imaginary parts and tiny negatives.
    eigvals = torch.linalg.eigvals(sigma_real @ sigma_fake)
    tr_covmean = eigvals.real.clamp(min=0.0).sqrt().sum()

    fid = (
        diff.dot(diff)
        + torch.trace(sigma_real)
        + torch.trace(sigma_fake)
        - 2.0 * tr_covmean
    )
    return fid.item()

def compute_metrics(real_batch, fake_batch):
    """
    Compare a real and a generated batch and return a dict of scalar metrics.

    Covers moment gaps (mean/variance), MMD, per-feature diversity, signal
    energy, the correlation of rFFT magnitude spectra along the last axis,
    and a weighted ``combined_score``.
    """
    with torch.no_grad():
        # Moment gaps
        real_mean, real_var = basic_stats(real_batch)
        fake_mean, fake_var = basic_stats(fake_batch)
        mean_diff = (real_mean - fake_mean).abs().mean().item()
        var_diff = (real_var - fake_var).abs().mean().item()

        # Kernel two-sample statistic
        mmd_val = mmd_flat(real_batch, fake_batch, sigma=10.0)

        # Diversity: average per-feature std across the batch
        real_flat = real_batch.view(real_batch.size(0), -1)
        fake_flat = fake_batch.view(fake_batch.size(0), -1)
        real_diversity = real_flat.std(dim=0).mean().item()
        fake_diversity = fake_flat.std(dim=0).mean().item()
        diversity_ratio = fake_diversity / max(real_diversity, 1e-8)

        # Energy: mean squared value, compared relative to the real batch
        real_energy = real_batch.pow(2).mean().item()
        fake_energy = fake_batch.pow(2).mean().item()
        energy_diff = abs(real_energy - fake_energy) / max(real_energy, 1e-8)

        # Magnitude spectra along the last axis, flattened for correlation
        real_spectrum = torch.fft.rfft(real_batch, dim=-1).abs().view(-1)
        fake_spectrum = torch.fft.rfft(fake_batch, dim=-1).abs().view(-1)

        # Correlation is undefined for (near-)constant spectra; report 0 then
        if real_spectrum.std() > 1e-8 and fake_spectrum.std() > 1e-8:
            paired = torch.stack([real_spectrum, fake_spectrum])
            spectrum_corr = torch.corrcoef(paired)[0, 1].item()
        else:
            spectrum_corr = 0.0

    combined = mean_diff + 0.5 * var_diff + 0.1 * mmd_val + 0.2 * energy_diff
    return {
        'mean_diff': mean_diff,
        'var_diff': var_diff,
        'mmd': mmd_val,
        'real_diversity': real_diversity,
        'fake_diversity': fake_diversity,
        'diversity_ratio': diversity_ratio,
        'real_energy': real_energy,
        'fake_energy': fake_energy,
        'energy_diff': energy_diff,
        'spectrum_correlation': spectrum_corr,
        'combined_score': combined,
    }

In [25]:
# Training Loop with Visualization - MODIFIED to save EVERY epoch



def train_ddpm_with_plots(config, diffusion, train_loader, val_loader, experiment_name="exp1"):
    """
    Trains a conditional DDPM and save:
      1) Checkpoints (best + periodic)
      2) Qualitative plots every epoch (real vs generated scalograms)
      3) Finance diagnostics every epoch:
         - inverse-wavelet reconstructed "return" waves
         - volatility clustering ACF(r^2)
         - fat-tail diagnostics (hist/QQ/survival)
      4) Metric history (loss, MMD, moment diffs, learning rate, grad norms)


    Inputs
    config : dict
        Needs:
          - config['training']: learning_rate, weight_decay, num_epochs, max_grad_norm, etc.
          - config['paths']: checkpoint_dir
          - optionally config['data']: x_mean, x_std for de-normalization
    diffusion : GaussianDiffusion2D
        Diffusion wrapper with .p_losses() (training loss) and .sample() (reverse process).
    train_loader / val_loader : DataLoader
        Yield tuples (x_batch, c_batch), where:
          x_batch: (B, 1, H, W) scalograms
          c_batch: (B, cond_dim) conditioning vectors
    experiment_name : str
        Used to separate output folders for different runs.

    Returns
    checkpoint_path : str
        Path to the "best" checkpoint (NOTE: current code overwrites every epoch).
    best_val_loss : float
    history : dict
        Contains loss curves and simple distribution metrics.
    """
    train_cfg = config['training']
    paths = config['paths']
    
    
    # Create directories
    checkpoint_dir = paths['checkpoint_dir']
    plot_dir = '/home/dsranelli/bigproject/artifacts_all/best_samples'
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)
    
    # Create subdirectories for this experiment
    exp_plot_dir = os.path.join(plot_dir, experiment_name)
    os.makedirs(exp_plot_dir, exist_ok=True)
    
    # Create epoch-specific subdirectory
    epochs_dir = os.path.join(exp_plot_dir, "epoch_plots_with_wave")
    os.makedirs(epochs_dir, exist_ok=True)

    # Optimizer and learning rate scheduler
    optimizer = torch.optim.AdamW(
        diffusion.parameters(),
        lr=config['training']['learning_rate'],
        weight_decay=config['training']['weight_decay']
    )

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=config['training']['num_epochs'],
        eta_min=config['training']['learning_rate'] * 0.1
    )



    # Training hyperparameters
    num_epochs = train_cfg['num_epochs']
    max_grad_norm = train_cfg['max_grad_norm']
    early_stop_pat = train_cfg.get('early_stop_patience', None)
    
    
    # Training history for plotting
    history = {
        'epochs': [],
        'train_loss': [],
        'val_loss': [],
        'mean_diff': [],
        'var_diff': [],
        'mmd': [],
        'learning_rates': [],
        'grad_norms': []
    }
    
    print(f"\nStarting training for experiment: {experiment_name}")
    print(f"Checkpoints saved to: {checkpoint_dir}")
    print(f"Epoch plots saved to: {epochs_dir}")

    ema_decay = 0.995
    ema_diffusion = copy.deepcopy(diffusion).to(diffusion.device)
    for p in ema_diffusion.parameters():
        p.requires_grad_(False)

    def update_ema(ema_model, model, decay):
        """
        EMA update: ema = decay * ema + (1 - decay) * model
        
        """
        
        with torch.no_grad():
            for (k, v), (_, v_model) in zip(ema_model.state_dict().items(),
                                            model.state_dict().items()):
                v.copy_(decay * v + (1.0 - decay) * v_model)
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)
    
    # Create subdirectories for this experiment
    exp_plot_dir = os.path.join(plot_dir, experiment_name)
    os.makedirs(exp_plot_dir, exist_ok=True)
    
    # Create epoch-specific subdirectory
    epochs_dir = os.path.join(exp_plot_dir, "epoch_plots_with_wave")
    os.makedirs(epochs_dir, exist_ok=True)
    # Directory for reverse-wavelet visualizations:
    waves_dir = os.path.join(exp_plot_dir, "epoch_plots1000")
    os.makedirs(waves_dir, exist_ok=True)



    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        
        # Train
    
        diffusion.train()
        train_loss_accum = 0.0
        n_train = 0
        grad_norms = []
        
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [train]")
        for step, (x_batch, c_batch) in enumerate(train_pbar):
            x_batch = x_batch.to(diffusion.device)
            c_batch = c_batch.to(diffusion.device)
            
            optimizer.zero_grad(set_to_none = True)

            # Training objective
            loss = diffusion.p_losses(x_batch, c_batch)
            loss.backward()
            
            # Gradient clipping
            max_grad = train_cfg.get('max_grad_norm', 1.0)
            if max_grad is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    diffusion.parameters(),  # <-- main model, not EMA
                    max_norm=max_grad
                ).item()
            else:
                grad_norm = torch.norm(
                    torch.stack([
                        p.grad.norm() for p in diffusion.parameters()
                        if p.grad is not None
                    ])
                ).item()
            optimizer.step()
            grad_norms.append(grad_norm)

            # Update EMA weights after each optimizer step
            update_ema(ema_diffusion, diffusion, ema_decay)

            # Update statistics
            bs = x_batch.size(0)
            train_loss_accum += loss.item() * bs
            n_train += bs
            
            # Update progress bar
            train_pbar.set_postfix({'loss': loss.item()})
        


        
            train_loss = train_loss_accum / max(n_train, 1)
            avg_grad_norm = np.mean(grad_norms) if grad_norms else 0
            current_lr = optimizer.param_groups[0]['lr']

        # Step LR schedule once per epoch
        scheduler.step()
        
        history['train_loss'].append(train_loss)
        history['grad_norms'].append(avg_grad_norm)
        history['learning_rates'].append(current_lr)
        
        # Validation
        diffusion.eval()
        val_loss_accum = 0.0
        n_val = 0
        
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [val]")
        with torch.no_grad():
            for x_batch, c_batch in val_pbar:
                x_batch = x_batch.to(diffusion.device)
                c_batch = c_batch.to(diffusion.device)
                
                loss = diffusion.p_losses(x_batch, c_batch)
                bs = x_batch.size(0)
                val_loss_accum += loss.item() * bs
                n_val += bs
                
                val_pbar.set_postfix({'loss': loss.item()})
        
        val_loss = val_loss_accum / max(n_val, 1)
        history['val_loss'].append(val_loss)
        history['epochs'].append(epoch + 1)

        
        # Sample generation and metrics
        with torch.no_grad():
            try:
                # Get validation batch for visualization
                x_val_batch, c_val_batch = next(iter(val_loader))
                x_val_batch = x_val_batch[:64].to(diffusion.device)
                c_val_batch = c_val_batch[:64].to(diffusion.device)
                
                # Generate samples using ema
                x_fake_batch = ema_diffusion.sample(c_val_batch, x_val_batch.shape)
                # De-normalize for plotting / metric checks
                if 'x_mean' in config['data']:
                    mu = config['data']['x_mean']
                    sigma = config['data']['x_std']
                    x_val_batch_plot = x_val_batch * sigma + mu
                    x_fake_batch_plot = x_fake_batch * sigma + mu
                else:
                    x_val_batch_plot = x_val_batch
                    x_fake_batch_plot = x_fake_batch
                
                # Distribution checks
                mu_real, var_real = basic_stats(x_val_batch)
                mu_fake, var_fake = basic_stats(x_fake_batch)
                mean_diff = (mu_real - mu_fake).abs().mean().item()
                var_diff = (var_real - var_fake).abs().mean().item()
                mmd_val = mmd_flat(x_val_batch, x_fake_batch, sigma=10.0)
                
                history['mean_diff'].append(mean_diff)
                history['var_diff'].append(var_diff)
                history['mmd'].append(mmd_val)
                
                # Convert to numpy for plotting
                real_np = x_val_batch.cpu().numpy()
                fake_np = x_fake_batch.cpu().numpy()

            # Edge case for empty validation loader
            except StopIteration:
                mean_diff = var_diff = mmd_val = float('nan')
                history['mean_diff'].append(float('nan'))
                history['var_diff'].append(float('nan'))
                history['mmd'].append(float('nan'))
                real_np = fake_np = None
        
        epoch_time = time.time() - epoch_start_time
        
        # Print Progress
        print(f"\n{'='*60}")
        print(f"Epoch {epoch+1}/{num_epochs} - Time: {epoch_time:.1f}s")
        print(f"{'='*60}")
        print(f"Train Loss: {train_loss:.6f}")
        print(f"Val Loss:   {val_loss:.6f}")
        print(f"LR:         {current_lr:.2e}")
        print(f"Grad Norm:  {avg_grad_norm:.2f}")
        
        if not np.isnan(mean_diff):
            print(f"{'='*60}")
            print("Generation Metrics:")
            print(f"  Mean Diff:  {mean_diff:.6f}")
            print(f"  Var Diff:   {var_diff:.6f}")
            print(f"  MMD:        {mmd_val:.6f}")
        print(f"{'='*60}")
        
        # Save Visualizations for each Epoch
        if real_np is not None and fake_np is not None:
            # Create epoch directory
            epoch_dir = os.path.join(epochs_dir, f"epoch_{epoch+1:04d}")
            os.makedirs(epoch_dir, exist_ok=True)
            
            # Save individual samples
            for i in range(min(4, real_np.shape[0])):
                # Real sample
                real_data = real_np[i, 0]
                real_path = os.path.join(epoch_dir, f"sample_{i}_real.png")
                save_scalogram_plot(
                    real_data,
                    f'Real Scalogram - Epoch {epoch+1}, Sample {i} (Loss: {val_loss:.4f})',
                    real_path
                )
                
                # Generated sample
                fake_data = fake_np[i, 0]
                fake_path = os.path.join(epoch_dir, f"sample_{i}_generated.png")
                save_scalogram_plot(
                    fake_data,
                    f'Generated Scalogram - Epoch {epoch+1}, Sample {i} (Loss: {val_loss:.4f})',
                    fake_path
                )
            
            # Save comparison grid
            fig, axes = plt.subplots(2, 4, figsize=(16, 8))
            fig.suptitle(f'Epoch {epoch+1} - Val Loss: {val_loss:.4f}', fontsize=16)
            vmin, vmax = -3, 3   # adjust if needed
            for i in range(min(4, real_np.shape[0])):
                im1 = axes[0, i].imshow(
                    real_np[i, 0],
                    aspect='auto',
                    cmap='viridis',
                    vmin=vmin,
                    vmax=vmax
                )
                axes[0, i].set_title(f'Real Sample {i}')
                axes[0, i].set_ylabel('Scales')
                axes[0, i].set_xlabel('Time')
                plt.colorbar(im1, ax=axes[0, i], shrink=0.7)

                # Fake
                im2 = axes[1, i].imshow(
                    fake_np[i, 0],
                    aspect='auto',
                    cmap='viridis',
                    vmin=vmin,
                    vmax=vmax
                )
                axes[1, i].set_title(f'Generated Sample {i}')
                axes[1, i].set_ylabel('Scales')
                axes[1, i].set_xlabel('Time')
                plt.colorbar(im2, ax=axes[1, i], shrink=0.7)
            
            plt.tight_layout()
            grid_path = os.path.join(epoch_dir, "comparison_grid.png")
            plt.savefig(grid_path, dpi=150, bbox_inches='tight')
            plt.close(fig)
            
            
            # Save 3-across plot: [Real | Generated | Inverse Wavelet]
            wave_plot_path = os.path.join(
                waves_dir,
                f"epoch_{epoch+1:04d}_waves.png"
            )

            # Fixed base conditioning vector once (so only last column changes)
            if epoch == 0:
                c_base_fixed = c_val_batch[0].detach().cpu()
                print("Cached c_base_fixed. Last element (before override):", float(c_base_fixed[-1]))

            # Save deterministic Regime A/B plot each epoch
            ab_path = save_regime_ab_plot(
                diffusion=ema_diffusion,
                c_base=c_base_fixed,
                epoch=epoch+1,
                save_dir=waves_dir,
                shape=(1, 1, 32, 128),
                seed=123,
                inverse_fn=inverse_wavelet_from_scalogram,
            )
            print("Saved regime A/B plot:", ab_path)

            plot_scalogram_and_waveform(
                x_real_batch=x_val_batch,
                x_fake_batch=x_fake_batch,
                epoch=epoch+1,
                save_path=wave_plot_path,
                max_rows=3,   # up to 3 samples per epoch
            )
            # Save fat-tail diagnostics right after wave plots
            fat_tail_path = os.path.join(waves_dir, f"epoch_{epoch+1:04d}_fat_tails.png")

            save_fat_tail_diagnostics(
                x_real_batch=x_val_batch_plot if 'x_mean' in config['data'] else x_val_batch,
                x_fake_batch=x_fake_batch_plot if 'x_mean' in config['data'] else x_fake_batch,
                epoch=epoch+1,
                save_path=fat_tail_path,
                inverse_fn=inverse_wavelet_from_scalogram,
                max_rows=20000,
                use_log_returns=False,
            )







            # Metrics history
            if "metrics_history" not in locals():
                metrics_history = []

            # Batch selection
            x_real_for_metrics = x_val_batch_plot if 'x_mean' in config['data'] else x_val_batch
            x_fake_for_metrics = x_fake_batch_plot

            # Convert to numpy scalograms: (B,1,S,T) -> (B,S,T)
            real_np = x_real_for_metrics.detach().float().cpu().numpy()[:, 0]
            fake_np = x_fake_for_metrics.detach().float().cpu().numpy()[:, 0]

            B = min(128, real_np.shape[0])

            # Reconstruct "returns waves" per sample using inverse
            real_waves = []
            gen_waves  = []
            for i in range(B):
                r = inverse_wavelet_from_scalogram(real_np[i])
                g = inverse_wavelet_from_scalogram(fake_np[i])

                r = np.asarray(r, dtype=np.float64)
                g = np.asarray(g, dtype=np.float64)

                r = r[np.isfinite(r)]
                g = g[np.isfinite(g)]
                if len(r) < 20 or len(g) < 20:
                    continue

                real_waves.append(r)
                gen_waves.append(g)

            # align lengths for ACF
            min_len = min(map(len, real_waves + gen_waves))
            real_waves = np.stack([w[:min_len] for w in real_waves], axis=0)  # (N,T)
            gen_waves  = np.stack([w[:min_len] for w in gen_waves], axis=0)

            # flatten for distribution metrics
            real_r = real_waves.reshape(-1)
            gen_r  = gen_waves.reshape(-1)

            # compute clustering scores
            real_acf = mean_acf_r2_over_windows(real_waves, nlags=20)
            gen_acf  = mean_acf_r2_over_windows(gen_waves,  nlags=20)

            row = {
                "epoch": int(epoch),


                "real_std": float(np.std(real_r, ddof=1)),
                "gen_std": float(np.std(gen_r,  ddof=1)),

                "real_skew": float(skewness(real_r)),
                "gen_skew": float(skewness(gen_r)),

                "real_exkurt": float(excess_kurtosis(real_r)),
                "gen_exkurt": float(excess_kurtosis(gen_r)),

                "real_abs_q990": float(np.quantile(np.abs(real_r), 0.990)),
                "gen_abs_q990": float(np.quantile(np.abs(gen_r),  0.990)),

                "real_abs_q995": float(np.quantile(np.abs(real_r), 0.995)),
                "gen_abs_q995": float(np.quantile(np.abs(gen_r),  0.995)),

                "real_cluster_1_5": float(clustering_score_from_acf(real_acf, 1, 5)),
                "gen_cluster_1_5": float(clustering_score_from_acf(gen_acf,  1, 5)),
            }

            metrics_history.append(row)

            # Save PNG dashboard for this epoch
            save_epoch_metrics_png(
                metrics_history,
                epoch=int(epoch),
                save_dir="/home/dsranelli/bigproject/artifacts_all/best_samples/full_training/epoch_plots1000/metrics_png"
            )


            
            
            # Save metrics plot for this epoch
            fig, axes = plt.subplots(2, 3, figsize=(15, 10))
            
            # Loss plot (up to current epoch)
            epochs_so_far = list(range(1, len(history['epochs']) + 1))
            axes[0, 0].plot(epochs_so_far, history['train_loss'], 'b-', label='Train', marker='o', markersize=3)
            axes[0, 0].plot(epochs_so_far, history['val_loss'], 'r-', label='Val', marker='s', markersize=3)
            axes[0, 0].axvline(x=epoch+1, color='g', linestyle='--', alpha=0.5)
            axes[0, 0].set_xlabel('Epoch')
            axes[0, 0].set_ylabel('Loss')
            axes[0, 0].set_title('Training Progress')
            axes[0, 0].legend()
            axes[0, 0].grid(True, alpha=0.3)
            
            # Learning rate
            axes[0, 1].plot(epochs_so_far, history['learning_rates'], 'g-', marker='^', markersize=3)
            axes[0, 1].axvline(x=epoch+1, color='g', linestyle='--', alpha=0.5)
            axes[0, 1].set_xlabel('Epoch')
            axes[0, 1].set_ylabel('Learning Rate')
            axes[0, 1].set_title('Learning Rate Schedule')
            axes[0, 1].grid(True, alpha=0.3)
            axes[0, 1].set_yscale('log')
            
            # Gradient norms
            axes[0, 2].plot(epochs_so_far, history['grad_norms'], 'm-', marker='d', markersize=3)
            axes[0, 2].axvline(x=epoch+1, color='g', linestyle='--', alpha=0.5)
            axes[0, 2].set_xlabel('Epoch')
            axes[0, 2].set_ylabel('Gradient Norm')
            axes[0, 2].set_title('Gradient Norms')
            axes[0, 2].grid(True, alpha=0.3)
            
            # Mean difference
            if not all(np.isnan(history['mean_diff'])):
                axes[1, 0].plot(epochs_so_far, history['mean_diff'], 'c-', marker='o', markersize=3)
                axes[1, 0].axvline(x=epoch+1, color='g', linestyle='--', alpha=0.5)
                axes[1, 0].set_xlabel('Epoch')
                axes[1, 0].set_ylabel('Mean Difference')
                axes[1, 0].set_title('Mean Distribution Difference')
                axes[1, 0].grid(True, alpha=0.3)
            
            # Variance difference
            if not all(np.isnan(history['var_diff'])):
                axes[1, 1].plot(epochs_so_far, history['var_diff'], 'orange', marker='s', markersize=3)
                axes[1, 1].axvline(x=epoch+1, color='g', linestyle='--', alpha=0.5)
                axes[1, 1].set_xlabel('Epoch')
                axes[1, 1].set_ylabel('Variance Difference')
                axes[1, 1].set_title('Variance Distribution Difference')
                axes[1, 1].grid(True, alpha=0.3)
            
            # MMD
            if not all(np.isnan(history['mmd'])):
                axes[1, 2].plot(epochs_so_far, history['mmd'], 'purple', marker='^', markersize=3)
                axes[1, 2].axvline(x=epoch+1, color='g', linestyle='--', alpha=0.5)
                axes[1, 2].set_xlabel('Epoch')
                axes[1, 2].set_ylabel('MMD')
                axes[1, 2].set_title('Maximum Mean Discrepancy')
                axes[1, 2].grid(True, alpha=0.3)
            
            plt.suptitle(f'Training Progress - Epoch {epoch+1}', fontsize=14)
            plt.tight_layout()
            metrics_path = os.path.join(epoch_dir, "training_progress.png")
            plt.savefig(metrics_path, dpi=150, bbox_inches='tight')
            plt.close(fig)
            
            # Save epoch summary as JSON
            epoch_summary = {
                'epoch': epoch + 1,
                'train_loss': float(train_loss),
                'val_loss': float(val_loss),
                'learning_rate': float(current_lr),
                'grad_norm': float(avg_grad_norm),
                'mean_diff': float(mean_diff) if not np.isnan(mean_diff) else None,
                'var_diff': float(var_diff) if not np.isnan(var_diff) else None,
                'mmd': float(mmd_val) if not np.isnan(mmd_val) else None,
                'epoch_time': float(epoch_time),
                'timestamp': datetime.now().isoformat()
            }
            
            with open(os.path.join(epoch_dir, "summary.json"), 'w') as f:
                json.dump(epoch_summary, f, indent=2)
            
            # Save simple text summary
            with open(os.path.join(epoch_dir, "summary.txt"), 'w') as f:
                f.write(f"Epoch {epoch+1} Summary\n")
                f.write(f"=====================\n\n")
                f.write(f"Train Loss: {train_loss:.6f}\n")
                f.write(f"Val Loss:   {val_loss:.6f}\n")
                f.write(f"Learning Rate: {current_lr:.2e}\n")
                f.write(f"Gradient Norm: {avg_grad_norm:.2f}\n")
                if not np.isnan(mean_diff):
                    f.write(f"Mean Difference: {mean_diff:.6f}\n")
                    f.write(f"Variance Difference: {var_diff:.6f}\n")
                    f.write(f"MMD: {mmd_val:.6f}\n")
                f.write(f"Epoch Time: {epoch_time:.1f}s\n")
                f.write(f"Timestamp: {datetime.now()}\n")
            
            print(f"  ✓ Visualizations saved to: {epoch_dir}")
        
            
        # Save best model checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f"{experiment_name}_best.pt")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': diffusion.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'train_loss': train_loss,
            'config': config,
            'history': history
        }, checkpoint_path)
        
        # Save periodic checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            periodic_path = os.path.join(checkpoint_dir, f"{experiment_name}_epoch_{epoch+1:04d}.pt")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': diffusion.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'train_loss': train_loss,
                'config': config,
                'history': history
            }, periodic_path)
            print(f"  ✓ Periodic checkpoint saved: {periodic_path}")

        
        # Save intermediate history every epoch
        history_path = os.path.join(exp_plot_dir, "training_history.npy")
        np.save(history_path, history)
    
    # Final Summary
    print(f"\n{'='*80}")
    print(f"TRAINING COMPLETE")
    print(f"{'='*80}")
    print(f"Best epoch: {best_epoch}")
    print(f"Best validation loss: {best_val_loss:.6f}")
    print(f"Total epochs trained: {epoch + 1}")
    print(f"\nCheckpoints saved to: {checkpoint_dir}")
    print(f"Epoch visualizations saved to: {epochs_dir}")
    print(f"Training history saved to: {os.path.join(exp_plot_dir, 'training_history.npy')}")
    
    # Creates final summary plot
    if len(history['epochs']) > 0:
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        
        # Loss curves
        axes[0, 0].plot(history['epochs'], history['train_loss'], 'b-', label='Train', alpha=0.7)
        axes[0, 0].plot(history['epochs'], history['val_loss'], 'r-', label='Val', alpha=0.7)
        axes[0, 0].axvline(x=best_epoch, color='g', linestyle='--', label=f'Best (epoch {best_epoch})')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].set_title('Final Training Curves')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)
        
        # Learning rate
        axes[0, 1].plot(history['epochs'], history['learning_rates'], 'g-')
        axes[0, 1].axvline(x=best_epoch, color='g', linestyle='--')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Learning Rate')
        axes[0, 1].set_title('Learning Rate Schedule')
        axes[0, 1].grid(True, alpha=0.3)
        axes[0, 1].set_yscale('log')
        
        # Metrics
        if not all(np.isnan(history['mmd'])):
            axes[1, 0].plot(history['epochs'], history['mmd'], 'purple', label='MMD')
            axes[1, 0].axvline(x=best_epoch, color='g', linestyle='--')
            axes[1, 0].set_xlabel('Epoch')
            axes[1, 0].set_ylabel('MMD')
            axes[1, 0].set_title('Distribution Distance (MMD)')
            axes[1, 0].legend()
            axes[1, 0].grid(True, alpha=0.3)
        
        if not all(np.isnan(history['mean_diff'])):
            axes[1, 1].plot(history['epochs'], history['mean_diff'], 'c-', label='Mean Diff')
            axes[1, 1].plot(history['epochs'], history['var_diff'], 'orange', label='Var Diff')
            axes[1, 1].axvline(x=best_epoch, color='g', linestyle='--')
            axes[1, 1].set_xlabel('Epoch')
            axes[1, 1].set_ylabel('Difference')
            axes[1, 1].set_title('Distribution Differences')
            axes[1, 1].legend()
            axes[1, 1].grid(True, alpha=0.3)
        
        plt.suptitle(f'Final Training Summary - {experiment_name}', fontsize=16)
        plt.tight_layout()
        final_plot_path = os.path.join(exp_plot_dir, "final_summary.png")
        plt.savefig(final_plot_path, dpi=150, bbox_inches='tight')
        plt.close(fig)
        print(f"Final summary plot saved to: {final_plot_path}")
        checkpoint_dir = paths['checkpoint_dir']
        



    
    print(f"{'='*80}")
    
    return checkpoint_path, best_val_loss, history

In [None]:
# Driver cell for the full-length run: pass the base CONFIG to the experiment
# runner under run_id "full_training". The second argument is an empty dict
# (presumably per-run config overrides — TODO confirm against the runner's signature).
print("Running full training!!!")
results = run_single_experiment_with_plots(CONFIG, {}, run_id="full_training")

# Closing banner: blank line, 80-char rule, completion message, 80-char rule.
print("\n" + "=" * 80, "Execution complete!", "=" * 80, sep="\n")