In [1]:
from statistics import mean, stdev

In [None]:
import lightning
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
import torch
import plotly.graph_objs as go
from sklearn.metrics import mean_absolute_percentage_error

In [3]:
from nepare.nn import NeuralPairwiseRegressor as NPR, FeedforwardNeuralNetwork as FNN
from nepare.data import PairwiseAugmentedDataset, PairwiseAnchoredDataset, PairwiseInferenceDataset

In [4]:
# Plotly scatter styles, one colour per split: red = train, green = validation, blue = test.
# *_MARKER styles draw ground truth; *_PRED styles draw model predictions as more opaque diamonds.
TRAIN_MARKER = {"mode": "markers", "marker": {"color": "red", "opacity": 0.2, "size": 4}}
TRAIN_PRED = {"mode": "markers", "marker": {"color": "red", "opacity": 0.4, "symbol": "diamond", "size": 4}}
VAL_MARKER = {"mode": "markers", "marker": {"color": "green", "opacity": 0.2, "size": 4}}
VAL_PRED = {"mode": "markers", "marker": {"color": "green", "opacity": 0.4, "symbol": "diamond", "size": 4}}
TEST_MARKER = {"mode": "markers", "marker": {"color": "blue", "opacity": 0.2, "size": 4}}
TEST_PRED = {"mode": "markers", "marker": {"color": "blue", "opacity": 0.4, "symbol": "diamond", "size": 4}}

In [None]:
# Seed the global RNGs through Lightning so the random data, weight init and
# shuffling in the cells below are reproducible across Restart & Run All.
SEED = 1701
lightning.seed_everything(seed=SEED)

In [6]:
# Synthetic 2-D regression problem on the unit square [0, 1)^2:
#     y = sin(8 * x1) + 3 * x2^2
# Targets are stored as an (n, 1) column vector.
X = torch.rand((200, 2), dtype=torch.float32)
y = (torch.sin(8 * X[:, 0]) + 3 * X[:, 1].pow(2)).reshape(-1, 1)

In [7]:
# Region-based split:
#   * rows with at least one coordinate < 0.5 are "seen" data; the first 20% of
#     those indices become validation, the rest training;
#   * rows with both coordinates >= 0.5 form the test set, a region never
#     present in training or validation.
train_idxs = torch.argwhere((X < 0.5).any(dim=1)).flatten()
val_idxs, train_idxs = (
    train_idxs[: int(len(train_idxs) * 0.2)],
    train_idxs[int(len(train_idxs) * 0.2) :],
)
test_idxs = torch.argwhere((X >= 0.5).all(dim=1)).flatten()

In [None]:
# Ground-truth surface samples, one coloured trace per split.
fig = go.Figure(
    data=[
        go.Scatter3d(x=X[idxs, 0], y=X[idxs, 1], z=y[idxs].flatten(), **style)
        for idxs, style in (
            (train_idxs, TRAIN_MARKER),
            (val_idxs, VAL_MARKER),
            (test_idxs, TEST_MARKER),
        )
    ]
)
fig.show()

In [None]:
# Baseline: a feed-forward network trained directly on (X, y).
training_dataset = torch.utils.data.TensorDataset(X[train_idxs], y[train_idxs])
validation_dataset = torch.utils.data.TensorDataset(X[val_idxs], y[val_idxs])
testing_dataset = torch.utils.data.TensorDataset(X[test_idxs], y[test_idxs])
predict_dataset = torch.utils.data.TensorDataset(X[test_idxs])  # inputs only, for trainer.predict
train_loader = torch.utils.data.DataLoader(dataset=training_dataset, batch_size=8, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=validation_dataset, batch_size=8)
test_loader = torch.utils.data.DataLoader(dataset=testing_dataset, batch_size=8)
predict_loader = torch.utils.data.DataLoader(dataset=predict_dataset, batch_size=4)  # unshuffled: rows stay aligned with test_idxs
fnn = FNN(2, 50, 3)
early_stopping = EarlyStopping(monitor="validation/loss")
model_checkpoint = ModelCheckpoint(monitor="validation/loss")
trainer = lightning.Trainer(
    callbacks=[early_stopping, model_checkpoint],
    log_every_n_steps=1,
    max_epochs=50,
)
trainer.fit(model=fnn, train_dataloaders=train_loader, val_dataloaders=val_loader)
# EarlyStopping only decides when training ends; the weights restored here are
# the best-by-validation/loss checkpoint tracked by ModelCheckpoint.
fnn = FNN.load_from_checkpoint(checkpoint_path=model_checkpoint.best_model_path)
trainer.test(model=fnn, dataloaders=test_loader)
y_pred = torch.vstack(trainer.predict(model=fnn, dataloaders=predict_loader)).numpy()


