# SAINT Model Architecture

Code adapted from https://github.com/somepago/saint for the flight dataset and for execution inside a notebook.

In [None]:
import torch
from torch import nn
from saint.pretrainmodel import SAINT
from saint.data_openml import data_prep_openml, task_dset_ids, DataSetCatCon
from saint.pretraining import SAINT_pretrain
import argparse
from torch.utils.data import DataLoader
import torch.optim as optim
from saint.utils import count_parameters, classification_scores, mean_sq_error, get_scheduler
from saint.augmentations import embed_data_mask, add_noise

import os
import numpy as np
from types import SimpleNamespace

In [None]:
# Every run option lives in one attribute-style namespace (opt.<name>).
opt = SimpleNamespace(
    # --- data ---
    dset_id='v8',  # dataset version
    vision_dset=False,
    task='regression',
    dtask='reg',
    # --- model ---
    cont_embeddings='MLP',  # 'MLP', 'Noemb', 'pos_singleMLP'
    embedding_size=32,
    transformer_depth=6,
    attention_heads=8,
    attention_dropout=0.1,
    ff_dropout=0.1,
    attentiontype='colrow',  # 'col', 'colrow', 'row', 'justmlp', 'attn', 'attnmlp'
    # --- optimisation ---
    optimizer='AdamW',  # 'AdamW', 'Adam', 'SGD'
    scheduler='cosine',  # 'cosine', 'linear'
    lr=0.0001,
    epochs=100,
    batchsize=256,
    # --- bookkeeping ---
    savemodelroot='./bestmodels',
    run_name='testrun',
    set_seed=1,
    dset_seed=1,
    active_log=True,  # wandb
    # --- pretraining ---
    pretrain=True,  # test with False # TODO
    pretrain_epochs=50,
    pt_tasks=['contrastive', 'denoising'],  # 'contrastive', 'contrastive_sim', 'denoising'
    pt_aug=[],  # 'mixup', 'cutmix' (list)
    pt_aug_lam=0.1,
    mixup_lam=0.3,
    # --- robustness / semi-supervised settings ---
    train_noise_type='missing',  # None, 'missing', 'cutmix'
    train_noise_level=0.01,
    ssl_samples=None,  # int or None
    pt_projhead_style='diff',  # 'diff', 'same', 'nohead'
    nce_temp=0.7,
    lam0=0.5,
    lam1=10,
    lam2=1,
    lam3=10,
    final_mlp_style='sep',  # 'common', 'sep'
)

# Checkpoints go to <cwd>/<savemodelroot>/<task>/<dset_id>/<run_name>.
modelsave_path = os.path.join(os.getcwd(), opt.savemodelroot, opt.task, str(opt.dset_id), opt.run_name)
os.makedirs(modelsave_path, exist_ok=True)

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device is {device}.")
torch.manual_seed(opt.set_seed)

In [None]:
# Optional Weights & Biases logging; the project and run name depend on the
# experiment family (robustness-to-noise vs. semi-supervised).
if opt.active_log:
    import wandb
    if opt.train_noise_type is not None and opt.train_noise_level > 0:
        wandb.init(project="saint_v2_robustness", group=f'{opt.run_name}_{opt.task}',
                   name=f'{opt.task}_{opt.train_noise_type}_{str(opt.train_noise_level)}_{str(opt.attentiontype)}_{str(opt.dset_id)}')
    elif opt.ssl_samples is not None:
        wandb.init(project="saint_v2_ssl", group=f'{opt.run_name}_{opt.task}',
                   name=f'{opt.task}_{str(opt.ssl_samples)}_{str(opt.attentiontype)}_{str(opt.dset_id)}')
    else:
        # Raising a bare string is itself a TypeError in Python 3; raise a real exception.
        raise ValueError('wrong config: with active_log=True, set train_noise_type/train_noise_level '
                         '(robustness run) or ssl_samples (semi-supervised run)')
    # Hand wandb a plain dict of all options.
    wandb.config.update(vars(opt))

