In [2]:
import os
import torch
import numpy as np
from tqdm import tqdm
import optuna
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn import metrics
import random
from pytorch_tabnet.tab_model import TabNetRegressor
from scipy import stats
import re
from pandas.core.frame import DataFrame
import warnings
warnings.filterwarnings("ignore")

In [2]:
def calc_metrics(y_pred, y_test):
    """Compute regression metrics for a set of predictions.

    Parameters
    ----------
    y_pred : array-like
        Predicted values. Must be 1-D, because it is also passed to
        ``scipy.stats.pearsonr``.
    y_test : array-like
        Ground-truth values, same length as ``y_pred``.

    Returns
    -------
    list
        ``[mse, rmse, r_2, mae, pcc]`` in exactly this order.
    """
    mse = metrics.mean_squared_error(y_test, y_pred)
    # RMSE is derived from the MSE instead of calling
    # mean_squared_error(..., squared=False): that keyword was deprecated in
    # scikit-learn 1.4 and later removed, so the old call fails on current releases.
    rmse = float(np.sqrt(mse))
    r_2 = metrics.r2_score(y_test, y_pred)
    mae = metrics.mean_absolute_error(y_test, y_pred)
    # pearsonr returns (statistic, p-value); only the correlation is reported.
    pcc, _ = stats.pearsonr(y_test, y_pred)
    return [mse, rmse, r_2, mae, pcc]

In [33]:
def average_predict(x: np.ndarray, all_models: list) -> np.ndarray:
    """Average the predictions of several models on the same inputs.

    Parameters
    ----------
    x : np.ndarray
        Feature matrix passed unchanged to every model's ``predict``.
    all_models : list
        Non-empty list of objects exposing ``predict(x)``. Each prediction
        must be 2-D so the results can be concatenated along axis 1.

    Returns
    -------
    np.ndarray
        1-D array with one averaged prediction per row of ``x``.

    Raises
    ------
    ValueError
        If ``all_models`` is empty.
    """
    if not all_models:
        raise ValueError("all_models must contain at least one model")

    # One column per model after concatenation, then the row-wise mean.
    per_model = [model.predict(x) for model in all_models]
    stacked = np.concatenate(per_model, axis=1)
    return stacked.mean(axis=1).flatten()

In [3]:
def print_metrics(**kwargs):
    """Return the given named metric values as a one-row DataFrame.

    Each keyword becomes a column; the single row has index label 0.
    Despite the name, nothing is printed: the frame is returned so a
    notebook cell can display it.
    """
    return pd.DataFrame(kwargs, index=[0])

