In [1]:
import pandas as pd
import numpy as np
import gc
import os
from pytorch_tabnet.tab_model import TabNetRegressor
import torch

# 'TSP' is our paper project code
from pytorch_tabnet.TSP_implement import TSP_Loss

import json

## Hyperparameters and Settings
- `data_path_fast` and `data_path_expensive` are the paths to the estimator datasets of the fast and the expensive model; both can be replaced. Pick the CSV files for the desired fast/expensive model pair from `./dataset_estimator/CIFAR100`.

- `lambda_CDC_weight` is the `lambda` in the paper.

- The trained weights will be stored in `./estimator_weights/weights_{fast_model}_{expensive_model}`.

In [2]:
# Project root: the current working directory, made absolute.
root = os.path.abspath('.')

# Estimator datasets: one CSV per model, all stored in the same CIFAR100 folder.
dataset_dir = os.path.join(root, 'dataset_estimator', 'CIFAR100')
data_path_fast = os.path.join(dataset_dir, 'resnet18_dataset.csv')
data_path_expensive = os.path.join(dataset_dir, 'resnet101_dataset.csv')

# Train/validation split strategy: 'split_by_json' or 'split_by_random'.
dataset_split_type = 'split_by_json'
# Split definition file, read only when dataset_split_type == 'split_by_json'.
json_for_estimator = os.path.join(root, 'seeds_json', 'json_split_for_estimator_1.json')

# Metric used for early stopping: 'AIR' or 'TSP-Loss-Overall'.
eval_metric = ['AIR']

# Optimisation settings.
learning_rate = 2e-3
max_epochs = 200

# Loss weights; lambda_CDC_weight is the paper's lambda (0.1, 0.15 or 0.2).
lambda_CDC_weight = 0.2
Loss_Base_weight = 1.0

## Processing settings


In [3]:
# Build the path under which the trained estimator weights are saved:
#   <root>/estimator_weights/weights_<fast>_<expensive>/TabNet_<loss>_<metric>_Seed_<id>

# Loss tag: "L-Overall" when lambda_CDC_weight is non-zero, otherwise "L-Base".
loss_type = "L-Overall" if lambda_CDC_weight != 0 else "L-Base"

# Metric tag: derived from the last entry of eval_metric.
metric_name = "Loss" if eval_metric[-1] == 'TSP-Loss-Overall' else "AIR"


def _model_name_from_dataset_path(dataset_path):
    """Return the model name encoded in a '<model>_dataset.csv' file name.

    Only the file's base name is parsed, so a parent directory that happens to
    contain '_dataset' (e.g. a checkout under 'my_dataset_project/') cannot
    truncate the path and produce a wrong name.
    """
    return os.path.basename(dataset_path).split('_dataset')[0]


# Seed id: the text between the last '_' and the first following '.' of the JSON
# file name. Parsing the base name keeps directory names out of the result.
seed_id = os.path.basename(json_for_estimator).split('_')[-1].split('.')[0]

shallow_name = _model_name_from_dataset_path(data_path_fast)
deep_name = _model_name_from_dataset_path(data_path_expensive)

saving_weight_name = os.path.join(
    root,
    'estimator_weights',
    f'weights_{shallow_name}_{deep_name}',
    f'TabNet_{loss_type}_{metric_name}_Seed_{seed_id}',
)

## Processing the data

In [4]:
# Read the fast model's dataset; the first CSV row holds the column names.
data_shallow = pd.read_csv(
    data_path_fast,
    header=0,
)
print(data_shallow.shape)

# Preview the first rows.
data_shallow.head()

(10000, 103)


