In [None]:
from wn import net
from wn.data import DataInterface, tr

import pandas as pd
import numpy as np

import torch
from torch.utils.data import DataLoader, Dataset
from torch import nn
from torch.optim import AdamW

import pickle
import os
from time import perf_counter

In [None]:
# Fixed seed so the synthetic data is identical on every Restart & Run All.
# Seeding torch's global RNG here as well makes the weight initialisation and
# DataLoader shuffling in later cells repeatable (up to any nondeterministic
# GPU kernels).
SEED = 0
N_OBS = 100_000  # number of synthetic observations

rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

# Three independent uniform [0, 1) numeric features
x1 = rng.random(N_OBS)
x2 = rng.random(N_OBS)
x3 = rng.random(N_OBS)

# For convenience later: one row per observation, one column per x feature
all_x = np.stack([x1, x2, x3], axis=1)

# Two independent binary (0/1) flags
c1 = rng.integers(0, 2, size=N_OBS)
c2 = rng.integers(0, 2, size=N_OBS)

# 1 only where both flags are 1 (logical AND)
c_and = c1 * c2

# Target: the index (0, 1 or 2) of the LARGEST of x1..x3 when c1 or c2 is 0,
# and the index of the SMALLEST of x1..x3 when both c1 and c2 are 1.
both_on = c_and == 1
y_ = np.zeros_like(x1, dtype=int)
y_[~both_on] = np.argmax(all_x[~both_on, :], axis=1)
y_[both_on] = np.argmin(all_x[both_on, :], axis=1)

dt = pd.DataFrame(
    {
        "x1": x1,
        "x2": x2,
        "x3": x3,
        "c1": c1,
        "c2": c2,
    }
)

# Declare how each column should be treated by the tabular input layer.
interface = DataInterface(
    {
        "x1": "numeric",
        "x2": "numeric",
        "x3": "numeric",
        "c1": "categorical",
        "c2": "categorical",
    }
)

# NOTE(review): presumably derives the remaining per-column metadata from the
# data itself — confirm against wn.data.DataInterface.complete.
interface.complete(dt)

In [None]:
class ToyDataset(Dataset):
    """Map-style dataset pairing per-column encoded features with targets.

    Each column named in ``interface.type_map`` is passed through ``tr`` once
    at construction time; items are then served by row index.
    """

    def __init__(self, x, y, interface):
        """Encode every interface column of ``x`` and keep ``y`` alongside.

        Args:
            x: Column-addressable container (``x[column_name]``) holding the
                raw feature columns.
            y: Targets, indexed as ``y[idx, :]``; its first dimension defines
                the dataset length.
            interface: Object exposing ``type_map``; also forwarded to ``tr``.
        """
        super().__init__()

        encoded = {}
        for column in interface.type_map:
            encoded[column] = tr(x[column], column, interface)

        self.data = encoded
        self.y = y
        self.interface = interface

    def __len__(self):
        """Return the number of rows, taken from the first dimension of ``y``."""
        return self.y.shape[0]

    def __getitem__(self, idx):
        """Return ``(features, target)`` for row ``idx``.

        ``features`` is a dict with one entry per encoded column, each sliced
        to row ``idx``; ``target`` is the matching row of ``y``.
        """
        features = {}
        for column, values in self.data.items():
            features[column] = values[idx, :]

        target = self.y[idx, :]
        return features, target


# Labels go in as a long tensor with a trailing singleton axis, because
# ToyDataset indexes its targets as y[idx, :].
ds = ToyDataset(
    x=dt,
    y=torch.tensor(y_, dtype=torch.long).unsqueeze(-1),
    interface=interface,
)

In [None]:
# Assemble the test network:
#   tabular input layer -> transformer encoder -> CLS selection -> output head

# Tabular-specific input layer, with a CLS token appended.
# NOTE(review): col_encoding_size (8) + embedding_size (24) = 32, which matches
# the transformer's d_model below — presumably the two are concatenated;
# confirm against net.TabularInputLayer before changing either value.
input_layer = net.TabularInputLayer(
    interface=interface, col_encoding_size=8, embedding_size=24, append_cls=True
)

# Stack of four identical encoder layers operating on batch-first input.
transformer = nn.TransformerEncoder(
    encoder_layer=nn.TransformerEncoderLayer(
        d_model=32, nhead=4, dim_feedforward=32, batch_first=True
    ),
    num_layers=4,
)

# Convenience layer that picks out the CLS encoding from the encoder output.
cls_selector = net.SelectCLSEncoding()

# Output head.
# NOTE(review): argument meanings are not visible here — presumably the first
# is the input width (32) and one of the 3s is the number of classes; confirm
# against net.OutputLayers.
output_layers = net.OutputLayers(32, 3, 3)

whole_net = nn.Sequential(input_layer, transformer, cls_selector, output_layers)

# Count trainable parameters only.
n_weights = sum(p.numel() for p in whole_net.parameters() if p.requires_grad)
print(f"Network has {n_weights} weights.")

# Pick the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
whole_net.to(device)

print(f"Using {device}")

In [None]:
# Example training run
# Create a dataloader, optimizer, and criterion

N_EPOCHS = 5
BATCH_SIZE = 64
LOG_EVERY = 50  # number of batches between progress reports

dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)
print(f"{len(dl)} batches")

optimizer = AdamW([p for p in whole_net.parameters() if p.requires_grad])
# Summed (not averaged) loss: dividing the running total by the number of
# observations seen then gives a per-observation loss in the progress report.
criterion = nn.CrossEntropyLoss(reduction="sum")

# Per epoch

for epoch in range(N_EPOCHS):

    print(f"Starting epoch {epoch+1}")

    whole_net.train()

    # `tick` is reset after every report (throughput of the last window);
    # `big_tick` marks the start of the epoch (elapsed time).
    tick = big_tick = perf_counter()
    running_loss = 0.0
    running_n = 0
    running_correct = 0

    for i, batch in enumerate(dl):

        optimizer.zero_grad()

        # Get a batch
        x, y = batch

        x = net.to_(x, device)
        # Drop only the trailing singleton axis. A bare squeeze() would also
        # collapse the batch axis when the final batch holds a single row,
        # producing a 0-d target that no longer matches the (1, C) logits.
        y = y.squeeze(-1).to(device)

        y_hat = whole_net(x)
        labels = y_hat.argmax(dim=1)
        correct = (labels == y).sum()

        loss = criterion(y_hat, y)

        loss.backward()
        optimizer.step()

        running_correct += correct.item()
        running_loss += loss.item()
        running_n += y_hat.shape[0]

        # Report and reset the running statistics every LOG_EVERY batches.
        # Batches after the last full window of an epoch are not reported.
        if (i + 1) % LOG_EVERY == 0:
            print(
                f"Epoch {epoch + 1}, Batch {i+1 :4}: {running_loss / running_n :.3f} | ",
                f"Accuracy: {running_correct / running_n :.2f} | ",
                f"{running_n / (perf_counter() - tick) :6.1f} obs/sec | ",
                f"{perf_counter() - big_tick :.2f} seconds",
            )
            running_loss = 0.0
            running_n = 0
            running_correct = 0
            tick = perf_counter()

This architecture can in fact learn this problem, so it's not totally broken at least.