In [None]:
def load_flight_data(filepath, datasplit=(.80, .1, .1), seed=42):
    """Load the flight CSV and prepare it for SAINT regression on take-off weight (``tow``).

    Each row is randomly assigned to train/valid/test according to ``datasplit``.
    Categorical features are label-encoded, and missing values become their own
    ``'MissingValue'`` category. Missing continuous values are imputed with the
    train-split mean. The target ``tow`` is standardized with train-split
    statistics, so an RMSE computed on it is in units of the train-set standard
    deviation of ``tow``. Rows without a ``tow`` value are dropped.

    Args:
        filepath: Path (or buffer) accepted by ``pandas.read_csv``.
        datasplit: Train/valid/test fractions; three values summing to 1.
        seed: Seed for the random train/valid/test assignment.

    Returns:
        ``(cat_dims, cat_idxs, con_idxs, X_train, y_train, X_valid, y_valid,
        X_test, y_test, train_mean, train_std)``, where:

        - Each ``X_*`` is a dict with two entries. ``'data'`` is the
          rows x features matrix (categorical columns first, then continuous).
          ``'mask'`` has the same shape (1 = observed, 0 = missing).
        - Each ``y_*`` is a dict whose ``'data'`` has shape ``(rows, 1)``.
        - ``train_mean``/``train_std`` are per-continuous-column statistics of
          the train split (std floored at 1e-6). Features are NOT scaled here.

    Raises:
        ValueError: If ``datasplit`` is not three fractions summing to 1.
    """
    import pandas as pd  # pandas is not imported at the top of this notebook

    if len(datasplit) != 3 or not np.isclose(sum(datasplit), 1.0):
        raise ValueError(f'datasplit must be three fractions summing to 1, got {datasplit!r}')

    categorical_cols = ['adep', 'country_code_adep', 'ades', 'name_ades', 'country_code_ades',
                        'aircraft_type', 'wtc', 'airline',
                        'offblock_hour', 'offblock_day_of_week', 'offblock_month']
    continuous_cols = ['flight_duration', 'taxiout_time', 'flown_distance']
    target_col = 'tow'  # Take-Off Weight: the regression target, not a feature

    df = pd.read_csv(filepath)
    # Rows without a target cannot contribute to supervised training.
    df = df.dropna(subset=[target_col]).reset_index(drop=True)

    # Feature engineering: off-block hour (rounded to the nearest hour), day of week, month.
    offblock = pd.to_datetime(df['actual_offblock_time'])
    df['offblock_hour'] = offblock.dt.round(pd.offsets.Hour()).dt.hour
    df['offblock_day_of_week'] = offblock.dt.dayofweek
    df['offblock_month'] = offblock.dt.month

    X = df[categorical_cols + continuous_cols].copy()
    # The observation mask is taken before any imputation: 1 = observed, 0 = missing.
    nan_mask = X.notna().astype(int).to_numpy()

    rng = np.random.default_rng(seed)
    assignment = rng.choice(['train', 'valid', 'test'], size=len(X), p=list(datasplit))
    split_indices = {name: np.flatnonzero(assignment == name) for name in ('train', 'valid', 'test')}
    train_idx = split_indices['train']

    # Label-encode categoricals (sorted labels); missing values get their own category.
    cat_dims = []
    for col in categorical_cols:
        labels = X[col].astype(str).where(X[col].notna(), 'MissingValue')
        codes, uniques = pd.factorize(labels, sort=True)
        X[col] = codes
        cat_dims.append(len(uniques))

    # Impute continuous gaps with the train-split mean, so valid/test values are not used.
    for col in continuous_cols:
        values = X[col].astype(float)
        X[col] = values.fillna(values.iloc[train_idx].mean())

    # Standardize the target with train-split statistics only.
    y = df[target_col].to_numpy(dtype=np.float64)
    y_mean, y_std = y[train_idx].mean(), y[train_idx].std()
    y = (y - y_mean) / max(y_std, 1e-6)

    data = X.to_numpy(dtype=np.float64)
    cat_idxs = list(range(len(categorical_cols)))
    con_idxs = list(range(len(categorical_cols), data.shape[1]))

    def _subset(indices):
        # Package one split in the {'data', 'mask'} / {'data'} layout used by the notebook.
        return ({'data': data[indices], 'mask': nan_mask[indices]},
                {'data': y[indices].reshape(-1, 1)})

    X_train, y_train = _subset(split_indices['train'])
    X_valid, y_valid = _subset(split_indices['valid'])
    X_test, y_test = _subset(split_indices['test'])

    # Per-column train statistics for the continuous features (consumed downstream
    # as continuous_mean_std); floor the std to avoid division by ~0.
    continuous_train = X_train['data'][:, con_idxs].astype(np.float32)
    train_mean, train_std = continuous_train.mean(0), continuous_train.std(0)
    train_std = np.where(train_std < 1e-6, 1e-6, train_std)

    return cat_dims, cat_idxs, con_idxs, X_train, y_train, X_valid, y_valid, X_test, y_test, train_mean, train_std