In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [21]:
# Signed decimal number with an optional exponent part, e.g. "-0.13", "7",
# ".25", "1." or "5.04e-05".
_FLOAT_PATTERN = re.compile(r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?')


def process_repr(repr):
    """Parse a stringified numeric vector into a list of floats.

    Parameters
    ----------
    repr : str
        Text containing numbers, e.g. ``"[0.0943, -0.1342, 1.2e-05]"``.
        Everything that is not part of a number is ignored.

    Returns
    -------
    list of float
        The numbers in order of appearance; empty if none are found.

    Notes
    -----
    Scientific notation is parsed as a single value. The earlier pattern
    had no exponent part, so ``"1.2e-05"`` was read as the two numbers
    ``1.2`` and ``-5.0``, which changed both the values and the length of
    the vector.
    """
    return [float(token) for token in _FLOAT_PATTERN.findall(repr)]

In [22]:
# Load the training table and the held-out ("blind") test table.
# Later cells read the `repr` and `Tm` columns from both frames, and the
# `group_Tm` and `clst` columns from `train`.
train = pd.read_csv('../data/train.csv')
full_test = pd.read_csv('../data/blind_test.csv')


In [23]:
# Convert the stringified vectors in `repr` into lists of floats, in place, for
# both tables. Run this once per load: the column holds lists afterwards, not strings.
for frame in (train, full_test):
    frame['repr'] = frame['repr'].apply(process_repr)


In [19]:
SEED = 42


def seed_everything(seed: int):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices), and switches cuDNN to deterministic mode.

    Parameters
    ----------
    seed : int
        Seed value applied to all generators.
    """
    random.seed(seed)
    # NOTE: setting PYTHONHASHSEED at runtime does not change hashing in the
    # already-running interpreter; it is kept for any child processes.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(seed)
    # benchmark=True lets cuDNN auto-tune and pick different kernels between
    # runs, which works against reproducibility; turn it off and request
    # deterministic algorithms instead.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(SEED)

In [20]:
def process_features(chunk: DataFrame) -> tuple[np.ndarray, np.ndarray]:
    """Split a table into a feature matrix and a column vector of targets.

    Parameters
    ----------
    chunk : DataFrame
        Must contain a ``repr`` column (one numeric sequence per row) and a
        ``Tm`` column (one target value per row).

    Returns
    -------
    tuple of np.ndarray
        ``(features, targets)``: the ``repr`` entries stacked into one
        array, and the ``Tm`` values reshaped to ``(n_rows, 1)``.
    """
    features = np.array(chunk['repr'].tolist())
    targets = np.array(chunk['Tm'].tolist()).reshape(-1, 1)
    return features, targets

In [23]:
# Feature matrix and Tm targets for the whole training table; these feed the
# hold-out split used during the hyperparameter search.
x, y = process_features(train)

In [36]:
# Same extraction for the blind test table; used for the final ensemble evaluation.
x_test, y_test = process_features(full_test)

In [24]:
# Random 67/33 hold-out split used by the Optuna objective; seeded with SEED.
# NOTE(review): this split is purely random and ignores the `clst` groups that
# the cross-validation later in the notebook respects — confirm this is intended.
X_train, X_valid, y_train, y_valid = train_test_split(x, y, test_size=0.33, random_state=SEED)

In [21]:
# Training settings passed to every TabNetRegressor.fit call in this notebook.
MAX_EPOCHS = 300   # value passed as `max_epochs`
PATIENCE = 30      # value passed as `patience` (early stopping)
BATCH_SIZE = 1024  # value passed as `batch_size`

##### Perform hyperparameter search

In [68]:
def objective(trial):
    """Optuna objective: train one TabNet model and return its validation MSE.

    Samples the TabNet architecture and regularisation hyperparameters from
    ``trial``, fits on ``X_train``/``y_train`` with early stopping monitored
    on ``X_valid``/``y_valid``, and scores the fitted model on that same
    validation split.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Trial object used to sample hyperparameters.

    Returns
    -------
    float
        Mean squared error on the validation split (lower is better).
    """
    n_shared = trial.suggest_int('n_shared', 1, 5)
    n_independent = trial.suggest_int('n_independent', 1, 5)
    n_d = trial.suggest_int('n_d', 8, 64)
    n_a = trial.suggest_int('n_a', 8, 64)
    n_steps = trial.suggest_int('n_steps', 3, 10)
    gamma = trial.suggest_float('gamma', 1.0, 2.0)
    momentum = trial.suggest_float('momentum', 0.1, 0.9)
    # suggest_loguniform is deprecated; suggest_float(..., log=True) samples
    # from the same log-uniform range.
    lambda_sparse = trial.suggest_float('lambda_sparse', 1e-5, 1e-1, log=True)

    model = TabNetRegressor(n_d=n_d, n_a=n_a, n_independent=n_independent, n_shared=n_shared, n_steps=n_steps, gamma=gamma, momentum=momentum,
                            lambda_sparse=lambda_sparse, verbose=0, seed=SEED, device_name=device, optimizer_fn=torch.optim.Adam,
                            scheduler_fn=torch.optim.lr_scheduler.StepLR, scheduler_params={'step_size': 10, 'gamma': 0.9})

    model.fit(X_train, y_train, eval_set=[(X_valid, y_valid)], patience=PATIENCE, max_epochs=MAX_EPOCHS, batch_size=BATCH_SIZE)
    # The score is computed on the split that also drove early stopping, so it
    # is an optimistic estimate; it is only used to rank trials.
    y_pred = model.predict(X_valid)
    mse = metrics.mean_squared_error(y_valid, y_pred)
    return mse

In [None]:
# Seed the sampler so the sequence of suggested hyperparameters is repeatable.
study = optuna.create_study(direction='minimize', sampler=optuna.samplers.TPESampler(seed=SEED))
study.optimize(objective, n_trials=70)

# Report each trial's own value next to the best value seen up to that trial.
# (Printing study.best_value inside the loop would repeat the final best on every line.)
running_best = float('inf')
for trial in study.trials:
    if trial.value is None:
        # Trials that did not complete (failed or interrupted) have no value.
        continue
    running_best = min(running_best, trial.value)
    print(f'Iteration {trial.number + 1}: value = {trial.value}, best so far = {running_best}')

[32m[I 2023-07-01 05:05:17,398][0m A new study created in memory with name: no-name-69076c3a-bea7-4304-af82-59d1a219d28e[0m



Early stopping occurred at epoch 114 with best_epoch = 84 and best_val_0_mse = 37.45805


[32m[I 2023-07-01 05:12:11,056][0m Trial 0 finished with value: 37.45805462099855 and parameters: {'n_shared': 1, 'n_independent': 3, 'n_d': 32, 'n_a': 53, 'n_steps': 10, 'gamma': 1.5843747996502784, 'momentum': 0.11172385993515209, 'lambda_sparse': 0.05430249254142419}. Best is trial 0 with value: 37.45805462099855.[0m



Early stopping occurred at epoch 103 with best_epoch = 73 and best_val_0_mse = 37.71494


[32m[I 2023-07-01 05:17:25,260][0m Trial 1 finished with value: 37.71494182874437 and parameters: {'n_shared': 3, 'n_independent': 3, 'n_d': 44, 'n_a': 32, 'n_steps': 6, 'gamma': 1.0515222749271895, 'momentum': 0.5236862805850789, 'lambda_sparse': 0.007550418166087127}. Best is trial 0 with value: 37.45805462099855.[0m



Early stopping occurred at epoch 85 with best_epoch = 55 and best_val_0_mse = 37.96869


[32m[I 2023-07-01 05:20:19,798][0m Trial 2 finished with value: 37.96869262736749 and parameters: {'n_shared': 2, 'n_independent': 1, 'n_d': 48, 'n_a': 37, 'n_steps': 6, 'gamma': 1.5478184581139967, 'momentum': 0.8974413782400997, 'lambda_sparse': 0.006302007357849685}. Best is trial 0 with value: 37.45805462099855.[0m



Early stopping occurred at epoch 82 with best_epoch = 52 and best_val_0_mse = 37.59423


[32m[I 2023-07-01 05:24:18,422][0m Trial 3 finished with value: 37.59422894933116 and parameters: {'n_shared': 5, 'n_independent': 2, 'n_d': 37, 'n_a': 9, 'n_steps': 5, 'gamma': 1.7710284039092177, 'momentum': 0.5308764788870531, 'lambda_sparse': 5.043860676264641e-05}. Best is trial 0 with value: 37.45805462099855.[0m



Early stopping occurred at epoch 72 with best_epoch = 42 and best_val_0_mse = 37.21356


[32m[I 2023-07-01 05:28:54,062][0m Trial 4 finished with value: 37.21355574017806 and parameters: {'n_shared': 5, 'n_independent': 5, 'n_d': 63, 'n_a': 12, 'n_steps': 5, 'gamma': 1.4111484359834034, 'momentum': 0.3592637810142343, 'lambda_sparse': 1.2038760187881758e-05}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 100 with best_epoch = 70 and best_val_0_mse = 38.00593


[32m[I 2023-07-01 05:34:31,412][0m Trial 5 finished with value: 38.005933372123195 and parameters: {'n_shared': 4, 'n_independent': 3, 'n_d': 21, 'n_a': 56, 'n_steps': 6, 'gamma': 1.071205384137437, 'momentum': 0.21877144426566197, 'lambda_sparse': 0.0005301323784886918}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 124 with best_epoch = 94 and best_val_0_mse = 38.10373


[32m[I 2023-07-01 05:41:46,609][0m Trial 6 finished with value: 38.10373140167757 and parameters: {'n_shared': 2, 'n_independent': 2, 'n_d': 32, 'n_a': 60, 'n_steps': 10, 'gamma': 1.481117673343563, 'momentum': 0.5746501909838163, 'lambda_sparse': 1.4579951229953682e-05}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 103 with best_epoch = 73 and best_val_0_mse = 37.99272


[32m[I 2023-07-01 05:50:38,530][0m Trial 7 finished with value: 37.99271507133866 and parameters: {'n_shared': 2, 'n_independent': 5, 'n_d': 27, 'n_a': 34, 'n_steps': 10, 'gamma': 1.5633104428754283, 'momentum': 0.7377936784160728, 'lambda_sparse': 0.028270599884157255}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 98 with best_epoch = 68 and best_val_0_mse = 37.94667


[32m[I 2023-07-01 06:00:42,141][0m Trial 8 finished with value: 37.946672439354664 and parameters: {'n_shared': 5, 'n_independent': 5, 'n_d': 43, 'n_a': 36, 'n_steps': 9, 'gamma': 1.570819955608183, 'momentum': 0.700835408418312, 'lambda_sparse': 0.04787312322229607}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 97 with best_epoch = 67 and best_val_0_mse = 37.46938


[32m[I 2023-07-01 06:04:24,396][0m Trial 9 finished with value: 37.469379445709535 and parameters: {'n_shared': 3, 'n_independent': 2, 'n_d': 11, 'n_a': 18, 'n_steps': 5, 'gamma': 1.440987406437393, 'momentum': 0.7896673033153728, 'lambda_sparse': 0.00010359274893013226}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 70 with best_epoch = 40 and best_val_0_mse = 37.83639


[32m[I 2023-07-01 06:07:08,440][0m Trial 10 finished with value: 37.83639098745498 and parameters: {'n_shared': 4, 'n_independent': 4, 'n_d': 63, 'n_a': 20, 'n_steps': 3, 'gamma': 1.284413309811417, 'momentum': 0.31205621835500663, 'lambda_sparse': 1.1307171659231187e-05}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 85 with best_epoch = 55 and best_val_0_mse = 37.77817


[32m[I 2023-07-01 06:12:02,639][0m Trial 11 finished with value: 37.778173308378804 and parameters: {'n_shared': 1, 'n_independent': 4, 'n_d': 64, 'n_a': 46, 'n_steps': 8, 'gamma': 1.9170886130994127, 'momentum': 0.11007841794509877, 'lambda_sparse': 0.0008123953803892716}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 96 with best_epoch = 66 and best_val_0_mse = 37.46122


[32m[I 2023-07-01 06:17:36,966][0m Trial 12 finished with value: 37.4612236608359 and parameters: {'n_shared': 1, 'n_independent': 4, 'n_d': 55, 'n_a': 46, 'n_steps': 8, 'gamma': 1.3092656732279955, 'momentum': 0.35317575499570003, 'lambda_sparse': 0.08117411370479448}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 69 with best_epoch = 39 and best_val_0_mse = 38.05818


[32m[I 2023-07-01 06:20:33,239][0m Trial 13 finished with value: 38.05818057988738 and parameters: {'n_shared': 4, 'n_independent': 5, 'n_d': 20, 'n_a': 52, 'n_steps': 3, 'gamma': 1.7221689476291813, 'momentum': 0.11475931898384872, 'lambda_sparse': 0.002971144787371413}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 73 with best_epoch = 43 and best_val_0_mse = 37.28311


[32m[I 2023-07-01 06:22:49,562][0m Trial 14 finished with value: 37.28310577838472 and parameters: {'n_shared': 1, 'n_independent': 3, 'n_d': 52, 'n_a': 8, 'n_steps': 4, 'gamma': 1.7002647775127337, 'momentum': 0.39155951341125883, 'lambda_sparse': 0.00023419314534341536}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 74 with best_epoch = 44 and best_val_0_mse = 38.01205


[32m[I 2023-07-01 06:26:34,452][0m Trial 15 finished with value: 38.012054748839795 and parameters: {'n_shared': 5, 'n_independent': 4, 'n_d': 55, 'n_a': 9, 'n_steps': 4, 'gamma': 1.9248846684965755, 'momentum': 0.4096907853354161, 'lambda_sparse': 0.00017667336771239585}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 70 with best_epoch = 40 and best_val_0_mse = 37.83311


[32m[I 2023-07-01 06:28:43,979][0m Trial 16 finished with value: 37.83310917309217 and parameters: {'n_shared': 3, 'n_independent': 1, 'n_d': 55, 'n_a': 19, 'n_steps': 4, 'gamma': 1.7446099793271768, 'momentum': 0.4327367557040908, 'lambda_sparse': 4.4638426743763165e-05}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 70 with best_epoch = 40 and best_val_0_mse = 37.91974


[32m[I 2023-07-01 06:30:53,995][0m Trial 17 finished with value: 37.91973517944574 and parameters: {'n_shared': 2, 'n_independent': 2, 'n_d': 49, 'n_a': 26, 'n_steps': 4, 'gamma': 1.3851320903346374, 'momentum': 0.27742849516423596, 'lambda_sparse': 0.00030351769979137594}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 102 with best_epoch = 72 and best_val_0_mse = 37.28944


[32m[I 2023-07-01 06:38:05,411][0m Trial 18 finished with value: 37.289437557187576 and parameters: {'n_shared': 4, 'n_independent': 4, 'n_d': 60, 'n_a': 14, 'n_steps': 7, 'gamma': 1.6658786633125056, 'momentum': 0.38045905887220716, 'lambda_sparse': 3.526875437479875e-05}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 97 with best_epoch = 67 and best_val_0_mse = 38.12808


[32m[I 2023-07-01 06:43:20,559][0m Trial 19 finished with value: 38.128077714478756 and parameters: {'n_shared': 5, 'n_independent': 3, 'n_d': 58, 'n_a': 25, 'n_steps': 5, 'gamma': 1.8257484251558926, 'momentum': 0.43672930159119683, 'lambda_sparse': 0.0001257056884844031}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 68 with best_epoch = 38 and best_val_0_mse = 38.13018


[32m[I 2023-07-01 06:45:57,786][0m Trial 20 finished with value: 38.13018338832778 and parameters: {'n_shared': 3, 'n_independent': 5, 'n_d': 51, 'n_a': 8, 'n_steps': 3, 'gamma': 1.664714479679484, 'momentum': 0.2616963113368859, 'lambda_sparse': 0.000378532036902018}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 98 with best_epoch = 68 and best_val_0_mse = 38.22152


[32m[I 2023-07-01 06:52:47,293][0m Trial 21 finished with value: 38.221515045600626 and parameters: {'n_shared': 4, 'n_independent': 4, 'n_d': 60, 'n_a': 14, 'n_steps': 7, 'gamma': 1.6639882524282332, 'momentum': 0.35797222840753984, 'lambda_sparse': 2.769105512352424e-05}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 123 with best_epoch = 93 and best_val_0_mse = 37.74256


[32m[I 2023-07-01 07:01:14,240][0m Trial 22 finished with value: 37.74256288516371 and parameters: {'n_shared': 4, 'n_independent': 4, 'n_d': 59, 'n_a': 14, 'n_steps': 7, 'gamma': 1.8518357581597114, 'momentum': 0.4680306366715444, 'lambda_sparse': 2.4013862238823508e-05}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 79 with best_epoch = 49 and best_val_0_mse = 37.65604


[32m[I 2023-07-01 07:06:13,553][0m Trial 23 finished with value: 37.656035573296876 and parameters: {'n_shared': 5, 'n_independent': 5, 'n_d': 43, 'n_a': 13, 'n_steps': 5, 'gamma': 1.9920902554898483, 'momentum': 0.36741117313482624, 'lambda_sparse': 6.240462578962895e-05}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 90 with best_epoch = 60 and best_val_0_mse = 37.58799


[32m[I 2023-07-01 07:11:59,489][0m Trial 24 finished with value: 37.58799365854912 and parameters: {'n_shared': 4, 'n_independent': 3, 'n_d': 64, 'n_a': 25, 'n_steps': 7, 'gamma': 1.6524657381322398, 'momentum': 0.21482228229981615, 'lambda_sparse': 2.100205785031453e-05}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 71 with best_epoch = 41 and best_val_0_mse = 37.66317


[32m[I 2023-07-01 07:15:31,240][0m Trial 25 finished with value: 37.66317282600975 and parameters: {'n_shared': 5, 'n_independent': 4, 'n_d': 51, 'n_a': 13, 'n_steps': 4, 'gamma': 1.5002644027968381, 'momentum': 0.482517665697767, 'lambda_sparse': 8.726361405203159e-05}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 138 with best_epoch = 108 and best_val_0_mse = 38.10154


[32m[I 2023-07-01 07:27:12,099][0m Trial 26 finished with value: 38.10154335377432 and parameters: {'n_shared': 4, 'n_independent': 5, 'n_d': 57, 'n_a': 22, 'n_steps': 8, 'gamma': 1.6631041455058866, 'momentum': 0.40536469703177946, 'lambda_sparse': 1.0839065693192679e-05}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 94 with best_epoch = 64 and best_val_0_mse = 37.63547


[32m[I 2023-07-01 07:32:00,399][0m Trial 27 finished with value: 37.635468523751435 and parameters: {'n_shared': 3, 'n_independent': 3, 'n_d': 39, 'n_a': 16, 'n_steps': 6, 'gamma': 1.7783352509013708, 'momentum': 0.32256393682391327, 'lambda_sparse': 0.000216111221450238}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 82 with best_epoch = 52 and best_val_0_mse = 37.49924


[32m[I 2023-07-01 07:36:51,328][0m Trial 28 finished with value: 37.49924035024813 and parameters: {'n_shared': 5, 'n_independent': 4, 'n_d': 61, 'n_a': 30, 'n_steps': 5, 'gamma': 1.4360167268066042, 'momentum': 0.3881407586051244, 'lambda_sparse': 3.851866534915528e-05}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 72 with best_epoch = 42 and best_val_0_mse = 37.7334


[32m[I 2023-07-01 07:39:36,920][0m Trial 29 finished with value: 37.733397096560786 and parameters: {'n_shared': 1, 'n_independent': 5, 'n_d': 52, 'n_a': 11, 'n_steps': 4, 'gamma': 1.6163562015224657, 'momentum': 0.5593736473218023, 'lambda_sparse': 8.55607100946924e-05}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 120 with best_epoch = 90 and best_val_0_mse = 37.88811


[32m[I 2023-07-01 07:46:08,049][0m Trial 30 finished with value: 37.88810775748842 and parameters: {'n_shared': 1, 'n_independent': 3, 'n_d': 54, 'n_a': 42, 'n_steps': 9, 'gamma': 1.7087368918845114, 'momentum': 0.4746258591675483, 'lambda_sparse': 2.6895865052175958e-05}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 90 with best_epoch = 60 and best_val_0_mse = 37.78058


[32m[I 2023-07-01 07:51:05,841][0m Trial 31 finished with value: 37.78057775808535 and parameters: {'n_shared': 1, 'n_independent': 3, 'n_d': 29, 'n_a': 64, 'n_steps': 9, 'gamma': 1.5514984629736155, 'momentum': 0.1805097473781949, 'lambda_sparse': 0.0010278830592473944}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 97 with best_epoch = 67 and best_val_0_mse = 37.49389


[32m[I 2023-07-01 07:55:29,243][0m Trial 32 finished with value: 37.493893942834625 and parameters: {'n_shared': 2, 'n_independent': 3, 'n_d': 46, 'n_a': 41, 'n_steps': 6, 'gamma': 1.5819525171714077, 'momentum': 0.28065503789406565, 'lambda_sparse': 0.015536465100099687}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 111 with best_epoch = 81 and best_val_0_mse = 37.87822


[32m[I 2023-07-01 08:01:59,817][0m Trial 33 finished with value: 37.87822294126423 and parameters: {'n_shared': 2, 'n_independent': 2, 'n_d': 34, 'n_a': 30, 'n_steps': 10, 'gamma': 1.6186655108983905, 'momentum': 0.1636118134982978, 'lambda_sparse': 0.0021376885172221106}. Best is trial 4 with value: 37.21355574017806.[0m



Early stopping occurred at epoch 103 with best_epoch = 73 and best_val_0_mse = 38.22591


[32m[I 2023-07-01 08:06:05,143][0m Trial 34 finished with value: 38.225914438927155 and parameters: {'n_shared': 1, 'n_independent': 3, 'n_d': 41, 'n_a': 49, 'n_steps': 6, 'gamma': 1.5262080621617142, 'momentum': 0.2520215535799848, 'lambda_sparse': 5.103762564680734e-05}. Best is trial 4 with value: 37.21355574017806.[0m


#### Trial 4 has the best score

##### Build stratified, grouped cross-validation folds that account for the Tm group and the evolutionary cluster

In [23]:
from sklearn.model_selection import StratifiedGroupKFold

In [24]:
# 10-fold splitter. Its split() call below receives `group_Tm` as the
# stratification label and `clst` as the group id.
sgkf = StratifiedGroupKFold(n_splits=10)


In [26]:
# Inputs for the fold assignment: features, stratification labels and group ids.
# NOTE: `y` is rebound here. From this cell on it holds the `group_Tm` labels,
# not the Tm targets returned earlier by process_features(train).
X = train['repr'].values
y = train['group_Tm'].values
groups = train['clst'].values

In [27]:
from collections import defaultdict

In [28]:
# Maps fold number -> list of index arrays holding that fold's held-out rows.
num_fold = defaultdict(list)

In [29]:
# Record, for every fold number, the row positions of its held-out part;
# the training-side indices from the splitter are not needed here.
for fold_id, (_train_rows, held_out_rows) in enumerate(sgkf.split(X, y, groups=groups)):
    num_fold[fold_id].append(held_out_rows)

In [30]:
# Create the fold-id column filled with None; the next cell fills it in.
train['num_fold'] = None

In [31]:
# Write each fold id into `train['num_fold']` for the rows held out in that fold.
# A single positional .iloc call on the frame replaces the chained form
# `train['num_fold'].iloc[...] = key`, which assigns through an intermediate
# Series and does not reach `train` when pandas Copy-on-Write is active.
fold_col = train.columns.get_loc('num_fold')
for key, held_out in num_fold.items():
    # Each dict value is a list of positional index arrays.
    for row_positions in held_out:
        train.iloc[row_positions, fold_col] = key

In [32]:
# Display the blind test frame as a visual sanity check.
full_test

Unnamed: 0,Protein ID,Tm,repr
0,Q72HG4_TT_C1523,74.007,"[0.094343215, -0.13420974, 0.018577041, -0.116..."
1,Q745T7_TT_P0220,77.065,"[0.0594677, -0.18177457, -0.013549126, -0.1414..."
2,Q72G97_recG,72.153,"[0.11095836, -0.09452299, 0.024978561, -0.1432..."
3,Q745Z3_TT_P0162,73.594,"[0.05716914, -0.10307638, 0.007367082, -0.0030..."
4,Q72HN7_TT_C1449,79.211,"[0.014693282, -0.11982835, -0.031532157, -0.09..."
...,...,...,...
1195,P54460_prmA,46.766,"[0.03920728340744972, -0.046484656631946564, 0..."
1196,O89001_Cpd,49.077,"[0.012924259528517723, -0.10131210833787918, 0..."
1197,P46019_PHKA2,51.235,"[0.04998369514942169, -0.05574483424425125, -0..."
1198,Q9HCN4-2_GPN1,60.556,"[0.022666864097118378, -0.08774963021278381, -..."


In [33]:
# Number of rows assigned to each fold.
train['num_fold'].value_counts()

8    2866
4    2866
1    2866
6    2866
5    2866
0    2866
3    2866
2    2866
9    2865
7    2865
Name: num_fold, dtype: int64

Trial 4 finished with value: 37.21355574017806 and parameters: {'n_shared': 5, 'n_independent': 5, 'n_d': 63, 'n_a': 12, 'n_steps': 5, 'gamma': 1.4111484359834034, 'momentum': 0.3592637810142343, 'lambda_sparse': 1.2038760187881758e-05}. Best is trial 4 with value: 37.21355574017806.

In [34]:
# Train one TabNet model per cross-validation fold with the hyperparameters
# hard-coded below, keeping every fitted model and its held-out metrics.
tabnet_params = dict(
    n_d=63, n_a=12, n_independent=5, n_shared=5, n_steps=5, gamma=1.411, momentum=0.35926,
    lambda_sparse=1.20e-5, verbose=0, seed=SEED, device_name=device, optimizer_fn=torch.optim.Adam,
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
)

models, valid_metrics = [], []
for fold in tqdm(train['num_fold'].unique()):
    # The current fold is held out; every other row is used for fitting.
    in_fold = train['num_fold'] == fold
    x_train, y_train = process_features(train[~in_fold])
    x_valid, y_valid = process_features(train[in_fold])

    # scheduler_params is built fresh for each model so no dict is shared between them.
    model = TabNetRegressor(**tabnet_params, scheduler_params={'step_size':10, 'gamma':0.9})
    model.fit(x_train, y_train, eval_set=[(x_valid, y_valid)], patience=PATIENCE, max_epochs=MAX_EPOCHS, batch_size=BATCH_SIZE)

    # calc_metrics is given 1-D arrays.
    y_pred = model.predict(x_valid).flatten()
    y_valid = y_valid.flatten()
    metrics_value = calc_metrics(y_pred, y_valid)
    print(metrics_value)

    valid_metrics.append(metrics_value)
    models.append(model)

  0%|          | 0/10 [00:00<?, ?it/s]


Early stopping occurred at epoch 72 with best_epoch = 42 and best_val_0_mse = 39.11485


 10%|█         | 1/10 [05:28<49:19, 328.85s/it]

[39.114846683851006, 6.254186332677578, 0.6772552140056775, 4.572483316343497, 0.8260796761525663]

Early stopping occurred at epoch 71 with best_epoch = 41 and best_val_0_mse = 39.42174


 20%|██        | 2/10 [10:57<43:50, 328.79s/it]

[39.42173671776494, 6.278673165388125, 0.6820624734010162, 4.572040203117088, 0.828024662593642]

Early stopping occurred at epoch 70 with best_epoch = 40 and best_val_0_mse = 35.6348


 30%|███       | 3/10 [16:19<37:58, 325.56s/it]

[35.63479741384396, 5.969488873751585, 0.7038354048067696, 4.471909816837311, 0.8396453599986895]

Early stopping occurred at epoch 70 with best_epoch = 40 and best_val_0_mse = 38.0163


 40%|████      | 4/10 [21:41<32:25, 324.17s/it]

[38.01630409983907, 6.165736298272824, 0.6920958945614515, 4.629172972787166, 0.8322062792703915]

Early stopping occurred at epoch 78 with best_epoch = 48 and best_val_0_mse = 38.5989


 50%|█████     | 5/10 [27:37<27:58, 335.76s/it]

[38.59889571494406, 6.21280095568368, 0.6812911375473149, 4.6727366760807465, 0.8256017696084578]

Early stopping occurred at epoch 83 with best_epoch = 53 and best_val_0_mse = 37.48842


 60%|██████    | 6/10 [33:57<23:22, 350.61s/it]

[37.48842076431081, 6.1227788433284775, 0.6927321594371727, 4.529669465636807, 0.8323790968607186]

Early stopping occurred at epoch 68 with best_epoch = 38 and best_val_0_mse = 36.57849


 70%|███████   | 7/10 [39:10<16:55, 338.60s/it]

[36.57848653529415, 6.048015090531285, 0.6861278432405499, 4.502126440486795, 0.8296186753342104]

Early stopping occurred at epoch 60 with best_epoch = 30 and best_val_0_mse = 36.61165


 80%|████████  | 8/10 [43:49<10:39, 319.62s/it]

[36.611652527653774, 6.050756359964742, 0.6891468899667648, 4.467581531035862, 0.830356273288147]

Early stopping occurred at epoch 78 with best_epoch = 48 and best_val_0_mse = 36.38061


 90%|█████████ | 9/10 [49:47<05:31, 331.41s/it]

[36.380612614186155, 6.031634323646134, 0.7046559053120582, 4.499017571711165, 0.8403641922081609]

Early stopping occurred at epoch 66 with best_epoch = 36 and best_val_0_mse = 37.77945


100%|██████████| 10/10 [54:51<00:00, 329.17s/it]

[37.77945273503293, 6.146499225984897, 0.68708594555277, 4.644880397782985, 0.8294169300944408]





In [38]:
# Stack the per-fold metric lists into an array: one row per fold, one column per metric.
valid_metrics = np.array(valid_metrics)

In [41]:
# Mean of each metric across folds.
mean_val_metrics = valid_metrics.mean(axis=0)

In [43]:
# Standard deviation of each metric across folds (ndarray.std with its default ddof).
std_val_metrics = valid_metrics.std(axis=0)

In [50]:
# Summary table with rows 'mean' and 'std'. The index labels must stay in the
# order calc_metrics returns its values: mse, rmse, r_2, mae, pcc.
val_metrics = pd.DataFrame({'mean': mean_val_metrics, 'std': std_val_metrics}, index=['mse', 'rmse', 'r_2', 'mae', 'pcc']).T

In [51]:
# Display the cross-validation summary table.
val_metrics

Unnamed: 0,mse,rmse,r_2,mae,pcc
mean,37.562521,6.128057,0.689629,4.556162,0.831369
std,1.190984,0.097153,0.008615,0.070101,0.004811


In [55]:
# Blind-test prediction: the average over all fold models collected in `models`.
ytest_pred = average_predict(x_test, models)

In [58]:
# process_features returns targets shaped (-1, 1); flatten to 1-D for calc_metrics.
y_test = y_test.flatten()

In [59]:
# Unpack in the order calc_metrics returns its values: [mse, rmse, r_2, mae, pcc].
mse, rmse, r_2, mae, pcc = calc_metrics(ytest_pred, y_test)

In [60]:
# Build a one-row table of the blind-test metrics; as the cell's last
# expression it is displayed by the notebook.
print_metrics(mse=mse, mae=mae, rmse=rmse, r_2=r_2, pcc=pcc)

Unnamed: 0,mse,mae,rmse,r_2,pcc
0,33.234517,4.257874,5.764939,0.70705,0.840879


#### Save models

In [62]:
# Create the output directory first so a missing folder cannot fail the save
# after the long training run.
os.makedirs('../model', exist_ok=True)

# Each fitted model object is pickled whole with torch.save, one file per fold.
for i, model in enumerate(models):
    torch.save(model, f'../model/tabnet_model{i}.pt')