# Equidistributed initial conditions

In [None]:
import math
import torch

@torch.no_grad()
def sobol_sphere(N: int, K: int, device="cuda", dtype=torch.float32) -> torch.Tensor:
    """
    Deterministic low-discrepancy points on S^{K-1}.

    An unscrambled Sobol sequence in [0,1)^K is mapped through the inverse
    standard-normal CDF and every resulting row is scaled to unit length.

    Args:
        N: number of points to return.
        K: ambient dimension (points lie on S^{K-1}).
        device: device the result is placed on.
        dtype: floating dtype of the result.

    Returns:
        A: (N,K) with ||A[i]||2=1 and no NaN rows.
    """
    eng = torch.quasirandom.SobolEngine(dimension=K, scramble=False)
    # Draw one spare point: the unscrambled sequence contains the cube centre
    # (0.5,...,0.5), which the inverse CDF sends to the zero vector. That row
    # cannot be normalized (0/0 -> NaN), so it is dropped below.
    U = eng.draw(N + 1).to(device=device, dtype=dtype)  # (N+1,K) in [0,1)
    eps = torch.finfo(dtype).eps
    # Keep erfinv away from +-1, where it returns +-inf.
    U = U.clamp(eps, 1 - eps)
    G = math.sqrt(2.0) * torch.erfinv(2 * U - 1)    # inverse Normal CDF
    norms = G.norm(dim=1, keepdim=True)
    # Sobol points are distinct, so at most one row is the degenerate centre;
    # after removing it at least N rows remain.
    keep = norms.squeeze(1) > 0
    G = G[keep][:N]
    norms = norms[keep][:N]
    return G / norms
    
@torch.no_grad()
def gram_correct(A: torch.Tensor, Phi: torch.Tensor, w: torch.Tensor, jitter=1e-10) -> torch.Tensor:
    """
    Map each row of A to a vector with the same length in the Gram metric.

    With M = Phi diag(w) Phi^T (symmetrized, plus jitter * I) and its Cholesky
    factorization M = L L^T, every row a of A is sent to c = L^{-T} a, so that
    c^T M c = a^T a.

    A: (N,K) rows on the unit sphere in R^K
    Phi: (K,nx)
    w: (nx,) weights applied to the columns of Phi
    jitter: diagonal shift added to M before the factorization
    returns C: (N,K) whose rows satisfy c^T M c = 1 approximately
    """
    weighted_basis = Phi * w[None, :]
    gram = weighted_basis @ Phi.t()
    identity = torch.eye(gram.shape[0], device=gram.device, dtype=gram.dtype)
    # Symmetrize to remove round-off asymmetry, then shift the diagonal.
    gram = 0.5 * (gram + gram.t()) + jitter * identity
    chol = torch.linalg.cholesky(gram)
    # Solve L^T X = A^T; column i of X is L^{-T} a_i, hence X^T is the result.
    coeffs_t = torch.linalg.solve_triangular(chol.t(), A.t(), upper=True)
    return coeffs_t.t()

@torch.no_grad()
def interp_time_rows(t_grid: torch.Tensor, U_grid: torch.Tensor, t_query: torch.Tensor) -> torch.Tensor:
    """
    Linearly interpolate the rows of U_grid in time.

    t_grid: (ntg,) sample times; torch.searchsorted requires them ascending
    U_grid: (ntg,nx) one row per entry of t_grid
    t_query: (nT,) times to evaluate at
    returns Uq: (nT,nx) linear interpolation in time.

    Queries outside [t_grid[0], t_grid[-1]] are clamped to that interval, so
    they return the first / last row. A single-sample grid returns its only
    row for every query, and a zero-width interval (repeated grid time)
    returns the row at its lower end instead of 0/0 = NaN.
    """
    n_grid = t_grid.numel()
    if n_grid == 1:
        # Nothing to interpolate between: constant in time.
        out_dtype = torch.result_type(t_query, U_grid)
        return U_grid[0].to(out_dtype).unsqueeze(0).repeat(t_query.numel(), 1)

    tq = t_query.clamp(t_grid[0], t_grid[-1])

    # idx is the upper end of the bracketing interval; forcing it into
    # [1, ntg-1] keeps idx-1 valid when a query sits exactly on t_grid[0].
    idx = torch.searchsorted(t_grid, tq, right=False)
    idx = idx.clamp(1, n_grid - 1)

    t_lo = t_grid[idx - 1]
    t_hi = t_grid[idx]
    dt = t_hi - t_lo
    # Guard repeated grid times: divide by 1 there and force the weight to 0.
    nonzero = dt != 0
    safe_dt = torch.where(nonzero, dt, torch.ones_like(dt))
    w = torch.where(nonzero, (tq - t_lo) / safe_dt, torch.zeros_like(dt))

    U_lo = U_grid[idx - 1]
    U_hi = U_grid[idx]
    return (1 - w).unsqueeze(1) * U_lo + w.unsqueeze(1) * U_hi

import numpy as np
import torch
from dataclasses import dataclass
from typing import Callable, List, Optional

# Relies on trapz_weights_1d, NeuralGalerkinDatasetConfig, and NeuralGalerkinDataset
# from neural_galerkin_ode.py.

