In [None]:
from infrastructure.classes import Experiment, TrainParameters
from infrastructure.training import train
import torch

# Single seed for the whole notebook: it seeds torch's global RNG here and is
# also passed to the dataset generator via `dataset_parameters` below.
# NOTE(review): assumes `train` does not reseed torch internally — if it does,
# this call is redundant but harmless.
SEED = 1
torch.manual_seed(SEED)

# Experiment configuration: float64 on CPU, 5000-point "star" dataset with 4 lobes,
# entropic-OT quantile-regression pushforward operator, 5 training epochs.
experiment = Experiment(
    tensor_parameters=dict(dtype=torch.float64, device=torch.device("cpu")),
    dataset_name="star",
    dataset_number_of_points=5000,
    dataset_parameters={"n_lobes": 4, "seed": SEED},
    dataloader_parameters=dict(batch_size=256, shuffle=True),
    pushforward_operator_name="entropic_optimal_transport_quantile_regression",
    pushforward_operator_parameters=dict(
        feature_dimension=2,
        response_dimension=2,
        # NOTE(review): a very narrow (width 4) but deep (16 layers) network is an
        # unusual combination — confirm these two values are not swapped.
        hidden_dimension=4,
        number_of_hidden_layers=16,
        epsilon=1e-5,
        number_of_samples_for_entropy_dual_estimation=2048,
        activation_function_name="Mish"
    ),
    train_parameters=TrainParameters(
        number_of_epochs_to_train=5,
        verbose=True,
        optimizer_parameters=dict(lr=0.01),
        scheduler_parameters=dict(eta_min=0)
    )
)

model = train(experiment)
# Put the trained model into evaluation mode before plotting; assigning to `_`
# keeps the return value from being echoed as the cell output.
_ = model.eval()

In [None]:
%matplotlib qt
from utils.plot import plot_potentials_from_star_dataset

# Visualise the trained model with the project's potential-plotting helper,
# reusing the dataset and tensor settings of the experiment trained above.
plot_potentials_from_star_dataset(
    model=model,
    dataset_parameters=experiment.dataset_parameters,
    tensor_parameters=experiment.tensor_parameters,
    # Tunable plot knobs: how many conditioning points, and samples drawn for each plot.
    number_of_conditional_points=4,
    number_of_points_to_sample=100,
)

In [None]:
%matplotlib qt
from utils.plot import plot_quantile_levels_from_star_dataset

# Visualise quantile levels of the trained model at a single conditioning value,
# reusing the dataset and tensor settings of the experiment trained above.
plot_quantile_levels_from_star_dataset(
    model=model,
    dataset_parameters=experiment.dataset_parameters,
    tensor_parameters=experiment.tensor_parameters,
    # Tunable plot knobs: conditioning value, number of quantile levels, sample count.
    conditional_value=0.7,
    number_of_quantile_levels=4,
    number_of_points_to_sample=1000,
)