# In[ ]:
import jax.random as random
from your_module.datasets import LinearGaussianDataset  # 请按实际路径调整
import numpy as np


# In[ ]:

def generate_linear_gaussian(seed: int,
                             intrinsic_dim: int,
                             ambient_dim: int,
                             noise_var: float = 0.0,
                             n_samples: int = 20000) -> np.ndarray:
    """Sample a batch from ``LinearGaussianDataset`` and return it as a NumPy array.

    Mirrors these ``run.py`` options:
      --dataset linear_gaussian
      -dd <intrinsic_dim>        -> dimension
      --padding_dim <pad_dim>    -> padding_dimension
      --var_added <noise_var>    -> var_added

    where ``ambient_dim = intrinsic_dim + padding_dim``.

    Args:
        seed: Seed forwarded to ``LinearGaussianDataset``.
        intrinsic_dim: Passed as both ``dimension`` and ``intrinsic_dimension``.
        ambient_dim: Total data dimension; must be >= ``intrinsic_dim``.
        noise_var: Forwarded to the dataset as ``var_added``.
        n_samples: Number of samples requested from ``get_batch``.

    Returns:
        The result of ``ds.get_batch(n_samples)`` converted to a NumPy array
        via ``np.array``.

    Raises:
        ValueError: If ``ambient_dim < intrinsic_dim`` (negative padding).
    """
    padding_dim = ambient_dim - intrinsic_dim
    # Reject negative padding before constructing the dataset; its handling of
    # a negative padding_dimension is not defined anywhere visible here.
    if padding_dim < 0:
        raise ValueError(
            f"ambient_dim ({ambient_dim}) must be >= intrinsic_dim "
            f"({intrinsic_dim}); got padding_dim={padding_dim}"
        )
    ds = LinearGaussianDataset(seed=seed,
                               dimension=intrinsic_dim,
                               intrinsic_dimension=intrinsic_dim,
                               padding_dimension=padding_dim,
                               var_added=noise_var)
    # Per the original author's note, get_batch is expected to return shape
    # (n_samples, intrinsic_dim + padding_dim) -- not verified here.
    X = ds.get_batch(n_samples)
    # np.array produces a NumPy ndarray (a copy), whatever array type get_batch returns.
    return np.array(X)

if __name__ == "__main__":
    # Seed-2 runs matching the paper's example commands, listed as
    # (intrinsic_dim, ambient_dim) pairs.
    dim_pairs = [(3, 12), (3, 20), (6, 12), (6, 20), (9, 12), (9, 20), (12, 20)]
    configs = [
        {"seed": 2, "intrinsic_dim": d_int, "ambient_dim": d_amb}
        for d_int, d_amb in dim_pairs
    ]

    for cfg in configs:
        intrinsic, ambient = cfg["intrinsic_dim"], cfg["ambient_dim"]
        X = generate_linear_gaussian(
            seed=cfg["seed"],
            intrinsic_dim=intrinsic,
            ambient_dim=ambient,
            noise_var=0.0,  # corresponds to var_added=0
            n_samples=20000,
        )
        print(f"Generated X with shape {X.shape}  (intrinsic={intrinsic}, ambient={ambient})")
