In [8]:
import yaml
import torch

from tabgns.utils import set_seed
from tabgns.mlp import MLP
from tabgns.trainer import TabGNSTrainer
from load_data import load_higgs, split_train_valid_test

In [9]:
with open("conf/simple.yaml", "r") as f:
    config_dict = yaml.safe_load(f)

set_seed(config_dict['general']['seed'])

In [10]:
X, y = load_higgs()
(X_train, y_train), (X_valid, y_valid), (X_test, y_test) = split_train_valid_test(X, y, dl=False, seed=config_dict['general']['seed'])

config_dict['general']['inp_dim'] = X.shape[1]
config_dict['general']['out_dim'] = 2 if config_dict['general']['dataset'] == 'higgs' else 7
config_dict['layer_widths'] = config_dict['search_space']['layer_widths']


seed: 42


In [11]:
mlp = MLP(config_dict)
mlp.fit(X_train, y_train, X_valid, y_valid)
y_pred = mlp.predict(torch.tensor(X_test, dtype=torch.float32))

Early stopper initialized. Patience: 20


Training:   0%|          | 0/300 [00:00<?, ?it/s]

Early stopping


In [12]:
# Convert y_test to class indices if it's one-hot encoded
y_test_tensor = torch.tensor(y_test)
if y_test_tensor.ndim > 1 and y_test_tensor.size(-1) > 1:
  y_test_indices = torch.argmax(y_test_tensor, dim=1)
else:
  y_test_indices = y_test_tensor.squeeze() if y_test_tensor.ndim > 1 else y_test_tensor
  y_test_indices = y_test_indices.long()

print(mlp)
mlp_acc = (y_pred == y_test_indices).float().mean().item()
print(f"MLP Accuracy: {mlp_acc}")

MLP(
  (net): Sequential(
    (0): Linear(in_features=28, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=2, bias=True)
  )
)
MLP Accuracy: 0.7100458741188049


In [13]:
trainer = TabGNSTrainer(config_dict)
mlp_nas = trainer.fit(X_train, y_train, X_valid, y_valid)
y_pred = mlp_nas.predict(torch.tensor(X_test, dtype=torch.float32))

print(mlp_nas)
nas_acc = (y_pred == y_test_indices).float().mean().item()
print(f"TabGNS Accuracy: {nas_acc}")

Early stopper initialized. Patience: 20


  0%|          | 0/300 [00:00<?, ?it/s]

Early stopping
Early stopper initialized. Patience: 20


Training:   0%|          | 0/300 [00:00<?, ?it/s]

Early stopping
MLP(
  (net): Sequential(
    (0): Linear(in_features=28, out_features=55, bias=True)
    (1): ReLU()
    (2): Linear(in_features=55, out_features=22, bias=True)
    (3): ReLU()
    (4): Linear(in_features=22, out_features=2, bias=True)
  )
)
TabGNS Accuracy: 0.717797040939331


In [14]:
print(f"""
MLP:
    Architecture: {mlp.get_architecture()}
    Params: {sum([a * b for a, b in zip([config_dict['general']['inp_dim']] + mlp.get_architecture(), mlp.get_architecture() + [config_dict['general']['out_dim']])]):,}
    Accuracy: {mlp_acc}
TabGNS:
    Architecture: {mlp_nas.get_architecture()}
    Params: {sum([a * b for a, b in zip([config_dict['general']['inp_dim']] + mlp_nas.get_architecture(), mlp_nas.get_architecture() + [config_dict['general']['out_dim']])]):,}
    Accuracy: {nas_acc}
""")


MLP:
    Architecture: [256, 256]
    Params: 73,216
    Accuracy: 0.7100458741188049
TabGNS:
    Architecture: [55, 22]
    Params: 2,794
    Accuracy: 0.717797040939331

