 Initial notes
 - Dyadic regression (used in the social sciences and economics; offers many ideas on how to actually deploy such a model, and on how the uncertainties across all pairs formed by a given test point and each training point can be informative)
 - Pairwise regression (Pairwise Difference Regression, PADRE)
 - Delta learning


 - Bilinear transduction as a method for out-of-sample (OOS) prediction
 - Related terms: out-of-domain, domain adaptation, out-of-combination (OOC)

In [203]:
import torch
import plotly.graph_objs as go
import pytorch_lightning as pl
import numpy as np

In [204]:
SEED = 1701
# Seed the global RNGs (Python, NumPy's legacy global state, torch) used by
# torch.rand, weight initialisation and DataLoader shuffling.
pl.seed_everything(SEED)
# `np.random.default_rng()` with no argument draws fresh OS entropy and is not
# touched by `seed_everything`, so anything drawn from `rng` (e.g. the anchor
# sampling later in the notebook) would differ on every run. Seed it explicitly.
rng = np.random.default_rng(SEED)

Seed set to 1701


In [205]:
# Synthetic 2-D regression problem: 200 inputs drawn uniformly from [0, 1)^2.
x = torch.rand((200, 2), dtype=torch.float32)
# Alternative, smoother target: y = x1^2 + x2
# y = x[:, 0].pow(2) + 1 * x[:, 1]
# Target: y = sin(8 * x1) + 3 * x2^2. The factor 8 gives roughly 1.3 periods of
# the sine over [0, 1), so the x1 dependence is strongly non-monotone.
y = torch.sin(8 * x[:, 0]) + 3 * x[:, 1].pow(2)
# (N, 1) column so it lines up with single-output model predictions in mse_loss.
y = y.reshape(-1, 1)

In [206]:
# Carve the unit square into a training region and a held-out test region.
# Exactly one scenario should be active.

# ooc (out-of-combination): each coordinate may be >= 0.6 during training, but
# never both at once; the test set is the corner square [0.6, 1)^2.
in_test_corner = (x >= 0.6).all(dim=1)
train_idxs = torch.argwhere(~in_test_corner).flatten()
test_idxs = torch.argwhere(in_test_corner).flatten()

# oos: hold out the strip x1 >= 0.8
# train_idxs = torch.argwhere(x[:, 0].lt(0.8)).flatten()
# test_idxs = torch.argwhere(x[:, 0].greater_equal(0.8)).flatten()

# very oos: hold out every point with any coordinate > 0.8
# train_idxs = torch.argwhere((x <= 0.8).all(dim=1)).flatten()
# test_idxs = torch.argwhere((x > 0.8).any(dim=1)).flatten()

In [207]:
def _marker_style(color, opacity, **extra_marker):
    """Return Scatter3d keyword arguments for a small, marker-only trace."""
    return {
        "mode": "markers",
        "marker": {"color": color, "opacity": opacity, **extra_marker, "size": 4},
    }


# Faint circles for ground truth, stronger diamonds for predictions;
# red = training region, blue = held-out test region.
TRAIN_MARKER = _marker_style("red", 0.2)
TRAIN_PRED = _marker_style("red", 0.4, symbol="diamond")
TEST_MARKER = _marker_style("blue", 0.2)
TEST_PRED = _marker_style("blue", 0.4, symbol="diamond")

In [208]:
# Ground-truth samples only: training region in red, held-out region in blue.
fig = go.Figure()
fig.add_traces([
    go.Scatter3d(x=x[idxs, 0], y=x[idxs, 1], z=y[idxs].flatten(), **style)
    for idxs, style in ((train_idxs, TRAIN_MARKER), (test_idxs, TEST_MARKER))
])
fig.show()

