In [None]:
from wn import net
from wn.data import MatchDataset

import torch
from torch.utils.data import DataLoader
from torch import nn
from torch.optim import AdamW

import pickle
import os
from time import perf_counter

In [None]:
# Tensorized data: the pickle unpacks into two objects, the model inputs and
# `y` (handed to MatchDataset as its second argument).
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
with open("data/tensor_list.pkl", "rb") as tensor_file:
    loaded_pair = pickle.load(tensor_file)
input_data, y = loaded_pair

# Wrap both objects in the project's dataset class.
ds = MatchDataset(input_data, y)

# Interface object, read from its own pickle; it is passed to the tabular
# input layer when the network is built.
with open("data/match_interface.pkl", "rb") as interface_file:
    match_interface = pickle.load(interface_file)

In [None]:
# Build the network used for this test run, stage by stage.

# Tabular-specific input layer, configured from the loaded interface.
input_layer = net.TabularInputLayer(
    interface=match_interface,
    col_encoding_size=32,
    embedding_size=96,
    append_cls=True,
)

# Transformer encoder: four stacked copies of one encoder block.
encoder_block = nn.TransformerEncoderLayer(
    d_model=128, nhead=4, dim_feedforward=256, batch_first=True
)
tr = nn.TransformerEncoder(encoder_layer=encoder_block, num_layers=4)

# Convenience layer; judging by its name it selects the CLS encoding from the
# encoder output -- see wn.net for the exact behavior.
cls_selector = net.SelectCLSEncoding()

# Output head.
output_layers = net.OutputLayers(128, 3, 1)

# Chain the stages in the order data flows through them.
whole_net = nn.Sequential(input_layer, tr, cls_selector, output_layers)

# Count trainable parameters only (those with requires_grad set).
n_weights = sum(p.numel() for p in whole_net.parameters() if p.requires_grad)
print(f"Network has {n_weights} weights.")

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
whole_net.to(device)

print(f"Using {device}")

In [None]:
# Example training step
# Create a dataloader, optimizer, and criterion

# Tunable settings for this run, gathered in one place.
N_EPOCHS = 5
BATCH_SIZE = 64
LOG_EVERY = 500  # report running statistics every LOG_EVERY batches

dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)
n_batches = len(dl)
print(f"{n_batches} batches")

# AdamW with default hyperparameters, over trainable parameters only.
optimizer = AdamW([p for p in whole_net.parameters() if p.requires_grad])
# The loss is summed (not averaged) over each batch, so dividing the running
# total by the number of observations below gives a per-observation loss.
criterion = nn.BCEWithLogitsLoss(reduction="sum")

# Per epoch

for epoch in range(N_EPOCHS):

    print(f"Starting epoch {epoch+1}")

    whole_net.train()

    # `tick` times the current reporting window; `big_tick` times the epoch.
    tick = big_tick = perf_counter()
    running_loss = 0.0
    running_n = 0
    running_correct = 0

    for i, batch in enumerate(dl):

        optimizer.zero_grad()

        # Get a batch
        # NOTE: `y` rebinds the notebook-level `y` loaded from the pickle;
        # `ds` already holds the data, so the loader itself is unaffected.
        x, y = batch

        # net.to_ is a project helper; presumably it moves the (possibly
        # structured) inputs onto `device` -- confirm in wn.net.
        x = net.to_(x, device)
        y = y.to(device)

        y_hat = whole_net(x)
        # A logit > 0 corresponds to a sigmoid probability > 0.5.
        labels = y_hat > 0
        correct = (labels == y).sum()

        loss = criterion(y_hat, y)

        loss.backward()
        optimizer.step()

        running_correct += correct.item()
        running_loss += loss.item()
        # Counts rows, so the accuracy below assumes one prediction per
        # observation -- TODO confirm the shape of y_hat.
        running_n += y_hat.shape[0]

        # Report every LOG_EVERY batches, and also after the last batch of the
        # epoch: otherwise the trailing partial window is silently dropped and
        # an epoch shorter than LOG_EVERY batches reports nothing at all.
        is_last_batch = i + 1 == n_batches
        if (i + 1) % LOG_EVERY == 0 or is_last_batch:
            print(
                f"Epoch {epoch + 1}, Batch {i+1 :4}: {running_loss / running_n :.3f} | ",
                f"Accuracy: {running_correct / running_n :.2f} | ",
                f"{running_n / (perf_counter() - tick) :6.1f} obs/sec | ",
                f"{perf_counter() - big_tick :.2f} seconds"
            )
            # Start a fresh reporting window.
            running_loss = 0.0
            running_n = 0
            running_correct = 0
            tick = perf_counter()