In [1]:
import pandas as pd
import numpy as np
import gc
import os
from pytorch_tabnet.tab_model import TabNetRegressor
import torch

# 'TSP' is our paper project code
from pytorch_tabnet.TSP_implement import TSP_Loss

import json

## Hyperparameters and Settings
- `data_path_fast` and `data_path_expensive` are datasets for an estimator that can be replaced. Please select the dataset for a specific fast/expensive model in `./dataset_estimator/ImageNet`.

- `lambda_CDC_weight` is the `lambda` in the paper.

- The trained weight will be stored in `./estimator_weights/weights_ImageNet_{fast_model}_{expensive_model}`

In [2]:
root = os.path.abspath('.')

# dataset path
data_path_fast= os.path.join(root,'dataset_estimator','ImageNet','tf_efficientnet_b0_ImageNet_dataset.csv')
data_path_expensive=os.path.join(root,'dataset_estimator','ImageNet','resnet101_ImageNet_dataset.csv')

# setting
dataset_split_type = 'split_by_json' # 'split_by_json' or 'split_by_random'
json_for_estimator = os.path.join(root,'seeds_json','ImageNet_json_split_for_estimator_1.json')

# The metric is used for early stopping
eval_metric=['AIR'] # 'AIR' or 'TSP-Loss-Overall'

learning_rate = 1e-2
max_epochs = 200

lambda_CDC_weight = 0.2 # 0.1, 0.15 or 0.2
Loss_Base_weight = 1.0

## Processing settings


In [3]:
# save weight name
if lambda_CDC_weight != 0:
    loss_type = "L-Overall"
else: 
    loss_type = "L-Base"

if eval_metric[-1] == 'TSP-Loss-Overall':
    metric_name = "Loss"
else: 
    metric_name = "AIR"

seed_id = json_for_estimator.split('_')[-1].split('.')[0]

shallow_name = os.path.split(data_path_fast.split('_dataset')[0])[-1]
deep_name = os.path.split(data_path_expensive.split('_dataset')[0])[-1]

saving_weight_name = os.path.join(root,'estimator_weights',f'weights_ImageNet_{shallow_name}_{deep_name}',f'TabNet_{loss_type}_{metric_name}_Seed_{seed_id}')

## Processing the data

In [4]:
data_shallow = pd.read_csv(data_path_fast,header=0)
print(data_shallow.shape)

data_shallow.head()

(64056, 1003)


Unnamed: 0,index_in_dataset,correctness,correct_cls_index,prob_index_0,prob_index_1,prob_index_2,prob_index_3,prob_index_4,prob_index_5,prob_index_6,...,prob_index_990,prob_index_991,prob_index_992,prob_index_993,prob_index_994,prob_index_995,prob_index_996,prob_index_997,prob_index_998,prob_index_999
0,1022729,1,798,3.016706e-07,1e-06,5e-06,4e-06,4e-06,2.358331e-05,3e-06,...,2e-06,2e-06,2e-06,2e-06,1e-06,8e-06,1e-06,2e-06,2.6e-05,5e-06
1,616088,1,479,3.725413e-05,4.1e-05,6e-05,0.000187,0.000156,0.0001999536,0.000133,...,5.9e-05,7.3e-05,0.000135,0.000115,0.000138,0.00038,9.9e-05,0.000118,0.000109,9.8e-05
2,376291,1,294,0.0001502893,0.000139,0.000136,3.5e-05,0.000239,0.0002784144,0.000143,...,4.8e-05,0.000137,0.000104,4.8e-05,8.4e-05,5.2e-05,6.1e-05,6.3e-05,7.5e-05,0.000175
3,819303,0,638,3.032851e-06,3e-06,3e-06,2e-06,2e-06,4.118787e-07,3e-06,...,7e-06,9e-06,7e-06,5e-06,1.5e-05,2.3e-05,4e-06,5.5e-05,3e-06,8e-06
4,523820,1,408,2.35018e-05,2.8e-05,6.5e-05,5.1e-05,2.4e-05,4.494722e-05,0.000103,...,9e-06,5.7e-05,4.9e-05,6.2e-05,2.8e-05,3e-06,1.5e-05,6.8e-05,5.4e-05,0.00016


In [5]:
data_deep = pd.read_csv(data_path_expensive,header=0)
print(data_deep.shape)

data_deep.head()

(64056, 1003)


