In [6]:
import os

# Move from the notebook's folder up one level (presumably the project root;
# every relative path used later, e.g. './checkpoint' and './ckpt/pretrained',
# is resolved from there). A bare os.chdir('../') climbs one MORE directory
# each time this cell is re-run, so remember the target on the first run and
# reuse it, making the cell idempotent.
if 'REPO_ROOT' not in globals():
    REPO_ROOT = os.path.abspath('../')
os.chdir(REPO_ROOT)

import transtab

# Seed transtab's RNGs once, up front, so the training runs below are
# intended to be repeatable.
SEED = 42
transtab.random_seed(SEED)

In [7]:
# Pretraining: load BOTH datasets in one call, then run plain supervised
# training on the first one (credit-g, the first name in the list). The
# finetuning cells later use index 1 of the same splits for credit-approval.
allset, trainset, valset, testset, cat_cols, num_cols, bin_cols = transtab.load_data(
    ['credit-g', 'credit-approval']
)

# Classifier over the categorical / numerical / binary column names that
# load_data returned.
model = transtab.build_classifier(cat_cols, num_cols, bin_cols)

# Keep the best epoch by validation loss (lower is better).
training_arguments = dict(
    num_epoch=50,
    eval_metric='val_loss',
    eval_less_is_better=True,
    output_dir='./checkpoint',
    batch_size=128,
    lr=1e-4,
    weight_decay=1e-4,
)
transtab.train(model, trainset[0], valset[0], **training_arguments)

# Persist the pretrained weights at the path the finetuning cell reloads from.
model.save('./ckpt/pretrained')

########################################
openml data index: 31
load data from credit-g
# data: 1000, # feat: 20, # cate: 11,  # bin: 2, # numerical: 7, pos rate: 0.70
########################################
openml data index: 29
load data from credit-approval
# data: 690, # feat: 15, # cate: 9,  # bin: 0, # numerical: 6, pos rate: 0.56


Epoch:   0%|          | 0/50 [00:00<?, ?it/s]

epoch: 0, test val_loss: 0.574102
epoch: 0, train loss: 3.9759, lr: 0.000100, spent: 0.4 secs
epoch: 1, test val_loss: 0.565162
epoch: 1, train loss: 3.7812, lr: 0.000100, spent: 0.9 secs
epoch: 2, test val_loss: 0.576745
EarlyStopping counter: 1 out of 5
epoch: 2, train loss: 3.6560, lr: 0.000100, spent: 1.1 secs
epoch: 3, test val_loss: 0.566665
EarlyStopping counter: 2 out of 5
epoch: 3, train loss: 3.6539, lr: 0.000100, spent: 1.4 secs
epoch: 4, test val_loss: 0.548929
epoch: 4, train loss: 3.6118, lr: 0.000100, spent: 1.7 secs
epoch: 5, test val_loss: 0.545800
epoch: 5, train loss: 3.5634, lr: 0.000100, spent: 2.2 secs
epoch: 6, test val_loss: 0.545121
epoch: 6, train loss: 3.5035, lr: 0.000100, spent: 2.4 secs
epoch: 7, test val_loss: 0.529130
epoch: 7, train loss: 3.4372, lr: 0.000100, spent: 2.7 secs
epoch: 8, test val_loss: 0.525149
epoch: 8, train loss: 3.3768, lr: 0.000100, spent: 3.0 secs
epoch: 9, test val_loss: 0.518042
epoch: 9, train loss: 3.3204, lr: 0.000100, spent: 3

2022-10-05 08:35:04.023 | INFO     | transtab.trainer:train:136 - load best at last from ./checkpoint
2022-10-05 08:35:04.042 | INFO     | transtab.trainer:save_model:243 - saving model checkpoint to ./checkpoint
2022-10-05 08:35:04.167 | INFO     | transtab.trainer:train:141 - training complete, cost 13.1 secs.


epoch: 38, test val_loss: 0.503903
EarlyStopping counter: 5 out of 5
early stopped


In [8]:
# Finetuning on the second dataset (credit-approval). load_data in the earlier
# cell already returned its splits, so there is nothing to reload here.

# Restore the weights saved at the end of pretraining.
model.load('./ckpt/pretrained')

