# Saint Model Architecture

Code from https://github.com/somepago/saint adapted for the dataset and notebook execution.

In [1]:
import os
from types import SimpleNamespace
import torch
import numpy as np
import pandas as pd
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

from saint.pretrainmodel import SAINT
from saint.data_openml import DataSetCatCon, data_split
from saint.pretraining import SAINT_pretrain
from saint.utils import count_parameters, classification_scores, mean_sq_error, get_scheduler
from saint.augmentations import embed_data_mask, add_noise

try:
    import wandb
except:
    wandb = None

In [2]:
opt = SimpleNamespace(**{
  'dset_seed': 42,
  'dset_id': 'v20,  # dataset version
  'run_name': 'train_v3',
  'cont_embeddings': 'MLP', # 'MLP', 'Noemb', 'pos_singleMLP'
  'attentiontype': 'colrow',  #  'col', 'colrow', 'row', 'justmlp', 'attn', 'attnmlp'
  'optimizer': 'AdamW',  # 'AdamW', 'Adam', 'SGD'
  'scheduler': 'cosine',  # 'cosine', 'linear'
  'embedding_size': 32,
  'transformer_depth': 6,
  'attention_heads': 8,
  'attention_dropout': 0.1,
  'ff_dropout': 0.1,
  'lr': 0.002,
  'epochs': 50,
  'batchsize': 2048,
  'pretrain': True,  # test with False # TODO
  'pretrain_epochs': 15,
  'savemodelroot': './bestmodels',
  'set_seed': 1,
  'active_log': True if wandb else False,  # Weights and Biases API for logging
  'pt_tasks': ['contrastive', 'denoising'],  # 'contrastive', 'contrastive_sim', 'denoising'
  'pt_aug': [],  # 'mixup', 'cutmix' (list)
  'pt_aug_lam': 0.1,
  'mixup_lam': 0.3,
  'train_noise_type': 'missing', # None, 'missing', 'cutmix'
  'train_noise_level': 0.01, 
  'ssl_samples': None,  # int or None
  'pt_projhead_style': 'diff',  # 'diff', 'same', 'nohead'
  'nce_temp': 0.7,
  'lam0': 0.5,
  'lam1': 10,
  'lam2': 1,
  'lam3': 10,
  'final_mlp_style': 'sep', # 'common', 'sep'
  'vision_dset': False,
  'task': 'regression', 'dtask': 'reg',
})
modelsave_path = os.path.join(os.getcwd(),opt.savemodelroot,opt.task,str(opt.dset_id),opt.run_name)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device is {device}.")
torch.manual_seed(opt.set_seed)
os.makedirs(modelsave_path, exist_ok=True)

Device is cuda:0.


