In [1]:
from wn import net
from wn.data import MatchDataset

import torch
from torch.utils.data import DataLoader, Subset
from torch import nn
from torch.optim import AdamW

import pickle
import os
from time import perf_counter

In [2]:
# Load the tensorized data
with open("data/tensor_list.pkl", "rb") as f:
    input_data, y  = pickle.load(f)

ds = MatchDataset(input_data, y)

# Split into training and validation
idx = torch.randperm(len(input_data["p1_dob"]))
split_idx = idx.shape[0] // 4
train_ds = Subset(ds, idx[split_idx:])
validation_ds = Subset(ds, idx[:split_idx])

# Load the interface
with open("data/match_interface.pkl", "rb") as f:
    match_interface = pickle.load(f)

In [3]:
# Set up the network for a test.

col_encoding_size = 16
dim_model = 64
dim_ff = 64
n_transformer_layers = 4
n_transformer_heads = 4
n_output_layers = 3

# Special tabular input layer
input_layer = net.TabularInputLayer(
    interface=match_interface,
    col_encoding_size=col_encoding_size,
    embedding_size=dim_model-col_encoding_size,
    append_cls=True,
)

# Transformer encoder
tr = nn.TransformerEncoder(
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=dim_model,
        nhead=n_transformer_heads,
        dim_feedforward=dim_ff,
        batch_first=True,
    ),
    num_layers=n_transformer_layers,
)

# Ugly convenience layer
cls_selector = net.SelectCLSEncoding()

# Output layer
output_layers = net.OutputLayers(dim_model, n_output_layers, 1)

whole_net = nn.Sequential(
    input_layer,
    tr,
    cls_selector,
    output_layers,
)

n_weights = sum([p.numel() for p in whole_net.parameters() if p.requires_grad])
print(f"Network has {n_weights} weights.")

# Setup device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
whole_net.to(device)

print(f"Using {device}")

Network has 123841 weights.
Using cuda:0


In [4]:
# Example training step
# Create a dataloader, optimizer, and criterion

# Large batches and low learning rates seem basically required for this network
# to reliably learn anything. Small batches just don't learn at all, large
# batches with higher learning rates have a tendency to spontaneously forget
# everything during training and then can't recover.
batch_size = 1024
learning_rate = 0.0001

train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=3)
validation_dl = DataLoader(validation_ds, batch_size=batch_size, shuffle=True, num_workers=3)

print(f"Training: {len(train_dl)} batches of size {batch_size}")
print(f"Validation: {len(validation_dl)} batches")

optimizer = AdamW(
    filter(lambda p: p.requires_grad, whole_net.parameters()),
    lr=learning_rate,
)
criterion = nn.BCEWithLogitsLoss(reduction="sum")

# Per epoch
big_tick = perf_counter()

# For this amount of data, 5 epochs gives a decent indication of what kind
# of performance to expect, it seems.
for epoch in range(5):

    print(f"Starting epoch {epoch+1 :2} ------")

    # Training

    whole_net.train()

    tick = perf_counter()
    running_loss = 0.0
    running_n = 0
    running_correct = 0

    for i, batch in enumerate(train_dl):

        optimizer.zero_grad()

        # Get a batch
        x, y = batch

        x = net.to_(x, device)
        y = y.to(device)

        y_hat = whole_net(x)
        labels = y_hat > 0
        correct = (labels == y).sum()

        loss = criterion(y_hat, y)

        loss.backward()
        optimizer.step()

        running_correct += correct.item()
        running_loss += loss.item()
        running_n += y_hat.shape[0]
        
        if i % 50 == 49:
            print(
                f"Epoch {epoch + 1}, Batch {i+1 :4}: {running_loss / running_n :.3f} | ",
                f"Accuracy: {running_correct / running_n :.2f} | ",
                f"{running_n / (perf_counter() - tick) :6.0f} obs/sec | ",
                f"{perf_counter() - big_tick :.2f} s"
            )
            running_loss = 0.0
            running_n = 0
            running_correct = 0
            tick = perf_counter()

    # Validation

    whole_net.eval()

    with torch.no_grad():

        tick = perf_counter()
        valid_loss = 0.0
        valid_n = 0
        valid_correct = 0

        for i, batch in enumerate(validation_dl):

            # Get a batch
            x, y = batch

            x = net.to_(x, device)
            y = y.to(device)

            y_hat = whole_net(x)
            labels = y_hat > 0
            correct = (labels == y).sum()

            loss = criterion(y_hat, y)

            valid_correct += correct.item()
            valid_loss += loss.item()
            valid_n += y_hat.shape[0]
            
        print(
            f"Epoch {epoch + 1} validation loss: {running_loss / running_n :.3f} | ",
            f"Accuracy: {valid_correct / valid_n :.2f} | ",
            f"{valid_n / (perf_counter() - tick) :6.0f} obs/sec | ",
            f"{perf_counter() - big_tick :.2f} s"
        )

Training: 244 batches of size 1024
Validation: 82 batches
Starting epoch  1 ------
Epoch 1, Batch   50: 0.700 |  Accuracy: 0.51 |    9740 obs/sec |  5.26 s
Epoch 1, Batch  100: 0.698 |  Accuracy: 0.50 |   33402 obs/sec |  6.79 s
Epoch 1, Batch  150: 0.694 |  Accuracy: 0.51 |   33532 obs/sec |  8.32 s
Epoch 1, Batch  200: 0.692 |  Accuracy: 0.52 |   33687 obs/sec |  9.84 s
Epoch 1 validation loss: 0.687 |  Accuracy: 0.57 |   19270 obs/sec |  15.77 s
Starting epoch  2 ------
Epoch 2, Batch   50: 0.683 |  Accuracy: 0.56 |   13963 obs/sec |  19.44 s
Epoch 2, Batch  100: 0.683 |  Accuracy: 0.55 |   33690 obs/sec |  20.96 s
Epoch 2, Batch  150: 0.680 |  Accuracy: 0.56 |   32220 obs/sec |  22.55 s
Epoch 2, Batch  200: 0.679 |  Accuracy: 0.57 |   32947 obs/sec |  24.10 s
Epoch 2 validation loss: 0.676 |  Accuracy: 0.59 |   18746 obs/sec |  30.14 s
Starting epoch  3 ------
Epoch 3, Batch   50: 0.675 |  Accuracy: 0.57 |   13837 obs/sec |  33.84 s
Epoch 3, Batch  100: 0.671 |  Accuracy: 0.58 |   