In [None]:
# Baseline predictions (diamonds) overlaid on the ground truth of every split.
fig = go.Figure()
fig.add_trace(go.Scatter3d(x=X[train_idxs, 0], y=X[train_idxs, 1], z=y[train_idxs].flatten(), **TRAIN_MARKER))
fig.add_trace(go.Scatter3d(x=X[val_idxs, 0], y=X[val_idxs, 1], z=y[val_idxs].flatten(), **VAL_MARKER))
fig.add_trace(go.Scatter3d(x=X[test_idxs, 0], y=X[test_idxs, 1], z=y[test_idxs].flatten(), **TEST_MARKER))
fig.add_trace(go.Scatter3d(x=X[test_idxs, 0], y=X[test_idxs, 1], z=y_pred.flatten(), **TEST_PRED))
fig.show()
# |y|-weighted MAPE on the test set, i.e. sum|error| / sum|y| (up to sklearn's
# epsilon clamp). sklearn expects sample_weight as a 1-D array of length
# n_samples, and the raw targets can be negative (the sin term), which would make
# the weighted average meaningless -- hence .abs().flatten().
print(mean_absolute_percentage_error(y[test_idxs], y_pred, sample_weight=y[test_idxs].abs().flatten()))

In [None]:
# Ground-truth test targets as a 1-D tensor, for eyeballing against the predictions.
y[test_idxs].reshape(-1)

In [None]:
# Baseline FNN test predictions (NumPy array from the training cell); rows follow
# X[test_idxs] order because predict_loader is not shuffled.
y_pred

In [12]:
# Pairwise setup: training pairs are drawn from the training points alone, while
# validation / test / inference pairs are each anchored on the training points.
training_dataset = PairwiseAugmentedDataset(X[train_idxs], y[train_idxs], how='full')
validation_dataset = PairwiseAnchoredDataset(
    X[train_idxs], y[train_idxs], X[val_idxs], y[val_idxs], how='full'
)
# scored by trainer.test -> metrics reported in the pairwise (difference) space
testing_dataset = PairwiseAnchoredDataset(
    X[train_idxs], y[train_idxs], X[test_idxs], y[test_idxs], how='full'
)
# consumed by trainer.predict -> mapped back to absolute values further down
predict_dataset = PairwiseInferenceDataset(X[train_idxs], y[train_idxs], X[test_idxs], how='full')
train_loader = torch.utils.data.DataLoader(dataset=training_dataset, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=validation_dataset, batch_size=64)
test_loader = torch.utils.data.DataLoader(dataset=testing_dataset, batch_size=64)
# NOTE: must stay unshuffled -- the collation cell zips predictions with
# predict_loader.dataset.pairs in order.
predict_loader = torch.utils.data.DataLoader(dataset=predict_dataset, batch_size=4)

In [13]:
# Fresh pairwise model and fresh callbacks (the baseline's callbacks already
# tracked the FNN run, so they are not reused).
npr = NPR(2, 50, 3)
early_stopping, model_checkpoint = (
    EarlyStopping(monitor="validation/loss"),
    ModelCheckpoint(monitor="validation/loss"),
)

