In [None]:
from prior.mlp_scm import MLPSCM
from TempMLP_SCM import TemporalMLPSCM
from metrics import evaluate_dataset_temporality

# Baseline run of the prior package's MLPSCM. It is kept under its own names so
# its output is not silently overwritten by the temporal model below.
baseline_model = MLPSCM(seq_len=500, num_features=10, is_causal=True)
X_baseline, y_baseline = baseline_model.forward()
print("MLPSCM:", X_baseline.shape, y_baseline.shape)

# Temporal SCM with the periodic component enabled; `model`, `X`, `y` refer to it.
model = TemporalMLPSCM(seq_len=500, num_features=20, alpha=0.5, beta=1.5, period=10, use_periodicity=True)
X, y = model.forward()

print(X.shape, y.shape)
evaluate_dataset_temporality(X)

In [None]:
from metrics import (
    dataset_signature, compute_correlation_signature,
    pairwise_distances, diversity_metrics,
    compute_pairwise_corr_distances, summarize_corr_diversity,
    plot_all_diversity
)
from TempMLP_SCM import TemporalMLPSCM

# Fixed reference configuration of the temporal SCM
model = TemporalMLPSCM(
    seq_len=100,
    num_features=10,
    num_causes=10,
    num_layers=4,
    hidden_dim=32,
    alpha=0.3,
    beta=1.2,
    period=20,
    use_periodicity=False,
    device="cpu",
)

# -------------------------------------------------------
# 1. Generate datasets (10 datasets from the same model instance)
# -------------------------------------------------------
X_list = []
for draw_idx in range(10):
    X, y = model.generate_dataset(n_individuals=50)
    X_list.append(X)

# -------------------------------------------------------
# 2. Compute signatures
# -------------------------------------------------------
signatures = [dataset_signature(X) for X in X_list]

# One list per signature component, in the same order as the keys below.
combined_sigs, marg_sigs, temp_sigs, struct_sigs = (
    [sig[part] for sig in signatures]
    for part in ("combined", "marginal", "temporal", "structure")
)

corr_sigs = [compute_correlation_signature(X) for X in X_list]

# -------------------------------------------------------
# 3. Compute distance matrices
# -------------------------------------------------------
D_global, D_marginal, D_temporal, D_structure = (
    pairwise_distances(component_sigs)
    for component_sigs in (combined_sigs, marg_sigs, temp_sigs, struct_sigs)
)
D_corr = compute_pairwise_corr_distances(corr_sigs)

# -------------------------------------------------------
# 4. Print stats
# -------------------------------------------------------
for header, dist in (
    ("=== GLOBAL DIVERSITY ===", D_global),
    ("\n=== MARGINAL DIVERSITY ===", D_marginal),
    ("\n=== TEMPORAL DIVERSITY ===", D_temporal),
    ("\n=== STRUCTURE DIVERSITY ===", D_structure),
):
    print(header)
    print(diversity_metrics(dist))

# Correlation distances use their own summary helper.
print("\n=== CORRELATION STRUCTURE DIVERSITY ===")
print(summarize_corr_diversity(D_corr))

# -------------------------------------------------------
# 5. Visualisation
# -------------------------------------------------------
plot_all_diversity(D_global, D_corr)

In [None]:
import random
import numpy as np
import torch.nn as nn
from TempMLP_SCM import TemporalMLPSCM

def sample_random_hyperparameters(seed=None):
    """Draw one random TemporalMLPSCM hyperparameter configuration.

    Discrete choices are drawn with Python's ``random`` and continuous values
    with NumPy. With the default ``seed=None`` the global generators
    (module-level ``random`` and the legacy ``np.random`` state) are used,
    exactly as before. Passing ``seed`` switches to private generators, so the
    same seed always yields the same configuration and global RNG state is
    left untouched.

    Args:
        seed: Optional integer seed. ``None`` keeps sampling from the global
            generators.

    Returns:
        dict mapping TemporalMLPSCM keyword-argument names to sampled values.
    """
    # Private generators when seeded; otherwise the global modules, which
    # expose the same ``choice`` / ``uniform`` methods used below.
    py_rng = random if seed is None else random.Random(seed)
    np_rng = np.random if seed is None else np.random.default_rng(seed)

    activations = [nn.Tanh, nn.ReLU, nn.GELU]

    hp = {
        # architecture
        "num_layers": py_rng.choice([2, 3, 4, 5]),
        "hidden_dim": py_rng.choice([16, 32, 64, 128]),
        "num_causes": py_rng.choice([4, 6, 8, 10]),

        # temporal dynamics
        "alpha": np_rng.uniform(0.1, 0.9),
        "beta": np_rng.uniform(0.0, 1.5),
        "period": py_rng.choice([5, 10, 20, 30, 40]),
        "use_periodicity": py_rng.choice([True, False]),

        # noise & init
        # log-uniform in [1e-3, 5e-2): uniform in log space, then exponentiated
        "noise_std": np.exp(np_rng.uniform(np.log(1e-3), np.log(5e-2))),
        "mlp_dropout_prob": np_rng.uniform(0.0, 0.3),
        "block_wise_dropout": py_rng.choice([True, False]),
        "init_std": np_rng.uniform(0.5, 1.5),

        # activations (classes, not instances)
        "mlp_activations": py_rng.choice(activations),
    }
    return hp

