In [2]:
import os, sys

sys.path.append("../")

from typing import List, Dict, Union, Optional, Any, Tuple, Literal

import math
import random

import pandas as pd

import numpy as np

import torch
from torch.nn import functional as F

import lightning as L
from lightning.pytorch import seed_everything

import torch_geometric
from torch_geometric.data import HeteroData
from torch_geometric.loader import NeighborLoader, HGTLoader
from torch_geometric.nn import conv, Sequential, summary
import torch_geometric.transforms as T
from torch_geometric.typing import EdgeType, NodeType

import torch_frame
from torch_frame import stype, NAStrategy
from torch_frame.nn import encoder

from torch_frame.nn import TabTransformerConv
from torch_frame.data import StatType


from db_transformer.nn import (
    BlueprintModel,
    EmbeddingTranscoder,
    SelfAttention,
    CrossAttentionConv,
    NodeApplied,
    AttentionAggregation
)
from db_transformer.nn.lightning import LightningWrapper
from db_transformer.nn.lightning.callbacks import BestMetricsLoggerCallback

from db_transformer.data.ctu_dataset import CTUDataset, CTU_REPOSITORY_DEFAULTS, TaskType

from experiments.blueprint_instances import create_blueprint_model


# Re-import edited project modules (db_transformer, experiments) before each cell runs,
# so changes to their source are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [13]:
# Load the CTU "legalActs" database as a heterogeneous graph and split the
# target-table nodes into train/validation sets (no test split).
SPLIT_SEED = 42
VAL_FRACTION = 0.30

dataset = CTUDataset("legalActs", data_dir="../datasets", force_remake=False)

data, col_stats_dict = dataset.build_hetero_data(
    force_rematerilize=False, no_text_emebedding=False
)

n_total = data[dataset.defaults.target_table].y.shape[0]

# RandomNodeSplit draws a random permutation from torch's global RNG; seed it right
# before the split so the train/val partition is reproducible across kernel restarts.
seed_everything(SPLIT_SEED)
# Everything not held out for validation becomes training data; num_test=0 means no test split.
data = T.RandomNodeSplit(
    split="train_rest", num_val=int(VAL_FRACTION * n_total), num_test=0
)(data)

Building data:   0%|          | 0/5 [00:00<?, ?it/s]

Table scrapefix has stypes:
	categorical: ['contributor', 'fix_description']
Table people has stypes:
	categorical: ['court', 'jury']
	embedding: ['name']
Table legalact_people has stypes:
	categorical: ['__filler']
Table legalact_link has stypes:
	categorical: ['__filler']
Table legalacts has stypes:
	timestamp: ['LegalDate', 'MotiveDate', 'SendDate', 'StartDate', 'update']
	categorical: ['ActLink', 'CaseKind', 'Court', 'HighCourt', 'MotiveLink', 'ResultOfAppeal', 'Status', 'TypeOfDocument']
	numerical: ['ActNumber', 'ActYear', 'CaseNumber', 'OutNumber', 'YearHigherCourt']
	embedding: ['Judge', 'hash']


In [23]:
# Build mini-batch loaders over the target table: HGT-style subgraph sampling seeded at
# the train / validation nodes produced by the split above.
target = dataset.defaults.target

seed_everything(42)

total_samples = data[target[0]].y.shape[0]
scale_exponent = 2
# Base batch size: total_samples / 500 rounded to a power of two (at least 16),
# then scaled by 2**scale_exponent and capped at 16384.
min_batch_size = max(16, int(2 ** np.around(np.log2(total_samples / 500))))
batch_size = min(min_batch_size * 2**scale_exponent, 16384)
print(batch_size)

# Nodes sampled per node type in each HGT sampling iteration; the list length is the
# number of iterations.
# NOTE(review): a single iteration yields 1-hop subgraphs, while the model built later
# uses gnn_layers=5 — confirm the sampling depth is intended.
HGT_NUM_SAMPLES = [30] * 1


def make_hgt_loader(mask: torch.Tensor, shuffle: bool) -> HGTLoader:
    """Return an HGTLoader over `data` whose seed nodes are the target-table nodes in `mask`.

    Args:
        mask: boolean node mask over the target table (e.g. `train_mask` / `val_mask`).
        shuffle: whether to reshuffle the seed nodes every epoch.
    """
    return HGTLoader(
        data,
        num_samples=HGT_NUM_SAMPLES,
        batch_size=batch_size,
        input_nodes=(target[0], mask),
        shuffle=shuffle,
    )


