In [1]:
import os, sys
sys.path.append('../')

from typing import List, Dict, Any

import torch

import lightning as L

from torch_geometric.data import HeteroData
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import conv
import torch_geometric.transforms as T
from torch_geometric.typing import EdgeType, NodeType

import torch_frame
from torch_frame.data import StatType

from db_transformer.nn.embedder import TabTransformerEmbedder, TromptEmbedder, FTTransformerEmbedder
from db_transformer.data.ctu_dataset import CTUDataset

# Set USE_CUDA = True to train on a GPU when one is available. It is currently
# forced off, so everything below runs on CPU regardless of the machine.
USE_CUDA = False
device = torch.device('cuda' if USE_CUDA and torch.cuda.is_available() else 'cpu')

# Clear the CUDA caching allocator and reset peak-memory counters so later
# memory readings start fresh. Guarded because these calls can raise when
# torch has no usable CUDA device (e.g. a CPU-only build).
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

# Re-import edited modules (e.g. the local db_transformer package made importable
# via sys.path above) before each cell runs, so code changes are picked up
# without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
class Model(torch.nn.Module):
    """Heterogeneous graph classifier over the tables of a relational database.

    Each table's rows are embedded by a TabTransformerEmbedder. One
    TransformerConv message-passing step per edge type is then applied (combined
    by HeteroConv), and the target table's node representations are projected to
    ``out_dim`` class logits.

    Args:
        target_table: node type whose rows are classified.
        table_col_stats: per-table column statistics, forwarded to the embedder.
        table_col_names_dict: per-table mapping of stype -> column names,
            forwarded to the embedder.
        edge_types: edge types to build one TransformerConv for; may be empty,
            in which case message passing is skipped.
        embed_dim: width of the per-row embeddings.
        out_dim: number of output classes.
        num_embedder_layers: ``num_layers`` of the embedder.
        num_embedder_heads: ``num_heads`` of the embedder.
        num_transformer_heads: attention heads per TransformerConv. Head outputs
            are concatenated, so the conv output width is
            ``embed_dim * num_transformer_heads``.
        embedder_attn_dropout: ``attn_dropout`` of the embedder.
    """

    def __init__(
        self,
        target_table: str,
        table_col_stats: Dict[NodeType, Dict[str, Dict[StatType, Any]]],
        table_col_names_dict: Dict[NodeType, Dict[torch_frame.stype, List[str]]],
        edge_types: List[EdgeType],
        embed_dim: int,
        out_dim: int,
        num_embedder_layers: int = 1,
        num_embedder_heads: int = 1,
        num_transformer_heads: int = 1,
        embedder_attn_dropout: float = 0.0,
    ):
        super().__init__()

        self.target_table = target_table

        self.embedder = TabTransformerEmbedder(
            table_col_stats=table_col_stats,
            table_col_names_dict=table_col_names_dict,
            embed_dim=embed_dim,
            num_layers=num_embedder_layers,
            num_heads=num_embedder_heads,
            attn_dropout=embedder_attn_dropout,
        )

        # Materialise once: callers may pass a dict keys view or another iterable.
        edge_types = list(edge_types)
        if edge_types:
            self.conv = conv.HeteroConv({
                edge_type: conv.TransformerConv(
                    in_channels=embed_dim, out_channels=embed_dim, heads=num_transformer_heads
                )
                for edge_type in edge_types
            })
            conv_out_dim = embed_dim * num_transformer_heads
        else:
            # No relations: skip message passing. The embeddings keep their
            # original width, so out_lin must not assume concatenated heads.
            self.conv = None
            conv_out_dim = embed_dim

        self.out_lin = torch.nn.Linear(conv_out_dim, out_dim)

    def forward(
        self,
        tf_dict: Dict[NodeType, torch_frame.TensorFrame],
        edge_dict: Dict[EdgeType, torch.Tensor],
    ) -> torch.Tensor:
        """Return unnormalised class logits for the target-table nodes.

        Logits (not probabilities) are returned because the training loss,
        ``torch.nn.CrossEntropyLoss``, applies log-softmax itself. Feeding it
        softmax outputs would apply softmax twice and flatten the gradients.
        Apply ``torch.softmax(out, dim=-1)`` when probabilities are needed.
        """
        x_dict = self.embedder(tf_dict)
        if self.conv is not None:
            x_dict = self.conv(x_dict, edge_dict)

        return self.out_lin(x_dict[self.target_table])

