In [1]:
import os

# Move the working directory to the parent of the notebook's directory.
# The original `os.chdir('../')` was relative to the *current* directory, so
# re-running this cell in a live kernel climbed one more level each time and
# silently changed what relative paths such as './checkpoint' resolve to.
# Recording the starting directory once makes the cell idempotent: the first
# run behaves exactly as before, and later re-runs land in the same place.
if '_NOTEBOOK_DIR' not in globals():
    _NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_DIR, os.pardir))

# Imported after the chdir on purpose -- presumably so the `transtab` package
# located in the parent directory is importable; confirm against repo layout.
import transtab

# set random seed
transtab.random_seed(42)

In [2]:
# Load the 'credit-g' dataset. The call returns the full set, the
# train/val/test splits, and the categorical / numerical / binary column lists.
(allset, trainset, valset, testset,
 cat_cols, num_cols, bin_cols) = transtab.load_data('credit-g')

########################################
openml data index: 31
load data from credit-g
# data: 1000, # feat: 20, # cate: 11,  # bin: 2, # numerical: 7, pos rate: 0.70


In [3]:
# Quick contrastive pre-training run for a TransTab model.
# Build the contrastive learner; supervised=True selects supervised VPCL.
model, collate_fn = transtab.build_contrastive_learner(
    cat_cols,
    num_cols,
    bin_cols,
    supervised=True,    # use the supervised flavour of contrastive learning
    num_partition=4,    # number of column partitions for pos/neg sampling
    overlap_ratio=0.5,  # overlap ratio between column partitions during CL
)

# Settings for the contrastive pre-training loop: the model is evaluated on
# 'val_loss' (lower is better) and checkpoints are written to './checkpoint'.
training_arguments = dict(
    num_epoch=50,
    batch_size=64,
    lr=1e-4,
    eval_metric='val_loss',
    eval_less_is_better=True,
    output_dir='./checkpoint',
)

# Run pre-training on the train split, validating on the val split.
transtab.train(
    model,
    trainset,
    valset,
    collate_fn=collate_fn,
    **training_arguments
)

Epoch:   0%|          | 0/50 [00:00<?, ?it/s]

epoch: 0, test val_loss: 6.349929
epoch: 0, train loss: 72.9975, lr: 0.000100, spent: 1.3 secs
epoch: 1, test val_loss: 6.043663
epoch: 1, train loss: 62.8806, lr: 0.000100, spent: 2.2 secs
epoch: 2, test val_loss: 5.999826
epoch: 2, train loss: 61.3078, lr: 0.000100, spent: 3.0 secs
epoch: 3, test val_loss: 5.989734
epoch: 3, train loss: 61.0470, lr: 0.000100, spent: 3.9 secs
epoch: 4, test val_loss: 5.986117
epoch: 4, train loss: 60.9742, lr: 0.000100, spent: 4.8 secs
epoch: 5, test val_loss: 5.984314
epoch: 5, train loss: 60.9454, lr: 0.000100, spent: 5.8 secs
epoch: 6, test val_loss: 5.983197
epoch: 6, train loss: 60.9270, lr: 0.000100, spent: 6.7 secs
epoch: 7, test val_loss: 5.982450
epoch: 7, train loss: 60.9164, lr: 0.000100, spent: 7.6 secs
epoch: 8, test val_loss: 5.981885
epoch: 8, train loss: 60.9102, lr: 0.000100, spent: 8.5 secs
epoch: 9, test val_loss: 5.981443
epoch: 9, train loss: 60.9047, lr: 0.000100, spent: 9.5 secs
epoch: 10, test val_loss: 5.981087
epoch: 10, trai

2022-08-31 14:15:16.839 | INFO     | transtab.trainer:train:132 - load best at last from ./checkpoint
2022-08-31 14:15:16.853 | INFO     | transtab.trainer:save_model:239 - saving model checkpoint to ./checkpoint


epoch: 49, test val_loss: 5.978854
epoch: 49, train loss: 60.8699, lr: 0.000100, spent: 46.8 secs


2022-08-31 14:15:17.035 | INFO     | transtab.trainer:train:137 - training complete, cost 47.0 secs.


In [4]:
# An encoder can be built from the pretrained checkpoint in two ways.
# Option 1: load the whole pretrained model and use the CLS token embedding
# from the last layer's outputs.
enc = transtab.build_encoder(binary_columns=bin_cols, checkpoint='./checkpoint')

2022-08-31 14:15:17.125 | INFO     | transtab.modeling_transtab:load:773 - missing keys: []
2022-08-31 14:15:17.126 | INFO     | transtab.modeling_transtab:load:774 - unexpected keys: ['projection_head.dense.weight']
2022-08-31 14:15:17.126 | INFO     | transtab.modeling_transtab:load:775 - load model from ./checkpoint
2022-08-31 14:15:17.159 | INFO     | transtab.modeling_transtab:load:222 - load feature extractor from ./checkpoint/extractor/extractor.json


In [5]:
# Run the encoder built above on the training features to get one embedding
# vector per row.
# trainset[0] is the feature table (it is what `df.head()` displays in the
# next cell); `df` is reused by later cells.
df = trainset[0]
output = enc(df)
# In the recorded run this printed torch.Size([700, 128]).
print(output.shape)
# Last expression: display the first two embedding rows.
output[:2]

torch.Size([700, 128])