train_loader = make_hgt_loader(data[target[0]].train_mask, shuffle=True)
# Evaluation does not benefit from shuffling; a fixed order keeps validation batches
# comparable across epochs (Lightning also warns about shuffled val dataloaders).
val_loader = make_hgt_loader(data[target[0]].val_mask, shuffle=False)

Seed set to 42


4096


In [24]:
# Peek at one training batch and show which target labels it contains and how often.
sample: HeteroData = next(iter(train_loader))

y: torch.Tensor = sample[target[0]].y
labels, counts = y.unique(return_counts=True)
print((labels, counts))

(tensor([0, 1, 2, 3, 4, 5, 6]), tensor([2597,  804,  302,  272,   59,   31,   31]))


In [28]:
# Instantiate the "transformer" blueprint model for this dataset's schema.
edge_types = list(data.collect("edge_index", allow_empty=True).keys())

# Drop the reference to a model left over from a previous run of this cell before
# building a new one. A bare `del model` raises NameError on a fresh kernel and
# breaks Restart & Run All.
globals().pop("model", None)

model = create_blueprint_model(
    "transformer",
    dataset.defaults,
    # Column names per node type, skipping node types whose TensorFrame has no rows.
    {node: tf.col_names_dict for node, tf in data.collect("tf").items() if tf.num_rows > 0},
    edge_types,
    col_stats_dict,
    dict(
        embed_dim=64,
        encoder="with_time",
        gnn_layers=5,
        mlp_dims=[64, 64],
        num_heads=1,
        residual=True,
        batch_norm=False,
        dropout=0.0,
    ),
)

In [29]:
# Layer-by-layer shapes and parameter counts of the model, traced on the sampled batch.
model_summary = summary(
    model,
    sample.collect("tf"),
    sample.collect("edge_index", allow_empty=True),
    max_depth=10,
)
print(model_summary)

+-----------------------------------------------------------------------------------+------------------------------------------------+--------------------------------+----------+
| Layer                                                                             | Input Shape                                    | Output Shape                   | #Param   |
|-----------------------------------------------------------------------------------+------------------------------------------------+--------------------------------+----------|
| BlueprintModel                                                                    |                                                | [4096, 7]                      | 240,517  |
| ├─(embedder)Sequential                                                            |                                                |                                | 50,944   |
| │    └─(module_0)DBEmbedder                                                       |                    

In [30]:
# Wrap the model in Lightning and train it, tracking the best validation metric.
USE_GPU = False  # set True to train on CUDA when a GPU is available

is_regression = dataset.defaults.task == TaskType.REGRESSION
device = "cuda" if USE_GPU and torch.cuda.is_available() else "cpu"

lightning_model = LightningWrapper(
    model, dataset.defaults.target_table, lr=0.0001, task_type=dataset.defaults.task
)

# Regression is scored by MAE (lower is better), classification by accuracy (higher is better).
metric = "mae" if is_regression else "acc"
cmp = "min" if is_regression else "max"

print(metric, cmp)

trainer = L.Trainer(
    accelerator=device,
    deterministic=False,
    max_epochs=10,
    max_steps=-1,
    enable_checkpointing=False,
    logger=False,
    callbacks=[
        BestMetricsLoggerCallback(
            monitor=f"val_{metric}",
            cmp=cmp,
            # NOTE(review): the split uses num_test=0 and no test loader is passed to
            # fit(), so test_{metric} is presumably never produced — confirm the callback
            # tolerates missing metrics.
            metrics=[
                # "train_loss",
                # "val_loss",
                # "test_loss",
                f"train_{metric}",
                f"val_{metric}",
                f"test_{metric}",
            ],
        ),
    ],
)

trainer.fit(lightning_model, train_loader, val_dataloaders=val_loader)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name        | Type             | Params
-------------------------------------------------
0 | model       | BlueprintModel   | 240 K 
1 | loss_module | CrossEntropyLoss | 0     
-------------------------------------------------
240 K     Trainable params
0         Non-trainable params
240 K     Total params
0.962     Total estimated model params size (MB)


acc max


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]