In [69]:
import os, sys

sys.path.append("../")

from typing import List, Dict, Union, Optional, Any, Tuple

import torch
from torch.nn import functional as F

import lightning as L

from torch_geometric.data import HeteroData
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import conv, Sequential
import torch_geometric.transforms as T
from torch_geometric.typing import EdgeType, NodeType

import torch_frame
from torch_frame import stype, NAStrategy
from torch_frame.nn import StypeEncoder, EmbeddingEncoder, LinearEncoder, TimestampEncoder
from torch_frame.data import StatType



from db_transformer.nn import (
    BlueprintModel,
    TabTransformerEmbedder,
    EmbeddingTranscoder,
    SelfAttention,
    NodeApplied
)
from db_transformer.data.ctu_dataset import CTUDataset

%reload_ext autoreload
%autoreload 2

device = torch.device("cuda" if False and torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

In [46]:
class LightningModel(L.LightningModule):
    def __init__(self, model: TabTransformerEmbedder, target_table: str, lr: float) -> None:
        super().__init__()
        self.model = model
        self.target_table = target_table
        self.lr = lr
        self.loss_module = torch.nn.CrossEntropyLoss()

    def forward(self, data: HeteroData):
        out = self.model(data.collect("tf"), data.collect("edge_index", allow_empty=True))

        target = data[self.target_table].y
        loss = self.loss_module(out, target)
        acc = (out.argmax(dim=-1) == target).type(torch.float).mean()
        return loss, acc

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, batch):
        loss, acc = self.forward(batch)
        batch_size = batch[self.target_table].y.shape[0]
        self.log("train_loss", loss, batch_size=batch_size, prog_bar=True)
        self.log("train_acc", acc, batch_size=batch_size, prog_bar=True)
        return loss

    def validation_step(self, batch):
        _, acc = self.forward(batch)
        batch_size = batch[self.target_table].y.shape[0]

        self.log("val_acc", acc, batch_size=batch_size, prog_bar=True)

    def test_step(self, batch):
        _, acc = self.forward(batch)
        batch_size = batch[self.target_table].y.shape[0]

        self.log("test_acc", acc, batch_size=batch_size, prog_bar=True)

In [47]:
dataset = CTUDataset("Chess", data_dir="../datasets", force_remake=False)

data = dataset.build_hetero_data(device, force_rematerilize=False)

n_total = data[dataset.defaults.target_table].y.shape[0]
data = T.RandomNodeSplit(split="train_rest", num_val=int(0.30 * n_total), num_test=0)(data)

Building data:   0%|          | 0/2 [00:00<?, ?it/s]

Table game has stypes:
	categorical: ['b1', 'b2', 'b3', 'b4', 'event', 'opening', 'site', 'w1', 'w2', 'w3', 'w4']
	timestamp: ['event_date']
	numerical: ['BlackElo', 'whiteElo']
	embedding: ['ECO', 'b10', 'b5', 'b6', 'b7', 'b8', 'b9', 'black', 'round', 'w10', 'w5', 'w6', 'w7', 'w8', 'w9', 'white']
Table opening has stypes:
	categorical: ['b1', 'b2', 'b3', 'b4', 'code', 'name', 'w1', 'w2', 'w3', 'w4']
	embedding: ['variation']


In [48]:
target = dataset.defaults.target

num_neighbors = {edge_type: [30] * 5 for edge_type in data.collect("edge_index").keys()}

train_loader = NeighborLoader(
    data,
    num_neighbors=num_neighbors,
    batch_size=100,
    input_nodes=(target[0], data[target[0]].train_mask),
    subgraph_type="bidirectional",
)

val_loader = NeighborLoader(
    data,
    num_neighbors=num_neighbors,
    batch_size=100,
    input_nodes=(target[0], data[target[0]].val_mask),
    subgraph_type="bidirectional",
)

In [70]:
target = dataset.defaults.target

model = BlueprintModel(
    target=target,
    embed_dim=64,
    col_stats_per_table=data.collect("col_stats"),
    col_names_dict_per_table={k: tf.col_names_dict for k, tf in data.collect("tf").items()},
    edge_types=list(data.collect("edge_index").keys()),
    stype_embedder_dict={
        stype.categorical: EmbeddingEncoder(
            na_strategy=NAStrategy.MOST_FREQUENT,
        ),
        stype.numerical: LinearEncoder(
            na_strategy=NAStrategy.MEAN,
        ),
        stype.embedding: EmbeddingTranscoder(
            in_channels=300,
        ),
        stype.timestamp: TimestampEncoder(),
    },
    num_gnn_layers=3,
    table_transform=Sequential(
        "x_dict",
        [
            (SelfAttention(64, list(data.collect("tf").keys())), "x_dict -> x_dict2"),
            (
                lambda x_dict1, x_dict2: {
                    k: x_dict1[k] + x_dict2[k] for k in x_dict1.keys()
                },
                "x_dict, x_dict2 -> x_dict",
            ),
            (
                NodeApplied(
                    lambda node_type: torch.nn.BatchNorm1d(data[]), list(data.collect("tf").keys())
                ),
                "x_dict -> x_dict",
            ),
        ],
    ),
).to(device)
lightning_model = LightningModel(model, dataset.defaults.target_table, lr=0.0001).to(device)

['game', 'opening']


In [71]:
trainer = L.Trainer(
    accelerator=device.type,
    devices=1,
    deterministic=False,
    max_epochs=100,
    max_steps=-1,
)

trainer.fit(lightning_model, train_loader, val_dataloaders=val_loader)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/jakub/miniconda3/envs/deep-db-learning/lib/python3.10/site-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.

  | Name        | Type             | Params
-------------------------------------------------
0 | model       | BlueprintModel   | 300 K 
1 | loss_module | CrossEntropyLoss | 0     
-------------------------------------------------
300 K     Trainable params
0         Non-trainable params
300 K     Total params
1.200     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/jakub/miniconda3/envs/deep-db-learning/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


RuntimeError: running_mean should contain 30 elements not 64