In [None]:
import os, sys

sys.path.append("../")

from typing import List, Dict, Union, Optional, Any, Tuple

import math

import torch
from torch.nn import functional as F

import lightning as L

from torch_geometric.data import HeteroData
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import conv, Sequential, summary
import torch_geometric.transforms as T
from torch_geometric.typing import EdgeType, NodeType

import torch_frame
from torch_frame import stype, NAStrategy
from torch_frame.nn import encoder
from torch_frame.nn import TabTransformerConv
from torch_frame.data import StatType


from db_transformer.nn import (
    BlueprintModel,
    EmbeddingTranscoder,
    SelfAttention,
    CrossAttentionConv,
    NodeApplied,
)
from db_transformer.data.ctu_dataset import CTUDataset

%reload_ext autoreload
%autoreload 2

# Runtime configuration.
# USE_CUDA is deliberately off, so the notebook runs on the CPU; flip it to
# train on a GPU when one is available.
USE_CUDA = False
# Seed for the torch RNG so that random operations driven by it are repeatable
# across "Restart & Run All".
SEED = 42

device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu")

torch.manual_seed(SEED)

# The CUDA memory helpers are only meaningful (and only safe to call) when a
# CUDA device exists; resetting peak stats without one raises.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

In [None]:
class LightningModel(L.LightningModule):
    """Lightning wrapper that trains a ``BlueprintModel`` as a node classifier.

    The wrapped model is called with the per-table ``tf`` objects and the
    ``edge_index`` tensors collected from a ``HeteroData`` batch. Its output is
    scored against the ``y`` attribute of ``target_table`` with cross-entropy,
    which expects unnormalized logits.

    Only the seed nodes of a mini-batch contribute to the loss and accuracy.
    ``NeighborLoader`` places the seed nodes first and records their count as
    ``batch_size`` on the input node type; any further target-table nodes in
    the batch were pulled in as neighbours and may belong to another split.

    Args:
        model: Network to train.
        target_table: Node type whose ``y`` attribute holds the class labels.
        lr: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, model: BlueprintModel, target_table: str, lr: float) -> None:
        super().__init__()
        self.model = model
        self.target_table = target_table
        self.lr = lr
        self.loss_module = torch.nn.CrossEntropyLoss()

    def _num_seed_nodes(self, data: HeteroData) -> int:
        """Return how many leading target-table rows are seed nodes.

        Falls back to all labelled rows when the store carries no
        ``batch_size`` (e.g. a full graph that did not come from a loader).
        """
        store = data[self.target_table]
        n_seed = getattr(store, "batch_size", None)
        if n_seed is None:
            return store.y.shape[0]
        return int(n_seed)

    def forward(self, data: HeteroData) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the model on ``data`` and return ``(loss, accuracy)`` over the seed nodes."""
        out = self.model(data.collect("tf"), data.collect("edge_index", allow_empty=True))

        # Keep only the seed nodes so neighbour-sampled rows do not leak into the metrics.
        n_seed = self._num_seed_nodes(data)
        out = out[:n_seed]
        target = data[self.target_table].y[:n_seed]

        loss = self.loss_module(out, target)
        acc = (out.argmax(dim=-1) == target).type(torch.float).mean()
        return loss, acc

    def configure_optimizers(self):
        """Optimize all parameters with Adam at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, batch):
        """Log ``train_loss`` / ``train_acc`` and return the loss for backpropagation."""
        loss, acc = self.forward(batch)
        batch_size = self._num_seed_nodes(batch)
        self.log("train_loss", loss, batch_size=batch_size, prog_bar=True)
        self.log("train_acc", acc, batch_size=batch_size, prog_bar=True)
        return loss

    def validation_step(self, batch):
        """Log ``val_acc`` for one validation batch."""
        _, acc = self.forward(batch)
        batch_size = self._num_seed_nodes(batch)

        self.log("val_acc", acc, batch_size=batch_size, prog_bar=True)

    def test_step(self, batch):
        """Log ``test_acc`` for one test batch."""
        _, acc = self.forward(batch)
        batch_size = self._num_seed_nodes(batch)

        self.log("test_acc", acc, batch_size=batch_size, prog_bar=True)

In [None]:
# Load the "medical" CTU dataset and build it into a heterogeneous graph on `device`.
# Both force flags are set, so cached artefacts are presumably rebuilt on every
# run -- NOTE(review): confirm this is intended, it may be slow.
dataset = CTUDataset("medical", data_dir="../datasets", force_remake=True)
data = dataset.build_hetero_data(device, force_rematerilize=True)

# Randomly reserve 30 % of the labelled target-table rows for validation; with
# split="train_rest" every remaining row is used for training. No test split.
VAL_FRACTION = 0.30
n_total = data[dataset.defaults.target_table].y.size(0)
split_transform = T.RandomNodeSplit(
    split="train_rest",
    num_val=int(VAL_FRACTION * n_total),
    num_test=0,
)
data = split_transform(data)

In [None]:
# `target` is indexed as target[0] (node type / table) and target[1] (column)
# elsewhere in this notebook -- presumably a (table, column) pair.
target = dataset.defaults.target

# Sample up to 30 neighbours per edge type at each of 5 hops.
num_neighbors = {edge_type: [30] * 5 for edge_type in data.collect("edge_index").keys()}

# Training mini-batches are seeded from the rows selected by `train_mask`.
# shuffle=True draws the seed nodes in a new random order every epoch instead
# of replaying the same fixed batches.
train_loader = NeighborLoader(
    data,
    num_neighbors=num_neighbors,
    batch_size=10000,
    input_nodes=(target[0], data[target[0]].train_mask),
    subgraph_type="bidirectional",
    shuffle=True,
)

# Validation mini-batches are seeded from `val_mask`; order is kept fixed.
val_loader = NeighborLoader(
    data,
    num_neighbors=num_neighbors,
    batch_size=1000,
    input_nodes=(target[0], data[target[0]].val_mask),
    subgraph_type="bidirectional",
)

In [None]:
target = dataset.defaults.target

embed_dim = 64

# Size of the output layer: the number of entries in the COUNT statistic of the
# target column -- presumably one entry per class; verify against the dataset.
num_classes = len(data[target[0]].col_stats[target[1]][StatType.COUNT][0])

model = BlueprintModel(
    target=target,
    embed_dim=embed_dim,
    col_stats_per_table=data.collect("col_stats"),
    col_names_dict_per_table={k: tf.col_names_dict for k, tf in data.collect("tf").items()},
    edge_types=list(data.collect("edge_index").keys()),
    stype_encoder_dict={
        stype.categorical: encoder.EmbeddingEncoder(
            na_strategy=NAStrategy.MOST_FREQUENT,
        ),
        stype.numerical: encoder.LinearEncoder(
            na_strategy=NAStrategy.MEAN,
        ),
        stype.embedding: EmbeddingTranscoder(in_channels=300),
    },
    positional_encoding=True,
    num_gnn_layers=3,
    table_transform=SelfAttention(embed_dim, 16),
    table_transform_unique=True,
    # table_transform=lambda i, node, cols: ExcelFormerConv(embed_dim, len(cols), 1),
    table_combination=CrossAttentionConv(embed_dim, 4),
    table_combination_unique=True,
    # table_combination=conv.TransformerConv(embed_dim, embed_dim, heads=4, dropout=0.1, root_weight=False),
    # Flatten everything after the first dimension into one feature vector per row.
    decoder_aggregation=lambda x: torch.reshape(x, (-1, math.prod(x.shape[1:]))),
    decoder=lambda cols: torch.nn.Sequential(
        torch.nn.Linear(
            embed_dim * len(cols),
            num_classes,
        ),
    ),
    # Emit raw logits: the training loss is torch.nn.CrossEntropyLoss, which
    # applies log-softmax itself. A Softmax here would normalize twice and
    # squash the gradients. Apply softmax explicitly if probabilities are needed.
    output_activation=torch.nn.Identity(),
    positional_encoding_dropout=0.0,
    table_transform_dropout=0.1,
    table_combination_dropout=0.1,
    table_transform_residual=False,
    table_combination_residual=False,
    table_transform_norm=False,
    table_combination_norm=False,
).to(device)
print(summary(model, data.collect("tf"), data.collect("edge_index")))

In [None]:
# Wrap the network for Lightning and place it on the configured device.
lightning_model = LightningModel(model, dataset.defaults.target_table, lr=0.0001)
lightning_model = lightning_model.to(device)

# Single-device run on the accelerator matching `device`.
# max_steps=-1 puts no cap on the step count, so max_epochs bounds the run.
trainer_kwargs = dict(
    accelerator=device.type,
    devices=1,
    deterministic=False,
    max_epochs=100,
    max_steps=-1,
)
trainer = L.Trainer(**trainer_kwargs)

trainer.fit(
    model=lightning_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)