In [None]:
import os, sys

sys.path.append("../")

from typing import List, Dict, Union, Optional, Any, Tuple, Literal

import math
import random

import pandas as pd

import numpy as np

import torch
from torch.nn import functional as F

import lightning as L
from lightning.pytorch import seed_everything

import torch_geometric
from torch_geometric.data import HeteroData
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import conv, Sequential, summary
import torch_geometric.transforms as T
from torch_geometric.typing import EdgeType, NodeType

import torch_frame
from torch_frame import stype, NAStrategy
from torch_frame.nn import encoder

from torch_frame.nn import TabTransformerConv
from torch_frame.data import StatType


from db_transformer.nn import (
    BlueprintModel,
    EmbeddingTranscoder,
    SelfAttention,
    CrossAttentionConv,
    NodeApplied,
)
from db_transformer.nn.lightning import LightningWrapper
from db_transformer.nn.lightning.callbacks import BestMetricsLoggerCallback

from db_transformer.data.ctu_dataset import CTUDataset, CTU_REPOSITORY_DEFAULTS, TaskType

from experiments.blueprint_instances import create_blueprint_model


%reload_ext autoreload
%autoreload 2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Overview of the available CTU datasets: bucket each one by its number of target
# tuples and list the classification ("categorical") datasets in each bucket.
# Buckets are inclusive integer ranges; rows outside every range keep size == "".
SIZE_BUCKETS = {
    "tiny": (0, 1000),
    "small": (1001, 10000),
    "medium": (10001, 100000),
    "big": (100001, 1000000),
    "giant": (1000001, 10000000),
}

df = pd.read_csv("../datasets/info.csv")
df["size"] = ""
for low, high in SIZE_BUCKETS.values():
    df.loc[df["n_target_tuples"].between(low, high), "size"] = f"{low}-{high}"

# Build one combined mask on the full frame instead of chaining `.loc[...]` calls,
# which only works through silent boolean-index re-alignment.
is_categorical = df["task"] == "categorical"
for bucket_name, (low, high) in SIZE_BUCKETS.items():
    in_bucket = is_categorical & (df["size"] == f"{low}-{high}")
    print(f"{bucket_name}:", df.loc[in_bucket, "dataset"].values, "\n")

In [None]:
# Load the PremierLeague database and convert it to a PyG HeteroData graph.
# (force_remake / force_rematerilize=False presumably reuse previously built artifacts — confirm.)
SPLIT_SEED = 42
VAL_FRACTION = 0.30  # fraction of target-table rows held out for validation; no test split

dataset = CTUDataset("PremierLeague", data_dir="../datasets", force_remake=False)

data = dataset.build_hetero_data(force_rematerilize=False, no_text_emebedding=False)

n_total = data[dataset.defaults.target_table].y.shape[0]

# Seed right before the random split; otherwise the train/val partition (and thus
# every validation metric below) changes on each kernel restart.
seed_everything(SPLIT_SEED)
data = T.RandomNodeSplit(
    split="train_rest", num_val=int(VAL_FRACTION * n_total), num_test=0
)(data)

In [None]:
# `target[0]` is used below as the node type of the target table.
target = dataset.defaults.target

seed_everything(42)

# Up to 30 sampled neighbours per edge type for each of 5 hops.
# NOTE(review): the model below is built with gnn_layers=2 — confirm all 5 hops are needed.
num_neighbors = {
    edge_type: [30] * 5 for edge_type in data.collect("edge_index", allow_empty=True).keys()
}


def make_neighbor_loader(
    graph: HeteroData,
    node_type: NodeType,
    mask: torch.Tensor,
    neighbors: Dict[EdgeType, List[int]],
    *,
    shuffle: bool,
    batch_size: int = 512,
) -> NeighborLoader:
    """Build a bidirectional NeighborLoader seeded at the ``node_type`` nodes selected by ``mask``.

    Args:
        graph: The heterogeneous graph to sample from.
        node_type: Node type of the seed nodes.
        mask: Boolean mask over ``node_type`` nodes choosing the seeds (e.g. ``train_mask``).
        neighbors: Per-edge-type list of neighbour counts, one entry per hop.
        shuffle: Whether to reshuffle the seed nodes each epoch.
        batch_size: Number of seed nodes per mini-batch.
    """
    return NeighborLoader(
        graph,
        num_neighbors=neighbors,
        batch_size=batch_size,
        input_nodes=(node_type, mask),
        subgraph_type="bidirectional",
        shuffle=shuffle,
    )


train_loader = make_neighbor_loader(
    data, target[0], data[target[0]].train_mask, num_neighbors, shuffle=True
)
# Validation order has no effect on the metrics; shuffling it only consumes RNG state
# and triggers Lightning's "shuffling enabled for val dataloader" warning.
val_loader = make_neighbor_loader(
    data, target[0], data[target[0]].val_mask, num_neighbors, shuffle=False
)

In [None]:
# Draw one training mini-batch and show the distinct target values it contains
# together with how often each occurs.
# NOTE(review): the batch may also contain sampled neighbour nodes of the target
# type, not only the batch_size seed nodes — confirm before reading this as class balance.
sample: HeteroData = next(iter(train_loader))

y: torch.Tensor = sample[target[0]].y
label_values, label_counts = y.unique(return_counts=True)
print((label_values, label_counts))

In [None]:
edge_types = list(data.collect("edge_index", allow_empty=True).keys())

# Column layout of every node table that actually has rows.
col_names_per_table = {
    node: tf.col_names_dict for node, tf in data.collect("tf").items() if tf.num_rows > 0
}

model = create_blueprint_model(
    "transformer",
    dataset.defaults,
    col_names_per_table,
    edge_types,
    data.collect("col_stats"),
    dict(
        embed_dim=64,
        encoder="with_time",
        gnn_layers=2,
        mlp_dims=[64, 64],
        num_heads=4,
        residual=True,
        batch_norm=True,
        dropout=0,
    ),
)

# `sample` was never moved off the CPU, so the summary must run on a CPU model.
# `Module.cpu()` moves parameters in place, so device placement is done only
# afterwards (placing first would be silently undone here).
print(
    summary(
        model.cpu(),
        sample.collect("tf"),
        sample.collect("edge_index", allow_empty=True),
        max_depth=10,
    )
)

model = model.to(device)

In [None]:
is_regression = dataset.defaults.task == TaskType.REGRESSION

# Regression is tracked by MAE (lower is better); otherwise by accuracy (higher is better).
if is_regression:
    metric, cmp = "mae", "min"
else:
    metric, cmp = "accuracy", "max"

lightning_model = LightningWrapper(
    model,
    dataset.defaults.target_table,
    lr=0.0001,
    task_type=dataset.defaults.task,
).to(device)

# Records the listed metrics at the best value of the monitored validation metric.
# NOTE(review): `test_{metric}` is requested, but the split has num_test=0 and fit()
# receives no test loader — confirm the callback tolerates a never-logged metric.
best_metrics_logger = BestMetricsLoggerCallback(
    monitor=f"val_{metric}",
    cmp=cmp,
    metrics=[f"val_{metric}", f"test_{metric}"],
)

# No checkpointing, no logger and no early stopping: all 1000 epochs run.
trainer = L.Trainer(
    accelerator=device.type,
    devices=1,
    deterministic=False,
    max_epochs=1000,
    max_steps=-1,
    enable_checkpointing=False,
    logger=False,
    callbacks=[best_metrics_logger],
)

trainer.fit(lightning_model, train_loader, val_dataloaders=val_loader)