Unnamed: 0,index_in_dataset,correctness,correct_cls_index,prob_index_0,prob_index_1,prob_index_2,prob_index_3,prob_index_4,prob_index_5,prob_index_6,...,prob_index_90,prob_index_91,prob_index_92,prob_index_93,prob_index_94,prob_index_95,prob_index_96,prob_index_97,prob_index_98,prob_index_99
0,1384,1,12,0.003795,0.001102,0.004457,0.003634,0.00204,0.008378,0.000414,...,0.013671,0.000946,0.001571,0.000334,0.001054,0.005671,0.001071,0.002967,0.003104,0.00143
1,47652,0,52,0.000691,0.000299,0.003798,0.000804,0.000154,0.000505,0.000129,...,0.000485,0.000235,0.000484,0.000238,0.000952,0.000562,0.009734,0.000484,0.000404,0.00046
2,7641,1,94,0.000874,0.000264,0.000128,0.000102,0.000478,0.000159,0.00226,...,0.000117,7.2e-05,0.001296,0.001115,0.938051,0.000655,0.000327,0.000278,0.000118,0.000482
3,29229,1,82,0.000273,0.000217,0.000216,0.000362,0.000523,0.000812,0.011486,...,0.000503,0.000302,0.005693,0.000231,0.00072,0.000559,0.000665,0.000743,0.001218,0.000375
4,13143,1,78,0.000508,0.000934,0.001392,0.004842,0.000776,0.001373,0.000226,...,0.002749,0.000832,0.001555,0.000562,0.001098,0.003118,0.002139,0.002678,0.001215,0.055704


In [5]:
# Read the expensive model's dataset; the first CSV row holds the column names.
data_deep = pd.read_csv(
    data_path_expensive,
    header=0,
)
print(data_deep.shape)

# Preview the first rows.
data_deep.head()

(10000, 103)


Unnamed: 0,index_in_dataset,correctness,correct_cls_index,prob_index_0,prob_index_1,prob_index_2,prob_index_3,prob_index_4,prob_index_5,prob_index_6,...,prob_index_90,prob_index_91,prob_index_92,prob_index_93,prob_index_94,prob_index_95,prob_index_96,prob_index_97,prob_index_98,prob_index_99
0,1384,1,12,0.000743,0.000413,0.000473,0.000702,0.001246,0.005389,0.000932,...,0.000386,0.000754,0.001024,0.001731,0.00054,0.001054,0.000727,0.001175,0.001221,0.000555
1,47652,0,52,0.001202,0.000653,0.000856,0.001213,0.000425,0.000925,0.000809,...,0.000573,0.000655,0.000943,0.000776,0.000982,0.001496,0.050321,0.001202,0.001111,0.001907
2,7641,1,94,0.000882,0.000636,0.000793,0.000435,0.000892,0.003257,0.002244,...,0.000737,0.000604,0.00093,0.000844,0.919552,0.001065,0.000744,0.000465,0.000798,0.000858
3,29229,1,82,0.000366,0.000351,0.000187,0.000696,0.000644,0.000296,0.001348,...,0.000401,0.000216,0.006174,0.000173,0.000313,0.000589,0.00072,0.000626,0.000435,0.000155
4,13143,1,78,0.000386,0.000672,0.000813,0.000718,0.000666,0.000286,0.00041,...,0.000838,0.001803,0.001267,0.000378,0.000859,0.001674,0.001201,0.000477,0.000381,0.03012


In [6]:
# Check that both datasets list the same samples in the same order by comparing
# their first column (`index_in_dataset`) row by row.
index_matches = data_shallow.iloc[:,0]==data_deep.iloc[:,0]
print(index_matches)

# The printed Series is truncated, so a mismatch in the middle would go unnoticed.
# Later cells concatenate the two tables row by row, so fail loudly here instead.
assert index_matches.all(), (
    f"{int((~index_matches).sum())} rows have different index_in_dataset values "
    "in the fast and expensive datasets"
)

0       True
1       True
2       True
3       True
4       True
        ... 
9995    True
9996    True
9997    True
9998    True
9999    True
Name: index_in_dataset, Length: 10000, dtype: bool


In [7]:
# Features: every column from position 3 onward (the prob_index_* columns) of each model.
probs_shallow = data_shallow.iloc[:, 3:]
probs_deep = data_deep.iloc[:, 3:]

# Labels, in this order: column 1 (`correctness`) of the fast model, column 1 of the
# expensive model, then column 2 (`correct_cls_index`) of the fast model.
label_columns = [
    data_shallow.iloc[:, 1],
    data_deep.iloc[:, 1],
    data_shallow.iloc[:, 2],
]