In [None]:
# Train the pairwise model with the same budget and callbacks as the baseline.
trainer = lightning.Trainer(
    callbacks=[early_stopping, model_checkpoint],
    log_every_n_steps=1,
    max_epochs=50,
)
trainer.fit(model=npr, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [15]:
# Restore the best-by-validation/loss weights saved by ModelCheckpoint
# (EarlyStopping only decides when training stops).
npr = NPR.load_from_checkpoint(checkpoint_path=model_checkpoint.best_model_path)

In [None]:
# Pairwise-space test metrics (testing_dataset pairs test points with training anchors).
trainer.test(model=npr, dataloaders=test_loader)

In [None]:
# Raw pairwise predictions stacked across batches; the collation below assumes one
# row per entry of predict_loader.dataset.pairs, in the same order.
# TODO: refactor this and the below cells into the nepare.inference.predict function
test_pred = torch.vstack(trainer.predict(model=npr, dataloaders=predict_loader))

In [18]:
# Convert each pairwise prediction back into an absolute estimate for its inference
# point, collecting one estimate per anchor pairing in absolute_predictions[idx].
# do the collation in Python-land for simplicity, sacrificing speed for now
# NOTE(review): keys span predict_loader.dataset.Xs[1], presumably the inference
# inputs X[test_idxs] -- confirm, since the next cell lines these up with y[test_idxs].
absolute_predictions = {idx: [] for idx in range(predict_loader.dataset.Xs[1].shape[0])}
for pair, prediction in zip(predict_loader.dataset.pairs, test_pred):
    # for a pair of inputs i,j the network predicts delta_i,j in that order
    # map back to the actual values here
    # NOTE(review): the else-branch assumes the inference point is then in position
    # one and the anchor in position two; a pair with two inference points (if
    # how='full' ever produced one) would index y_anchors with an inference index.
    if pair.src_2 == 1:  # inference point is in position two
        # y_1 - y_2 = f(x_1,x_2) -> y_1 - f(x_1,x_2) = y_2
        _pred = predict_loader.dataset.y_anchors[pair.idx_1] - prediction
        absolute_predictions[pair.idx_2].append(_pred.item())
    else:
        # y_1 - y_2 = f(x_1,x_2) -> y_1 = f(x_1,x_2) + y_2
        _pred = prediction + predict_loader.dataset.y_anchors[pair.idx_2]
        absolute_predictions[pair.idx_1].append(_pred.item())

In [None]:
# Summarise each test point's per-anchor estimates as mean +/- sample stdev.
# absolute_predictions was built from range(...), so iteration is in ascending
# index order and y_pred / y_stdev line up with X[test_idxs] (assuming the
# inference set is X[test_idxs] -- see the collation cell).
y_test_true = y[test_idxs].flatten()  # index once instead of on every iteration
y_pred = []
y_stdev = []
for idx, preds in absolute_predictions.items():
    pred_mean = mean(preds)
    # statistics.stdev raises on fewer than two points; a lone estimate has no spread.
    pred_stdev = stdev(preds) if len(preds) > 1 else float("nan")
    print(f"Index {idx}: actual={y_test_true[idx].item():.3f}, predicted={pred_mean:.3f}+/-{pred_stdev:.3f}")
    y_pred.append(pred_mean)
    y_stdev.append(pred_stdev)

In [None]:
# Ground truth for every split, plus the pairwise model's test predictions with
# symmetric error bars of one sample stdev (spread across anchors).
fig = go.Figure(
    data=[
        go.Scatter3d(x=X[idxs, 0], y=X[idxs, 1], z=y[idxs].flatten(), **style)
        for idxs, style in (
            (train_idxs, TRAIN_MARKER),
            (val_idxs, VAL_MARKER),
            (test_idxs, TEST_MARKER),
        )
    ]
)
fig.add_trace(
    go.Scatter3d(
        x=X[test_idxs, 0],
        y=X[test_idxs, 1],
        z=y_pred,
        error_z=dict(type='data', array=y_stdev, visible=True),  # bar size given in data units
        **TEST_PRED,
    )
)
fig.show()

In [None]:
# |y|-weighted MAPE of the pairwise model on the test set (sum|error| / sum|y|).
# sklearn expects sample_weight as a non-negative 1-D array of length n_samples;
# the raw (n, 1) targets can be negative, so use their absolute values, flattened.
print(mean_absolute_percentage_error(y[test_idxs], y_pred, sample_weight=y[test_idxs].abs().flatten()))