# Refresh the model's categorical / numerical / binary column vocabulary.
# NOTE(review): these are the column lists from the two-dataset load_data call,
# presumably covering both datasets' columns; confirm in transtab.load_data.
# Passing 'num_class' makes update() build a fresh classification head that
# still needs finetuning; it is required when the class count differs from
# the pretraining task.
column_config = {'cat': cat_cols, 'num': num_cols, 'bin': bin_cols, 'num_class': 2}
model.update(column_config)

2022-10-05 08:35:04.352 | INFO     | transtab.modeling_transtab:load:773 - missing keys: []
2022-10-05 08:35:04.354 | INFO     | transtab.modeling_transtab:load:774 - unexpected keys: []
2022-10-05 08:35:04.355 | INFO     | transtab.modeling_transtab:load:775 - load model from ./ckpt/pretrained
2022-10-05 08:35:04.370 | INFO     | transtab.modeling_transtab:load:222 - load feature extractor from ./ckpt/pretrained/extractor/extractor.json
2022-10-05 08:35:04.372 | INFO     | transtab.modeling_transtab:update:832 - Build a new classifier with num 2 classes outputs, need further finetune to work.


In [9]:
# Finetune on the second dataset, keeping the best epoch by validation AUC
# (higher is better). weight_decay is not passed here, so transtab's default
# applies. output_dir is the same './checkpoint' used for pretraining, so this
# run overwrites that checkpoint; the pretrained weights survive in
# './ckpt/pretrained'.
training_arguments = dict(
    num_epoch=50,
    eval_metric='auc',
    eval_less_is_better=False,
    output_dir='./checkpoint',
    batch_size=128,
    lr=2e-4,
)

transtab.train(model, trainset[1], valset[1], **training_arguments)


Epoch:   0%|          | 0/50 [00:00<?, ?it/s]

  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 0, test auc: 0.282251
epoch: 0, train loss: 3.3862, lr: 0.000200, spent: 0.2 secs
epoch: 1, test auc: 0.865801
epoch: 1, train loss: 2.8794, lr: 0.000200, spent: 0.3 secs


  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 2, test auc: 0.865801
epoch: 2, train loss: 2.5943, lr: 0.000200, spent: 0.7 secs
epoch: 3, test auc: 0.865801
epoch: 3, train loss: 2.4300, lr: 0.000200, spent: 0.8 secs


  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 4, test auc: 0.872727
epoch: 4, train loss: 2.2617, lr: 0.000200, spent: 1.0 secs
epoch: 5, test auc: 0.879654
epoch: 5, train loss: 2.0867, lr: 0.000200, spent: 1.1 secs


  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 6, test auc: 0.880519
epoch: 6, train loss: 1.9774, lr: 0.000200, spent: 1.3 secs
epoch: 7, test auc: 0.883117
epoch: 7, train loss: 1.8739, lr: 0.000200, spent: 1.4 secs


  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 8, test auc: 0.889177
epoch: 8, train loss: 1.8919, lr: 0.000200, spent: 1.5 secs
epoch: 9, test auc: 0.890909
epoch: 9, train loss: 1.8794, lr: 0.000200, spent: 1.7 secs


  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 10, test auc: 0.896970
epoch: 10, train loss: 1.8456, lr: 0.000200, spent: 2.0 secs
epoch: 11, test auc: 0.897835
epoch: 11, train loss: 1.8213, lr: 0.000200, spent: 2.2 secs


  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 12, test auc: 0.896104
EarlyStopping counter: 1 out of 5
epoch: 12, train loss: 1.8219, lr: 0.000200, spent: 2.3 secs
epoch: 13, test auc: 0.903896
epoch: 13, train loss: 1.7924, lr: 0.000200, spent: 2.4 secs


  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 14, test auc: 0.905628
epoch: 14, train loss: 1.7964, lr: 0.000200, spent: 2.6 secs
epoch: 15, test auc: 0.904762
EarlyStopping counter: 1 out of 5
epoch: 15, train loss: 1.7641, lr: 0.000200, spent: 2.7 secs


  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 16, test auc: 0.904762
EarlyStopping counter: 2 out of 5
epoch: 16, train loss: 1.7788, lr: 0.000200, spent: 2.8 secs
epoch: 17, test auc: 0.909091
epoch: 17, train loss: 1.7456, lr: 0.000200, spent: 2.9 secs


  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 18, test auc: 0.910823