In [209]:
class MLP(pl.LightningModule):
    """Feed-forward regressor: ``num_layers`` (Linear -> ReLU) blocks plus a linear head.

    Args:
        input_size: number of input features.
        hidden_size: width of every hidden layer.
        num_layers: number of hidden (Linear -> ReLU) blocks. ``0`` yields a single
            linear layer ``input_size -> output_size``.
        output_size: width of the final linear layer; defaults to 1 (scalar target).
        lr: AdamW learning rate returned by ``configure_optimizers``; defaults to 1e-3.

    Raises:
        ValueError: if ``num_layers`` is negative.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size=1, lr=1e-3):
        super().__init__()
        if num_layers < 0:
            raise ValueError(f"num_layers must be non-negative, got {num_layers}")
        self.lr = lr
        modules = []
        in_features = input_size
        for _ in range(num_layers):
            modules.append(torch.nn.Linear(in_features, hidden_size))
            modules.append(torch.nn.ReLU())
            in_features = hidden_size
        # The head must accept whatever the previous layer emits: ``hidden_size``
        # after at least one hidden block, but ``input_size`` when num_layers == 0.
        modules.append(torch.nn.Linear(in_features, output_size))
        self.mlp = torch.nn.Sequential(*modules)

    def forward(self, x):
        """Map inputs with last dim ``input_size`` to outputs with last dim ``output_size``."""
        return self.mlp(x)

    def _step(self, batch, name):
        """Compute the MSE loss for an ``(x, y)`` batch and log it as ``<name>/loss``."""
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.mse_loss(y_hat, y)
        self.log(f"{name}/loss", loss, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._step(batch, "train")

    def predict_step(self, batch, batch_idx):
        # Batches come from (inputs, targets) datasets; only the inputs are needed.
        return self(batch[0])

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

In [210]:
# Baseline: fit an MLP directly on (x, y) from the training region, then predict
# the held-out region and, in original order, the training region.
train_dataset = torch.utils.data.TensorDataset(x[train_idxs], y[train_idxs])
predict_dataset = torch.utils.data.TensorDataset(x[test_idxs], y[test_idxs])
train_loader, train_loader_noshuffle = (
    torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=shuffle)
    for shuffle in (True, False)
)
predict_loader = torch.utils.data.DataLoader(predict_dataset, batch_size=4)

# 2 inputs -> 3 hidden layers of width 100 -> 1 output
model = MLP(input_size=2, hidden_size=100, num_layers=3)

trainer = pl.Trainer(max_epochs=20, log_every_n_steps=1)
trainer.fit(model, train_loader)
# trainer.predict yields one tensor per batch; stack them back into one column.
y_test_pred, y_train_pred = (
    torch.vstack(trainer.predict(model, loader))
    for loader in (predict_loader, train_loader_noshuffle)
)
print(f"Training MSE: {torch.nn.functional.mse_loss(y_train_pred, y[train_idxs]):.3e}")
print(f"Testing MSE: {torch.nn.functional.mse_loss(y_test_pred, y[test_idxs]):.3e}")

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name | Type       | Params | Mode 
--------------------------------------------
0 | mlp  | Sequential | 20.6 K | train
--------------------------------------------
20.6 K    Trainable params
0         Non-trainable params
20.6 K    Total params
0.082     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode


Epoch 19: 100%|██████████| 43/43 [00:00<00:00, 52.21it/s, v_num=76, train/loss=0.00209] 

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 43/43 [00:00<00:00, 51.08it/s, v_num=76, train/loss=0.00209]
Predicting DataLoader 0: 100%|██████████| 8/8 [00:00<00:00, 305.51it/s]
Predicting DataLoader 0: 100%|██████████| 43/43 [00:00<00:00, 306.01it/s]
Training MSE: 9.846e-03
Testing MSE: 3.662e-01


In [211]:
# Baseline MLP: ground truth (circles) vs predictions (diamonds) for both regions.
fig = go.Figure()
fig.add_traces([
    go.Scatter3d(x=x[idxs, 0], y=x[idxs, 1], z=z.flatten(), **style)
    for idxs, z, style in (
        (train_idxs, y[train_idxs], TRAIN_MARKER),
        (train_idxs, y_train_pred, TRAIN_PRED),
        (test_idxs, y[test_idxs], TEST_MARKER),
        (test_idxs, y_test_pred, TEST_PRED),
    )
])
fig.show()

Things to implement and compare:
 - plain delta learning
 - bilinear transduction
 - delta learning with a learned distance

In [212]:
class LearnableMahalanobis(torch.nn.Module):
    """Learnable squared Mahalanobis-style distance between paired rows.

    The metric is parameterised as ``M = L^T L`` (positive semi-definite by
    construction) with a learnable ``L`` of shape ``(rank, size)``. For each row
    pair, ``forward`` returns ``(x_1 - x_2) M (x_1 - x_2)^T = ||L (x_1 - x_2)||^2``.

    Note: this is the *squared* distance; no square root is taken.

    Args:
        size: number of input features per row.
        rank: number of rows of ``L`` (the rank bound of ``M``); defaults to 50.
    """

    def __init__(self, size: int = 2, rank: int = 50) -> None:
        super().__init__()
        # Uniform [0, 1) initialisation of the (rank, size) factor.
        self.L = torch.nn.Parameter(torch.rand((rank, size)))

    def forward(self, batch) -> torch.Tensor:
        """Return a ``(B, 1)`` column of squared distances for ``batch = (x_1, x_2)``.

        ``x_1`` and ``x_2`` are expected to be ``(B, size)`` tensors.
        """
        x_1, x_2 = batch
        # (B, size) @ (size, rank) -> (B, rank). Summing squares row-wise equals
        # diag(dx @ L^T L @ dx^T) without materialising the (B, B) matrix.
        projected = (x_1 - x_2) @ self.L.T
        return projected.pow(2).sum(dim=1, keepdim=True)

In [213]:
class ManhattanDistance(torch.nn.Module):
    """L1 (Manhattan) distance between paired rows: ``sum_k |x_1[:, k] - x_2[:, k]|``.

    ``batch`` is indexable with ``batch[0]`` / ``batch[1]`` holding equally shaped
    ``(B, D)`` tensors; any further elements are ignored. Returns a ``(B, 1)``
    column of non-negative distances.

    NOTE(review): like any distance this is symmetric in its two arguments, so a
    model fed only this feature predicts the same value for (x1, x2) and (x2, x1).
    That cannot represent the sign of a y1 - y2 target.
    """

    def forward(self, batch):
        # abs() is what makes this the Manhattan distance; without it the signed
        # sum of differences can be negative, or zero for distinct points such as
        # (1, 0) vs (0, 1).
        return (batch[0] - batch[1]).abs().sum(dim=1, keepdim=True)

In [214]:
class DeltaLearning(pl.LightningModule):
    """Pairwise regressor driven by a single distance feature.

    ``forward`` reduces a pair ``(x1, x2)`` to ``self.distance((x1, x2))`` and maps
    that feature through a small MLP (1 input, 3 hidden layers of width 10,
    1 output). Training batches are ``(x1, x2, y)`` triples, where ``y`` is the
    pair's target. Subclasses may swap ``self.distance`` for another module.
    """

    def __init__(self):
        super().__init__()
        self.distance = ManhattanDistance()
        self.mlp = MLP(1, 10, 3)

    def forward(self, batch):
        return self.mlp(self.distance(batch))

    def _step(self, batch, name):
        """Log and return the MSE between predictions and targets for ``(x1, x2, y)``."""
        left, right, target = batch
        prediction = self((left, right))
        loss = torch.nn.functional.mse_loss(prediction, target)
        self.log(f"{name}/loss", loss, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._step(batch, name="train")

    def predict_step(self, batch, batch_idx):
        # Keep only the (x1, x2) inputs; a trailing target tensor is ignored.
        inputs = batch[:2]
        return self(inputs)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(params=self.parameters(), lr=1e-3)
        return optimizer

In [215]:
class LearnableDeltaLearning(DeltaLearning):
    """DeltaLearning whose distance feature is a learnable Mahalanobis-style metric on 2-D inputs."""

    def __init__(self):
        super().__init__()
        # The parent constructor builds the MLP and installs a fixed distance;
        # swap that distance for the learnable one.
        self.distance = LearnableMahalanobis(size=2)

In [216]:
class LazyDeltaLearning(pl.LightningModule):
    """Pairwise regressor that sees both raw inputs side by side.

    Instead of reducing ``(x1, x2)`` to a distance, it stacks them horizontally
    (4 features, i.e. two 2-D inputs in this notebook) and lets an MLP
    (4 inputs, 3 hidden layers of width 30, 1 output) learn the pair's target.
    Training batches are ``(x1, x2, y)`` triples.
    """

    def __init__(self):
        super().__init__()
        self.mlp = MLP(4, 30, 3)

    def forward(self, batch):
        features = torch.hstack(batch)
        return self.mlp(features)

    def _step(self, batch, name):
        """Log and return the MSE between predictions and targets for ``(x1, x2, y)``."""
        left, right, target = batch
        prediction = self((left, right))
        loss = torch.nn.functional.mse_loss(prediction, target)
        self.log(f"{name}/loss", loss, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._step(batch, name="train")

    def predict_step(self, batch, batch_idx):
        # Keep only the (x1, x2) inputs; a trailing target tensor is ignored.
        inputs = batch[:2]
        return self(inputs)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(params=self.parameters(), lr=1e-3)
        return optimizer

In [217]:
# Convert to a delta-learning problem: every ordered pair (i, j) of points within
# a split becomes one example with inputs (x_i, x_j) and target y_i - y_j.
def _all_ordered_pairs(x_split, y_split):
    """Return ``(x1, x2, y1 - y2)`` for all n^2 ordered pairs of rows of a split.

    ``x_split`` is expected to be (n, d) and ``y_split`` (n, 1); outputs are
    (n^2, d), (n^2, d) and (n^2, 1). Row ``i * n + j`` holds pair (i, j), i.e. the
    first index varies slowest.
    """
    n = x_split.shape[0]
    x1 = x_split.repeat_interleave(n, dim=0)  # x_0 (n times), x_1 (n times), ...
    x2 = x_split.repeat(n, 1)                 # x_0, ..., x_{n-1}, tiled n times
    dy = y_split.repeat_interleave(n, dim=0) - y_split.repeat(n, 1)
    return x1, x2, dy


# Tensor ops instead of an n^2 Python loop: the loop form re-indexed (copied) the
# whole split on every iteration and took far longer for ~170 training points.
x1_train, x2_train, pairwise_y_train = _all_ordered_pairs(x[train_idxs], y[train_idxs])
x1_test, x2_test, pairwise_y_test = _all_ordered_pairs(x[test_idxs], y[test_idxs])

In [218]:
# Train the pairwise model on all ordered training pairs, then predict every
# ordered test pair and, in original order, every training pair.
train_dataset = torch.utils.data.TensorDataset(x1_train, x2_train, pairwise_y_train)
predict_dataset = torch.utils.data.TensorDataset(x1_test, x2_test, pairwise_y_test)
train_loader, train_loader_noshuffle = (
    torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=shuffle)
    for shuffle in (True, False)
)
predict_loader = torch.utils.data.DataLoader(predict_dataset, batch_size=4)

# Concatenation baseline: the MLP sees [x1, x2] directly, with no distance feature.
model = LazyDeltaLearning()

trainer = pl.Trainer(max_epochs=2, log_every_n_steps=1)
trainer.fit(model, train_loader)
# trainer.predict yields one tensor per batch; stack them back into one column.
y_test_pred, y_train_pred = (
    torch.vstack(trainer.predict(model, loader))
    for loader in (predict_loader, train_loader_noshuffle)
)
print(f"Training MSE: {torch.nn.functional.mse_loss(y_train_pred, pairwise_y_train):.3e}")
print(f"Testing MSE: {torch.nn.functional.mse_loss(y_test_pred, pairwise_y_test):.3e}")

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name | Type | Params | Mode 
--------------------------------------
0 | mlp  | MLP  | 2.0 K  | train
--------------------------------------
2.0 K     Trainable params
0         Non-trainable params
2.0 K     Total params
0.008     Total estimated model params size (MB)
9         Modules in train mode
0         Modules in eval mode


Epoch 1: 100%|██████████| 7141/7141 [02:06<00:00, 56.43it/s, v_num=77, train/loss=0.000593]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 7141/7141 [02:06<00:00, 56.42it/s, v_num=77, train/loss=0.000593]
Predicting DataLoader 0: 100%|██████████| 241/241 [00:00<00:00, 249.57it/s]
Epoch 0:   0%|          | 0/1600 [1:43:30<?, ?it/s] [00:24<00:04, 249.61it/s]
Epoch 0:   0%|          | 0/1600 [1:30:21<?, ?it/s]
Epoch 0:   0%|          | 0/1600 [1:25:06<?, ?it/s]
Epoch 0:   0%|          | 0/1600 [1:24:42<?, ?it/s]
Predicting DataLoader 0: 100%|██████████| 7141/7141 [00:29<00:00, 241.19it/s]
Training MSE: 2.731e-02
Testing MSE: 1.235e-01


In [219]:
# Undo the delta-learning mapping. The pairwise model is trained so that
#     f(x1, x2) ~= y(x1) - y(x2),
# hence for a query point q and a training "anchor" a with known label y_a,
#     y(q) ~= y_a - f(a, q).
# Each query's estimate is averaged over several anchors.
n_anchors = min(10, train_idxs.shape[0])  # cannot use more anchors than training points
# Distinct anchors: Generator.choice samples WITH replacement by default, which
# could silently count the same anchor several times in the average.
anchor_idxs = torch.from_numpy(rng.choice(train_idxs.shape[0], size=n_anchors, replace=False))
x_anchor_pool = x[train_idxs][anchor_idxs]
y_anchor_pool = y[train_idxs][anchor_idxs]


def _predict_from_anchors(x_query):
    """Return a 1-D tensor of anchor-averaged predictions, one per row of ``x_query``."""
    n_query = x_query.shape[0]
    # Pair every query with every anchor, grouped by query:
    # (q_0, a_0..a_{k-1}), (q_1, a_0..a_{k-1}), ...
    x_anchors = x_anchor_pool.repeat(n_query, 1)
    y_anchors = y_anchor_pool.repeat(n_query, 1)
    x_queries = x_query.repeat_interleave(n_anchors, dim=0)
    # predict_step only reads batch[0:2]; the trailing y_anchors tensor is ignored.
    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(x_anchors, x_queries, y_anchors), batch_size=4
    )
    pairwise_pred = torch.vstack(trainer.predict(model, loader))
    per_anchor = y_anchors - pairwise_pred  # one estimate of y(q) per (query, anchor)
    return per_anchor.reshape(n_query, n_anchors).mean(dim=1)


test_predict = _predict_from_anchors(x[test_idxs])
train_predict = _predict_from_anchors(x[train_idxs])


Predicting:   0%|          | 0/78 [00:00<?, ?it/s]

Predicting DataLoader 0: 100%|██████████| 78/78 [00:00<00:00, 299.97it/s]
Predicting DataLoader 0: 100%|██████████| 423/423 [00:01<00:00, 250.58it/s]


In [220]:
# Delta learning: ground truth (circles) vs anchor-reconstructed predictions
# (diamonds) for both regions.
fig = go.Figure()
fig.add_traces([
    go.Scatter3d(x=x[idxs, 0], y=x[idxs, 1], z=z.flatten(), **style)
    for idxs, z, style in (
        (train_idxs, y[train_idxs], TRAIN_MARKER),
        (train_idxs, train_predict, TRAIN_PRED),
        (test_idxs, y[test_idxs], TEST_MARKER),
        (test_idxs, test_predict, TEST_PRED),
    )
])
fig.show()