# Column-wise concatenation: features first, the three label columns last.
traning_data = pd.concat([probs_shallow, probs_deep, *label_columns], axis=1)
print(traning_data.shape)
traning_data.head()

(10000, 203)


Unnamed: 0,prob_index_0,prob_index_1,prob_index_2,prob_index_3,prob_index_4,prob_index_5,prob_index_6,prob_index_7,prob_index_8,prob_index_9,...,prob_index_93,prob_index_94,prob_index_95,prob_index_96,prob_index_97,prob_index_98,prob_index_99,correctness,correctness.1,correct_cls_index
0,0.003795,0.001102,0.004457,0.003634,0.00204,0.008378,0.000414,0.002539,0.006601,0.001594,...,0.001731,0.00054,0.001054,0.000727,0.001175,0.001221,0.000555,1,1,12
1,0.000691,0.000299,0.003798,0.000804,0.000154,0.000505,0.000129,0.000344,0.000887,0.001053,...,0.000776,0.000982,0.001496,0.050321,0.001202,0.001111,0.001907,0,0,52
2,0.000874,0.000264,0.000128,0.000102,0.000478,0.000159,0.00226,0.001749,0.000146,0.001091,...,0.000844,0.919552,0.001065,0.000744,0.000465,0.000798,0.000858,1,1,94
3,0.000273,0.000217,0.000216,0.000362,0.000523,0.000812,0.011486,0.000421,0.000445,0.000648,...,0.000173,0.000313,0.000589,0.00072,0.000626,0.000435,0.000155,1,1,82
4,0.000508,0.000934,0.001392,0.004842,0.000776,0.001373,0.000226,0.001912,0.001828,0.000898,...,0.000378,0.000859,0.001674,0.001201,0.000477,0.000381,0.03012,1,1,78


## Splitting training, validation and test sets

In [8]:
# Split the assembled table into feature/label arrays for training and validation
# (plus a test set for the random split).
#
# The last three columns of `traning_data` are the labels; everything before them is
# a feature. Nothing is written back into `traning_data`, so re-running this cell
# gives the same column layout every time (adding a helper column to the table would
# shift `training_data_dim` on the second run).
label_dim = 3
training_data_dim = traning_data.shape[1] - label_dim

features = traning_data.iloc[:, :training_data_dim].values
labels = traning_data.iloc[:, training_data_dim:training_data_dim + label_dim].values

if dataset_split_type == 'split_by_random':
    # NOTE(review): no random seed is set here, so this split differs between runs.
    split_assignment = np.random.choice(
        ["train", "valid", "test"], p=[.8, .1, .1], size=(traning_data.shape[0],)
    )

    # Boolean masks select rows by position, independent of the DataFrame's index labels.
    train_mask = split_assignment == "train"
    valid_mask = split_assignment == "valid"
    test_mask = split_assignment == "test"

    train_indices = traning_data.index[train_mask]
    valid_indices = traning_data.index[valid_mask]
    test_indices = traning_data.index[test_mask]

    X_train, y_train = features[train_mask], labels[train_mask]
    X_valid, y_valid = features[valid_mask], labels[valid_mask]
    X_test, y_test = features[test_mask], labels[test_mask]
    print(X_valid.shape,y_valid.shape)

elif dataset_split_type == 'split_by_json':
    # The JSON file lists which `index_in_dataset` values belong to each subset.
    with open(json_for_estimator) as f:
        estimator_split = json.load(f)

    # Sample ids come from the first column of the fast dataset, row-aligned with
    # `traning_data`.
    sample_ids = data_shallow.iloc[:, 0]
    train_mask = sample_ids.isin(estimator_split['index_train']).values
    valid_mask = sample_ids.isin(estimator_split['index_val']).values

    X_train, y_train = features[train_mask], labels[train_mask]
    X_valid, y_valid = features[valid_mask], labels[valid_mask]

    print(X_valid.shape, y_valid.shape)

else:
    # ValueError is a subclass of Exception, so existing handlers still catch it.
    raise ValueError('please provide dataset_split_type')

(3000, 200) (3000, 3)


## Define estimator model

In [9]:
# Optimiser and learning-rate scheduler settings for the estimator.
# StepLR decays the learning rate by `gamma` once every `step_size` scheduler steps.
optimizer_settings = dict(lr=learning_rate)
scheduler_settings = {"gamma": 0.9, "step_size": 5}

