In [None]:
import os, sys

# Put the parent directory on the import path; this must run before the
# project imports further down (presumably `db_transformer` and `experiments`
# live there -- TODO confirm).
# Guarded so that re-running this cell does not stack duplicate entries.
if "../" not in sys.path:
    sys.path.append("../")

from typing import List, Dict, Union, Optional, Any, Tuple, Literal

import math

import torch
from torch.nn import functional as F

import lightning as L

from torch_geometric.data import HeteroData
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import conv, Sequential, summary
import torch_geometric.transforms as T
from torch_geometric.typing import EdgeType, NodeType

import torch_frame
from torch_frame import stype, NAStrategy
from torch_frame.nn import encoder
from torch_frame.nn import TabTransformerConv
from torch_frame.data import StatType


from db_transformer.nn import (
    BlueprintModel,
    EmbeddingTranscoder,
    SelfAttention,
    CrossAttentionConv,
    NodeApplied,
)
from db_transformer.nn.lightning import LightningWrapper
from db_transformer.nn.lightning.callbacks import BestMetricsLoggerCallback

from db_transformer.data.ctu_dataset import CTUDataset, CTU_REPOSITORY_DEFAULTS

from experiments.blueprint_instances import create_blueprint_model


# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell execution, so edits to the local project code
# are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

# --- Run configuration ---------------------------------------------------
# Set to True to run on a GPU when one is available. Kept False by default,
# which matches the previous hard-wired `False and ...` (CPU-only) behaviour.
USE_CUDA = False

# Fixed seed for PyTorch's global RNG so that torch-driven randomness later in
# the notebook is repeatable after a kernel restart. Bit-exact reproducibility
# is still not guaranteed because the trainer runs with deterministic=False.
SEED = 42
torch.manual_seed(SEED)

# Device on which the data, model and trainer are placed.
device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu")

In [None]:
# Load the "Chess" CTU dataset from ../datasets, reusing any existing copy
# (force_remake=False).
dataset = CTUDataset("Chess", data_dir="../datasets", force_remake=False)

# Build the heterogeneous graph on `device`.
# NOTE(review): the keyword spellings `force_rematerilize` and
# `no_text_emebedding` are passed exactly as written -- confirm they match the
# signature of `build_hetero_data`.
data = dataset.build_hetero_data(
    device,
    force_rematerilize=False,
    no_text_emebedding=True,
)

# Number of labelled rows in the target table.
target_store = data[dataset.defaults.target_table]
n_total = target_store.y.shape[0]

# Random split: 30% of the rows go to validation, none to test, and the
# remainder to training ("train_rest").
split_transform = T.RandomNodeSplit(
    split="train_rest",
    num_val=int(0.30 * n_total),
    num_test=0,
)
data = split_transform(data)

In [None]:
# Neighbour-sampling mini-batch loaders for the training and validation nodes.
target = dataset.defaults.target

# For every edge type that has an `edge_index`, sample up to 30 neighbours at
# each of 5 hops.
# NOTE(review): confirm that 5 sampled hops is the intended depth for the
# model trained below.
num_neighbors = {edge_type: [30] * 5 for edge_type in data.collect("edge_index").keys()}

# `target[0]` is used as the node type of the seed nodes
# (presumably the target table name -- TODO confirm).
train_loader = NeighborLoader(
    data,
    num_neighbors=num_neighbors,
    batch_size=10000,
    input_nodes=(target[0], data[target[0]].train_mask),
    subgraph_type="bidirectional",
    # Reshuffle the training seed nodes every epoch. Without this the loader
    # falls back to the DataLoader default (shuffle=False) and yields the
    # seeds in the same fixed order on every pass.
    shuffle=True,
)

# Validation keeps a fixed order (no shuffling) so evaluation is comparable
# between epochs.
val_loader = NeighborLoader(
    data,
    num_neighbors=num_neighbors,
    batch_size=1000,
    input_nodes=(target[0], data[target[0]].val_mask),
    subgraph_type="bidirectional",
)

# One sampled training batch, kept around as example input for later cells.
sample = next(iter(train_loader))

In [None]:
# Hyper-parameters handed to the "honza" blueprint instance.
model_config = {
    "embed_dim": 64,
    "num_layers": 3,
    "mlp_dims": [64],
    "batch_norm": True,
}

# Build the model from the dataset defaults and the graph, then move it to
# the configured device.
model = create_blueprint_model("honza", dataset.defaults, data, model_config)
model = model.to(device)

# Print a layer-by-layer summary (up to 10 levels deep), using the sampled
# batch's "tf" and "edge_index" collections as example inputs.
sample_tf = sample.collect("tf")
sample_edge_index = sample.collect("edge_index")
print(summary(model, sample_tf, sample_edge_index, max_depth=10))

In [None]:
# Wrap the model for Lightning training on the target table and place it on
# the configured device.
lightning_model = LightningWrapper(model, dataset.defaults.target_table, lr=0.0001)
lightning_model = lightning_model.to(device)

# Single-device trainer on the same accelerator type as `device`.
# max_steps=-1 puts no cap on the step count, so max_epochs is the limit.
trainer_options = dict(
    accelerator=device.type,
    devices=1,
    deterministic=False,
    max_epochs=1000,
    max_steps=-1,
)
trainer = L.Trainer(**trainer_options)

# Train on the neighbour-sampled batches, validating with `val_loader`.
trainer.fit(
    lightning_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)