In [8]:
from typing import Any, Dict, List

import torch

from torch_geometric.data import HeteroData
from torch_geometric.nn import MessagePassing, HeteroConv, MLP
from torch_geometric.typing import EdgeType, NodeType

from sentence_transformers import SentenceTransformer

import torch_frame
from torch_frame.nn import TabTransformer

# from db_transformer.schema import Schema
from db_transformer.data.relbench.ctu_dataset import CTUDataset

%reload_ext autoreload
%autoreload 2

In [13]:
dataset = CTUDataset("Chess", force_remake=False)

data, column_defs = dataset.build_hetero_data()

In [12]:
class Model(torch.nn.Module):
    pass

In [14]:
print(data['game'].tf)

TensorFrame(
  num_cols=33,
  num_rows=295,
  numerical (4): ['BlackElo', 'game_id', 'opening_id', 'whiteElo'],
  categorical (12): ['b1', 'b2', 'b3', 'b4', 'event', 'game_result', 'opening', 'site', 'w1', 'w2', 'w3', 'w4'],
  timestamp (1): ['event_date'],
  embedding (16): ['ECO', 'b10', 'b5', 'b6', 'b7', 'b8', 'b9', 'black', 'round', 'w10', 'w5', 'w6', 'w7', 'w8', 'w9', 'white'],
  has_target=False,
  device='cpu',
)


In [11]:
transformer = TabTransformer(
                  channels=32,
                  out_channels=2,
                  num_layers=2,
                  num_heads=2,
                  encoder_pad_size=2,
                  attn_dropout=0.3,
                  ffn_dropout=0.3,
                  col_stats=data['game'].col_stats,
                  col_names_dict=data['game'].tf.col_names_dict,
              )

In [10]:
transformer(data['game'].tf)

tensor([[ 3.2959e-01, -1.6237e-01],
        [ 2.2755e-01, -1.1533e-01],
        [-5.8350e-01, -6.7785e-01],
        [ 3.3720e-01, -2.8958e-01],
        [-6.3100e-01, -1.1450e-01],
        [ 6.4985e-01, -2.6905e-01],
        [-9.2802e-03, -1.2599e-02],
        [-5.8783e-02, -4.3250e-02],
        [ 1.4249e-01, -3.6134e-01],
        [ 8.5389e-02,  2.1043e-01],
        [-2.7181e-01, -2.9079e-01],
        [-2.0875e-01, -8.0574e-01],
        [ 3.8368e-01, -5.5429e-01],
        [-1.5287e-01, -3.1342e-01],
        [-7.6406e-01, -3.6313e-01],
        [ 2.5291e-01, -1.4300e-01],
        [ 3.8533e-01, -9.4152e-01],
        [-1.3602e-01, -6.6054e-01],
        [-1.3088e-01, -3.1372e-01],
        [-2.2780e-01,  2.0198e-01],
        [ 8.0682e-01, -1.4439e+00],
        [ 8.9984e-02, -4.2073e-01],
        [ 2.6843e-01, -3.9216e-01],
        [-1.0941e-01,  3.2122e-02],
        [-1.4794e-01, -6.4846e-01],
        [-1.5513e-01, -8.4613e-01],
        [ 1.2024e+00, -2.2786e-01],
        [ 5.3427e-01, -1.827