epoch: 18, train loss: 1.7438, lr: 0.000200, spent: 3.3 secs
epoch: 19, test auc: 0.912554
epoch: 19, train loss: 1.7569, lr: 0.000200, spent: 3.4 secs


  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 20, test auc: 0.912554
epoch: 20, train loss: 1.7533, lr: 0.000200, spent: 3.5 secs
epoch: 21, test auc: 0.915152
epoch: 21, train loss: 1.7439, lr: 0.000200, spent: 3.7 secs
epoch: 22, test auc: 0.915152
epoch: 22, train loss: 1.7020, lr: 0.000200, spent: 3.9 secs


  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 23, test auc: 0.916883
epoch: 23, train loss: 1.7017, lr: 0.000200, spent: 4.0 secs
epoch: 24, test auc: 0.917749
epoch: 24, train loss: 1.6625, lr: 0.000200, spent: 4.1 secs


  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 25, test auc: 0.918615
epoch: 25, train loss: 1.6432, lr: 0.000200, spent: 4.3 secs


  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 26, test auc: 0.922944
epoch: 26, train loss: 1.6299, lr: 0.000200, spent: 4.7 secs
epoch: 27, test auc: 0.922944
EarlyStopping counter: 1 out of 5
epoch: 27, train loss: 1.6158, lr: 0.000200, spent: 4.8 secs


  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 28, test auc: 0.925541
epoch: 28, train loss: 1.5971, lr: 0.000200, spent: 4.9 secs
epoch: 29, test auc: 0.926407
epoch: 29, train loss: 1.5771, lr: 0.000200, spent: 5.0 secs


  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 30, test auc: 0.927273
epoch: 30, train loss: 1.5763, lr: 0.000200, spent: 5.2 secs
epoch: 31, test auc: 0.933333
epoch: 31, train loss: 1.6021, lr: 0.000200, spent: 5.3 secs


  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 32, test auc: 0.936797
epoch: 32, train loss: 1.5513, lr: 0.000200, spent: 5.5 secs
epoch: 33, test auc: 0.938528
epoch: 33, train loss: 1.5160, lr: 0.000200, spent: 5.6 secs


  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 34, test auc: 0.938528
epoch: 34, train loss: 1.5250, lr: 0.000200, spent: 5.8 secs
epoch: 35, test auc: 0.938528
epoch: 35, train loss: 1.4732, lr: 0.000200, spent: 6.0 secs


  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 36, test auc: 0.934199
EarlyStopping counter: 1 out of 5
epoch: 36, train loss: 1.4738, lr: 0.000200, spent: 6.1 secs
epoch: 37, test auc: 0.934199
EarlyStopping counter: 2 out of 5
epoch: 37, train loss: 1.4667, lr: 0.000200, spent: 6.2 secs


  y_test = pd.concat(y_test, 0)
  y_test = pd.concat(y_test, 0)


epoch: 38, test auc: 0.933333
EarlyStopping counter: 3 out of 5
epoch: 38, train loss: 1.4209, lr: 0.000200, spent: 6.3 secs
epoch: 39, test auc: 0.933333
EarlyStopping counter: 4 out of 5
epoch: 39, train loss: 1.4371, lr: 0.000200, spent: 6.4 secs


  y_test = pd.concat(y_test, 0)
2022-10-05 08:35:10.982 | INFO     | transtab.trainer:train:136 - load best at last from ./checkpoint
2022-10-05 08:35:10.994 | INFO     | transtab.trainer:save_model:243 - saving model checkpoint to ./checkpoint
2022-10-05 08:35:11.142 | INFO     | transtab.trainer:train:141 - training complete, cost 6.7 secs.


epoch: 40, test auc: 0.929870
EarlyStopping counter: 5 out of 5
early stopped


In [10]:
from sklearn.metrics import roc_auc_score

# Evaluate the finetuned model on the held-out test split of the second dataset.
x_test, y_test = testset[1]
ypred = transtab.predict(model, x_test)

# transtab's own AUC report; it prints a summary line itself and its return
# value is not displayed because it is not the cell's last expression.
transtab.evaluate(ypred, y_test, metric='auc')

# Cross-check with scikit-learn's AUC on the same predictions. The two printed
# figures are not identical, so the report above is presumably not a plain
# single-pass AUC.
print(roc_auc_score(y_test, ypred))

auc 0.95 mean/interval 0.8757(0.05)
0.8807749627421758