In [3]:
if opt.active_log:
    try:
      import wandb
      if opt.train_noise_type is not None and opt.train_noise_level > 0:
          wandb.init(project="saint_v2_robustness", group =f'{opt.run_name}_{opt.task}' ,name = f'{opt.task}_{opt.train_noise_type}_{str(opt.train_noise_level)}_{str(opt.attentiontype)}_{str(opt.dset_id)}')
      elif opt.ssl_samples is not None:
          wandb.init(project="saint_v2_ssl", group = f'{opt.run_name}_{opt.task}' ,name = f'{opt.task}_{str(opt.ssl_samples)}_{str(opt.attentiontype)}_{str(opt.dset_id)}')
      else:
          raise'wrong config.check the file you are running'
      wandb.config.update(opt)
    except:
      opt.active_log = False

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mgabrui[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
def load_flight_data(dataset_version, dset_seed=42, valid_split=0.1):
    df = pd.read_csv(f'data/challenge_set_updated_{dataset_version}.csv')
    df_test = pd.read_csv(f'data/submission_set_updated_{dataset_version}.csv')
  
    train_indices, valid_indices = train_test_split(np.arange(len(df)), test_size=valid_split, random_state=dset_seed)
    test_indices = np.arange(len(df), len(df) + len(df_test))
    df_combined = pd.concat([df, df_test], ignore_index=True).replace([np.inf, -np.inf], np.nan)
    X, y = df_combined.drop(columns=['tow']), df_combined['tow']
    # categorical_indicator = df_combined.dtypes.apply(lambda x: not pd.api.types.is_numeric_dtype(x)).to_list()

    # categorical_columns = X.columns[list(np.where(np.array(categorical_indicator)==True)[0])].tolist()
    # print(categorical_columns)
    categorical_columns = ['aircraft_type', 
             'wtc', 
             'airline',
             'offblock_hour',
             'offblock_minute', 
             'offblock_day_of_week',
             'offblock_month',
             'offblock_week_of_year', 
             'offblock_season', 
             'arrival_hour',
             'arrival_minute',
             'is_offblock_weekend',
             'is_offblock_rush_hour',
             'flight_duration_category',                       
             'adep_region', 
             'ades_region', 
             'same_country_flight',
             'same_region_flight',                        
             'flight_direction',
             'is_intercontinental',
             'Manufacturer',
             'Model_FAA',
             'Physical_Class_Engine',
             'FAA_Weight',
             'adep_geo_cluster',
             'ades_geo_cluster']
    cont_columns = list(set(X.columns.tolist()) - set(categorical_columns))
    # cat_idxs = list(np.where(np.array(categorical_indicator)==True)[0])
    cat_idxs = [df_combined.columns.get_loc(col) for col in categorical_columns]
    con_idxs = list(set(range(len(X.columns))) - set(cat_idxs))
    for col in categorical_columns:
        X[col] = X[col].astype("object")

    temp = X.fillna("MissingValue")
    nan_mask = temp.ne("MissingValue").astype(int)
    
    cat_dims = []
    for col in categorical_columns:
    #     X[col] = X[col].cat.add_categories("MissingValue")
        X[col] = X[col].fillna("MissingValue")
        l_enc = LabelEncoder() 
        X[col] = l_enc.fit_transform(X[col].values)
        cat_dims.append(len(l_enc.classes_))
    for col in cont_columns:
    #     X[col].fillna("MissingValue",inplace=True)
        X.fillna(X.loc[train_indices, col].mean(), inplace=True)
    y = y.values
    y_min, y_max = y[train_indices].min()*0.95, y[train_indices].max()*1.05
    y = (y - y_min) / (y_max - y_min)
    X_train, y_train = data_split(X,y,nan_mask,train_indices)
    X_valid, y_valid = data_split(X,y,nan_mask,valid_indices)
    X_test, y_test = data_split(X,y,nan_mask,test_indices)

    train_mean, train_std = np.array(X_train['data'][:,con_idxs],dtype=np.float32).mean(0), np.array(X_train['data'][:,con_idxs],dtype=np.float32).std(0)
    train_std = np.where(train_std < 1e-6, 1e-6, train_std)
    # import ipdb; ipdb.set_trace()
    return cat_dims, cat_idxs, con_idxs, X_train, y_train, X_valid, y_valid, X_test, y_test, train_mean, train_std, y_min, y_max

cat_dims, cat_idxs, con_idxs, X_train, y_train, X_valid, y_valid, X_test, y_test, train_mean, train_std, y_min, y_max = load_flight_data(opt.dset_id, opt.dset_seed)
continuous_mean_std = np.array([train_mean,train_std]).astype(np.float32)
cat_dims = np.append(np.array([1]),np.array(cat_dims)).astype(int) #Appending 1 for CLS token, this is later used to generate embeddings.

In [5]:
if opt.attentiontype != 'col':
    opt.transformer_depth = 1
    opt.attention_heads = 4
    opt.attention_dropout = 0.8
    opt.embedding_size = 16
    if opt.optimizer =='SGD':
        opt.ff_dropout = 0.4
        opt.lr = 0.01
    else:
        opt.ff_dropout = 0.8
opt.__dict__

{'dset_seed': 42,
 'dset_id': 'v8',
 'run_name': 'train_v3',
 'cont_embeddings': 'MLP',
 'attentiontype': 'colrow',
 'optimizer': 'AdamW',
 'scheduler': 'cosine',
 'embedding_size': 16,
 'transformer_depth': 1,
 'attention_heads': 4,
 'attention_dropout': 0.8,
 'ff_dropout': 0.8,
 'lr': 0.002,
 'epochs': 50,
 'batchsize': 2048,
 'pretrain': True,
 'pretrain_epochs': 15,
 'savemodelroot': './bestmodels',
 'set_seed': 1,
 'active_log': True,
 'pt_tasks': ['contrastive', 'denoising'],
 'pt_aug': [],
 'pt_aug_lam': 0.1,
 'mixup_lam': 0.3,
 'train_noise_type': 'missing',
 'train_noise_level': 0.01,
 'ssl_samples': None,
 'pt_projhead_style': 'diff',
 'nce_temp': 0.7,
 'lam0': 0.5,
 'lam1': 10,
 'lam2': 1,
 'lam3': 10,
 'final_mlp_style': 'sep',
 'vision_dset': False,
 'task': 'regression',
 'dtask': 'reg'}

In [6]:
train_ds = DataSetCatCon(X_train, y_train, cat_idxs,opt.dtask,continuous_mean_std)
trainloader = DataLoader(train_ds, batch_size=opt.batchsize, shuffle=True,num_workers=8)

valid_ds = DataSetCatCon(X_valid, y_valid, cat_idxs,opt.dtask, continuous_mean_std)
validloader = DataLoader(valid_ds, batch_size=opt.batchsize, shuffle=False,num_workers=8)

test_ds = DataSetCatCon(X_test, y_test, cat_idxs,opt.dtask, continuous_mean_std)
testloader = DataLoader(test_ds, batch_size=opt.batchsize, shuffle=False,num_workers=8)

y_dim = 1 # opt.task 'regression'
criterion = nn.MSELoss().to(device)

model = SAINT(
    categories = tuple(cat_dims), 
    num_continuous = len(con_idxs),                
    dim = opt.embedding_size,                           
    dim_out = 1,                       
    depth = opt.transformer_depth,                       
    heads = opt.attention_heads,                         
    attn_dropout = opt.attention_dropout,             
    ff_dropout = opt.ff_dropout,                  
    mlp_hidden_mults = (4, 2),       
    cont_embeddings = opt.cont_embeddings,
    attentiontype = opt.attentiontype,
    final_mlp_style = opt.final_mlp_style,
    y_dim = y_dim
)
vision_dset = opt.vision_dset

# print(count_parameters(model))
# import ipdb; ipdb.set_trace()
model.to(device)

if opt.pretrain:
    model = SAINT_pretrain(model, cat_idxs,X_train,y_train, continuous_mean_std, opt,device)

if opt.ssl_samples is not None and opt.ssl_samples > 0 :
    print('We are in semi-supervised learning case')
    train_pts_touse = np.random.choice(X_train['data'].shape[0], opt.ssl_samples)
    X_train['data'] = X_train['data'][train_pts_touse,:]
    y_train['data'] = y_train['data'][train_pts_touse]
    
    X_train['mask'] = X_train['mask'][train_pts_touse,:]
    train_bsize = min(opt.ssl_samples//4,opt.batchsize)

    train_ds = DataSetCatCon(X_train, y_train, cat_idxs,opt.dtask,continuous_mean_std)
    trainloader = DataLoader(train_ds, batch_size=train_bsize, shuffle=True,num_workers=4)

Pretraining begins!
Epoch: 0, Running Loss: 4969.757315635681
Epoch: 1, Running Loss: 1311.4788708686829
Epoch: 2, Running Loss: 992.9838337898254
Epoch: 3, Running Loss: 958.7148609161377
Epoch: 4, Running Loss: 947.4848208427429
Epoch: 5, Running Loss: 940.8644299507141
Epoch: 6, Running Loss: 937.3305153846741
Epoch: 7, Running Loss: 933.893542766571
Epoch: 8, Running Loss: 932.4347143173218
Epoch: 9, Running Loss: 931.3027195930481
Epoch: 10, Running Loss: 930.0310459136963
Epoch: 11, Running Loss: 927.9718689918518
Epoch: 12, Running Loss: 927.7993550300598
Epoch: 13, Running Loss: 926.7460074424744
Epoch: 14, Running Loss: 926.265483379364
END OF PRETRAINING!


In [7]:
if opt.optimizer == 'SGD':
    optimizer = optim.SGD(model.parameters(), lr=opt.lr,
                          momentum=0.9, weight_decay=5e-4)
    scheduler = get_scheduler(opt, optimizer)
elif opt.optimizer == 'Adam':
    optimizer = optim.Adam(model.parameters(),lr=opt.lr)
elif opt.optimizer == 'AdamW':
    optimizer = optim.AdamW(model.parameters(),lr=opt.lr)
best_valid_auroc = 0
best_valid_accuracy = 0
best_test_auroc = 0
best_test_accuracy = 0
best_valid_rmse = 100000
patience, current_patience = 10, 0

for epoch in range(opt.epochs):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        optimizer.zero_grad()
        # x_categ is the the categorical data, with y appended as last feature. x_cont has continuous data. cat_mask is an array of ones same shape as x_categ except for last column(corresponding to y's) set to 0s. con_mask is an array of ones same shape as x_cont. 
        x_categ, x_cont, y_gts, cat_mask, con_mask = data[0].to(device), data[1].to(device),data[2].to(device),data[3].to(device),data[4].to(device)
        if opt.train_noise_type is not None and opt.train_noise_level>0:
            noise_dict = {
                'noise_type' : opt.train_noise_type,
                'lambda' : opt.train_noise_level
            }
            if opt.train_noise_type == 'cutmix':
                x_categ, x_cont = add_noise(x_categ,x_cont, noise_params = noise_dict)
            elif opt.train_noise_type == 'missing':
                cat_mask, con_mask = add_noise(cat_mask, con_mask, noise_params = noise_dict)
        # We are converting the data to embeddings in the next step
        _ , x_categ_enc, x_cont_enc = embed_data_mask(x_categ, x_cont, cat_mask, con_mask,model,vision_dset)           
        reps = model.transformer(x_categ_enc, x_cont_enc)
        # select only the representations corresponding to y and apply mlp on it in the next step to get the predictions.
        y_reps = reps[:,0,:]
        
        y_outs = model.mlpfory(y_reps)
        loss = criterion(y_outs,y_gts) 
        loss.backward()
        optimizer.step()
        if opt.optimizer == 'SGD':
            scheduler.step()
        running_loss += loss.item()
    # print(running_loss)
    if opt.active_log:
        wandb.log({'epoch': epoch ,'train_epoch_loss': running_loss, 
        'loss': loss.item()
        })
    model.eval()
    with torch.no_grad():
        valid_rmse = mean_sq_error(model, validloader, device,vision_dset)    
        # test_rmse = mean_sq_error(model, testloader, device,vision_dset)  
        print('[EPOCH %d] TRAIN RMSE: %.3f   VALID RMSE: %.3f' %
            (epoch + 1, running_loss, valid_rmse ))
        # print('[EPOCH %d] TEST RMSE: %.3f' %
        #     (epoch + 1, test_rmse ))
        if opt.active_log:
            wandb.log({'valid_rmse': valid_rmse}) # ,'test_rmse': test_rmse })     
        if valid_rmse < best_valid_rmse:
            current_patience = 0
            best_valid_rmse = valid_rmse
            # best_test_rmse = test_rmse
            torch.save(model.state_dict(),'%s/bestmodel.pth' % (modelsave_path))
        else:
            current_patience += 1
    model.train()
    if current_patience == patience:
        break

[EPOCH 1] TRAIN RMSE: 1116.950   VALID RMSE: 0.149
[EPOCH 2] TRAIN RMSE: 3.734   VALID RMSE: 0.134
[EPOCH 3] TRAIN RMSE: 3.121   VALID RMSE: 0.150
[EPOCH 4] TRAIN RMSE: 3.042   VALID RMSE: 0.125
[EPOCH 5] TRAIN RMSE: 2.435   VALID RMSE: 0.095
[EPOCH 6] TRAIN RMSE: 1.991   VALID RMSE: 0.121
[EPOCH 7] TRAIN RMSE: 1.864   VALID RMSE: 0.075
[EPOCH 8] TRAIN RMSE: 1.402   VALID RMSE: 0.073
[EPOCH 9] TRAIN RMSE: 1.125   VALID RMSE: 0.076
[EPOCH 10] TRAIN RMSE: 1.135   VALID RMSE: 0.073
[EPOCH 11] TRAIN RMSE: 0.693   VALID RMSE: 0.052
[EPOCH 12] TRAIN RMSE: 0.579   VALID RMSE: 0.049
[EPOCH 13] TRAIN RMSE: 0.519   VALID RMSE: 0.047
[EPOCH 14] TRAIN RMSE: 0.495   VALID RMSE: 0.047
[EPOCH 15] TRAIN RMSE: 0.487   VALID RMSE: 0.049
[EPOCH 16] TRAIN RMSE: 0.437   VALID RMSE: 0.047
[EPOCH 17] TRAIN RMSE: 0.454   VALID RMSE: 0.044
[EPOCH 18] TRAIN RMSE: 0.404   VALID RMSE: 0.040
[EPOCH 19] TRAIN RMSE: 0.346   VALID RMSE: 0.034
[EPOCH 20] TRAIN RMSE: 0.262   VALID RMSE: 0.032
[EPOCH 21] TRAIN RMSE: 0.2

In [12]:
model.load_state_dict(torch.load('%s/bestmodel.pth' % (modelsave_path)))
total_parameters = count_parameters(model)
print('TOTAL NUMBER OF PARAMS: %d' %(total_parameters))
print('RMSE on best model:  %.3f' %(best_valid_rmse * (y_max - y_min)))
if opt.active_log:
    wandb.log({'total_parameters': total_parameters, 'best_valid_rmse':best_valid_rmse, 
    'cat_dims':len(cat_idxs) , 'con_dims':len(con_idxs) })

TOTAL NUMBER OF PARAMS: 25532070
RMSE on best model:  5072.786


In [13]:
def get_preds(model, dloader, device, vision_dset):
    model.eval()
    y_test = torch.empty(0).to(device)
    y_pred = torch.empty(0).to(device)
    with torch.no_grad():
        for i, data in enumerate(dloader, 0):
            x_categ, x_cont, y_gts, cat_mask, con_mask = data[0].to(device), data[1].to(device),data[2].to(device),data[3].to(device),data[4].to(device)
            _ , x_categ_enc, x_cont_enc = embed_data_mask(x_categ, x_cont, cat_mask, con_mask,model,vision_dset)           
            reps = model.transformer(x_categ_enc, x_cont_enc)
            y_reps = reps[:,0,:]
            y_outs = model.mlpfory(y_reps)
            y_test = torch.cat([y_test,y_gts.squeeze().float()],dim=0)
            y_pred = torch.cat([y_pred,y_outs],dim=0)
        return y_pred.cpu().numpy(), y_test.cpu().numpy()

valid_preds, valid_test = get_preds(model, validloader, device, vision_dset)
test_preds, _ = get_preds(model, testloader, device, vision_dset)

In [14]:
dft0 = pd.read_csv('./data/submission_set.csv')
dft0['tow'] = test_preds * (y_max - y_min) + y_min
dft0[['flight_id', 'tow']].to_csv('saint.csv', index=False)
pd.DataFrame.from_dict({'tow': valid_preds.numpy() * (y_max-y_min) + y_min}).to_csv('saint_val.csv', index=False)