# NOTE(review): opt.dset_id ('v8') is a dataset version label, not a file path --
# point this argument at the flight CSV before running.
cat_dims, cat_idxs, con_idxs, X_train, y_train, X_valid, y_valid, X_test, y_test, train_mean, train_std = \
    load_flight_data(opt.dset_id, datasplit=[.65, .15, .2])
# Train mean and std stacked into one float32 array, passed to DataSetCatCon as continuous_mean_std.
continuous_mean_std = np.array([train_mean, train_std]).astype(np.float32)
# Prepend a category of size 1 for the CLS token; it is later used to generate embeddings.
cat_dims = np.append(np.array([1]), np.array(cat_dims)).astype(int)

In [None]:
# Derive a few hyperparameters from the data width and the attention type.
_, nfeat = X_train['data'].shape

if nfeat > 100:
    # Wide tables: cap the embedding size at 4 and the batch size at 64.
    opt.embedding_size = min(4, opt.embedding_size)
    opt.batchsize = min(64, opt.batchsize)

if opt.attentiontype != 'col':
    # Every attention type except pure column attention uses this smaller,
    # heavily dropped-out configuration.
    opt.transformer_depth, opt.attention_heads = 1, 4
    opt.attention_dropout, opt.embedding_size = 0.8, 16
    if opt.optimizer == 'SGD':
        opt.ff_dropout, opt.lr = 0.4, 0.01
    else:
        opt.ff_dropout = 0.8

# TODO optimize: this fixed value also overrides the batch-size cap applied above.
opt.batchsize = 2048

# Notebook display of the resulting settings.
nfeat, opt.batchsize, opt

In [None]:
# Wrap each split in DataSetCatCon and build the loaders (only training is shuffled).
train_ds = DataSetCatCon(X_train, y_train, cat_idxs, opt.dtask, continuous_mean_std)
trainloader = DataLoader(train_ds, batch_size=opt.batchsize, shuffle=True, num_workers=4)

valid_ds = DataSetCatCon(X_valid, y_valid, cat_idxs, opt.dtask, continuous_mean_std)
validloader = DataLoader(valid_ds, batch_size=opt.batchsize, shuffle=False, num_workers=4)

test_ds = DataSetCatCon(X_test, y_test, cat_idxs, opt.dtask, continuous_mean_std)
testloader = DataLoader(test_ds, batch_size=opt.batchsize, shuffle=False, num_workers=4)

y_dim = 1  # single output for opt.task == 'regression'
criterion = nn.MSELoss().to(device)