def instantiate_model_from_hparams(
    hp,
    seq_len=100,
    num_features=10,
    num_outputs=1,
    device="cpu"
):
    """Build a TemporalMLPSCM from a hyperparameter dict.

    Args:
        hp: dict providing the keys listed in ``forwarded_keys`` below (for
            example the output of ``sample_random_hyperparameters``). Extra
            keys are ignored; a missing key raises ``KeyError``.
        seq_len, num_features, num_outputs, device: passed straight through
            to the TemporalMLPSCM constructor.

    Returns:
        The constructed TemporalMLPSCM instance.
    """
    # Same keys, looked up in the same order, as an explicit keyword list.
    forwarded_keys = (
        "num_layers", "hidden_dim", "num_causes",
        "alpha", "beta", "period", "use_periodicity",
        "noise_std", "mlp_dropout_prob", "block_wise_dropout",
        "init_std", "mlp_activations",
    )
    picked = {key: hp[key] for key in forwarded_keys}

    return TemporalMLPSCM(
        seq_len=seq_len,
        num_features=num_features,
        num_outputs=num_outputs,
        **picked,
        device=device,
    )

# One random configuration per run. No seed is set in this notebook, so
# re-running this cell analyses a different model each time.
model_hparams = sample_random_hyperparameters()
model = instantiate_model_from_hparams(model_hparams, device="cpu")

# Show which configuration the diversity numbers below belong to.
print("=== SAMPLED HYPERPARAMETERS ===")
for hp_name, hp_value in model_hparams.items():
    print(f"{hp_name:>20}: {hp_value}")

# -------------------------------------------------------
# 1. Generate datasets
# -------------------------------------------------------
X_list = []
for k in range(10):
    X, y = model.generate_dataset(n_individuals=50)
    X_list.append(X)

# -------------------------------------------------------
# 2. Compute signatures
# -------------------------------------------------------
# NOTE: the raw signature dicts are not printed; the summaries in section 4
# are the readable view of them.
signatures = [dataset_signature(X) for X in X_list]

combined_sigs = [s["combined"]  for s in signatures]
marg_sigs     = [s["marginal"]  for s in signatures]
temp_sigs     = [s["temporal"]  for s in signatures]
struct_sigs   = [s["structure"] for s in signatures]

corr_sigs = [compute_correlation_signature(X) for X in X_list]

# -------------------------------------------------------
# 3. Compute distance matrices
# -------------------------------------------------------
D_global     = pairwise_distances(combined_sigs)
D_marginal   = pairwise_distances(marg_sigs)
D_temporal   = pairwise_distances(temp_sigs)
D_structure  = pairwise_distances(struct_sigs)
D_corr       = compute_pairwise_corr_distances(corr_sigs)

# -------------------------------------------------------
# 4. Print stats
# -------------------------------------------------------
print("\n=== GLOBAL DIVERSITY ===")
print(diversity_metrics(D_global))

print("\n=== MARGINAL DIVERSITY ===")
print(diversity_metrics(D_marginal))

print("\n=== TEMPORAL DIVERSITY ===")
print(diversity_metrics(D_temporal))

print("\n=== STRUCTURE DIVERSITY ===")
print(diversity_metrics(D_structure))

print("\n=== CORRELATION STRUCTURE DIVERSITY ===")
print(summarize_corr_diversity(D_corr))

# -------------------------------------------------------
# 5. Visualisation
# -------------------------------------------------------
plot_all_diversity(D_global, D_corr)

In [None]:
# Larger configuration (500 steps, 20 features) for the temporality check.
model = TemporalMLPSCM(
    seq_len=500,
    num_features=20,
    num_outputs=1,
    num_layers=4,
    hidden_dim=32,
    num_causes=8,
    alpha=0.3,         # weaker AR(1) component
    beta=1.2,          # stronger periodicity; NOTE(review): presumably inert while use_periodicity=False — confirm
    period=20,
    use_periodicity=False,
)

# generate_dataset works on a freshly constructed model (as in the diversity
# cells above), so no separate forward() call is needed.
X, y = model.generate_dataset(n_individuals=100)
print(X.shape, y.shape)
evaluate_dataset_temporality(X)