Unnamed: 0,index_in_dataset,correctness,correct_cls_index,prob_index_0,prob_index_1,prob_index_2,prob_index_3,prob_index_4,prob_index_5,prob_index_6,...,prob_index_990,prob_index_991,prob_index_992,prob_index_993,prob_index_994,prob_index_995,prob_index_996,prob_index_997,prob_index_998,prob_index_999
0,1022729,1,798,1e-06,8.676604e-07,2e-06,7.610744e-07,2e-06,2e-06,4.783777e-07,...,8.534692e-07,3e-06,2e-06,2e-06,1e-06,7.887821e-07,7.520305e-07,3e-06,1e-06,8.415565e-07
1,616088,1,479,9.7e-05,0.0001246419,9.8e-05,0.0001952513,5.2e-05,0.000102,0.0001192066,...,0.0001321169,8.2e-05,0.000205,0.000204,0.000215,4.614002e-05,4.976594e-05,6.5e-05,6.1e-05,2.028638e-05
2,376291,1,294,0.000338,6.952937e-05,4.6e-05,0.0001501845,0.000151,8.4e-05,7.68114e-05,...,5.494538e-05,9.3e-05,7.8e-05,0.000153,0.00015,0.0001367076,9.844397e-05,0.000128,3.4e-05,0.0002950774
3,819303,0,638,0.000159,0.000121574,0.00017,3.645422e-05,0.000408,0.000255,0.0001105206,...,0.0002582925,0.000253,0.000216,0.000172,0.000213,0.0002282218,7.214426e-05,0.00043,0.000189,0.0001136141
4,523820,1,408,0.000132,8.513094e-05,0.000248,7.586108e-05,0.000183,0.000123,0.0001160375,...,9.529041e-05,0.000141,0.00017,0.00011,0.000173,0.0001211055,5.923554e-05,0.00013,9.2e-05,0.0001753411


In [6]:
# check filename
print(data_shallow.iloc[:,0]==data_deep.iloc[:,0])

0        True
1        True
2        True
3        True
4        True
         ... 
64051    True
64052    True
64053    True
64054    True
64055    True
Name: index_in_dataset, Length: 64056, dtype: bool


In [7]:
traning_data=pd.concat([data_shallow.iloc[:,3:],data_deep.iloc[:,3:],data_shallow.iloc[:,1],data_deep.iloc[:,1],data_shallow.iloc[:,2]],axis=1)
print(traning_data.shape)
traning_data.head()

(64056, 2003)


Unnamed: 0,prob_index_0,prob_index_1,prob_index_2,prob_index_3,prob_index_4,prob_index_5,prob_index_6,prob_index_7,prob_index_8,prob_index_9,...,prob_index_993,prob_index_994,prob_index_995,prob_index_996,prob_index_997,prob_index_998,prob_index_999,correctness,correctness.1,correct_cls_index
0,3.016706e-07,1e-06,5e-06,4e-06,4e-06,2.358331e-05,3e-06,3e-06,7.440551e-07,2e-06,...,2e-06,1e-06,7.887821e-07,7.520305e-07,3e-06,1e-06,8.415565e-07,1,1,798
1,3.725413e-05,4.1e-05,6e-05,0.000187,0.000156,0.0001999536,0.000133,0.000224,0.0003195728,0.000224,...,0.000204,0.000215,4.614002e-05,4.976594e-05,6.5e-05,6.1e-05,2.028638e-05,1,1,479
2,0.0001502893,0.000139,0.000136,3.5e-05,0.000239,0.0002784144,0.000143,3.9e-05,6.137923e-05,2.7e-05,...,0.000153,0.00015,0.0001367076,9.844397e-05,0.000128,3.4e-05,0.0002950774,1,1,294
3,3.032851e-06,3e-06,3e-06,2e-06,2e-06,4.118787e-07,3e-06,1e-06,2.556007e-06,3e-06,...,0.000172,0.000213,0.0002282218,7.214426e-05,0.00043,0.000189,0.0001136141,0,0,638
4,2.35018e-05,2.8e-05,6.5e-05,5.1e-05,2.4e-05,4.494722e-05,0.000103,1.7e-05,1.099746e-05,2.1e-05,...,0.00011,0.000173,0.0001211055,5.923554e-05,0.00013,9.2e-05,0.0001753411,1,1,408


## Splitting training, validation and test sets

In [8]:
if dataset_split_type == 'split_by_random':  
    training_data_dim=traning_data.shape[1]-3
    
    traning_data["Set"] = np.random.choice(["train", "valid", "test"], p =[.8, .1, .1], size=(traning_data.shape[0],))

    train_indices = traning_data[traning_data.Set=="train"].index
    valid_indices = traning_data[traning_data.Set=="valid"].index
    test_indices = traning_data[traning_data.Set=="test"].index

    X_train = traning_data.iloc[:,:training_data_dim].values[train_indices]
    y_train = traning_data.iloc[:,training_data_dim:-1].values[train_indices]#.reshape(-1, 3)

    X_valid = traning_data.iloc[:,:training_data_dim].values[valid_indices]
    y_valid = traning_data.iloc[:,training_data_dim:-1].values[valid_indices]#.reshape(-1, 3)

    X_test = traning_data.iloc[:,:training_data_dim].values[test_indices]
    y_test = traning_data.iloc[:,training_data_dim:-1].values[test_indices]#.reshape(-1, 3)
    print(X_valid.shape,y_valid.shape)