model = SAINT(
    categories=tuple(cat_dims),
    num_continuous=len(con_idxs),
    dim=opt.embedding_size,
    dim_out=1,
    depth=opt.transformer_depth,
    heads=opt.attention_heads,
    attn_dropout=opt.attention_dropout,
    ff_dropout=opt.ff_dropout,
    mlp_hidden_mults=(4, 2),
    cont_embeddings=opt.cont_embeddings,
    attentiontype=opt.attentiontype,
    final_mlp_style=opt.final_mlp_style,
    y_dim=y_dim,
)
vision_dset = opt.vision_dset

model.to(device)

# Optional pretraining on the full training split; SAINT_pretrain returns the model to fine-tune.
if opt.pretrain:
    model = SAINT_pretrain(model, cat_idxs, X_train, y_train, continuous_mean_std, opt, device)

# Semi-supervised setting: fine-tune on a random labelled subset of the training split.
if opt.ssl_samples is not None and opt.ssl_samples > 0:
    print('We are in semi-supervised learning case')
    n_train = X_train['data'].shape[0]
    # Sample without replacement so the labelled subset contains no duplicated rows.
    train_pts_touse = np.random.choice(n_train, min(opt.ssl_samples, n_train), replace=False)
    X_train['data'] = X_train['data'][train_pts_touse, :]
    X_train['mask'] = X_train['mask'][train_pts_touse, :]
    y_train['data'] = y_train['data'][train_pts_touse]
    # Keep the batch size at least 1: fewer than 4 requested samples would otherwise give 0.
    train_bsize = max(1, min(opt.ssl_samples // 4, opt.batchsize))

    train_ds = DataSetCatCon(X_train, y_train, cat_idxs, opt.dtask, continuous_mean_std)
    trainloader = DataLoader(train_ds, batch_size=train_bsize, shuffle=True, num_workers=4)

In [None]:
# Optimizer; an LR scheduler is only created (and stepped once per batch) for SGD.
scheduler = None
if opt.optimizer == 'SGD':
    optimizer = optim.SGD(model.parameters(), lr=opt.lr,
                          momentum=0.9, weight_decay=5e-4)
    scheduler = get_scheduler(opt, optimizer)
elif opt.optimizer == 'Adam':
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
elif opt.optimizer == 'AdamW':
    optimizer = optim.AdamW(model.parameters(), lr=opt.lr)
else:
    raise ValueError(f"Unknown optimizer {opt.optimizer!r}; expected 'SGD', 'Adam' or 'AdamW'.")

best_valid_auroc = 0
best_valid_accuracy = 0
best_test_auroc = 0
best_test_accuracy = 0
# Start from +inf so the first evaluation always records a best model, and define
# best_test_rmse up front so the final report never hits a NameError.
best_valid_rmse = float('inf')
best_test_rmse = float('inf')
best_model_path = os.path.join(modelsave_path, 'bestmodel.pth')

for epoch in range(opt.epochs):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        optimizer.zero_grad()
        # Batch = (categorical codes, continuous values, targets, categorical mask,
        # continuous mask). Position 0 of x_categ presumably holds the CLS token
        # (cat_dims was prefixed with a size-1 category for it) -- confirm in DataSetCatCon.
        x_categ, x_cont, y_gts, cat_mask, con_mask = data[0].to(device), data[1].to(device), data[2].to(device), data[3].to(device), data[4].to(device)
        if opt.train_noise_type is not None and opt.train_noise_level > 0:
            noise_dict = {
                'noise_type': opt.train_noise_type,
                'lambda': opt.train_noise_level,
            }
            if opt.train_noise_type == 'cutmix':
                x_categ, x_cont = add_noise(x_categ, x_cont, noise_params=noise_dict)
            elif opt.train_noise_type == 'missing':
                cat_mask, con_mask = add_noise(cat_mask, con_mask, noise_params=noise_dict)
        # Convert the batch to embeddings, then run the transformer.
        _, x_categ_enc, x_cont_enc = embed_data_mask(x_categ, x_cont, cat_mask, con_mask, model, vision_dset)
        reps = model.transformer(x_categ_enc, x_cont_enc)
        # Only the representation at position 0 (the CLS slot) feeds the prediction head.
        y_reps = reps[:, 0, :]

        y_outs = model.mlpfory(y_reps)
        if opt.task == 'regression':
            loss = criterion(y_outs, y_gts)
        else:
            loss = criterion(y_outs, y_gts.squeeze())
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        running_loss += loss.item()
    if opt.active_log:
        wandb.log({'epoch': epoch, 'train_epoch_loss': running_loss,
                   'loss': loss.item()})
    # Every 5 epochs: evaluate on valid/test and checkpoint on validation improvement.
    if epoch % 5 == 0:
        model.eval()
        with torch.no_grad():
            if opt.task in ['binary', 'multiclass']:
                accuracy, auroc = classification_scores(model, validloader, device, opt.task, vision_dset)
                test_accuracy, test_auroc = classification_scores(model, testloader, device, opt.task, vision_dset)

                print('[EPOCH %d] VALID ACCURACY: %.3f, VALID AUROC: %.3f' %
                      (epoch + 1, accuracy, auroc))
                print('[EPOCH %d] TEST ACCURACY: %.3f, TEST AUROC: %.3f' %
                      (epoch + 1, test_accuracy, test_auroc))
                if opt.active_log:
                    wandb.log({'valid_accuracy': accuracy, 'valid_auroc': auroc})
                    wandb.log({'test_accuracy': test_accuracy, 'test_auroc': test_auroc})
                if opt.task == 'multiclass':
                    if accuracy > best_valid_accuracy:
                        best_valid_accuracy = accuracy
                        best_test_auroc = test_auroc
                        best_test_accuracy = test_accuracy
                        torch.save(model.state_dict(), best_model_path)
                else:
                    if auroc > best_valid_auroc:
                        best_valid_auroc = auroc
                        best_test_auroc = test_auroc
                        best_test_accuracy = test_accuracy
                        torch.save(model.state_dict(), best_model_path)

            else:
                valid_rmse = mean_sq_error(model, validloader, device, vision_dset)
                test_rmse = mean_sq_error(model, testloader, device, vision_dset)
                print('[EPOCH %d] VALID RMSE: %.3f' %
                      (epoch + 1, valid_rmse))
                print('[EPOCH %d] TEST RMSE: %.3f' %
                      (epoch + 1, test_rmse))
                if opt.active_log:
                    wandb.log({'valid_rmse': valid_rmse, 'test_rmse': test_rmse})
                if valid_rmse < best_valid_rmse:
                    best_valid_rmse = valid_rmse
                    best_test_rmse = test_rmse
                    torch.save(model.state_dict(), best_model_path)
        model.train()

In [None]:
# Report the parameter count and the test metric of the validation-selected best model.
total_parameters = count_parameters(model)
print('TOTAL NUMBER OF PARAMS: %d' % total_parameters)

if opt.task == 'binary':
    best_model_report = 'AUROC on best model:  %.3f' % best_test_auroc
elif opt.task == 'multiclass':
    best_model_report = 'Accuracy on best model:  %.3f' % best_test_accuracy
else:
    best_model_report = 'RMSE on best model:  %.3f' % best_test_rmse
print(best_model_report)

if opt.active_log:
    # Build the summary payload in the same key order as before.
    final_metrics = {'total_parameters': total_parameters}
    if opt.task == 'regression':
        final_metrics['test_rmse_bestep'] = best_test_rmse
    else:
        final_metrics['test_auroc_bestep'] = best_test_auroc
        final_metrics['test_accuracy_bestep'] = best_test_accuracy
    final_metrics['cat_dims'] = len(cat_idxs)
    final_metrics['con_dims'] = len(con_idxs)
    wandb.log(final_metrics)