In [46]:
class LightningModel(L.LightningModule):
    """Lightning wrapper that trains ``model`` as a node classifier on one table.

    Loss and accuracy are computed only over the seed (input) nodes of each
    mini-batch. NeighborLoader places the seed nodes first within their node
    type and stores their count in ``batch[target_table].batch_size``. The
    remaining target-table rows are sampled neighbours, which can belong to
    another split (e.g. validation rows inside a training batch). Including
    them would leak labels and distort the metrics.

    Args:
        model: module called as ``model(tf_dict, edge_index_dict)`` that returns
            class logits for every target-table node in the batch.
        target_table: node type holding the labels ``y``.
        lr: Adam learning rate.
    """

    def __init__(self, model: torch.nn.Module, target_table: str, lr: float) -> None:
        super().__init__()
        self.model = model
        self.target_table = target_table
        self.lr = lr
        self.loss_module = torch.nn.CrossEntropyLoss()

    def _num_seed_nodes(self, data: HeteroData) -> int:
        """Number of seed nodes in ``data``; all target rows if no ``batch_size`` is set."""
        store = data[self.target_table]
        batch_size = getattr(store, "batch_size", None)
        # A full (non-sampled) HeteroData has no batch_size: every row counts.
        return store.y.shape[0] if batch_size is None else int(batch_size)

    def forward(self, data: HeteroData):
        """Return ``(loss, accuracy)`` computed over the seed nodes of ``data``."""
        out = self.model(data.collect("tf"), data.collect("edge_index", allow_empty=True))

        n_seed = self._num_seed_nodes(data)
        out = out[:n_seed]
        target = data[self.target_table].y[:n_seed]

        loss = self.loss_module(out, target)
        acc = (out.argmax(dim=-1) == target).type(torch.float).mean()
        return loss, acc

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, batch):
        loss, acc = self.forward(batch)
        batch_size = self._num_seed_nodes(batch)
        self.log("train_loss", loss, batch_size=batch_size, prog_bar=True)
        self.log("train_acc", acc, batch_size=batch_size, prog_bar=True)
        return loss

    def validation_step(self, batch):
        loss, acc = self.forward(batch)
        batch_size = self._num_seed_nodes(batch)
        # val_loss is logged too so it can drive checkpointing / early stopping.
        self.log("val_loss", loss, batch_size=batch_size)
        self.log("val_acc", acc, batch_size=batch_size, prog_bar=True)

    def test_step(self, batch):
        _, acc = self.forward(batch)
        batch_size = self._num_seed_nodes(batch)
        self.log("test_acc", acc, batch_size=batch_size, prog_bar=True)

In [None]:
# Seed Python, NumPy and torch RNGs so the random train/val split below (and the
# model initialisation in later cells) is reproducible under Restart & Run All.
SEED = 42
L.seed_everything(SEED)

dataset = CTUDataset("Chess", data_dir="../datasets", force_remake=False)

data = dataset.build_hetero_data(device)

# Hold out 30% of the target-table rows for validation; everything else is
# training data and there is no test split.
n_total = data[dataset.defaults.target_table].y.shape[0]
data = T.RandomNodeSplit(split="train_rest", num_val=int(0.30 * n_total), num_test=0)(data)

_target_store = data[dataset.defaults.target_table]
assert int(_target_store.train_mask.sum() + _target_store.val_mask.sum()) == n_total, \
    "train and val masks should together cover every target row"

In [48]:
target = dataset.defaults.target
target_table = target[0]
target_column = target[1]

# Per-table inputs for the embedder, plus the relation types that actually
# carry edges (possibly none).
col_stats_per_table = data.collect("col_stats")
col_names_per_table = {
    table: frame.col_names_dict
    for table, frame in data.collect("tf").items()
}
edge_types_present = data.collect("edge_index", allow_empty=True).keys()
# Presumably the target column's cardinality, i.e. the number of classes.
num_classes = dataset.schema[target_table].columns[target_column].card

model = Model(
    target_table=target_table,
    table_col_stats=col_stats_per_table,
    table_col_names_dict=col_names_per_table,
    edge_types=edge_types_present,
    embed_dim=64,
    out_dim=num_classes,
    num_embedder_layers=4,
    num_embedder_heads=4,
    num_transformer_heads=1,
    embedder_attn_dropout=0.1,
).to(device)
lightning_model = LightningModel(model, dataset.defaults.target_table, lr=1e-4).to(device)

In [49]:
# Each mini-batch holds up to BATCH_SIZE seed rows of the target table plus
# neighbours sampled for len(NUM_NEIGHBORS) hops (30 per node and relation per hop).
# NOTE(review): Model applies a single HeteroConv layer, so (assuming the
# embedder does not pass messages between tables) only the first hop affects
# its output; the deeper hops add sampling cost only.
NUM_NEIGHBORS = [30] * 5
BATCH_SIZE = 1000


def make_target_loader(
    graph: HeteroData, table: str, mask: torch.Tensor, shuffle: bool
) -> NeighborLoader:
    """Build a NeighborLoader whose seed nodes are the rows of ``table`` selected by ``mask``."""
    return NeighborLoader(
        graph,
        num_neighbors=NUM_NEIGHBORS,
        batch_size=BATCH_SIZE,
        input_nodes=(table, mask),
        shuffle=shuffle,
    )


# Training seeds are reshuffled every epoch (NeighborLoader does not shuffle
# by default); validation order stays fixed.
train_loader = make_target_loader(data, target[0], data[target[0]].train_mask, shuffle=True)
val_loader = make_target_loader(data, target[0], data[target[0]].val_mask, shuffle=False)

In [None]:
# Train for up to 100 epochs on a single device of the type chosen in the setup
# cell, evaluating on val_loader during fitting.
trainer_config = dict(
    accelerator=device.type,
    devices=1,
    deterministic=False,
    max_epochs=100,
    max_steps=-1,
)
trainer = L.Trainer(**trainer_config)

trainer.fit(
    lightning_model,
    train_loader,
    val_dataloaders=val_loader,
)