In [None]:
from wn import net
from wn.data import MatchDataset

import torch
from torch.utils.data import DataLoader
from torch import nn
from torch.optim import AdamW

import pickle
import os

In [None]:
# Tensorized features and targets were pickled together as a single pair.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only load pickles produced by a trusted pipeline.
with open("data/tensor_list.pkl", "rb") as f:
    tensor_pair = pickle.load(f)
input_data, y = tensor_pair
ds = MatchDataset(input_data, y)

# Interface object consumed by net.TabularInputLayer when the network is built.
with open("data/match_interface.pkl", "rb") as f:
    match_interface = pickle.load(f)

In [None]:
# Build the model: tabular input layer -> transformer encoder -> CLS selection -> output head.

SEED = 0
COL_ENCODING_SIZE = 32
EMBEDDING_SIZE = 96
# NOTE(review): assumes TabularInputLayer emits tokens of width
# col_encoding_size + embedding_size (32 + 96 = 128, the value previously
# hard-coded for d_model and the output head) — confirm against wn.net.
D_MODEL = COL_ENCODING_SIZE + EMBEDDING_SIZE
N_HEADS = 4
FEEDFORWARD_SIZE = 256
N_ENCODER_LAYERS = 4

# Seed torch's global RNG so weight initialisation is reproducible; when the
# cells run in order this also fixes the DataLoader shuffling downstream.
torch.manual_seed(SEED)

# Special tabular input layer
input_layer = net.TabularInputLayer(
    interface=match_interface,
    col_encoding_size=COL_ENCODING_SIZE,
    embedding_size=EMBEDDING_SIZE,
    append_cls=True,
)

# Transformer encoder over the token sequence (batch dimension first).
tr = nn.TransformerEncoder(
    encoder_layer=nn.TransformerEncoderLayer(
        d_model=D_MODEL,
        nhead=N_HEADS,
        dim_feedforward=FEEDFORWARD_SIZE,
        batch_first=True,
    ),
    num_layers=N_ENCODER_LAYERS,
)

# Presumably selects the encoding at the CLS position added by append_cls=True
# — TODO confirm against wn.net.SelectCLSEncoding.
cls_selector = net.SelectCLSEncoding()

# Output head. The first argument was the same 128 as d_model, so it is tied to
# D_MODEL here; the remaining (3, 1) are passed through unchanged — their meaning
# is defined by wn.net.OutputLayers (TODO document).
output_layers = net.OutputLayers(D_MODEL, 3, 1)

whole_net = nn.Sequential(
    input_layer,
    tr,
    cls_selector,
    output_layers,
)

# Count trainable parameters with a generator (no intermediate list).
n_trainable = sum(p.numel() for p in whole_net.parameters() if p.requires_grad)
print(f"Network has {n_trainable} weights.")

In [None]:
# Example training pass: one epoch over `ds`, logging the loss every LOG_EVERY batches.
BATCH_SIZE = 64
LOG_EVERY = 20

dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)
print(f"{len(dl)} batches")

# Only trainable parameters are handed to the optimizer.
optimizer = AdamW([p for p in whole_net.parameters() if p.requires_grad])
# reduction="sum": the loss is summed over the batch, so its magnitude scales
# with batch size (the final batch may be smaller than BATCH_SIZE).
criterion = nn.BCEWithLogitsLoss(reduction="sum")

whole_net.train()

# Batch variables get their own names so the global `y` (full label set loaded
# earlier and passed to MatchDataset) is not clobbered by the last batch.
for i, (x_batch, y_batch) in enumerate(dl):
    optimizer.zero_grad()

    # NOTE(review): BCEWithLogitsLoss requires y_hat and y_batch to have the same
    # shape and y_batch to be floating point — confirm against MatchDataset and
    # OutputLayers.
    y_hat = whole_net(x_batch)
    loss = criterion(y_hat, y_batch)

    # `loss` is already a scalar (reduction="sum"), so backpropagate it directly.
    loss.backward()
    optimizer.step()

    if i % LOG_EVERY == LOG_EVERY - 1:
        # .item() yields a plain Python float rather than a tensor repr with grad_fn.
        print(f"batch {i + 1}/{len(dl)}: summed loss {loss.item():.4f}")