elif dataset_split_type == 'split_by_json':
    training_data_dim = traning_data.shape[1]-3

    # Add index to the last column
    traning_data = pd.concat([traning_data,data_shallow.iloc[:,0]],axis=1)

    with open(json_for_estimator) as f:
        estimator_split = json.load(f)

    X_train = traning_data[traning_data['index_in_dataset'].isin(estimator_split['index_train'])].iloc[:,:training_data_dim].values
    y_train = traning_data[traning_data['index_in_dataset'].isin(estimator_split['index_train'])].iloc[:,training_data_dim:-1].values

    X_valid = traning_data[traning_data['index_in_dataset'].isin(estimator_split['index_val'])].iloc[:,:training_data_dim].values
    y_valid = traning_data[traning_data['index_in_dataset'].isin(estimator_split['index_val'])].iloc[:,training_data_dim:-1].values 

    print(X_valid.shape, y_valid.shape) 
    
else:
    raise Exception('please provide dataset_split_type')

(19215, 2000) (19215, 3)


## Define estimator model

In [9]:
clf = TabNetRegressor(optimizer_params=dict(lr=learning_rate),scheduler_fn=torch.optim.lr_scheduler.StepLR,scheduler_params={"gamma": 0.9, "step_size": 5})



## Training TabNet

In [10]:
clf.fit(
    X_train=X_train, y_train=y_train,
    eval_set=[(X_train, y_train), (X_valid, y_valid)],
    eval_name=['train', 'valid'],
    # custom loss
    loss_fn=TSP_Loss(lambda_CDC_weight = lambda_CDC_weight, Loss_Base_weight = Loss_Base_weight),
    eval_metric=eval_metric,
    max_epochs=max_epochs,
    patience=15,
    batch_size=512, virtual_batch_size=128, #default:batch_size=1024, virtual_batch_size=128
    num_workers=0,
    drop_last=False
) 

TSP_TabNet!
epoch 0  | loss: 0.66471 | train_AIR: 0.68453 | valid_AIR: 0.65867 |  0:00:12s
epoch 1  | loss: 0.24537 | train_AIR: 0.68146 | valid_AIR: 0.65689 |  0:00:22s
epoch 2  | loss: 0.24084 | train_AIR: 0.68399 | valid_AIR: 0.65778 |  0:00:33s
epoch 3  | loss: 0.23974 | train_AIR: 0.68544 | valid_AIR: 0.66178 |  0:00:44s
epoch 4  | loss: 0.23837 | train_AIR: 0.68453 | valid_AIR: 0.65733 |  0:00:54s
epoch 5  | loss: 0.23709 | train_AIR: 0.69305 | valid_AIR: 0.66311 |  0:01:04s
epoch 6  | loss: 0.23605 | train_AIR: 0.69794 | valid_AIR: 0.67244 |  0:01:14s
epoch 7  | loss: 0.23602 | train_AIR: 0.69685 | valid_AIR: 0.67378 |  0:01:25s
epoch 8  | loss: 0.23535 | train_AIR: 0.69902 | valid_AIR: 0.67067 |  0:01:35s
epoch 9  | loss: 0.23513 | train_AIR: 0.69938 | valid_AIR: 0.67156 |  0:01:46s
epoch 10 | loss: 0.2348  | train_AIR: 0.69866 | valid_AIR: 0.672   |  0:01:57s
epoch 11 | loss: 0.23454 | train_AIR: 0.70083 | valid_AIR: 0.67067 |  0:02:07s
epoch 12 | loss: 0.23457 | train_AIR: 0.



## Saving weight of estimator

In [11]:
#save tabnet model
saved_filepath = clf.save_model(saving_weight_name)

Successfully saved model at c:\Users\yaoching\Desktop\ICCV_Smart_cascade_code\estimator_weights\weights_ImageNet_tf_efficientnet_b0_ImageNet_resnet101_ImageNet\TabNet_L-Overall_AIR_Seed_1.zip


## Toy test

In [12]:
print(clf.predict(np.array([1.0,0.,0.,0.,0.,0.]+[0.0]*994).reshape(-1,1000)))
print(clf.predict(np.array([0.9,0.1,0.,0.,0.,0.]+[0.0]*994).reshape(-1,1000)))
print(clf.predict(np.array([0.8,0.2,0.,0.,0.,0.]+[0.0]*994).reshape(-1,1000)))
print(clf.predict(np.array([0.7,0.3,0.,0.,0.,0.]+[0.0]*994).reshape(-1,1000)))

[[1.0453709]]
[[0.87544674]]
[[0.7082589]]
[[0.6741762]]