clf = TabNetRegressor(
    optimizer_params=optimizer_settings,
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    scheduler_params=scheduler_settings,
)



## Training TabNet

In [10]:
# Custom training loss, weighted with the values from the settings cell.
tsp_loss = TSP_Loss(lambda_CDC_weight = lambda_CDC_weight, Loss_Base_weight = Loss_Base_weight)

# Both the training and the validation set are evaluated and reported as
# 'train' / 'valid'.
evaluation_sets = [(X_train, y_train), (X_valid, y_valid)]
evaluation_names = ['train', 'valid']

clf.fit(
    X_train=X_train,
    y_train=y_train,
    eval_set=evaluation_sets,
    eval_name=evaluation_names,
    loss_fn=tsp_loss,
    # eval_metric is the metric used for early stopping (see the settings cell).
    eval_metric=eval_metric,
    max_epochs=max_epochs,
    # presumably the number of epochs without improvement before stopping — confirm
    patience=50,
    # default: batch_size=1024, virtual_batch_size=128
    batch_size=512,
    virtual_batch_size=128,
    num_workers=0,
    drop_last=False,
)

TSP_TabNet!
epoch 0  | loss: 2.61202 | train_AIR: 0.64139 | valid_AIR: 0.62844 |  0:00:03s
epoch 1  | loss: 1.97001 | train_AIR: 0.6779  | valid_AIR: 0.58486 |  0:00:04s
epoch 2  | loss: 1.52769 | train_AIR: 0.73315 | valid_AIR: 0.68578 |  0:00:05s
epoch 3  | loss: 1.14739 | train_AIR: 0.71723 | valid_AIR: 0.70872 |  0:00:07s
epoch 4  | loss: 0.91326 | train_AIR: 0.71255 | valid_AIR: 0.73165 |  0:00:08s
epoch 5  | loss: 0.94872 | train_AIR: 0.76873 | valid_AIR: 0.74541 |  0:00:09s
epoch 6  | loss: 0.7496  | train_AIR: 0.7706  | valid_AIR: 0.75229 |  0:00:11s
epoch 7  | loss: 0.66606 | train_AIR: 0.77154 | valid_AIR: 0.77294 |  0:00:12s
epoch 8  | loss: 0.56893 | train_AIR: 0.76592 | valid_AIR: 0.74083 |  0:00:13s
epoch 9  | loss: 0.51383 | train_AIR: 0.76124 | valid_AIR: 0.71101 |  0:00:15s
epoch 10 | loss: 0.47333 | train_AIR: 0.72378 | valid_AIR: 0.68807 |  0:00:16s
epoch 11 | loss: 0.46596 | train_AIR: 0.71536 | valid_AIR: 0.66284 |  0:00:17s
epoch 12 | loss: 0.42913 | train_AIR: 0.



## Saving weight of estimator

In [11]:
# Save the trained TabNet estimator under the name built in the save-name cell.
# The call's return value is kept in `saved_filepath`; presumably this is the path of
# the written archive (the logged message ends in '.zip') — confirm against save_model.
saved_filepath = clf.save_model(saving_weight_name)

Successfully saved model at c:\Users\yaoching\Desktop\ICCV_Smart_cascade_code\estimator_weights\weights_resnet18_resnet101\TabNet_L-Overall_AIR_Seed_1.zip


## Toy test

In [12]:
# Toy check: predict on vectors whose mass sits in the first two entries, moving
# from (1.0, 0.0) to (0.7, 0.3); the remaining 98 entries are zero.
# NOTE(review): these inputs are 100 wide, while the split cell printed X_valid with
# 200 columns — confirm the input width the trained estimator expects.
toy_leading_values = [(1.0, 0.), (0.9, 0.1), (0.8, 0.2), (0.7, 0.3)]

for first_value, second_value in toy_leading_values:
    toy_input = np.array([first_value, second_value] + [0.0]*98).reshape(-1,100)
    print(clf.predict(toy_input))

[[1.0395687]]
[[0.98425615]]
[[0.9629064]]
[[0.65394646]]