tensor([[ 1.2959e+00,  1.5239e+00, -1.2096e+00,  3.0303e-01,  7.4638e-01,
          1.1758e+00,  1.1774e+00, -2.1921e-01,  4.2850e-01,  8.3295e-03,
         -5.3477e-01,  1.4859e+00, -2.0534e+00, -9.4093e-01,  3.7010e-01,
          1.3663e-01,  4.4837e-01,  1.3882e+00,  1.6472e+00, -1.2430e+00,
         -4.8809e-01, -5.1914e-01, -3.3168e-01,  1.9889e+00, -4.9873e-01,
          1.2286e+00,  8.6373e-01,  5.1300e-01,  6.7551e-01, -1.2021e+00,
          6.3210e-01,  6.2366e-01,  5.6712e-01,  1.2275e-03, -1.5154e+00,
          2.0082e+00, -1.2255e+00, -2.4254e-01, -5.1009e-01,  1.6733e+00,
         -1.2059e+00, -7.0246e-01,  1.8980e-01, -7.8196e-01,  1.0777e+00,
         -6.1830e-01, -1.1279e+00, -1.3290e+00,  9.6929e-01, -7.6388e-02,
         -4.5835e-01, -1.1462e+00,  1.5084e+00,  5.7778e-01,  2.0644e-01,
          4.3633e-01,  7.6116e-03,  5.2441e-01, -1.9919e-01, -1.9441e-01,
          1.8144e+00,  2.7863e-01, -1.8727e+00, -9.4760e-01,  1.1152e+00,
          3.5514e-01,  1.6321e+00,  4.

In [6]:
# Preview the first rows of the feature table that was passed to the encoder.
df.head()

Unnamed: 0,own_telephone,foreign_worker,duration,credit_amount,installment_commitment,residence_since,age,existing_credits,num_dependents,checking_status,credit_history,purpose,savings_status,employment,personal_status,other_parties,property_magnitude,other_payment_plans,housing,job
636,0,1,0.294118,0.061957,1.0,0.0,0.160714,0.0,0.0,no checking,existing paid,radio/tv,500<=X<1000,4<=X<7,female div/dep/mar,none,car,none,own,skilled
182,0,1,0.25,0.076868,1.0,0.333333,0.375,0.333333,1.0,<0,all paid,new car,no known savings,1<=X<4,male single,none,life insurance,none,own,unskilled resident
736,0,1,0.294118,0.622318,0.0,1.0,0.071429,0.333333,0.0,0<=X<200,existing paid,used car,<100,1<=X<4,female div/dep/mar,none,car,none,rent,high qualif/self emp/mgmt
922,0,1,0.073529,0.061406,0.666667,1.0,0.053571,0.0,0.0,<0,existing paid,radio/tv,<100,<1,female div/dep/mar,none,life insurance,none,rent,skilled
511,1,1,0.470588,0.244085,0.333333,0.333333,0.232143,0.0,0.0,no checking,existing paid,used car,<100,1<=X<4,male single,none,no known property,none,for free,high qualif/self emp/mgmt


In [8]:
# Option 2: if only the token-level embeddings are wanted (the embeddings
# produced before the transformer layers), build the encoder with num_layer=0.
# NOTE: this rebinds `enc`, replacing the encoder created earlier.
enc = transtab.build_encoder(
    binary_columns=bin_cols, checkpoint='./checkpoint', num_layer=0
)

2022-08-31 14:16:28.124 | INFO     | transtab.modeling_transtab:load:222 - load feature extractor from ./checkpoint/extractor/extractor.json
2022-08-31 14:16:28.134 | INFO     | transtab.modeling_transtab:load:523 - missing keys: []
2022-08-31 14:16:28.135 | INFO     | transtab.modeling_transtab:load:524 - unexpected keys: []
2022-08-31 14:16:28.136 | INFO     | transtab.modeling_transtab:load:525 - load model from ./checkpoint


In [12]:
# Encode the same feature table with the num_layer=0 encoder built in the
# previous cell (`enc` must have been rebound there for this to work).
# Here the result is indexed by key, unlike the earlier encoder's output;
# `output` is rebound to this new result.
output = enc(df)
# In the recorded run this printed torch.Size([700, 85, 128]).
print(output['embedding'].shape)
# Last expression: display the token-level embeddings of the first two rows.
output['embedding'][:2]

torch.Size([700, 85, 128])


tensor([[[ 0.1370,  0.0427, -0.0106,  ..., -0.0806,  0.0518, -0.1315],
         [ 0.0657,  0.0341, -0.0128,  ..., -0.0207,  0.0102, -0.0046],
         [ 0.1494,  0.4290,  0.2463,  ...,  0.1992, -0.0848, -0.0840],
         ...,
         [ 1.1575,  0.0165,  0.9202,  ..., -0.2052,  1.0815, -1.0268],
         [ 1.1575,  0.0165,  0.9202,  ..., -0.2052,  1.0815, -1.0268],
         [ 1.1575,  0.0165,  0.9202,  ..., -0.2052,  1.0815, -1.0268]],

        [[ 0.1204,  0.0388, -0.0098,  ..., -0.0738,  0.0400, -0.1099],
         [ 0.0752,  0.0383, -0.0145,  ..., -0.0174,  0.0190, -0.0085],
         [ 0.1494,  0.4290,  0.2463,  ...,  0.1992, -0.0848, -0.0840],
         ...,
         [ 1.1575,  0.0165,  0.9202,  ..., -0.2052,  1.0815, -1.0268],
         [ 1.1575,  0.0165,  0.9202,  ..., -0.2052,  1.0815, -1.0268],
         [ 1.1575,  0.0165,  0.9202,  ..., -0.2052,  1.0815, -1.0268]]],
       device='cuda:0', grad_fn=<SliceBackward0>)