In [None]:
import os, sys

sys.path.append("../")

from typing import List, Dict, Union, Optional, Any, Tuple, Literal

import math

import torch
from torch.nn import functional as F

import lightning as L

from torch_geometric.data import HeteroData
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import conv, Sequential, summary
import torch_geometric.transforms as T
from torch_geometric.typing import EdgeType, NodeType

import torch_frame
from torch_frame import stype, NAStrategy
from torch_frame.nn import encoder
from torch_frame.nn import TabTransformerConv
from torch_frame.data import StatType


from db_transformer.nn import (
    BlueprintModel,
    EmbeddingTranscoder,
    SelfAttention,
    CrossAttentionConv,
    NodeApplied,
)
from db_transformer.nn.lightning import LightningWrapper
from db_transformer.nn.lightning.callbacks import BestMetricsLoggerCallback

from db_transformer.data.ctu_dataset import CTUDataset, CTU_REPOSITORY_DEFAULTS

%reload_ext autoreload
%autoreload 2

device = torch.device("cuda" if False and torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

In [None]:
# Load the CTU "Chess" relational dataset and materialize it as a heterogeneous graph.
dataset = CTUDataset("Chess", data_dir="../datasets", force_remake=False)
data = dataset.build_hetero_data(device, force_rematerilize=False)

# Hold out 30% of the target-table rows for validation. With split="train_rest" every
# remaining row becomes training data; no test split is created (num_test=0).
n_total = data[dataset.defaults.target_table].y.shape[0]
val_count = int(0.30 * n_total)
node_split = T.RandomNodeSplit(split="train_rest", num_val=val_count, num_test=0)
data = node_split(data)

In [None]:
target = dataset.defaults.target

# Sample up to 30 neighbors per edge type for each of 5 hops.
# NOTE(review): the model is built with num_layers=3 GNN layers; hops beyond that are
# sampled but presumably cannot influence the target nodes — consider matching the two.
num_neighbors = {edge_type: [30] * 5 for edge_type in data.collect("edge_index").keys()}


def make_target_loader(mask: torch.Tensor, batch_size: int, shuffle: bool) -> NeighborLoader:
    """Build a NeighborLoader seeded at the target-table nodes selected by ``mask``.

    Args:
        mask: boolean mask over the target table's nodes (e.g. ``train_mask``).
        batch_size: number of seed nodes per mini-batch.
        shuffle: whether to reshuffle the seed nodes every epoch.
    """
    return NeighborLoader(
        data,
        num_neighbors=num_neighbors,
        batch_size=batch_size,
        input_nodes=(target[0], mask),
        shuffle=shuffle,
        subgraph_type="bidirectional",
    )


# Reshuffle training seeds every epoch so batches differ; keep validation order fixed.
train_loader = make_target_loader(data[target[0]].train_mask, batch_size=10000, shuffle=True)
val_loader = make_target_loader(data[target[0]].val_mask, batch_size=1000, shuffle=False)

In [None]:
# (table, column) pair the model predicts.
target = dataset.defaults.target

# Per-column embedding width entering the first GNN layer; each layer halves it.
embed_dim = 64
# Number of message-passing layers in the model below.
num_layers = 3

# Raw pandas frame of every table, keyed by table name.
# NOTE(review): not referenced by any later cell visible here.
df_dict = {name: tbl.df for name, tbl in dataset.db.table_dict.items()}


def squeeze_dict(x: torch.Tensor) -> torch.Tensor:
    """Merge the last two dimensions of ``x`` into one.

    A tensor of shape ``(*lead, a, b)`` becomes ``(*lead, a * b)``.

    ``reshape`` is used instead of ``view`` so that non-contiguous inputs (e.g. the
    output of ``transpose``) are copied instead of raising a RuntimeError.
    """
    return x.reshape(*x.shape[:-2], -1)


def _width(i: int) -> int:
    """Per-column embedding width at GNN layer ``i``: ``embed_dim`` halved ``i`` times."""
    return int(embed_dim / 2**i)


def _table_transform(i, node, cols):
    """Module applied to each table's embeddings at layer ``i``.

    The first layer passes embeddings through unchanged; later layers apply a ReLU.
    """
    if i == 0:
        return torch.nn.Identity()
    return torch.nn.ReLU()


def _table_combination(i, edge, cols):
    """SAGEConv (sum aggregation) passing messages along ``edge`` at layer ``i``.

    ``cols[0]`` and ``cols[1]`` size the two entries of SAGEConv's ``(in_src, in_dst)``
    input tuple at the layer-``i`` per-column width. The output is sized from ``cols[1]``
    at the halved layer-``i + 1`` width.
    """
    n_src_cols = len(cols[0])
    n_dst_cols = len(cols[1])
    in_channels = (n_src_cols * _width(i), n_dst_cols * _width(i))
    out_channels = n_dst_cols * _width(i + 1)
    return conv.SAGEConv(in_channels, out_channels, aggr="sum")


def _decoder(cols):
    """Linear head over ``len(cols)`` columns at the final-layer width.

    The output size is the number of categories in the target column's COUNT statistic.
    """
    in_features = len(cols) * _width(num_layers)
    num_classes = len(data[target[0]].col_stats[target[1]][StatType.COUNT][0])
    return torch.nn.Sequential(torch.nn.Linear(in_features, num_classes))


model = BlueprintModel(
    target=target,
    embed_dim=embed_dim,
    col_stats_per_table=data.collect("col_stats"),
    col_names_dict_per_table={name: tf.col_names_dict for name, tf in data.collect("tf").items()},
    edge_types=list(data.collect("edge_index").keys()),
    stype_encoder_dict={
        # Missing categorical values are imputed with the most frequent category,
        # missing numerical values with the column mean.
        stype.categorical: encoder.EmbeddingEncoder(na_strategy=NAStrategy.MOST_FREQUENT),
        stype.numerical: encoder.LinearEncoder(na_strategy=NAStrategy.MEAN),
    },
    positional_encoding=False,
    per_column_embedding=False,
    num_gnn_layers=num_layers,
    table_transform=_table_transform,
    table_transform_unique=True,
    table_combination=_table_combination,
    table_combination_unique=True,
    decoder_aggregation=torch.nn.Identity(),
    decoder=_decoder,
    # NOTE(review): if LightningWrapper applies a cross-entropy loss to these outputs,
    # softmax would effectively be applied twice — confirm against LightningWrapper.
    output_activation=torch.nn.Softmax(dim=-1),
    # Dropout, residual connections and normalization are all disabled.
    positional_encoding_dropout=0.0,
    table_transform_dropout=0.0,
    table_combination_dropout=0.0,
    table_transform_residual=False,
    table_combination_residual=False,
    table_transform_norm=False,
    table_combination_norm=False,
).to(device)
print(summary(model, data.collect("tf"), data.collect("edge_index")))

In [None]:
# Wrap the model for Lightning training (learning rate 1e-4); the target table name is
# passed so the wrapper presumably knows which node type carries the labels.
lightning_model = LightningWrapper(model, dataset.defaults.target_table, lr=0.0001)
lightning_model = lightning_model.to(device)

# BestMetricsLoggerCallback() can be enabled via callbacks=[...]; it is currently off.
trainer = L.Trainer(
    accelerator=device.type,
    devices=1,
    deterministic=False,
    max_epochs=1000,
    max_steps=-1,  # no step limit; training length is governed by max_epochs
)

trainer.fit(lightning_model, train_loader, val_dataloaders=val_loader)