In [1]:
import lightning
import torch
import plotly.graph_objs as go

In [2]:
from nepare.nn import NeuralPairwiseRegressor as NPR
from nepare.data import PairwiseAugmentedDataset, PairwiseAnchoredDataset

In [21]:
def _scatter_style(color, opacity, symbol=None):
    """Return a plotly ``Scatter3d`` style dict: markers only, size-4 points.

    ``symbol`` is inserted only when given, so ground-truth styles carry no
    ``symbol`` key while prediction styles use e.g. ``"diamond"``.
    """
    marker = {"color": color, "opacity": opacity}
    if symbol is not None:
        marker["symbol"] = symbol
    marker["size"] = 4
    return {"mode": "markers", "marker": marker}


# Ground truth is drawn faint (opacity 0.2); predictions are diamonds at 0.4.
TRAIN_MARKER = _scatter_style("red", 0.2)
TRAIN_PRED = _scatter_style("red", 0.4, symbol="diamond")
VAL_MARKER = _scatter_style("green", 0.2)
VAL_PRED = _scatter_style("green", 0.4, symbol="diamond")
TEST_MARKER = _scatter_style("blue", 0.2)
TEST_PRED = _scatter_style("blue", 0.4, symbol="diamond")

In [4]:
SEED: int = 1701  # global seed for this notebook
# Seed every RNG Lightning manages; the call returns the seed, which is shown
# as this cell's output.
lightning.seed_everything(seed=SEED)

Seed set to 1701


1701

In [5]:
# 200 points drawn uniformly from the unit square [0, 1)^2, shape (200, 2).
# The draw depends on the global torch RNG seeded via lightning.seed_everything above.
X = torch.rand((200, 2), dtype=torch.float32)
# Target: y = sin(8 * x1) + 3 * x2^2, stored as a column vector of shape (200, 1).
y = (torch.sin(8 * X[:, 0]) + 3 * X[:, 1].pow(2)).unsqueeze(1)

In [26]:
# Partition the unit square into three disjoint regions:
#   train: at least one coordinate < 0.3 (L-shaped band along both axes)
#   test:  both coordinates >= 0.7 (the far corner)
#   val:   everything else, i.e. both coordinates >= 0.3 and at least one < 0.7
# Val is defined as the complement of train | test. A per-coordinate
# "in [0.3, 0.7)" test would also match points such as (0.1, 0.5), which
# already belong to the train band, and the splits would overlap.
train_mask = (X < 0.3).any(dim=1)
test_mask = (X >= 0.7).all(dim=1)
val_mask = ~(train_mask | test_mask)

train_idxs = torch.argwhere(train_mask).flatten()
val_idxs = torch.argwhere(val_mask).flatten()
test_idxs = torch.argwhere(test_mask).flatten()

# Invariant: every point belongs to exactly one split.
assert (train_mask.int() + val_mask.int() + test_mask.int() == 1).all()

In [27]:
# One 3-D scatter trace per split, styled by the *_MARKER constants. Traces are
# named so the legend identifies each split instead of "trace 0/1/2".
fig = go.Figure()
for split_name, split_idxs, split_style in (
    ("train", train_idxs, TRAIN_MARKER),
    ("val", val_idxs, VAL_MARKER),
    ("test", test_idxs, TEST_MARKER),
):
    fig.add_trace(
        go.Scatter3d(
            x=X[split_idxs, 0],
            y=X[split_idxs, 1],
            z=y[split_idxs].flatten(),
            name=split_name,
            **split_style,
        )
    )
fig.update_layout(
    title="y = sin(8 x1) + 3 x2^2 by split",
    scene=dict(xaxis_title="x1", yaxis_title="x2", zaxis_title="y"),
)
fig.show()

In [9]:
def _anchor_set():
    """Return ``(X_train, y_train)``, the training points used as anchors.

    Indexing with an index tensor yields new tensors on every call, so each
    dataset below still receives its own copy of the anchors.
    """
    return X[train_idxs], y[train_idxs]


# Training pairs are built within the training set. Validation and test points
# are each paired against the training anchors; all datasets use how='full'.
training_dataset = PairwiseAugmentedDataset(*_anchor_set(), how='full')
validation_dataset = PairwiseAnchoredDataset(*_anchor_set(), X[val_idxs], y[val_idxs], how='full')
testing_dataset = PairwiseAnchoredDataset(*_anchor_set(), X[test_idxs], y[test_idxs], how='full')