This notebook is for testing the hybrid tabular + sequential data representation.

In [1]:
from wn import net
from wn.data import MatchHistoryDataset

import torch
from torch.utils.data import DataLoader, Subset
from torch import nn
from torch.optim import AdamW

import pickle
import os
from time import perf_counter

In [2]:
# Load the tensorized data
with open("data/tensor_list.pkl", "rb") as f:
    input_data, y = pickle.load(f)

with open("data/history_tensor_list.pkl", "rb") as f:
    history_data, pid = pickle.load(f)

# Load the interfaces
with open("data/match_interface.pkl", "rb") as f:
    match_interface = pickle.load(f)

with open("data/history_interface.pkl", "rb") as f:
    history_interface = pickle.load(f)

history_size = 20

# Make a dataset
ds = MatchHistoryDataset(input_data, y, history_data, pid, history_size=20)

# Split into training and validation
idx = torch.randperm(len(input_data["p1_dob"]))
split_idx = idx.shape[0] // 4  # Just gets a 75/25 split, maximum laziness
train_ds = Subset(ds, idx[split_idx:])
validation_ds = Subset(ds, idx[:split_idx])

In [3]:
# Set up the network for a test.

col_encoding_size = 16
dim_model = 64
dim_ff = 64
n_transformer_layers = 4
n_transformer_heads = 4
n_output_layers = 3

# Special tabular input layer
table_input_layer = net.TabularInputLayer(
    interface=match_interface,
    col_encoding_size=col_encoding_size,
    embedding_size=dim_model - col_encoding_size,
    append_cls=True,
)

# Input layer for sequential features
sequence_input_layer = net.SequentialInputLayer(
    interface=history_interface,
    sequence_encoding_size=[history_size, col_encoding_size],
    embedding_size=dim_model - col_encoding_size,
)

output_layers = net.OutputLayers(dim_model, n_output_layers, 1)

# Transformer encoder
tr = nn.TransformerEncoder(
    encoder_layer=nn.TransformerEncoderLayer(
        d_model=dim_model,
        nhead=n_transformer_heads,
        dim_feedforward=dim_ff,
        batch_first=True,
    ),
    num_layers=n_transformer_layers,
)

whole_net = net.FusionNet(table_input_layer, sequence_input_layer, tr, output_layers)

n_weights = sum([p.numel() for p in whole_net.parameters() if p.requires_grad])
print(f"Network has {n_weights} weights.")

# Setup device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
whole_net.to(device)

print(f"Using {device}")

Network has 149953 weights.
Using cuda:0


In [4]:
# Training

# Training parameters. Initially lifted from the simpler network.

batch_size = 1024
learning_rate = 0.0001

# Create a dataloader, optimizer, and criterion

train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=3)
validation_dl = DataLoader(
    validation_ds, batch_size=batch_size, shuffle=True, num_workers=3
)

print(f"Training: {len(train_dl)} batches of size {batch_size}")
print(f"Validation: {len(validation_dl)} batches")

optimizer = AdamW(
    filter(lambda p: p.requires_grad, whole_net.parameters()),
    lr=learning_rate,
)
criterion = nn.BCEWithLogitsLoss(reduction="sum")

# Per epoch
big_tick = perf_counter()

# For this amount of data, 5 epochs gives a decent indication of what kind
# of performance to expect, it seems.
for epoch in range(7):

    print(f"Starting epoch {epoch+1 :2} ------")

    # Training

    whole_net.train()

    tick = perf_counter()
    running_loss = 0.0
    running_n = 0
    running_correct = 0

    for i, batch in enumerate(train_dl):

        optimizer.zero_grad()

        # Get a batch
        mx, hx, y, mask = batch

        mx = net.to_(mx, device)
        hx = net.to_(hx, device)
        mask = mask.to(device)
        y = y.to(device)

        y_hat = whole_net(mx, hx, mask)
        labels = y_hat > 0
        correct = (labels == y).sum()

        loss = criterion(y_hat, y)

        loss.backward()
        optimizer.step()

        running_correct += correct.item()
        running_loss += loss.item()
        running_n += y_hat.shape[0]

        if i % 50 == 49:
            print(
                f"Epoch {epoch + 1}, Batch {i+1 :4}: {running_loss / running_n :.3f} | ",
                f"Accuracy: {running_correct / running_n :.2f} | ",
                f"{running_n / (perf_counter() - tick) :6.0f} obs/sec | ",
                f"{perf_counter() - big_tick :.2f} s",
            )
            running_loss = 0.0
            running_n = 0
            running_correct = 0
            tick = perf_counter()

    # Validation

    whole_net.eval()

    with torch.no_grad():

        tick = perf_counter()
        valid_loss = 0.0
        valid_n = 0
        valid_correct = 0

        for i, batch in enumerate(validation_dl):

            # Get a batch
            mx, hx, y, mask = batch

            mx = net.to_(mx, device)
            hx = net.to_(hx, device)
            mask = mask.to(device)
            y = y.to(device)

            y_hat = whole_net(mx, hx, mask)
            labels = y_hat > 0
            correct = (labels == y).sum()

            loss = criterion(y_hat, y)

            valid_correct += correct.item()
            valid_loss += loss.item()
            valid_n += y_hat.shape[0]

        print(
            f"Epoch {epoch + 1} validation loss: {running_loss / running_n :.3f} | ",
            f"Accuracy: {valid_correct / valid_n :.2f} | ",
            f"{valid_n / (perf_counter() - tick) :6.0f} obs/sec | ",
            f"{perf_counter() - big_tick :.2f} s",
        )

Training: 244 batches of size 1024
Validation: 82 batches
Starting epoch  1 ------
Epoch 1, Batch   50: 0.680 |  Accuracy: 0.56 |    5160 obs/sec |  9.92 s
Epoch 1, Batch  100: 0.670 |  Accuracy: 0.58 |    8433 obs/sec |  15.99 s
Epoch 1, Batch  150: 0.661 |  Accuracy: 0.59 |    8331 obs/sec |  22.14 s
Epoch 1, Batch  200: 0.618 |  Accuracy: 0.62 |    7610 obs/sec |  28.87 s
Epoch 1 validation loss: 0.634 |  Accuracy: 0.51 |    5594 obs/sec |  49.10 s
Starting epoch  2 ------
Epoch 2, Batch   50: 0.582 |  Accuracy: 0.65 |    5375 obs/sec |  58.63 s
Epoch 2, Batch  100: 0.553 |  Accuracy: 0.68 |    7325 obs/sec |  65.62 s
Epoch 2, Batch  150: 0.538 |  Accuracy: 0.69 |    8054 obs/sec |  71.97 s
Epoch 2, Batch  200: 0.526 |  Accuracy: 0.70 |    8394 obs/sec |  78.07 s
Epoch 2 validation loss: 0.517 |  Accuracy: 0.70 |    6798 obs/sec |  95.87 s
Starting epoch  3 ------
Epoch 3, Batch   50: 0.513 |  Accuracy: 0.71 |    5709 obs/sec |  104.84 s
Epoch 3, Batch  100: 0.513 |  Accuracy: 0.72 