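The total energy now contains nonlocal density–density interactions. We impose Neumann (zero-derivative) boundary conditions, implemented by even reflection of the density (`pad_mode = "vonNeumann"` in the code); with these boundary conditions the interaction operator is diagonal in the DCT-I (type-I discrete cosine transform) basis.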

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import math, os, time, copy
import torch.fft as tfft
import pandas as pd
import torch_dct as dct
from numpy import size

torch.random.manual_seed(1234)  # fixed seed so sampled data and model inits are reproducible

# ---------------------------------------------------------------------------
# Global numerical settings
# ---------------------------------------------------------------------------
dtype = torch.float64
device = "cpu"

# "smooth": first 50 cosine harmonics, standard deviation decaying as 2 / (1 + m)^2
# "rough" : all N_grid - 1 harmonics, each with the same standard deviation 2
data_regime = "rough"  # "smooth" or "rough"

N_grid = 512  # number of grid points

if data_regime == "smooth":
    M_cutoff = 50  # maximum harmonic
elif data_regime == "rough":
    M_cutoff = N_grid - 1  # maximum harmonic
else:
    raise ValueError("regime must be 'smooth' or 'rough'")

m = torch.arange(1, M_cutoff+1, dtype=dtype, device=device)   # (M,) harmonic indices
x = torch.linspace(0, 1, N_grid, dtype=dtype, device=device)  # (N,) grid on [0, 1]

# Design matrix used to sample densities (row m holds cos(m pi x) on the grid)
# and its x-derivative.
_phase = torch.pi * torch.outer(m, x)                      # (M, N)
DesignMatrix = torch.cos(_phase)                           # (M, N)
DerDM = -torch.pi * m[:, None] * torch.sin(_phase)         # (M, N)

if data_regime == "smooth":
    std_harm = 2.0 / (1.0 + m)**2
else:
    std_harm = 2.0 / (1.0 + 0.0 * m)**2  # a constant 2.0 for every harmonic

# Dataset sizes
N_train = 1500
N_test = 250
N_val = 250

# Optimisation / early-stopping hyperparameters
N_batch = 50
N_epochs = 10000
lr = 1e-3         # initial learning rate; a LR scheduler adjusts it during training
min_delta = 1e-5  # minimum change of the monitored quantity that counts as an improvement
patience = 30     # epochs to wait for an improvement before stopping training

pad_mode = "vonNeumann"  # padding mode for convolution-based routines: either "zero" or "vonNeumann"

Set up the energy functional and sample density profiles.

In [None]:
def sample_density(*, rho_b=0.0):
    """
    Sample one random density profile from the cosine basis.

        rho_j = rho_b + sum_{m=1}^{M} a_m cos(m pi x_j),   x_j in [0, 1],

    with independent amplitudes a_m ~ Normal(0, std_harm[m]). Only cosines are
    used, so d rho / dx vanishes at x = 0 and x = 1. The m >= 1 harmonics average
    to zero under the trapezoidal rule on the grid, so rho_b sets that average.

    Args:
        rho_b: constant offset added to the profile (default 0.0, i.e. no offset).

    Returns:
        rho   : (N_grid,) density profile
        d_rho : (N_grid,) derivative d rho / dx on the same grid
        a     : (M,) sampled cosine amplitudes a_1..a_M
    """
    a = torch.normal(torch.zeros_like(std_harm), std_harm)
    rho = a @ DesignMatrix + rho_b  # rho_b was previously accepted but ignored
    d_rho_a = a @ DerDM  # derivative of rho w.r.t. x (unaffected by the constant offset)

    return rho, d_rho_a, a

def sample_density_batch(B: int, rho_b=0.0):
    """
    Sample a batch of B density profiles from the cosine basis.

    Each row is rho_j = rho_b + sum_{m=1}^{M} a_m cos(m pi x_j), with independent
    amplitudes a_m ~ Normal(0, std_harm[m]). With the default rho_b = 0 the
    (trapezoidal) spatial average of every profile is zero.

    Args:
        B: number of profiles to draw.
        rho_b: constant offset added to every profile (default 0.0).

    Returns:
        rho   : (B, N_grid) density profiles
        d_rho : (B, N_grid) derivatives d rho / dx on the grid
        a     : (B, M) sampled cosine amplitudes a_1..a_M
    """
    a = torch.normal(torch.zeros(B, std_harm.numel(), dtype=dtype, device=device), std_harm.expand(B, -1))
    rho = a @ DesignMatrix + rho_b  # (B, N_grid); rho_b was previously ignored
    d_rho_a = a @ DerDM  # (B, N_grid) derivative of rho w.r.t. x

    return rho, d_rho_a, a

def rho_to_cosine_coeffs(rho_batch: torch.Tensor) -> torch.Tensor:
    """
    Expand density profile(s) in the DCT-I cosine basis.

    Finds a_m (m = 0..N-1) such that

        rho_j = sum_m a_m cos(pi m j / (N - 1)),   j = 0..N-1.

    Args:
        rho_batch: (..., N) real tensor with N >= 2, e.g. a single (N,) profile or
            a (B, N) batch. The transform acts on the last dimension.

    Returns:
        a_batch: tensor of the same shape holding the coefficients a_m.
        a_0 is the trapezoidal-rule average of rho over the grid.

    Raises:
        ValueError: if N < 2 (DCT-I needs at least two samples).
    """
    N = rho_batch.shape[-1]
    if N < 2:
        raise ValueError(f"DCT-I needs at least 2 grid points, got N={N}")

    # torch_dct.dct1 acts along the last axis. The scaling below assumes its unnormalized
    # DCT-I convention, X_m = rho_0 + (-1)^m rho_{N-1} + 2 sum_{j=1}^{N-2} rho_j cos(pi m j/(N-1)).
    X = dct.dct1(rho_batch)

    # Inverting that sum: interior coefficients are X_m / (N-1); the two end
    # coefficients (m = 0 and m = N-1) carry an extra factor 1/2.
    divisor = torch.full((N,), float(N - 1), dtype=X.dtype, device=X.device)
    divisor[0] = divisor[-1] = 2.0 * (N - 1)
    return X / divisor

def cosine_coeffs_to_rho(a_batch: torch.Tensor) -> torch.Tensor:
    """
    Inverse of rho_to_cosine_coeffs: synthesize rho_j = sum_m a_m cos(pi m j / (N - 1)).

    Args:
        a_batch: (..., N) cosine coefficients a_m with N >= 2, e.g. (N,) or (B, N).
            The transform acts on the last dimension.

    Returns:
        rho_batch: tensor of the same shape with the grid values rho_j.

    Raises:
        ValueError: if N < 2 (DCT-I needs at least two samples).
    """
    N = a_batch.shape[-1]
    if N < 2:
        raise ValueError(f"DCT-I needs at least 2 grid points, got N={N}")

    # Map a_m back to the unnormalized DCT-I coefficients X_m expected by idct1:
    # interior X_m = (N-1) a_m, end points X_m = 2 (N-1) a_m.
    multiplier = torch.full((N,), float(N - 1), dtype=a_batch.dtype, device=a_batch.device)
    multiplier[0] = multiplier[-1] = 2.0 * (N - 1)

    # Inverse DCT-I along the last dimension
    return dct.idct1(a_batch * multiplier)

# density–density interaction kernels K(r)
# r: tensor (can be negative)
def K_gaussian(r, sigma=1.0):
    """
    Gaussian kernel K(r) = exp(-r^2 / sigma^2), short-ranged.

    r: tensor of separations (negative values allowed); it is cast to the
    module-level `dtype` before evaluation.
    """
    r_cast = r.to(dtype=dtype)
    exponent = -(r_cast**2) / (sigma**2)
    return torch.exp(exponent)

def K_exp(r, xi=2.0):
    """
    Exponential kernel K(r) = exp(-|r| / xi), short-to-intermediate range depending on xi.

    r: tensor of separations (negative values allowed); cast to the module-level `dtype`.
    """
    distance = torch.abs(r.to(dtype=dtype))
    return torch.exp(-distance / xi)

def K_yukawa(r, lam=10.0):
    """
    Screened (Yukawa-like) kernel K(r) = exp(-|r| / lam) / max(|r|, 1), with K(0) = 0.

    r: tensor of separations (negative values allowed); |r| is cast to the
    module-level `dtype`. The clamp only guards the division; the r == 0 entry is
    masked to zero afterwards.
    """
    distance = torch.abs(r).to(dtype=dtype)
    screened = torch.exp(-distance / lam)
    nonzero = distance > 0
    return screened / distance.clamp(min=1.0) * nonzero

def K_power(r, alpha=1.0):
    """
    Unscreened power-law kernel K(r) = 1 / max(|r|, 1)^alpha, with K(0) = 0.

    alpha = 1 gives a Coulomb-like long-range interaction. |r| is cast to the
    module-level `dtype`.
    """
    distance = torch.abs(r).to(dtype=dtype)
    denominator = distance.clamp(min=1.0) ** alpha
    return (1.0 / denominator) * (distance > 0)

def E_int_conv(rho: torch.Tensor, kernel: str, **kwargs) -> torch.Tensor:
    """
    Interaction energy E = (1 / 2N) * sum_{i,j} rho_i K_{i-j} rho_j by direct convolution.

    The kernel is evaluated on every separation r = -(N-1), ..., N-1 and applied with
    F.conv1d. The boundary treatment follows the module-level ``pad_mode``:
      - "zero": rho is taken to vanish outside the grid;
      - "vonNeumann": rho is extended by even reflection (F.pad mode='reflect'),
        the Neumann boundary condition assumed by the DCT-I routines.

    Args:
        rho: (N,) or (B, N) density profile(s).
        kernel: "gaussian", "exp", "yukawa" or "power".
        **kwargs: kernel parameters forwarded to the kernel function (sigma, xi, lam, alpha).

    Returns:
        0-dim tensor if rho is 1-D, otherwise a (B,) tensor (also when B == 1).

    Raises:
        ValueError: unknown ``kernel`` or ``pad_mode``.
    """
    kernel_fns = {"gaussian": K_gaussian, "exp": K_exp, "yukawa": K_yukawa, "power": K_power}
    if kernel not in kernel_fns:
        raise ValueError(f"Unknown kernel: {kernel}")
    K_fun = kernel_fns[kernel]

    # Remember whether a single profile was passed, so that a (1, N) batch still
    # returns shape (1,) instead of being squeezed to a scalar.
    single = rho.dim() == 1
    if single:
        rho = rho.unsqueeze(0)
    B, N = rho.shape
    device, dtype = rho.device, rho.dtype
    r_vals = torch.arange(-(N-1), N, device=device, dtype=dtype)  # (2N-1,)
    # The kernel functions cast to the module-level dtype; match rho so conv1d accepts the weight.
    k_full = K_fun(r_vals, **kwargs).to(dtype=dtype)              # (2N-1,)
    weight = k_full.view(1, 1, -1)                                # (1, 1, 2N-1)

    if pad_mode == "zero":
        u = F.conv1d(rho.unsqueeze(1), weight, padding=N-1).squeeze(1)  # (B, N)
    elif pad_mode == "vonNeumann":
        # even reflection padding by N-1 samples on each side
        rho_pad = F.pad(rho.unsqueeze(1), (N-1, N-1), mode='reflect')   # (B, 1, 3N-2)
        u = F.conv1d(rho_pad, weight).squeeze(1)                        # (B, N)
    else:
        raise ValueError(f"Unknown padding: {pad_mode}")

    E = 0.5 * (rho * u).sum(dim=-1) / N  # (B,)
    return E.squeeze(0) if single else E

def kernel_eigenvals_dct(K: torch.Tensor) -> torch.Tensor:
    """
    Eigenvalues of the convolution operator with an even kernel under even-reflection
    (Neumann) boundary conditions, obtained from a DCT-I.

    Args:
        K: (N,) real tensor of kernel values K_r for r = 0, 1, ..., N-1
           (the kernel is assumed even, K_{-r} = K_r). Singleton dims are squeezed.

    Returns:
        lam_K: (N,) real tensor with lam_m = DCT-I(K)_m + (-1)^m K_{N-1}.
    """
    kernel = K.squeeze()
    n_points = kernel.shape[-1]
    alternating = (-1.0) ** torch.arange(n_points, device=kernel.device, dtype=kernel.dtype)  # (-1)^m
    transformed = dct.dct1(kernel)  # (N,)
    # The separation r = N-1 is reached from both sides of the reflected signal,
    # which contributes an extra (-1)^m K_{N-1} to each eigenvalue.
    return transformed + kernel[-1] * alternating

def kernel_from_eigenvals_dct(lam_K: torch.Tensor) -> torch.Tensor:
    """
    Invert kernel_eigenvals_dct: rebuild the kernel K_r (r = 0..N-1) from its eigenvalues.

    Since lam = DCT-I(K) + K_{N-1} v with v_m = (-1)^m, applying idct1 gives
    A = K + K_{N-1} c with c = idct1(v). Evaluating at r = N-1 yields
    K_{N-1} = A_{N-1} / (1 + c_{N-1}), after which K = A - K_{N-1} c.

    Args:
        lam_K: (..., N) eigenvalues; a single (N,) vector or a batch of them
            (the inversion is done independently along the last dimension).

    Returns:
        K: tensor of the same shape with kernel values at r = 0, 1, ..., N-1.
    """
    device, dtype = lam_K.device, lam_K.dtype
    N = lam_K.shape[-1]

    # v_m = (-1)^m
    v = (-1.0) ** torch.arange(N, device=device, dtype=dtype)  # (N,)
    A = dct.idct1(lam_K)           # (..., N)
    c = dct.idct1(v)               # (N,)

    # K_{N-1} = A_{N-1} / (1 + c_{N-1}), one value per batch entry; the slice keeps
    # a trailing singleton dim so it broadcasts against c.
    K_nm1 = A[..., -1:] / (1.0 + c[-1])  # (..., 1)

    # K = A - K_{N-1} * c
    return A - K_nm1 * c

def E_int_dct(rho: torch.Tensor, kernel: str, **kwargs) -> torch.Tensor:
    """
    Interaction energy E = (1 / 2N) * sum_j rho_j [K * rho]_j using DCT-I eigenvalues.

    The convolution is diagonal in the cosine basis under Neumann (even-reflection)
    boundary conditions, so [K * rho] is obtained by scaling the cosine coefficients
    of rho by the kernel eigenvalues and transforming back.

    Args:
        rho: (N,) or (B, N) density profile(s).
        kernel: "gaussian", "exp", "yukawa" or "power".
        **kwargs: kernel parameters forwarded to the kernel function (sigma, xi, lam, alpha).

    Returns:
        0-dim tensor if rho is 1-D, otherwise a (B,) tensor (also when B == 1).

    Raises:
        ValueError: unknown ``kernel``.
    """
    kernel_fns = {"gaussian": K_gaussian, "exp": K_exp, "yukawa": K_yukawa, "power": K_power}
    if kernel not in kernel_fns:
        raise ValueError(f"Unknown kernel: {kernel}")
    K_fun = kernel_fns[kernel]

    # Track 1-D input explicitly so a (1, N) batch keeps its batch dimension.
    single = rho.dim() == 1
    if single:
        rho = rho.unsqueeze(0)
    B, N = rho.shape
    device, dtype = rho.device, rho.dtype
    r_vals = torch.arange(0, N, device=device, dtype=dtype)  # (N,) separations r = 0..N-1
    K_vals = K_fun(r_vals, **kwargs)                         # (N,)

    # Kernel functions return the module-level dtype; cast back to rho's dtype/device.
    lam_K = kernel_eigenvals_dct(K_vals).to(device=device, dtype=dtype)  # (N,)
    a = rho_to_cosine_coeffs(rho)                                        # (B, N)
    u = cosine_coeffs_to_rho(lam_K.unsqueeze(0) * a)                     # (B, N) = [K * rho]
    E = 0.5 * (rho * u).sum(dim=-1) / N                                  # (B,)

    return E.squeeze(0) if single else E

def E_int_dct_v2(rho: torch.Tensor, kernel: str, **kwargs) -> torch.Tensor:
    """
    Interaction energy via DCT-I eigenvalues, summed directly in coefficient space.

    Gives the same result as E_int_dct, but never transforms [K * rho] back to the grid.
    The grid sum sum_j rho_j [K * rho]_j is split into:
      - a diagonal term, the trapezoid-weighted inner product in which the cosine modes
        are orthogonal with norms n_0 = n_{N-1} = N - 1 and n_m = (N - 1) / 2 otherwise;
      - a boundary term that restores the missing half weight of j = 0 and j = N-1.

    Args:
        rho: (N,) or (B, N) density profile(s).
        kernel: "gaussian", "exp", "yukawa" or "power".
        **kwargs: kernel parameters forwarded to the kernel function (sigma, xi, lam, alpha).

    Returns:
        0-dim tensor if rho is 1-D, otherwise a (B,) tensor (also when B == 1).

    Raises:
        ValueError: unknown ``kernel``.
    """
    kernel_fns = {"gaussian": K_gaussian, "exp": K_exp, "yukawa": K_yukawa, "power": K_power}
    if kernel not in kernel_fns:
        raise ValueError(f"Unknown kernel: {kernel}")
    K_fun = kernel_fns[kernel]

    # Track 1-D input explicitly so a (1, N) batch keeps its batch dimension.
    single = rho.dim() == 1
    if single:
        rho = rho.unsqueeze(0)
    B, N = rho.shape
    device, dtype = rho.device, rho.dtype

    r_vals = torch.arange(0, N, device=device, dtype=dtype)   # (N,)
    K_vals = K_fun(r_vals, **kwargs)                          # (N,)
    lam_K = kernel_eigenvals_dct(K_vals).to(device=device, dtype=dtype)  # (N,)
    a = rho_to_cosine_coeffs(rho)                             # (B, N)

    # spectral "convolution" coefficients: λ_m a_m
    spec = lam_K.unsqueeze(0) * a                             # (B, N)

    # ---------- diagonal (bulk) term ----------
    # norms n_m = <φ_m|φ_m> in the trapezoid-weighted scalar product
    norms = torch.full((N,), (N - 1.0) / 2.0, device=device, dtype=dtype)
    norms[0] = N - 1.0
    norms[-1] = N - 1.0

    # E_diag = (1/(2N)) * sum_m λ_m a_m^2 n_m
    E_diag = 0.5 * (spec * a * norms.unsqueeze(0)).sum(dim=-1) / N   # (B,)

    # ---------- boundary (mixing) term ----------
    # [K * rho]_0 = sum_m λ_m a_m
    sum_spec = spec.sum(dim=-1)                                      # (B,)
    # [K * rho]_{N-1} = sum_m λ_m a_m (-1)^m
    sign = (-1.0) ** torch.arange(N, device=device, dtype=dtype)     # (N,)
    sum_spec_signed = (spec * sign.unsqueeze(0)).sum(dim=-1)         # (B,)

    # E_bnd = (1/(4N)) * [ ρ_0 * sum_m λ_m a_m + ρ_{N-1} * sum_m λ_m a_m (-1)^m ]
    E_bnd = 0.25 * (rho[:, 0] * sum_spec + rho[:, -1] * sum_spec_signed) / N  # (B,)

    E = E_diag + E_bnd
    return E.squeeze(0) if single else E

# Parameters of the reference interaction kernels that define the target energy
lam = 10.0    # Yukawa screening length: K(r) = exp(-|r| / lam) / |r| for r != 0, K(0) = 0
alpha = 1.0   # power-law exponent: K(r) = 1 / |r|^alpha for r != 0, K(0) = 0
xi = 100.0    # exponential decay length: K(r) = exp(-|r| / xi)
kernel_regime = "exp"  # reference kernel for the targets: "power", "yukawa" or "exp"

if kernel_regime == "power":
    def E_tot(rho: torch.Tensor) -> torch.Tensor:
        """Reference energy with the power-law kernel, evaluated via DCT-I."""
        return E_int_dct(rho, kernel="power", alpha=alpha)
elif kernel_regime == "yukawa":
    def E_tot(rho: torch.Tensor) -> torch.Tensor:
        """Reference energy with the Yukawa kernel, evaluated via DCT-I."""
        return E_int_dct(rho, kernel="yukawa", lam=lam)
elif kernel_regime == "exp":
    def E_tot(rho: torch.Tensor) -> torch.Tensor:
        """Reference energy with the exponential kernel, evaluated via DCT-I."""
        return E_int_dct(rho, kernel="exp", xi=xi)
else:
    # Without this, E_tot would silently stay undefined and fail much later.
    raise ValueError(f"kernel_regime must be 'power', 'yukawa' or 'exp', got {kernel_regime!r}")

# --- Plot the reference kernel K_r that defines the target energy E_tot ---
R = 100
r_grid = torch.arange(-R, R+1)

# kernel function and its parameters for each supported regime
_ref_kernels = {
    "yukawa": (K_yukawa, {"lam": lam}),
    "power": (K_power, {"alpha": alpha}),
    "exp": (K_exp, {"xi": xi}),
}
if kernel_regime not in _ref_kernels:
    raise ValueError(f"kernel_regime must be one of {sorted(_ref_kernels)}, got {kernel_regime!r}")
K_ref_fun, K_ref_kwargs = _ref_kernels[kernel_regime]

plt.figure(figsize=(6,4))
# the curve needs a label, otherwise plt.legend() has nothing to show and warns
plt.plot(r_grid, K_ref_fun(r_grid, **K_ref_kwargs), 'o-', linewidth=2, label=f"{kernel_regime} kernel")
plt.axhline(0, color='k', linewidth=0.5)
plt.xlabel("r = i - j")
plt.ylabel("K_r")
plt.xlim(-R, R)
plt.title(f"Reference interaction kernel ({kernel_regime})")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# --- Round trip rho -> cosine coefficients -> rho on a few random samples ---
rho_batch, d_rho_batch, a_batch = sample_density_batch(3)  # (B, N_grid)

a_batch_rec = rho_to_cosine_coeffs(rho_batch)
rho_batch_rec = cosine_coeffs_to_rho(a_batch_rec)

plt.figure(figsize=(8,4))
plt.plot(x.numpy(), rho_batch[0, :].numpy(), lw=2, label="sampled")
plt.plot(x.numpy(), rho_batch_rec[0, :].numpy(), lw=2, ls='--', label="reconstructed")
plt.xlabel("x")
plt.ylabel("rho(x)")
plt.title("Random sampled densities from cosine basis")
plt.legend()
plt.grid(True)
plt.show()

vec = torch.arange(0, N_grid, dtype=dtype, device=device)
plt.figure(figsize=(8,4))
plt.plot(m.numpy(), a_batch[0, :].numpy(), lw=2, label="sampled a_m")
plt.plot(vec.numpy(), a_batch_rec[0, :].numpy(), lw=2, ls='--', label="recovered a_m")
plt.xlabel("m")
plt.ylabel("a_m")
plt.xlim(0, M_cutoff+5)
plt.title("Cosine coefficients of the sampled density")
plt.legend()
plt.grid(True)
plt.show()

# The sampled harmonics are m = 1..M_cutoff; index 0 of the recovered coefficients is the mean.
n_harm = a_batch.shape[1]
print("Max abs diff in a_m:", torch.max(torch.abs(a_batch - a_batch_rec[:, 1:n_harm+1])).item())
print("Max abs diff in rho:", torch.max(torch.abs(rho_batch - rho_batch_rec)).item())
print("Zero freq component (should be 0):", a_batch_rec[:, 0])

# --- Cross-check the three interaction-energy implementations ---
E_int_v1 = E_int_conv(rho_batch[0, :], kernel="power", alpha=alpha)
E_int_v2 = E_int_dct(rho_batch[0, :], kernel="power", alpha=alpha)
E_int_v3 = E_int_dct_v2(rho_batch[0, :], kernel="power", alpha=alpha)
print(f"E_int_conv = {E_int_v1.item():.6f}, E_int_dct = {E_int_v2.item():.6f}, E_int_dct_v2 = {E_int_v3.item():.6f}, diff = {abs(E_int_v1 - E_int_v2).item()}, diff_v2 = {abs(E_int_v1 - E_int_v3).item()}")

# --- Kernel -> eigenvalues -> kernel round trip ---
r_vals = torch.arange(0, N_grid, device=device, dtype=dtype)
K_vals = K_ref_fun(r_vals, **K_ref_kwargs)

lam_K = kernel_eigenvals_dct(K_vals)
K_rec = kernel_from_eigenvals_dct(lam_K)

print("Max abs diff in kernel reconstruction:", torch.max(torch.abs(K_vals - K_rec)).item())

Feature processing (generation and normalization)

In [None]:
# we save features as (B, N_grid, N_feat), where N_feat is the number of features per grid point
# generate train/test split

def compute_normalization_stats(features):
    """
    Compute per-feature mean and standard deviation for standardization.

    Statistics are taken over both the data and the spatial (grid) dimension.
    Features with zero (or undefined) spread get a unit scale instead of 0, so that
    normalize_features never divides by zero. This is the same convention as
    scikit-learn's StandardScaler.

    Args:
        features: torch.Tensor of shape (N_data, N_grid, N_feat)

    Returns:
        mean: torch.Tensor of shape (1, 1, N_feat)
        std: torch.Tensor of shape (1, 1, N_feat), strictly positive
    """
    mean_feat = features.mean(dim=(0, 1), keepdim=True)  # (1, 1, N_feat)
    std_feat = features.std(dim=(0, 1), keepdim=True)    # (1, 1, N_feat), unbiased

    # Constant features (std == 0) or a single sample (std is NaN) -> use scale 1.
    std_feat = torch.where(std_feat > 0, std_feat, torch.ones_like(std_feat))

    return mean_feat, std_feat

def normalize_features(features, mean_feat, std_feat):
    """
    Standardize features with precomputed statistics: (features - mean) / std.

    Args:
        features: torch.Tensor of shape (B, N_grid, N_feat)
        mean_feat: torch.Tensor of shape (1, 1, N_feat), broadcast over batch and grid
        std_feat: torch.Tensor of shape (1, 1, N_feat), broadcast over batch and grid

    Returns:
        torch.Tensor with the same shape as ``features`` (only the normalized
        features are returned, not the statistics).
    """
    centered = features - mean_feat
    return centered / std_feat

def generate_loc_features_rs(rho: torch.Tensor, N_feat=2) -> torch.Tensor:
    """
    Build local real-space (rs) features: pointwise powers of the density.

    Args:
        rho: torch.Tensor of shape (B, N_grid)
        N_feat: number of powers to generate (k = 1, ..., N_feat)

    Returns:
        torch.Tensor of shape (B, N_grid, N_feat) whose feature k-1 is rho**k.
    """
    powers = [torch.pow(rho, k) for k in range(1, N_feat + 1)]
    return torch.stack(powers, dim=-1)

def generate_loc_features_ms(d_rho: torch.Tensor, N_feat=2) -> torch.Tensor:
    """
    Build local features from the density derivative: pointwise powers of d_rho.

    (The "ms" suffix follows the file's naming; the input is d rho / dx on the grid.)

    Args:
        d_rho: torch.Tensor of shape (B, N_grid)
        N_feat: number of powers to generate (k = 1, ..., N_feat)

    Returns:
        torch.Tensor of shape (B, N_grid, N_feat) whose feature k-1 is d_rho**k.
    """
    powers = [torch.pow(d_rho, k) for k in range(1, N_feat + 1)]
    return torch.stack(powers, dim=-1)

# Draw independent train / test / val density sets. The draws consume the global
# RNG in this order (train, test, val), so reordering these lines changes the data.
rho_train, d_rho_train, a_train = sample_density_batch(N_train)  # (N_train, N_grid)
rho_test, d_rho_test, a_test = sample_density_batch(N_test)   # (N_test, N_grid)
rho_val, d_rho_val, a_val = sample_density_batch(N_val)    # (N_val, N_grid)

# number of powers generated per raw field (rho and d rho / dx)
N_feat = 1 

# real-space features: rho^1 .. rho^N_feat
features_train_rs = generate_loc_features_rs(rho_train, N_feat=N_feat)  # (N_train, N_grid, N_feat)
features_test_rs  = generate_loc_features_rs(rho_test, N_feat=N_feat)   # (N_test, N_grid, N_feat)
features_val_rs   = generate_loc_features_rs(rho_val, N_feat=N_feat)    # (N_val, N_grid, N_feat)

# derivative features: (d rho / dx)^1 .. (d rho / dx)^N_feat
features_train_ms = generate_loc_features_ms(d_rho_train, N_feat=N_feat)  # (N_train, N_grid, N_feat)
features_test_ms  = generate_loc_features_ms(d_rho_test, N_feat=N_feat)   # (N_test, N_grid, N_feat)
features_val_ms   = generate_loc_features_ms(d_rho_val, N_feat=N_feat)    # (N_val, N_grid, N_feat)

# (B, N_grid, 2*N_feat): the rs features come first, so feature index 0 is rho itself,
# which is what the energy models read via features[..., 0].
features_train = torch.cat([features_train_rs, features_train_ms], dim=-1)
features_test  = torch.cat([features_test_rs, features_test_ms], dim=-1)
features_val   = torch.cat([features_val_rs, features_val_ms], dim=-1)

# Targets: interaction energy of each profile under the reference functional E_tot
targets_train = E_tot(rho_train)            # (N_train,)
targets_test  = E_tot(rho_test)             # (N_test,)
targets_val   = E_tot(rho_val)              # (N_val,)

# Normalize features (statistics from the training set only, reused for test/val)
mean_feat, std_feat = compute_normalization_stats(features_train)
features_train_norm = normalize_features(features_train, mean_feat, std_feat)
features_test_norm = normalize_features(features_test, mean_feat, std_feat)
features_val_norm = normalize_features(features_val, mean_feat, std_feat)

# Normalize targets (training-set statistics; the models also read E_mean / E_std
# to map their physical energies into these normalized units)
E_mean = targets_train.mean()
E_std = targets_train.std()
targets_train_norm = (targets_train - E_mean) / E_std
targets_test_norm = (targets_test - E_mean) / E_std
targets_val_norm = (targets_val - E_mean) / E_std

# Datasets
train_dataset = TensorDataset(features_train_norm, targets_train_norm)
val_dataset   = TensorDataset(features_val_norm,   targets_val_norm)
test_dataset  = TensorDataset(features_test_norm,  targets_test_norm)

# Loaders (only the training loader shuffles)
train_loader = DataLoader(train_dataset, batch_size=N_batch, shuffle=True,  drop_last=False)
val_loader   = DataLoader(val_dataset,   batch_size=N_batch, shuffle=False, drop_last=False)
test_loader  = DataLoader(test_dataset,  batch_size=N_batch, shuffle=False, drop_last=False)

Learning models:
1. A real-space model of a local interaction kernel (either a finite window or a mixture of Gaussians).
2. A momentum-space model via DCT-I (the interaction kernel is parametrized either in momentum space or in real space; the parametrization can be a model with only a few learnable parameters).
3. A hybrid model in which the short-range part is captured by the real-space approach and the long-range part by the momentum-space one.

In [None]:
class LearnableKernelConv1d(nn.Module):
    """
    Learnable discrete 1D convolution kernel K_r supported on |r| <= R.

    forward(rho) returns phi = [K * rho], a same-length linear convolution in which
    rho is padded by R samples per side, with zeros ("zero") or by even reflection
    ("reflect").

    Args:
        R: kernel half-width; the kernel has 2R + 1 taps.
        even_kernel: if True, only K_0..K_R are learned (parameter ``kernel_half``)
            and K_{-r} = K_r is enforced; otherwise all 2R + 1 taps are free
            (parameter ``kernel``).
        pad_mode: "zero" or "reflect" (checked in forward).
    """
    def __init__(self, R=5, even_kernel=True, pad_mode="zero"):
        super().__init__()
        self.R = R
        self.pad_mode = pad_mode
        self.even_kernel = even_kernel

        n_taps = R + 1 if even_kernel else 2 * R + 1
        init = torch.randn(n_taps) * 0.01  # small random initialization
        if even_kernel:
            self.kernel_half = nn.Parameter(init)  # [K_0, K_1, ..., K_R]
        else:
            self.kernel = nn.Parameter(init)       # [K_-R, ..., K_0, ..., K_R]

    def build_kernel(self):
        """
        Return the full kernel with shape (1, 1, 2R+1), the layout conv1d expects.
        """
        if not self.even_kernel:
            taps = self.kernel
        else:
            half = self.kernel_half
            # mirror K_1..K_R to the negative side around the center tap K_0
            taps = torch.cat([half[1:].flip(0), half[:1], half[1:]], dim=0)
        return taps.view(1, 1, -1)  # (out=1, in=1, kernel_size)

    def forward(self, rho):
        """
        rho: (B, N_grid)
        Returns: phi: (B, N_grid)
        """
        n_batch, n_grid = rho.shape  # also enforces a 2-D input
        weight = self.build_kernel().to(dtype=rho.dtype, device=rho.device)
        margin = (self.R, self.R)

        pad_options = {
            "zero": dict(mode='constant', value=0.0),
            "reflect": dict(mode='reflect'),
        }
        if self.pad_mode not in pad_options:
            raise ValueError("pad_mode must be zero or reflect")
        padded = F.pad(rho.unsqueeze(1), margin, **pad_options[self.pad_mode])

        return F.conv1d(padded, weight, padding=0).squeeze(1)

class KernelOnlyEnergyNN(nn.Module):
    """
    Energy model with a short-range learnable kernel:

        E_tot = (1 / 2N) * sum_{i,j} rho_i rho_j K_{i-j} = (1 / 2N) * sum_{i} rho_i [K * rho]_i

    K is an even learnable kernel supported on |r| <= R, applied by convolution with
    zero padding (LearnableKernelConv1d).

    Uses the module-level globals std_feat, mean_feat (feature normalization),
    E_mean, E_std (target normalization) and N_grid.
    """

    def __init__(self, R=5):
        super().__init__()
        self.R = R
        self.kernel_conv = LearnableKernelConv1d(R, even_kernel=True, pad_mode="zero")

    def forward(self, features):
        """
        Args:
            features: (B, N_grid, N_feat) normalized features; only the first
                feature (the density) is used.

        Returns:
            total_energy_norm: (B,) total energy in normalized units,
                (E - E_mean) / E_std. (Local energies are not returned.)
        """
        rho_norm = features[..., 0]                           # (B, N_grid), isolate density
        rho = rho_norm * std_feat[0, 0, 0] + mean_feat[0, 0, 0]  # denormalize to physical units

        phi = self.kernel_conv(rho)  # (B, N_grid), [K * rho] in physical units

        local_energies = 0.5 * rho * phi                   # (B, N_grid)
        total_energy = local_energies.sum(dim=1) / N_grid  # (B,), physical units

        return (total_energy - E_mean) / E_std  # normalize for training
    
class GaussianMixtureKernelConv1d(nn.Module):
    """
    Learnable 1D kernel on |r| <= R written as a sum of Gaussians:

        K(r) = sum_{n=1}^{M} A_n * exp(-r^2 / sigma_n^2)

    forward(rho) returns phi = [K * rho] via conv1d, with zero or even-reflection padding.

    Args:
        R: real-space cutoff (kernel has 2R + 1 taps).
        n_components: number of Gaussians M.
        pad_mode: "zero" or "reflect" (checked in forward).
    """

    def __init__(self, R: int, n_components: int, pad_mode: str = "zero"):
        super().__init__()
        self.R = R
        self.n_components = n_components
        self.pad_mode = pad_mode

        # fixed separations r = -R..R, stored as a buffer (moves with the module)
        self.register_buffer("r_vals", torch.arange(-R, R + 1, dtype=dtype))  # (2R+1,)

        # amplitudes A_n: small random values, sorted so that A_1 >= A_2 >= ... >= A_M
        initial_amps = torch.sort(0.01 * torch.randn(n_components, dtype=dtype), descending=True).values
        self.amplitudes = nn.Parameter(initial_amps)

        # widths sigma_n = softplus(log_sigmas_n), initialised to 1, 4, 9, ...
        initial_sigmas = torch.arange(1, n_components + 1, dtype=dtype) ** 2
        # inverse softplus, so that softplus(raw) == initial_sigmas
        self.log_sigmas = nn.Parameter(torch.log(torch.expm1(initial_sigmas)))

    def build_kernel(self):
        """
        Return the kernel with shape (1, 1, 2R+1), as required by conv1d.
        """
        widths = F.softplus(self.log_sigmas) + 1e-8              # (M,), strictly positive
        r_squared = self.r_vals.pow(2).unsqueeze(0)              # (1, 2R+1)
        gaussians = torch.exp(-r_squared / widths.unsqueeze(1).pow(2))  # (M, 2R+1)
        taps = (self.amplitudes.unsqueeze(1) * gaussians).sum(dim=0)    # (2R+1,)
        return taps.view(1, 1, -1)

    def forward(self, rho: torch.Tensor) -> torch.Tensor:
        """
        rho: (B, N_grid)
        Returns: phi = (K * rho): (B, N_grid)
        """
        n_batch, n_grid = rho.shape  # also enforces a 2-D input
        weight = self.build_kernel()
        margin = (self.R, self.R)

        if self.pad_mode == "zero":
            padded = F.pad(rho.unsqueeze(1), margin, mode="constant", value=0.0)
        elif self.pad_mode == "reflect":
            padded = F.pad(rho.unsqueeze(1), margin, mode="reflect")
        else:
            raise ValueError("pad_mode must be 'zero' or 'reflect'")

        return F.conv1d(padded, weight, padding=0).squeeze(1)  # (B, N)


class GaussianMixtureEnergyNN(nn.Module):
    """
    Energy model whose short-range kernel is a learnable mixture of Gaussians:

        E_tot = (1 / 2N) * sum_{i,j} rho_i rho_j K_{i-j}
              = (1 / 2N) * sum_i rho_i [K * rho]_i,
        K(r)  = sum_n A_n exp(-r^2 / sigma_n^2),  |r| <= R.

    Uses the module-level globals std_feat, mean_feat, E_mean, E_std and N_grid.
    """

    def __init__(self, R=20, n_components=3, pad_mode="zero"):
        super().__init__()
        self.R = R
        self.kernel_conv = GaussianMixtureKernelConv1d(R=R, n_components=n_components, pad_mode=pad_mode)

    def forward(self, features):
        """
        Args:
            features: (B, N_grid, N_feat) normalized features; only feature 0 (density) is used.

        Returns:
            total_energy_norm: (B,) total energy in normalized units.
        """
        density = mean_feat[0, 0, 0] + std_feat[0, 0, 0] * features[..., 0]  # physical units
        potential = self.kernel_conv(density)                                 # (B, N_grid)
        energy = (0.5 * density * potential).sum(dim=1) / N_grid              # (B,)
        return (energy - E_mean) / E_std
    
class LearnableRSNonLocalKernelDCT(nn.Module):
    """
    Fully learnable real-space (RS) kernel K_r, r = 0..N_grid-1, applied through DCT-I.

    phi = [K * rho] is evaluated in the cosine basis: the kernel's DCT-I eigenvalues
    (kernel_eigenvals_dct) scale the cosine coefficients of rho. This corresponds to
    von Neumann (even-reflection) boundary conditions.

    Args:
        zero_r_flag: if True, the self-interaction K_0 is pinned to 0, both at
            initialization and in every forward pass.
    """
    def __init__(self, zero_r_flag=True):
        super().__init__()
        self.zero_r_flag = zero_r_flag
        # full kernel K_r, r = 0..N_grid-1, small random initialization
        self.rs_kernel = nn.Parameter(torch.randn(N_grid) * 0.01)
        if zero_r_flag:
            self.rs_kernel.data[0] = 0.0  # start from K_0 = 0

    def forward(self, rho):
        """
        rho: (B, N_grid)
        Returns: phi: (B, N_grid)
        """
        K = self.rs_kernel.to(dtype=rho.dtype, device=rho.device)  # (N_grid,)
        if self.zero_r_flag:
            # rebuilt out-of-place: K_0 is exactly zero and receives no gradient
            K = torch.cat([K.new_zeros(1), K[1:]])

        eigvals = kernel_eigenvals_dct(K).to(device=rho.device, dtype=rho.dtype)  # (N_grid,)
        coeffs = rho_to_cosine_coeffs(rho)
        return cosine_coeffs_to_rho(eigvals.unsqueeze(0) * coeffs)

class ExpMixtureRSNonLocalKernelDCT(nn.Module):
    """
    Learnable nonlocal kernel on r = 0..N-1 written as a sum of exponentials,

        K(r) = sum_{n=1}^{M} A_n * exp(-r / sigma_n),

    applied through its DCT-I eigenvalues (even-reflection boundary conditions).

    Args:
        zero_r_flag: if True, K_0 is forced to 0 in every forward pass.
        n_components: number of exponentials M.
        N: grid size (number of kernel samples).
    """

    def __init__(self, zero_r_flag=True, n_components=3, N=N_grid):
        super().__init__()
        self.zero_r_flag = zero_r_flag
        self.n_components = n_components
        self.N = N

        self.register_buffer("r_vals", torch.arange(0, N, dtype=dtype))  # (N,)

        # amplitudes A_n: small random values, sorted so that A_1 >= A_2 >= ... >= A_M
        initial_amps = torch.sort(0.01 * torch.randn(n_components, dtype=dtype), descending=True).values
        self.amplitudes = nn.Parameter(initial_amps)

        # decay lengths sigma_n = softplus(log_sigmas_n), initialised to 10, 30, 50, ...
        initial_sigmas = 10 + 20 * torch.arange(n_components, dtype=dtype)
        # inverse softplus, so that softplus(raw) == initial_sigmas
        self.log_sigmas = nn.Parameter(torch.log(torch.expm1(initial_sigmas)))

    def build_kernel(self):
        """Return the kernel values K_r for r = 0..N-1, shape (N,)."""
        lengths = F.softplus(self.log_sigmas) + 1e-8                            # (M,)
        decays = torch.exp(-self.r_vals.unsqueeze(0) / lengths.unsqueeze(1))     # (M, N)
        return torch.sum(self.amplitudes.unsqueeze(1) * decays, dim=0)           # (N,)

    def forward(self, rho: torch.Tensor) -> torch.Tensor:
        """
        rho: (B, N_grid)
        Returns: phi = (K * rho): (B, N_grid)
        """
        K = self.build_kernel()
        if self.zero_r_flag:
            # out-of-place: K_0 becomes exactly zero and receives no gradient
            K = torch.cat([K.new_zeros(1), K[1:]])

        eigvals = kernel_eigenvals_dct(K).to(device=rho.device, dtype=rho.dtype)  # (N_grid,)
        return cosine_coeffs_to_rho(eigvals.unsqueeze(0) * rho_to_cosine_coeffs(rho))
    
class LearnableMSNonLocalKernelDCT(nn.Module):
    """
    Learnable nonlocal kernel parameterized directly in momentum (DCT-I) space.

    The eigenvalues λ_m for m = 1..range_ms are learnable. λ_0 = 0 (no coupling
    to the uniform component), and λ_m = 0 for m > range_ms. range_ms must not
    exceed N_grid - 1.
    """
    def __init__(self, range_ms: int = 50):
        super().__init__()
        self.range_ms = range_ms
        # learnable λ_m for m = 1..range_ms, small random initialization
        self.ms_kernel = nn.Parameter(torch.randn(range_ms) * 0.01)

    def forward(self, rho: torch.Tensor) -> torch.Tensor:
        """
        rho: (B, N_grid)
        Returns: phi = (K * rho): (B, N_grid)
        """
        n_batch, n_grid = rho.shape  # also enforces a 2-D input

        # full eigenvalue vector: zero everywhere except the learned band m = 1..range_ms
        lam_K = torch.zeros(n_grid, device=rho.device, dtype=rho.dtype)
        lam_K[1:1 + self.range_ms] = self.ms_kernel.to(device=rho.device, dtype=rho.dtype)

        coeffs = rho_to_cosine_coeffs(rho)                       # (B, N_grid)
        return cosine_coeffs_to_rho(lam_K.unsqueeze(0) * coeffs)  # (B, N_grid)
    
class DCTKernelEnergyNN(nn.Module):
    """
    Nonlocal energy model evaluated with DCT-I (even-reflection boundaries):

        E_tot = (1 / 2N) sum_i rho_i [K * rho]_i

    where, depending on ``learning_mode``, K is represented by
      - 'dct_rs_blind'       : a free real-space kernel K_r (K_-r = K_r), applied via
                               its DCT-I eigenvalues (N log N instead of N^2);
      - 'dct_exp_rs_mixture' : a real-space mixture of ``n_components`` exponentials;
      - 'dct_ms_blind'       : free DCT-I eigenvalues λ_1..λ_range_ms (momentum space).
    ``zero_r_flag`` (pin K_0 = 0) is used by the two real-space modes only.

    Uses the module-level globals std_feat, mean_feat, E_mean, E_std and N_grid.

    Raises:
        ValueError: if ``learning_mode`` is not one of the modes above.
    """
    def __init__(self, learning_mode='dct_rs_blind', zero_r_flag=False, n_components=3, range_ms=50):
        super().__init__()
        self.learning_mode = learning_mode
        if learning_mode == 'dct_rs_blind':
            self.nonlocal_kernel = LearnableRSNonLocalKernelDCT(zero_r_flag=zero_r_flag)  # zero_r_flag=True enforces K_0 = 0
        elif learning_mode == 'dct_exp_rs_mixture':
            self.nonlocal_kernel = ExpMixtureRSNonLocalKernelDCT(n_components=n_components, N=N_grid, zero_r_flag=zero_r_flag)
        elif learning_mode == 'dct_ms_blind':
            self.nonlocal_kernel = LearnableMSNonLocalKernelDCT(range_ms=range_ms)
        else:
            # Fail at construction instead of with an AttributeError on the first forward pass.
            raise ValueError(
                f"Unknown learning_mode: {learning_mode!r} "
                "(expected 'dct_rs_blind', 'dct_exp_rs_mixture' or 'dct_ms_blind')"
            )

    def forward(self, features):
        """
        Args:
            features: (B, N_grid, N_feat) - only the first feature (density) is used

        Returns:
            total_energy_norm: (B,) - normalized total energy
        """
        rho_norm = features[..., 0]          # (B, N_grid)
        rho = rho_norm * std_feat[0, 0, 0] + mean_feat[0, 0, 0]  # physical density

        phi = self.nonlocal_kernel(rho)      # (B, N_grid)

        local_energies = 0.5 * rho * phi     # (B, N_grid)
        total_energy = local_energies.sum(dim=1) / N_grid    # (B,)

        return (total_energy - E_mean) / E_std
    

class HybridKernelEnergyNN(nn.Module):
    """
    Energy model that adds a short-range and a long-range density-density term.

        E_tot = E_local + E_nonlocal,
        E_x   = (1 / 2N) sum_i rho_i [K_x * rho]_i   for x in {local, nonlocal}

    The local term uses a learnable symmetric window of half-width ``R``
    (``LearnableKernelConv1d`` with even_kernel=True and reflect padding). The
    nonlocal term uses ``ExpMixtureRSNonLocalKernelDCT`` with ``n_exp_components``
    components and K_0 not forced to zero (zero_r_flag=False).

    Args:
        R: half-width of the local kernel window.
        n_exp_components: number of components in the nonlocal mixture kernel.
    """

    def __init__(self, R=5, n_exp_components=1):
        super().__init__()
        # Stored as plain attributes so they can be written into checkpoint configs.
        self.R = R
        self.n_exp_components = n_exp_components

        # The submodules are created in a fixed order (local first). This fixes the
        # parameter-initialisation order and the state_dict key order.
        self.local_kernel = LearnableKernelConv1d(R=R, even_kernel=True, pad_mode="reflect")
        self.nonlocal_kernel = ExpMixtureRSNonLocalKernelDCT(
            n_components=n_exp_components, N=N_grid, zero_r_flag=False
        )

    @staticmethod
    def _pair_energy(rho, phi):
        """Return (1 / 2N) * sum_i rho_i * phi_i per batch row, giving shape (B,)."""
        return 0.5 * (rho * phi).sum(dim=1) / N_grid

    def forward(self, features):
        """
        Args:
            features: normalized features whose channel 0 is the density. The shape is
                assumed to be (B, N_grid, N_feat), as for DCTKernelEnergyNN.

        Returns:
            (B,) total energy, normalized with the module-level E_mean / E_std.
        """
        # Map the normalized density back to physical units using module-level stats.
        density = features[..., 0] * std_feat[0, 0, 0] + mean_feat[0, 0, 0]

        energy = self._pair_energy(density, self.local_kernel(density)) \
            + self._pair_energy(density, self.nonlocal_kernel(density))

        return (energy - E_mean) / E_std

Helper routines for training, followed by the actual model training and exploration

In [None]:
def evaluate(model, loader, criterion):
    """
    Return the mean of the per-batch ``criterion`` losses of ``model`` over ``loader``.

    The function switches the model to eval mode and leaves it there. Gradients are
    disabled, and each batch is moved to the module-level ``device`` before the
    forward pass. An empty loader yields 0.0.
    """
    model.eval()
    loss_acc, seen = 0.0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            loss_acc += criterion(model(inputs), targets).item()
            seen += 1
    # max(1, ...) guards against division by zero for an empty loader.
    return loss_acc / max(1, seen)

def load_checkpoint(path, model_class, device="cpu"):
    """
    Rebuild a model from a checkpoint saved with the layout used by train_with_early_stopping.

    The checkpoint's "config" dict is passed as keyword arguments to ``model_class``.
    The stored "model_state_dict" is then loaded, and the model is put in eval mode.

    Args:
        path: checkpoint file path.
        model_class: class whose __init__ accepts the keys of ckpt["config"].
        device: map_location for torch.load and target device for the model.

    Returns:
        model: reconstructed and loaded model (eval mode)
        normalization: dict of normalization stats, or None if absent
        epoch: best epoch, or None if absent
        val_loss: best validation loss, or None if absent
    """
    checkpoint = torch.load(path, map_location=device)

    model = model_class(**checkpoint["config"]).to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    # Optional metadata: missing keys come back as None.
    metadata = (checkpoint.get(key) for key in ("normalization", "epoch", "val_loss"))
    return (model, *metadata)
    
def _run_epoch(model, loader, criterion, train: bool):
    """
    Make one pass over ``loader`` and return the mean of the per-batch losses.

    With ``train=True`` the model is put in training mode and gradients are tracked.
    Each batch then does zero_grad / backward / step on the module-level ``optimizer``.
    With ``train=False`` the model is put in eval mode and gradients are disabled.
    Batches are moved to the module-level ``device``. An empty loader yields 0.0.
    """
    set_mode = model.train if train else model.eval
    set_mode()

    loss_total = 0.0
    batch_count = 0
    with torch.set_grad_enabled(train):
        for features, targets in loader:
            features = features.to(device)
            targets = targets.to(device)

            if train:
                optimizer.zero_grad()
            batch_loss = criterion(model(features), targets)
            if train:
                batch_loss.backward()
                optimizer.step()

            loss_total += batch_loss.item()
            batch_count += 1

    return loss_total / max(1, batch_count)

# Constructor kwargs stored in each checkpoint's "config", keyed by learning regime.
# load_checkpoint() passes this dict straight to the model class as **kwargs, so the
# keys must match that class's __init__ parameters.
_CKPT_CONFIG_BUILDERS = {
    "window": lambda m: {"R": m.R},
    "gaussmixt": lambda m: {"n_components": m.kernel_conv.n_components},
    "dct_rs_blind": lambda m: {
        "learning_mode": "dct_rs_blind",
        "zero_r_flag": m.nonlocal_kernel.zero_r_flag,
    },
    "dct_exp_rs_mixture": lambda m: {
        "learning_mode": "dct_exp_rs_mixture",
        "zero_r_flag": m.nonlocal_kernel.zero_r_flag,
        "n_components": m.nonlocal_kernel.n_components,
    },
    "dct_ms_blind": lambda m: {
        "learning_mode": "dct_ms_blind",
        "range_ms": m.nonlocal_kernel.range_ms,
    },
    "hybrid": lambda m: {"R": m.R, "n_exp_components": m.n_exp_components},
}


def _save_best_checkpoint(ckpt_path, state, epoch, val_loss, config):
    """Write the best model state, its constructor config and the global normalization stats."""
    torch.save(
        {
            "model_state_dict": state,
            "epoch": epoch,
            "val_loss": val_loss,
            "config": config,
            "normalization": {
                "mean_feat": mean_feat.cpu(),
                "std_feat": std_feat.cpu(),
                "E_mean": E_mean.cpu(),
                "E_std": E_std.cpu(),
                "N_grid": int(N_grid),
            },
        },
        ckpt_path,
    )


def train_with_early_stopping(
    model,
    train_loader,
    val_loader,
    criterion,
    scheduler=None,
    max_epochs=10000,
    patience=10,
    min_delta=1e-5,
    ckpt_dir="checkpoints",
    run_name=None,
    learning_regime="window",
):
    """
    Train ``model`` with per-epoch validation, optional LR scheduling and early stopping.

    Each epoch does one training pass and one validation pass (via ``_run_epoch``) and
    records both mean losses. If a scheduler is given, it then calls
    ``scheduler.step(val_loss)``, which matches ``ReduceLROnPlateau``.

    An epoch counts as an improvement when the validation loss drops by more than
    ``min_delta``. On improvement the weights are deep-copied and written to
    ``<ckpt_dir>/<run_name>_best.pt``, together with a regime-specific constructor
    config and the module-level normalization stats (mean_feat, std_feat, E_mean,
    E_std, N_grid). Training stops after ``patience`` consecutive epochs without
    improvement.

    On exit the best weights are restored into ``model``. The loss history is written
    (best-effort) to ``<ckpt_dir>/<run_name>_history.csv``.

    Args:
        model, train_loader, val_loader, criterion: passed to ``_run_epoch``.
        scheduler: optional LR scheduler stepped with the validation loss, or None.
        max_epochs: maximum number of epochs.
        patience: epochs without improvement before stopping.
        min_delta: minimum decrease in validation loss that counts as improvement.
        ckpt_dir: output directory (created if missing).
        run_name: prefix for the checkpoint / CSV file names.
        learning_regime: key selecting which constructor config is saved; for an
            unknown key no checkpoint is written (a warning is printed).

    Returns:
        hist: dict with lists "train_loss" and "val_loss" (one entry per epoch run).
        best_epoch: 1-based epoch of the best validation loss, or -1 if none improved.
    """
    import csv

    os.makedirs(ckpt_dir, exist_ok=True)

    build_config = _CKPT_CONFIG_BUILDERS.get(learning_regime)
    if build_config is None:
        print(f"[warn] unknown learning_regime {learning_regime!r}: no checkpoints will be saved")

    best_val = math.inf
    best_state = None
    best_epoch = -1
    since_improved = 0

    hist = {"train_loss": [], "val_loss": []}

    for epoch in range(1, max_epochs + 1):
        train_loss = _run_epoch(model, train_loader, criterion, train=True)
        val_loss = _run_epoch(model, val_loader, criterion, train=False)

        hist["train_loss"].append(train_loss)
        hist["val_loss"].append(val_loss)

        # scheduler defaults to None; without this guard the default would crash.
        if scheduler is not None:
            scheduler.step(val_loss)

        improved = (best_val - val_loss) > min_delta
        if improved:
            best_val = val_loss
            best_state = copy.deepcopy(model.state_dict())
            best_epoch = epoch
            since_improved = 0

            # save checkpoint with normalization stats
            if build_config is not None:
                ckpt_path = os.path.join(ckpt_dir, f"{run_name}_best.pt")
                _save_best_checkpoint(ckpt_path, best_state, best_epoch, best_val, build_config(model))
        else:
            since_improved += 1

        if (epoch % 10) == 0 or epoch == 1:
            print(f"[{epoch:04d}] train={train_loss:.6f} | val={val_loss:.6f} "
                  f"| best_val={best_val:.6f} (epoch {best_epoch})")

        if since_improved >= patience:
            print(f"Early stopping at epoch {epoch} (best @ {best_epoch}).")
            break

    # restore best
    if best_state is not None:
        model.load_state_dict(best_state)

    # Best-effort: a failure to write the history must not lose the trained model.
    csv_path = os.path.join(ckpt_dir, f"{run_name}_history.csv")
    try:
        with open(csv_path, "w", newline="") as f:
            w = csv.writer(f)
            w.writerow(["epoch", "train_loss", "val_loss"])
            for i, (tr, va) in enumerate(zip(hist["train_loss"], hist["val_loss"]), start=1):
                w.writerow([i, tr, va])
    except Exception as e:
        print(f"[warn] could not write CSV: {e}")

    return hist, best_epoch


ckpt_dir = "LearningNonLocalKernel_wRSDCTKernel_checkpoints"
flag_train = True  # set to True to train models
learning_regime = "hybrid"  # "dct_rs_blind", "dct_exp_rs_mixture", "dct_ms_blind", or "hybrid"
R = 5  # for hybrid
n_exp_components = 1  # for hybrid
range_ms = 50  # for dct_ms_blind

if flag_train:

    torch.manual_seed(1234)  # for reproducibility

    # The run name identifies the checkpoint/CSV files; the later cells rebuild it the same way.
    # NOTE: the kernel parameter and learning_regime are joined without a separator
    # (e.g. "..._xi<xi>hybrid"). This is kept unchanged so existing files still resolve.
    if kernel_regime == "exp":
        run_name = "rs_dct_kernel_" + data_regime + "_" + kernel_regime + f"_xi{xi}" + learning_regime
    elif kernel_regime == "yukawa":
        run_name = "rs_dct_kernel_" + data_regime + "_" + kernel_regime + f"_lam{lam}" + learning_regime
    elif kernel_regime == "power":
        run_name = "rs_dct_kernel_" + data_regime + "_" + kernel_regime + f"_alpha{alpha}" + learning_regime
    else:
        # Fail loudly rather than reuse a stale run_name left over from an earlier cell.
        raise ValueError(f"Unknown kernel_regime {kernel_regime!r}")

    if learning_regime == "dct_exp_rs_mixture":
        model = DCTKernelEnergyNN(learning_mode="dct_exp_rs_mixture", n_components=1).to(device)
    elif learning_regime == "dct_rs_blind":
        model = DCTKernelEnergyNN(learning_mode="dct_rs_blind").to(device)
    elif learning_regime == "dct_ms_blind":
        model = DCTKernelEnergyNN(learning_mode="dct_ms_blind", range_ms=range_ms).to(device)
    elif learning_regime == "hybrid":
        model = HybridKernelEnergyNN(R=R, n_exp_components=n_exp_components).to(device)
    else:
        # Without this, an unknown regime would silently train a stale `model` from a previous cell.
        raise ValueError(f"Unknown learning_regime {learning_regime!r}")

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    # Reduce LR when val loss plateaus.
    # NOTE(review): `patience` is not defined in this cell; it is assumed to be set earlier
    # in the notebook. The same value is used here and for early stopping, so the
    # scheduler may rarely get to act before early stopping triggers. Consider a smaller
    # scheduler patience.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=patience, cooldown=2, min_lr=1e-6
    )

    hist, best_epoch = train_with_early_stopping(
        model,
        train_loader,
        val_loader,
        criterion,
        scheduler=scheduler,
        max_epochs=N_epochs,
        patience=patience,
        min_delta=min_delta,
        ckpt_dir=ckpt_dir,
        run_name=run_name,
        learning_regime=learning_regime,
    )


In [None]:
# Rebuild the run name used by the training cell to locate its loss-history CSV.
if kernel_regime == "exp":
    run_name = "rs_dct_kernel_" + data_regime + "_" + kernel_regime + f"_xi{xi}" + learning_regime
elif kernel_regime == "yukawa":
    run_name = "rs_dct_kernel_" + data_regime + "_" + kernel_regime + f"_lam{lam}" + learning_regime
elif kernel_regime == "power":
    run_name = "rs_dct_kernel_" + data_regime + "_" + kernel_regime + f"_alpha{alpha}" + learning_regime
else:
    # Avoid silently plotting a different run through a stale run_name.
    raise ValueError(f"Unknown kernel_regime {kernel_regime!r}")

path = os.path.join(ckpt_dir, f"{run_name}_history.csv")
hist_df = pd.read_csv(path)
print(hist_df.head())
hist_df.plot(x="epoch", y=["train_loss", "val_loss"], logy=True, grid=True, title=run_name)


Evaluate the trained model: compare the learned kernel with the true kernel, both in real space and in DCT-I space

In [None]:
# Rebuild the run name used by the training cell to locate its best checkpoint.
if kernel_regime == "exp":
    run_name = "rs_dct_kernel_" + data_regime + "_" + kernel_regime + f"_xi{xi}" + learning_regime
elif kernel_regime == "yukawa":
    run_name = "rs_dct_kernel_" + data_regime + "_" + kernel_regime + f"_lam{lam}" + learning_regime
elif kernel_regime == "power":
    run_name = "rs_dct_kernel_" + data_regime + "_" + kernel_regime + f"_alpha{alpha}" + learning_regime
else:
    # Avoid silently loading a different run through a stale run_name.
    raise ValueError(f"Unknown kernel_regime {kernel_regime!r}")

# "hybrid" checkpoints hold a HybridKernelEnergyNN; all other regimes here use DCTKernelEnergyNN.
model_class = HybridKernelEnergyNN if learning_regime == "hybrid" else DCTKernelEnergyNN
model, normalization, epoch, val_loss = load_checkpoint(
    os.path.join(ckpt_dir, f"{run_name}_best.pt"),
    model_class,
    device=device,
)

# Extract the learned real-space kernel k_full (length N_grid, r = 0..N_grid-1)
# from the loaded model, depending on how the kernel was parameterized.
if learning_regime == "dct_exp_rs_mixture":
    print("Learned Exponential Mixture RS DCT Kernel parameters:")
    print(model.nonlocal_kernel.n_components)
    print("Amplitudes A_n:", model.nonlocal_kernel.amplitudes.detach().cpu().numpy())
    print("Sigmas sigma_n:", F.softplus(model.nonlocal_kernel.log_sigmas).detach().cpu().numpy())
    with torch.no_grad():
        k_full = model.nonlocal_kernel.build_kernel().detach().cpu().numpy()  # (N_grid,)
elif learning_regime == "dct_rs_blind":
    with torch.no_grad():
        k_full = model.nonlocal_kernel.rs_kernel.detach().cpu().numpy()  # (N_grid,)
elif learning_regime == "dct_ms_blind":
    print("Learned MS DCT Kernel parameters:")
    print("Range m:", model.nonlocal_kernel.range_ms)
    with torch.no_grad():
        # Use the loaded model's range_ms (it comes from its checkpoint config), not the
        # global `range_ms`, which may have been edited since this model was trained.
        n_ms = model.nonlocal_kernel.range_ms
        lam_K = torch.zeros(N_grid, dtype=dtype)  # lambda_0 (m = 0) stays 0
        lam_K[1:1 + n_ms] = model.nonlocal_kernel.ms_kernel.detach().cpu()
        k_full = kernel_from_eigenvals_dct(lam_K).cpu().numpy()
elif learning_regime == "hybrid":
    print("Learned Hybrid Kernel parameters:")
    print("Local kernel range R:", model.R)
    print("Nonlocal exponential mixture components:", model.n_exp_components)

    print("Amplitudes A_n:", model.nonlocal_kernel.amplitudes.detach().cpu().numpy())
    print("Sigmas sigma_n:", F.softplus(model.nonlocal_kernel.log_sigmas).detach().cpu().numpy())

    with torch.no_grad():
        k_local_small = model.local_kernel.build_kernel().squeeze()  # (2R+1,)
        k_nonlocal = model.nonlocal_kernel.build_kernel()            # (N_grid,)

        R = model.R  # local range
        print("Local kernel values K_r (r=0..R):", k_local_small[R:].cpu().numpy())

        # k_local_small layout: [-R, ..., -1, 0, 1, ..., R]; index R is r=0, R+1..2R is r=1..R.
        # Keep r = 0..R and zero-pad up to N_grid so it can be added to the nonlocal kernel.
        k_local_full = torch.cat(
            [
                k_local_small[R:],  # (R+1,)  -> r=0..R
                torch.zeros(N_grid - 1 - R, device=device, dtype=dtype),
            ],
            dim=0,
        )  # (N_grid,)
        k_full = k_local_full.cpu().numpy() + k_nonlocal.cpu().numpy()  # (N_grid,)
else:
    # Without this, k_full would be undefined or stale from a previous cell.
    raise ValueError(f"Unknown learning_regime {learning_regime!r}")


r_grid = np.arange(0, N_grid)

# Reference ("true") kernel that generated the data, as a function of the distance r.
if kernel_regime == "exp":
    def K_true(r):
        """Exponential kernel exp(-|r| / xi); note K(0) = 1."""
        return np.exp(-np.abs(r) / xi)
elif kernel_regime == "yukawa":
    def K_true(r):
        """Yukawa kernel exp(-|r| / lam) / |r| for r != 0, and 0 at r = 0."""
        r_abs = np.abs(r)
        # np.maximum(., 1) only avoids division by zero; the r = 0 entry is masked below.
        out = np.exp(-r_abs / lam) / np.maximum(r_abs, 1.0)
        return out * (r_abs > 0)
elif kernel_regime == "power":
    def K_true(r):
        """Power-law kernel |r|^(-alpha) for r != 0, and 0 at r = 0."""
        r_abs = np.abs(r)
        out = 1.0 / (np.maximum(r_abs, 1.0)**alpha)
        return out * (r_abs > 0)
else:
    # Without this, K_true would be undefined or stale from a previous cell.
    raise ValueError(f"Unknown kernel_regime {kernel_regime!r}")

# Project both kernels through DCT-I and zero the m = 0 eigenvalue, so they are compared
# without that mode. Presumably this is because lambda_0 is not identifiable from the
# training densities (the sampled harmonics start at m = 1); confirm.
k_true = K_true(r_grid)
lamb_K_true = kernel_eigenvals_dct(torch.tensor(k_true, dtype=dtype))
lamb_K_true[0] = 0.0  # enforce zero mode
k_true = kernel_from_eigenvals_dct(lamb_K_true).cpu().numpy()

lamb_K_full = kernel_eigenvals_dct(torch.tensor(k_full, dtype=dtype))
lamb_K_full[0] = 0.0  # enforce zero mode
k_full = kernel_from_eigenvals_dct(lamb_K_full).cpu().numpy()

# Title shared by both figures; describes the true kernel that generated the data.
if kernel_regime == "exp":
    kernel_desc = f"exp, xi={xi}"
elif kernel_regime == "yukawa":
    kernel_desc = f"yukawa, lam={lam}"
elif kernel_regime == "power":
    kernel_desc = f"power, alpha={alpha}"
else:
    kernel_desc = str(kernel_regime)
plot_title = f"Kernel learning ({data_regime}, {kernel_desc})"

# Tag the learned curve with the learning regime that actually produced it.
learned_label = f"Learned kernel ({learning_regime})"

# First figure: real-space kernel K_r vs r. Second figure: its DCT-I eigenvalues
# lambda_m vs m. Both are zoomed to the first 100 points; add plt.ylim(...) in the
# loop to zoom vertically.
for y_learned, y_true, x_label, y_label in (
    (k_full, k_true, "r = i - j", "K_r"),
    (lamb_K_full, lamb_K_true, "m", "lambda_m"),
):
    plt.figure(figsize=(6, 4))
    plt.plot(r_grid, y_learned, 'o-', label=learned_label, linewidth=2)
    plt.plot(r_grid, y_true, 's--', label='True kernel', alpha=0.7)
    plt.axhline(0, color='k', linewidth=0.5)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.xlim(0, 100)
    plt.legend()
    plt.grid(True)
    plt.title(plot_title)
    plt.tight_layout()
    plt.show()