In [None]:
from infrastructure.classes import Experiment, TrainParameters
from infrastructure.training import train
from pushforward_operators import FastNonLinearQuantileRegression
import torch

# Configuration: set TRAIN_FROM_SCRATCH = True to call train(experiment) instead of
# loading the previously saved weights from WEIGHTS_PATH.
TRAIN_FROM_SCRATCH = False
WEIGHTS_PATH = (
    "../../experiments_full_13_09_2025_fnlvqr/fnlvqr_banana/"
    "fast_non_linear_vector_quantile_regression/weights.pth"
)

# Full experiment description: tensor dtype/device, dataset, dataloader,
# pushforward-operator hyper-parameters and outer training parameters.
experiment = Experiment(
    tensor_parameters=dict(dtype=torch.float64, device=torch.device("cpu")),
    dataset_name="fnlvqr_banana",
    dataset_number_of_points=20000,
    dataloader_parameters=dict(batch_size=512, shuffle=True),
    pushforward_operator_name="fast_non_linear_vector_quantile_regression",
    pushforward_operator_parameters=dict(
        fnlvqr_mlp_arguments=dict(
            verbose=True,
            num_epochs=1,
            epsilon=5e-3,
            lr=0.4,
            # NOTE(review): gpu=True while tensor_parameters pins the CPU —
            # confirm which of the two the operator actually honours.
            gpu=True,
            skip=False,
            batchnorm=False,
            hidden_layers=(2, 10, 20),
            batchsize_y=None,
            batchsize_u=None,
            inference_batch_size=100,
            lr_factor=0.9,
            lr_patience=300,
            lr_threshold=0.5 * 0.01,
        ),
        feature_dimension=1,
        response_dimension=2,
        hidden_dimension=32,
        number_of_hidden_layers=4,
    ),
    train_parameters=TrainParameters(
        # NOTE(review): 0 epochs — revisit before enabling TRAIN_FROM_SCRATCH
        # if train() relies on this value.
        number_of_epochs_to_train=0,
        verbose=True,
        optimizer_parameters=dict(lr=0.1),
        scheduler_parameters=dict(eta_min=0),
    ),
)

if TRAIN_FROM_SCRATCH:
    model = train(experiment)
else:
    model = FastNonLinearQuantileRegression.load_class(WEIGHTS_PATH)
model.to(**experiment.tensor_parameters)
_ = model.eval()

In [4]:
import torch
from datasets import FNLVQR_Banana

# Number of points drawn for both the ground-truth and model samples.
N_SAMPLES = 1000
# Seeds torch's global RNG so U and X (and sample_joint, presumably, if it draws
# from torch's RNG — TODO confirm) are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

dataset = FNLVQR_Banana(experiment.tensor_parameters)
X_gt, Y_gt = dataset.sample_joint(N_SAMPLES)

# U ~ Uniform[0, 1)^2: reference samples pushed forward by the model.
U = torch.rand(N_SAMPLES, 2).to(**experiment.tensor_parameters)

# X ~ Uniform over the observed range [min(X_gt), max(X_gt)).
x_gt_min, x_gt_max = X_gt.min(), X_gt.max()
X = torch.rand(N_SAMPLES, 1).to(**experiment.tensor_parameters) * (x_gt_max - x_gt_min) + x_gt_min

# Forward map (U, X) -> Y and inverse map (Y_gt, X_gt) -> U_approx.
Y = model.push_u_given_x(U, X)
U_approx = model.push_y_given_x(Y_gt, X_gt)

In [None]:
%matplotlib qt
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(121, projection='3d')
ax.scatter(X[:, 0], Y[:, 0], Y[:, 1])
ax.scatter(X_gt[:, 0], Y_gt[:, 0], Y_gt[:, 1])

plt.show()