In [2]:
import os, sys
sys.path.append('../')

from typing import List, Dict, Any

import torch

import lightning as L

from torch_geometric.data import HeteroData
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import conv
import torch_geometric.transforms as T
from torch_geometric.typing import EdgeType, NodeType

import torch_frame
from torch_frame.data import StatType

from db_transformer.nn.embedder import TabTransformerEmbedder
from db_transformer.data.ctu_dataset import CTUDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

%reload_ext autoreload
%autoreload 2

In [2]:
class Model(torch.nn.Module):
    def __init__(
        self,
        target_table: str,
        table_col_stats: Dict[NodeType, Dict[str, Dict[StatType, Any]]],
        table_col_names_dict: Dict[torch_frame.stype, List[str]],
        edge_types: List[EdgeType],
        embed_dim: int,
        num_embedder_layers: int,
        num_embedder_heads: int,
        embedder_attn_dropout: float,
        out_dim: int,
    ):
        super().__init__()

        self.target_table = target_table

        self.table_embedder = TabTransformerEmbedder(
            table_col_stats=table_col_stats,
            table_col_names_dict=table_col_names_dict,
            embed_dim=embed_dim,
            num_transformer_layers=num_embedder_layers,
            num_transformer_heads=num_embedder_heads,
            attn_dropout=embedder_attn_dropout,
        )

        # TODO: This can be also a cross-attention layer
        convs = {
            edge_type: conv.SAGEConv(in_channels=embed_dim, out_channels=embed_dim)
            for edge_type in edge_types
        }
        self.conv = conv.HeteroConv(convs)

        self.out_lin = torch.nn.Linear(embed_dim, out_dim)

    def forward(
        self,
        tf_dict: Dict[str, torch_frame.TensorFrame],
        edge_dict: Dict[str, torch.Tensor],
    ):
        x_dict = self.table_embedder(tf_dict)
        x_dict = self.conv(x_dict, edge_dict)

        x_target = x_dict[self.target_table]

        x_target = self.out_lin(x_target)
        return torch.softmax(x_target, dim=-1)

In [3]:
class LightningModel(L.LightningModule):
    def __init__(self, model: TabTransformerEmbedder, target_table: str, lr: float) -> None:
        super().__init__()
        self.model = model
        self.target_table = target_table
        self.lr = lr
        self.loss_module = torch.nn.CrossEntropyLoss()

    def forward(self, data: HeteroData):
        out = self.model(data.collect("tf"), data.collect("edge_index"))

        target = data[self.target_table].y
        loss = self.loss_module(out, target)
        acc = (out.argmax(dim=-1) == target).sum().float() / target.shape[0]
        return loss, acc

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, batch):
        loss, acc = self.forward(batch)
        batch_size = batch[self.target_table].y.shape[0]
        self.log("train_loss", loss, batch_size=batch_size, prog_bar=True)
        self.log("train_acc", acc, batch_size=batch_size, prog_bar=True)
        return loss

    def validation_step(self, batch):
        _, acc = self.forward(batch)

        self.log("val_acc", acc, batch_size=1, prog_bar=True)

    def test_step(self, batch):
        _, acc = self.forward(batch)

        self.log("test_acc", acc, batch_size=1, prog_bar=True)

In [4]:
dataset = CTUDataset("CiteSeer", data_dir="../datasets", force_remake=False)

data = dataset.build_hetero_data(device)

n_total = data[dataset.defaults.target_table].y.shape[0]
data = T.RandomNodeSplit(split="train_rest", num_val=int(0.30 * n_total), num_test=0)(data)

Table content has stypes: {'paper_id': <stype.categorical: 'categorical'>, 'word_cited_id': <stype.categorical: 'categorical'>}
Table cites has stypes: {'cited_paper_id': <stype.text_embedded: 'text_embedded'>, 'citing_paper_id': <stype.text_embedded: 'text_embedded'>}
Table paper has stypes: {'paper_id': <stype.text_embedded: 'text_embedded'>, 'class_label': <stype.categorical: 'categorical'>}


In [5]:
target = dataset.defaults.target

model = Model(
    target_table=target[0],
    table_col_stats=data.collect("col_stats"),
    table_col_names_dict={k: tf.col_names_dict for k, tf in data.collect("tf").items()},
    edge_types=data.collect("edge_index").keys(),
    embed_dim=64,
    num_embedder_layers=4,
    num_embedder_heads=4,
    embedder_attn_dropout=0.1,
    out_dim=dataset.schema[target[0]].columns[target[1]].card,
).to(device)
lightning_model = LightningModel(model, dataset.defaults.target_table, lr=0.0001).to(device)



In [6]:
train_loader = NeighborLoader(
    data,
    num_neighbors=[30] * 5,
    batch_size=10000,
    input_nodes=(target[0], data[target[0]].train_mask),
)

val_loader = NeighborLoader(
    data,
    num_neighbors=[30] * 5,
    batch_size=10000,
    input_nodes=(target[0], data[target[0]].val_mask),
)

In [7]:
trainer = L.Trainer(
    accelerator=device.type,
    devices=1,
    deterministic=False,
    max_epochs=100,
    max_steps=-1,
)

trainer.fit(lightning_model, train_loader, val_dataloaders=val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/jakub/miniconda3/envs/deep-db-learning/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type             | Params
-------------------------------------------------
0 | model       | Model            | 917 K 
1 | loss_module | CrossEntropyLoss | 0     
-------------------------------------------------
917 K     Train

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/jakub/miniconda3/envs/deep-db-learning/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/jakub/miniconda3/envs/deep-db-learning/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/jakub/miniconda3/envs/deep-db-learning/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/home/jakub/miniconda3/envs/deep-db-learning/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
