In [7]:
import os, sys
sys.path.append('../')

import torch

import lightning as L

from torch_geometric.data import HeteroData
from torch_geometric.loader import NeighborLoader
import torch_geometric.transforms as T

from db_transformer.nn.tabtransformer_gnn import TabTransformerGNN
from db_transformer.data.ctu_dataset import CTUDataset

%reload_ext autoreload
%autoreload 2

In [8]:
class LightningModel(L.LightningModule):
    """Lightning wrapper that trains a ``TabTransformerGNN`` as a classifier.

    The wrapped model is fed the ``"tf"`` and ``"edge_index"`` attributes
    collected from a ``HeteroData`` batch, and its output is scored against
    the ``y`` attribute of ``target_table`` with cross-entropy.

    Args:
        model: The network to optimise.
        target_table: Name of the node type whose ``y`` holds the labels.
        lr: Learning rate passed to Adam.
    """

    def __init__(self, model: TabTransformerGNN, target_table: str, lr: float) -> None:
        super().__init__()
        self.model = model
        self.target_table = target_table
        self.lr = lr
        self.loss_module = torch.nn.CrossEntropyLoss()

    def _num_targets(self, batch: HeteroData) -> int:
        """Return how many labelled rows of the target table ``batch`` holds."""
        return batch[self.target_table].y.shape[0]

    def forward(self, data: HeteroData):
        """Run the model on ``data`` and score it against the target labels.

        Returns:
            A ``(loss, acc)`` pair: the cross-entropy loss and the fraction of
            rows whose arg-max prediction equals the label.
        """
        out = self.model(data.collect("tf"), data.collect("edge_index"))

        # NOTE(review): when batches come from a NeighborLoader, the target
        # table may also contain sampled (non-seed) nodes, so loss/accuracy
        # may be computed over more than the seed nodes. Confirm what the
        # model returns before restricting `out`/`target` to the seed rows.
        target = data[self.target_table].y
        loss = self.loss_module(out, target)
        acc = (out.argmax(dim=-1) == target).float().mean()
        return loss, acc

    def configure_optimizers(self):
        """Optimise every parameter with Adam at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, batch):
        """Log loss and accuracy for one training batch and return the loss."""
        loss, acc = self.forward(batch)
        batch_size = self._num_targets(batch)
        self.log("train_loss", loss, batch_size=batch_size, prog_bar=True)
        self.log("train_acc", acc, batch_size=batch_size, prog_bar=True)
        return loss

    def validation_step(self, batch):
        """Log accuracy for one validation batch."""
        _, acc = self.forward(batch)

        # Use the real row count (as in training_step) so that the epoch-level
        # average weights each batch by its size instead of treating every
        # batch as a single sample.
        self.log("val_acc", acc, batch_size=self._num_targets(batch), prog_bar=True)

    def test_step(self, batch):
        """Log accuracy for one test batch."""
        _, acc = self.forward(batch)

        # Same weighting as validation_step / training_step.
        self.log("test_acc", acc, batch_size=self._num_targets(batch), prog_bar=True)

In [9]:
# Seed Python, NumPy and torch RNGs so the random node split below (and any
# later weight initialisation) is reproducible across "Restart & Run All".
SEED = 42
L.seed_everything(SEED)

# Fraction of the labelled target-table rows held out for validation.
VAL_FRACTION = 0.30

dataset = CTUDataset("CiteSeer", data_dir="../datasets", force_remake=False)

data = dataset.build_hetero_data()

# Hold out VAL_FRACTION of the labelled rows for validation; with
# split="train_rest" and num_test=0 every remaining row goes to training.
n_total = data[dataset.defaults.target_table].y.shape[0]
data = T.RandomNodeSplit(
    split="train_rest",
    num_val=int(VAL_FRACTION * n_total),
    num_test=0,
)(data)

In [10]:
# `target` is indexed as (table, column) below; it is reused by later cells.
target = dataset.defaults.target
target_table_name = target[0]

# Per-table inputs the model constructor needs, gathered from the hetero graph.
col_stats_by_table = data.collect("col_stats")
col_names_by_table = dict(
    (table, frame.col_names_dict) for table, frame in data.collect("tf").items()
)
edge_type_keys = data.collect("edge_index").keys()

# Output width is taken from the `card` attribute of the target column's schema.
target_column_schema = dataset.schema[target_table_name].columns[target[1]]

model = TabTransformerGNN(
    target_table=target_table_name,
    table_col_stats=col_stats_by_table,
    table_col_names_dict=col_names_by_table,
    edge_types=edge_type_keys,
    embed_dim=64,
    num_transformer_layers=4,
    num_transformer_heads=4,
    attn_dropout=0.1,
    out_dim=target_column_schema.card,
)
lightning_model = LightningModel(
    model,
    dataset.defaults.target_table,
    lr=1e-4,
)



In [11]:
def _neighbor_loader_for(node_mask):
    """Build a ``NeighborLoader`` over ``data`` seeded at the masked target nodes.

    Args:
        node_mask: Mask over the rows of the target table (``target[0]``)
            selecting the seed nodes.
    """
    return NeighborLoader(
        data,
        num_neighbors=[30, 30, 30, 30, 30],
        batch_size=10000,
        input_nodes=(target[0], node_mask),
    )


# Training and validation loaders differ only in the mask selecting seed nodes.
train_loader = _neighbor_loader_for(data[target[0]].train_mask)
val_loader = _neighbor_loader_for(data[target[0]].val_mask)

In [12]:
# Single-device CPU run; max_steps=-1 leaves the step count unbounded so
# max_epochs is the stopping criterion.
trainer_options = {
    "accelerator": "cpu",
    "devices": 1,
    "deterministic": True,
    "max_epochs": 100,
    "max_steps": -1,
}
trainer = L.Trainer(**trainer_options)

trainer.fit(
    lightning_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name        | Type              | Params
--------------------------------------------------
0 | model       | TabTransformerGNN | 917 K 
1 | loss_module | CrossEntropyLoss  | 0     
--------------------------------------------------
917 K     Trainable params
0         Non-trainable params
917 K     Total params
3.670     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/jakub/miniconda3/envs/deep-db-learning/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


RuntimeError: torch.cat(): expected a non-empty list of Tensors

In [None]:

trainer.