@torch.no_grad()
def create_NeuralGalerkin_dataset_fast_gpu(
    solution_functions: List,                 # ideally BurgersSolution objects
    x_grid: np.ndarray,
    t_min: float,
    t_max: float,
    basis_eval: Callable[[np.ndarray], np.ndarray],  # returns (K,nx)
    n_time_samples: int = 200,
    t_sampling: str = "grid",
    seed: Optional[int] = 0,
    weights: Optional[np.ndarray] = None,
    device: str = "cuda",
    dtype: torch.dtype = torch.float32,
    normalize_t: bool = False,
    normalize_c: bool = False,
    return_k_coords: bool = False,
    pde_name: str = "unknown",
    batch_size: int = 8,
):
    """
    Build a NeuralGalerkinDataset of projected coefficients c(t) per trajectory.

    For every entry of `solution_functions` the solution is evaluated on
    `x_grid` at `n_time_samples` times in [t_min, t_max] and multiplied by
    P = (w * Phi)^T, giving coefficients of shape (nT, K).

    Args:
        solution_functions: one entry per trajectory. Objects exposing `U`,
            `t_vals` and `z_vals` take a fast path (time interpolation of
            `U` on `device`); anything else is called as `sol(T, X)` with
            (nT, nx) NumPy arrays.
        x_grid: spatial grid, converted to a float NumPy array.
        t_min, t_max: time interval to sample.
        basis_eval: maps `x_grid` to the basis matrix Phi of shape (K, nx).
        n_time_samples: number of time samples per trajectory.
        t_sampling: "grid" (shared linspace) or "random" (sorted uniform
            samples, drawn independently per trajectory).
        seed: seed for "random" sampling; None seeds from entropy.
        weights: (nx,) weights; defaults to `trapz_weights_1d(x_grid)`.
        device, dtype: where and in which precision the work is done.
        normalize_t, normalize_c, return_k_coords, pde_name: forwarded
            unchanged into NeuralGalerkinDatasetConfig.
        batch_size: kept for backward compatibility; trajectories are
            processed one at a time, so it does not affect the result.

    Returns:
        NeuralGalerkinDataset constructed from the un-normalized (physical)
        `t` and `c` arrays.

    Raises:
        ValueError: if `t_sampling` is neither "grid" nor "random".
    """
    from .neural_galerkin_ode import trapz_weights_1d, NeuralGalerkinDatasetConfig, NeuralGalerkinDataset

    x_grid = np.asarray(x_grid, dtype=float)
    n_traj = len(solution_functions)
    if weights is None:
        weights = trapz_weights_1d(x_grid)
    w = torch.tensor(weights, device=device, dtype=dtype)  # (nx,)

    # time samples, one row per trajectory: (n_traj, nT)
    if t_sampling == "grid":
        t_1d = torch.linspace(t_min, t_max, n_time_samples, device=device, dtype=dtype)
        T_all = t_1d.unsqueeze(0).repeat(n_traj, 1)
    elif t_sampling == "random":
        g = torch.Generator(device="cpu")
        if seed is None:
            # A fresh torch.Generator starts from a fixed default seed, so
            # without this every "unseeded" call would draw the same samples.
            g.seed()
        else:
            g.manual_seed(int(seed))
        tr = torch.rand((n_traj, n_time_samples), generator=g).to(device=device, dtype=dtype)
        T_all = (t_min + (t_max - t_min) * tr).sort(dim=1).values
    else:
        raise ValueError("t_sampling must be 'grid' or 'random'")

    # basis + projection matrix on the target device
    Phi_np = np.asarray(basis_eval(x_grid), dtype=float)  # (K,nx)
    Phi = torch.tensor(Phi_np, device=device, dtype=dtype)
    P = (w.unsqueeze(0) * Phi).t().contiguous()          # (nx,K)

    K = Phi.shape[0]
    C_all = torch.empty((n_traj, n_time_samples, K), device=device, dtype=dtype)

    for m, sol in enumerate(solution_functions):
        if hasattr(sol, "U") and hasattr(sol, "t_vals") and hasattr(sol, "z_vals"):
            # Fast path: move the precomputed grid to the device once and
            # interpolate in time there.
            # NOTE: assumes sol.U is sampled on x_grid; sol.z_vals is not
            # compared against x_grid here.
            U_grid = torch.tensor(sol.U, device=device, dtype=dtype)        # (ntg,nx)
            t_grid = torch.tensor(sol.t_vals, device=device, dtype=dtype)   # (ntg,)
            Uq = interp_time_rows(t_grid, U_grid, T_all[m])                # (nT,nx)
        else:
            # Fallback: evaluate the callable on the CPU, then move the result.
            t_m = T_all[m].detach().cpu().numpy()
            Tm = np.repeat(t_m[:, None], x_grid.size, axis=1)
            Xm = np.repeat(x_grid[None, :], t_m.size, axis=0)
            U_cpu = np.asarray(sol(Tm, Xm), dtype=float)
            Uq = torch.tensor(U_cpu, device=device, dtype=dtype)

        C_all[m] = Uq @ P  # (nT,K)

    cfg = NeuralGalerkinDatasetConfig(
        n_time_samples=n_time_samples,
        t_sampling=t_sampling,
        seed=seed,
        normalize_t=normalize_t,
        normalize_c=normalize_c,
        return_k_coords=return_k_coords,
        pde_name=pde_name,
    )

    # The dataset is handed the *physical* arrays together with the
    # normalize_* flags in cfg; no normalization is applied in this function.
    ds = NeuralGalerkinDataset(
        cfg,
        t=T_all.detach().cpu().numpy(),
        c=C_all.detach().cpu().numpy(),
        device=device,
        dtype=dtype,
        x_grid=x_grid,
        basis_matrix=Phi_np,
    )
    return ds
