In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys
sys.path.append('/cdthome/iuo722/accel')
from pathlib import Path
from accel.read_data import prep_data
from sklearn.metrics import r2_score, mean_squared_error, explained_variance_score
from sklearn.model_selection import GridSearchCV
from skopt import BayesSearchCV
from sklearn.model_selection import ParameterGrid

from pytorch_tabnet.tab_model import TabNetRegressor

# Seed numpy's global (legacy RandomState) generator so numpy-based randomness in this
# notebook is repeatable across runs.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

In [2]:
def score(y_true, y_pred):
    """Print the regression metrics used to judge a model's predictions.

    Reports mean squared error, R^2 and explained variance, one per line.
    All three metrics are computed before anything is printed.

    Parameters
    ----------
    y_true : array-like
        Ground-truth target values.
    y_pred : array-like
        Model predictions aligned with ``y_true``.

    Returns
    -------
    None
        The metrics are only printed, not returned.
    """
    report = {
        "MSE is: ": mean_squared_error(y_true, y_pred),
        "R2 is: ": r2_score(y_true, y_pred),
        "Explained variance is:": explained_variance_score(y_true, y_pred),
    }
    for label, value in report.items():
        print(label, value)

### Load the data

In [3]:
# Shared group directory that holds the prepared dataset.
PATH = '/cdtshared/wearables/students/group5/'
data_csv = PATH + "eliminated-missing-participants.csv"

# prep_data is asked to normalise and one-hot encode (see accel.read_data for exactly
# what that entails). It returns train/validation/test splits plus `mean_mode`
# (presumably the fill statistics it used — confirm in accel.read_data).
X_train, X_val, X_test, y_train, y_val, y_test, mean_mode = prep_data(
    data_csv,
    normalise=True,
    one_hot=True,
)

<class 'pandas.core.frame.DataFrame'>


### Model tuning

In [4]:
# Candidate values for each searched TabNet hyper-parameter.
# n_a is deliberately not searched on its own: the search loop sets n_a = n_d.
n_d = [8, 16, 32, 64]
n_steps = [3, 5, 8]
gamma = [1.0, 1.3, 1.7, 2.0]

# Search space consumed by sklearn's ParameterGrid.
param_grid = dict(n_d=n_d, n_steps=n_steps, gamma=gamma)

In [5]:
"""
# Attempt to use the scikit-learn built in functions
model = TabNetRegressor()
model.fit(X_train.to_numpy(), y_train.values.reshape(-1,1))

clf = GridSearchCV(model, param_grid)
clf.fit(X_train.to_numpy(), y_train.values.reshape(-1,1),scoring=mean_squared_error)

# Try with Bayesian optimisation for faster computation of tuning
opt = BayesSearchCV(model, param_grid, n_iter=30, cv=5, verbose=1)
opt.fit(X_train, y_train)
"""

'\n# Attempt to use the scikit-learn built in functions\nmodel = TabNetRegressor()\nmodel.fit(X_train.to_numpy(), y_train.values.reshape(-1,1))\n\nclf = GridSearchCV(model, param_grid)\nclf.fit(X_train.to_numpy(), y_train.values.reshape(-1,1),scoring=mean_squared_error)\n\n# Try with Bayesian optimisation for faster computation of tuning\nopt = BayesSearchCV(model, param_grid, n_iter=30, cv=5, verbose=1)\nopt.fit(X_train, y_train)\n'

In [None]:
grid = ParameterGrid(param_grid)

# One record per hyper-parameter combination. The results frame is rebuilt from the
# records after each fit (DataFrame.append was removed in pandas 2.0).
records = []
search_results = pd.DataFrame()
for params in grid:
    # Tie n_a to n_d (n_a = n_d, as the original notes say the TabNet paper does).
    # Build a new dict rather than mutating the one yielded by the grid.
    params = {**params, 'n_a': params['n_d']}
    tabnet = TabNetRegressor(**params)
    tabnet.fit(X_train.to_numpy(), y_train.values.reshape(-1, 1), max_epochs=60)
    # Named val_mse, not `score`, so the score() helper defined earlier is not
    # shadowed. Otherwise the final evaluation cell would fail.
    val_mse = mean_squared_error(y_val.to_numpy(), tabnet.predict(X_val.to_numpy()))

    records.append({**params, 'score': val_mse})
    search_results = pd.DataFrame(records)
    # Checkpoint after every fit so partial results survive an interrupted run.
    search_results.to_csv("results-larger-grid.csv")
    

Device used : cuda
No early stopping will be performed, last training weights will be used.
epoch 0  | loss: 312.37536|  0:00:03s
epoch 1  | loss: 81.69316|  0:00:06s
epoch 2  | loss: 77.45436|  0:00:09s
epoch 3  | loss: 76.54649|  0:00:12s
epoch 4  | loss: 75.49936|  0:00:15s
epoch 5  | loss: 74.45534|  0:00:19s
epoch 6  | loss: 73.7138 |  0:00:22s
epoch 7  | loss: 73.11408|  0:00:25s
epoch 8  | loss: 72.70931|  0:00:28s
epoch 9  | loss: 72.29734|  0:00:31s
epoch 10 | loss: 71.88532|  0:00:34s
epoch 11 | loss: 71.69349|  0:00:38s
epoch 12 | loss: 71.25032|  0:00:41s
epoch 13 | loss: 71.16936|  0:00:44s
epoch 14 | loss: 71.21986|  0:00:47s
epoch 15 | loss: 70.78746|  0:00:50s
epoch 16 | loss: 70.63534|  0:00:53s
epoch 17 | loss: 70.61977|  0:00:56s
epoch 18 | loss: 70.42246|  0:00:59s
epoch 19 | loss: 70.043  |  0:01:03s
epoch 20 | loss: 70.09652|  0:01:06s
epoch 21 | loss: 69.54108|  0:01:09s
epoch 22 | loss: 69.22034|  0:01:12s
epoch 23 | loss: 69.01707|  0:01:15s
epoch 24 | loss: 68

epoch 14 | loss: 74.86076|  0:01:20s
epoch 15 | loss: 75.77875|  0:01:26s
epoch 16 | loss: 75.09118|  0:01:32s
epoch 17 | loss: 75.02481|  0:01:37s
epoch 18 | loss: 74.98275|  0:01:43s
epoch 19 | loss: 75.11447|  0:01:49s
epoch 20 | loss: 74.99015|  0:01:54s
epoch 21 | loss: 74.64398|  0:02:00s
epoch 22 | loss: 74.64298|  0:02:05s
epoch 23 | loss: 74.6953 |  0:02:10s
epoch 24 | loss: 74.53597|  0:02:16s
epoch 25 | loss: 74.85296|  0:02:21s
epoch 26 | loss: 74.48037|  0:02:26s
epoch 27 | loss: 74.44211|  0:02:32s
epoch 28 | loss: 74.95167|  0:02:37s
epoch 29 | loss: 74.31039|  0:02:42s
epoch 30 | loss: 74.26841|  0:02:48s
epoch 31 | loss: 74.32581|  0:02:53s
epoch 32 | loss: 74.62068|  0:02:58s
epoch 33 | loss: 74.57591|  0:03:03s
epoch 34 | loss: 74.24712|  0:03:09s
epoch 35 | loss: 74.18614|  0:03:14s
epoch 36 | loss: 74.23686|  0:03:19s
epoch 37 | loss: 74.39203|  0:03:25s
epoch 38 | loss: 74.57047|  0:03:30s
epoch 39 | loss: 74.21797|  0:03:36s
epoch 40 | loss: 74.73837|  0:03:41s
e

epoch 31 | loss: 61.71572|  0:02:05s
epoch 32 | loss: 61.2905 |  0:02:10s
epoch 33 | loss: 60.28201|  0:02:15s
epoch 34 | loss: 59.85488|  0:02:19s
epoch 35 | loss: 59.49552|  0:02:24s
epoch 36 | loss: 58.63632|  0:02:29s
epoch 37 | loss: 58.06444|  0:02:34s
epoch 38 | loss: 57.66857|  0:02:38s
epoch 39 | loss: 57.21827|  0:02:42s
epoch 40 | loss: 57.35132|  0:02:46s
epoch 41 | loss: 56.67903|  0:02:50s
epoch 42 | loss: 55.5095 |  0:02:54s
epoch 43 | loss: 55.44364|  0:02:57s
epoch 44 | loss: 55.19598|  0:03:01s
epoch 45 | loss: 54.50453|  0:03:05s
epoch 46 | loss: 54.52905|  0:03:08s
epoch 47 | loss: 53.73705|  0:03:12s
epoch 48 | loss: 53.2034 |  0:03:16s
epoch 49 | loss: 52.82004|  0:03:19s
epoch 50 | loss: 52.46177|  0:03:23s
epoch 51 | loss: 52.06117|  0:03:27s
epoch 52 | loss: 52.25378|  0:03:30s
epoch 53 | loss: 51.13118|  0:03:34s
epoch 54 | loss: 51.26238|  0:03:37s
epoch 55 | loss: 50.80119|  0:03:41s
epoch 56 | loss: 50.65709|  0:03:45s
epoch 57 | loss: 50.85032|  0:03:49s
e

epoch 48 | loss: 38.54979|  0:02:20s
epoch 49 | loss: 38.34441|  0:02:24s
epoch 50 | loss: 37.51291|  0:02:27s
epoch 51 | loss: 36.92427|  0:02:29s
epoch 52 | loss: 36.8115 |  0:02:32s
epoch 53 | loss: 36.52337|  0:02:34s
epoch 54 | loss: 35.52908|  0:02:37s
epoch 55 | loss: 35.33866|  0:02:40s
epoch 56 | loss: 34.38518|  0:02:42s
epoch 57 | loss: 34.42707|  0:02:45s
epoch 58 | loss: 33.94797|  0:02:47s
epoch 59 | loss: 33.58185|  0:02:50s
epoch 60 | loss: 33.15498|  0:02:53s
epoch 61 | loss: 32.40138|  0:02:55s
epoch 62 | loss: 31.57974|  0:02:58s
epoch 63 | loss: 31.97055|  0:03:01s
epoch 64 | loss: 31.49951|  0:03:04s
epoch 65 | loss: 31.36065|  0:03:07s
epoch 66 | loss: 31.17927|  0:03:10s
epoch 67 | loss: 31.05164|  0:03:13s
epoch 68 | loss: 30.91767|  0:03:16s
epoch 69 | loss: 30.88273|  0:03:19s
epoch 70 | loss: 30.09005|  0:03:22s
epoch 71 | loss: 29.79993|  0:03:24s
epoch 72 | loss: 29.90763|  0:03:27s
epoch 73 | loss: 29.94107|  0:03:30s
epoch 74 | loss: 28.86526|  0:03:33s
e

epoch 65 | loss: 51.61065|  0:05:52s
epoch 66 | loss: 51.43685|  0:05:57s
epoch 67 | loss: 51.04534|  0:06:02s
epoch 68 | loss: 51.12594|  0:06:07s
epoch 69 | loss: 50.52495|  0:06:12s
epoch 70 | loss: 49.80589|  0:06:17s
epoch 71 | loss: 49.36082|  0:06:22s
epoch 72 | loss: 48.18191|  0:06:27s
epoch 73 | loss: 48.35283|  0:06:32s
epoch 74 | loss: 48.10306|  0:06:37s
epoch 75 | loss: 47.24865|  0:06:42s
epoch 76 | loss: 46.65369|  0:06:47s
epoch 77 | loss: 46.00038|  0:06:51s
epoch 78 | loss: 45.45222|  0:06:56s
epoch 79 | loss: 44.95182|  0:07:01s
epoch 80 | loss: 44.26415|  0:07:06s
epoch 81 | loss: 43.38165|  0:07:11s
epoch 82 | loss: 42.61244|  0:07:16s
epoch 83 | loss: 42.27978|  0:07:21s
epoch 84 | loss: 41.81759|  0:07:26s
epoch 85 | loss: 42.1396 |  0:07:31s
epoch 86 | loss: 42.34114|  0:07:36s
epoch 87 | loss: 41.98118|  0:07:41s
epoch 88 | loss: 41.42957|  0:07:46s
epoch 89 | loss: 40.41868|  0:07:50s
epoch 90 | loss: 39.3108 |  0:07:55s
epoch 91 | loss: 39.21797|  0:08:00s
e

epoch 82 | loss: 28.94869|  0:05:11s
epoch 83 | loss: 28.58069|  0:05:14s
epoch 84 | loss: 27.53312|  0:05:18s
epoch 85 | loss: 27.07875|  0:05:22s
epoch 86 | loss: 25.24382|  0:05:25s
epoch 87 | loss: 25.16894|  0:05:29s
epoch 88 | loss: 24.44266|  0:05:32s
epoch 89 | loss: 23.78145|  0:05:36s
epoch 90 | loss: 23.12467|  0:05:39s
epoch 91 | loss: 22.51027|  0:05:43s
epoch 92 | loss: 21.66728|  0:05:46s
epoch 93 | loss: 22.04993|  0:05:50s
epoch 94 | loss: 22.24571|  0:05:53s
epoch 95 | loss: 21.35564|  0:05:57s
epoch 96 | loss: 20.90151|  0:06:00s
epoch 97 | loss: 21.04756|  0:06:04s
epoch 98 | loss: 20.38627|  0:06:08s
epoch 99 | loss: 19.59142|  0:06:12s
Device used : cuda
No early stopping will be performed, last training weights will be used.
epoch 0  | loss: 197.95162|  0:00:05s
epoch 1  | loss: 81.01093|  0:00:10s
epoch 2  | loss: 77.90592|  0:00:15s
epoch 3  | loss: 79.28522|  0:00:21s
epoch 4  | loss: 76.55397|  0:00:26s
epoch 5  | loss: 76.76456|  0:00:31s
epoch 6  | loss: 76

epoch 99 | loss: 55.99507|  0:04:32s
Device used : cuda
No early stopping will be performed, last training weights will be used.
epoch 0  | loss: 298.44219|  0:00:03s
epoch 1  | loss: 84.58405|  0:00:07s
epoch 2  | loss: 81.45239|  0:00:11s
epoch 3  | loss: 80.26908|  0:00:14s
epoch 4  | loss: 80.64962|  0:00:17s
epoch 5  | loss: 78.6668 |  0:00:21s
epoch 6  | loss: 76.20063|  0:00:24s
epoch 7  | loss: 74.52824|  0:00:28s
epoch 8  | loss: 73.75183|  0:00:31s
epoch 9  | loss: 73.49773|  0:00:34s
epoch 10 | loss: 72.87206|  0:00:38s
epoch 11 | loss: 72.44639|  0:00:41s
epoch 12 | loss: 72.40273|  0:00:45s
epoch 13 | loss: 72.33265|  0:00:48s
epoch 14 | loss: 72.11021|  0:00:51s
epoch 15 | loss: 71.77442|  0:00:55s
epoch 16 | loss: 71.78125|  0:00:58s
epoch 17 | loss: 71.29802|  0:01:02s
epoch 18 | loss: 71.07581|  0:01:05s
epoch 19 | loss: 70.98983|  0:01:08s
epoch 20 | loss: 70.57178|  0:01:12s
epoch 21 | loss: 70.47894|  0:01:15s
epoch 22 | loss: 70.32701|  0:01:19s
epoch 23 | loss: 69

epoch 13 | loss: 70.98155|  0:00:35s
epoch 14 | loss: 70.76432|  0:00:38s
epoch 15 | loss: 70.50987|  0:00:40s
epoch 16 | loss: 70.25897|  0:00:43s
epoch 17 | loss: 70.12916|  0:00:45s
epoch 18 | loss: 70.3433 |  0:00:48s
epoch 19 | loss: 69.76126|  0:00:50s
epoch 20 | loss: 69.54296|  0:00:53s
epoch 21 | loss: 69.53135|  0:00:56s
epoch 22 | loss: 69.04427|  0:00:59s
epoch 23 | loss: 68.75108|  0:01:02s
epoch 24 | loss: 68.06149|  0:01:04s
epoch 25 | loss: 68.05892|  0:01:07s
epoch 26 | loss: 66.96625|  0:01:09s
epoch 27 | loss: 66.34567|  0:01:12s
epoch 28 | loss: 65.6543 |  0:01:15s
epoch 29 | loss: 65.21142|  0:01:17s
epoch 30 | loss: 64.6053 |  0:01:20s
epoch 31 | loss: 64.29841|  0:01:22s
epoch 32 | loss: 63.73717|  0:01:25s
epoch 33 | loss: 63.10847|  0:01:27s
epoch 34 | loss: 62.21088|  0:01:30s
epoch 35 | loss: 61.39598|  0:01:33s
epoch 36 | loss: 61.16294|  0:01:35s
epoch 37 | loss: 61.04173|  0:01:38s
epoch 38 | loss: 60.14988|  0:01:40s
epoch 39 | loss: 59.64727|  0:01:43s
e

epoch 30 | loss: 70.92085|  0:02:35s
epoch 31 | loss: 70.44972|  0:02:41s
epoch 32 | loss: 70.34704|  0:02:46s
epoch 33 | loss: 69.98215|  0:02:51s
epoch 34 | loss: 70.26058|  0:02:56s
epoch 35 | loss: 70.25467|  0:03:01s
epoch 36 | loss: 69.92986|  0:03:06s
epoch 37 | loss: 69.14297|  0:03:11s
epoch 38 | loss: 68.669  |  0:03:16s
epoch 39 | loss: 68.43169|  0:03:21s
epoch 40 | loss: 68.29199|  0:03:27s
epoch 41 | loss: 67.16482|  0:03:33s
epoch 42 | loss: 66.86968|  0:03:39s
epoch 43 | loss: 66.83779|  0:03:44s
epoch 44 | loss: 65.97023|  0:03:51s
epoch 45 | loss: 65.20364|  0:03:57s
epoch 46 | loss: 64.51473|  0:04:02s
epoch 47 | loss: 64.34095|  0:04:07s
epoch 48 | loss: 63.78105|  0:04:12s
epoch 49 | loss: 63.54267|  0:04:17s
epoch 50 | loss: 62.7532 |  0:04:22s
epoch 51 | loss: 62.03497|  0:04:27s
epoch 52 | loss: 62.29542|  0:04:32s
epoch 53 | loss: 61.48368|  0:04:37s
epoch 54 | loss: 60.65086|  0:04:42s
epoch 55 | loss: 61.08146|  0:04:47s
epoch 56 | loss: 60.9178 |  0:04:52s
e

epoch 47 | loss: 68.63693|  0:03:09s
epoch 48 | loss: 68.2163 |  0:03:13s
epoch 49 | loss: 68.27103|  0:03:17s
epoch 50 | loss: 68.46879|  0:03:20s
epoch 51 | loss: 68.80932|  0:03:24s
epoch 52 | loss: 67.51275|  0:03:28s
epoch 53 | loss: 66.53126|  0:03:32s
epoch 54 | loss: 65.78046|  0:03:36s
epoch 55 | loss: 65.00261|  0:03:39s
epoch 56 | loss: 64.84863|  0:03:43s
epoch 57 | loss: 63.5898 |  0:03:47s
epoch 58 | loss: 62.96461|  0:03:50s
epoch 59 | loss: 62.75085|  0:03:54s
epoch 60 | loss: 62.26306|  0:03:57s
epoch 61 | loss: 61.56814|  0:04:01s
epoch 62 | loss: 60.76082|  0:04:05s
epoch 63 | loss: 59.89637|  0:04:09s
epoch 64 | loss: 59.1732 |  0:04:12s
epoch 65 | loss: 57.95835|  0:04:16s
epoch 66 | loss: 57.28588|  0:04:20s
epoch 67 | loss: 56.49249|  0:04:24s
epoch 68 | loss: 55.87882|  0:04:28s
epoch 69 | loss: 55.13348|  0:04:31s
epoch 70 | loss: 54.98931|  0:04:35s
epoch 71 | loss: 54.24755|  0:04:39s
epoch 72 | loss: 52.81844|  0:04:42s
epoch 73 | loss: 52.91517|  0:04:46s
e

epoch 64 | loss: 15.22298|  0:03:05s
epoch 65 | loss: 15.91633|  0:03:08s
epoch 66 | loss: 15.64023|  0:03:10s
epoch 67 | loss: 15.1594 |  0:03:13s
epoch 68 | loss: 14.78283|  0:03:15s
epoch 69 | loss: 14.58202|  0:03:18s
epoch 70 | loss: 14.55478|  0:03:21s
epoch 71 | loss: 14.15737|  0:03:23s
epoch 72 | loss: 14.49427|  0:03:26s
epoch 73 | loss: 13.93561|  0:03:28s
epoch 74 | loss: 13.6699 |  0:03:31s
epoch 75 | loss: 13.71909|  0:03:33s
epoch 76 | loss: 13.43778|  0:03:36s
epoch 77 | loss: 13.17942|  0:03:39s
epoch 78 | loss: 13.84175|  0:03:41s
epoch 79 | loss: 13.34881|  0:03:44s
epoch 80 | loss: 13.10451|  0:03:46s
epoch 81 | loss: 13.23032|  0:03:49s
epoch 82 | loss: 12.98756|  0:03:52s
epoch 83 | loss: 12.63791|  0:03:55s
epoch 84 | loss: 13.0409 |  0:03:58s
epoch 85 | loss: 12.29649|  0:04:00s
epoch 86 | loss: 12.42456|  0:04:03s
epoch 87 | loss: 12.2679 |  0:04:06s
epoch 88 | loss: 12.0895 |  0:04:09s
epoch 89 | loss: 12.63399|  0:04:12s
epoch 90 | loss: 11.87019|  0:04:15s
e

epoch 81 | loss: 51.01507|  0:07:00s
epoch 82 | loss: 50.63381|  0:07:05s
epoch 83 | loss: 49.7741 |  0:07:10s
epoch 84 | loss: 48.19607|  0:07:15s
epoch 85 | loss: 48.10037|  0:07:21s
epoch 86 | loss: 46.54575|  0:07:27s
epoch 87 | loss: 45.46958|  0:07:33s
epoch 88 | loss: 44.83378|  0:07:39s
epoch 89 | loss: 44.19389|  0:07:45s
epoch 90 | loss: 43.43567|  0:07:51s
epoch 91 | loss: 43.25422|  0:07:56s
epoch 92 | loss: 40.93397|  0:08:01s
epoch 93 | loss: 39.81989|  0:08:07s
epoch 94 | loss: 39.33914|  0:08:12s
epoch 95 | loss: 38.20728|  0:08:18s
epoch 96 | loss: 38.27772|  0:08:24s
epoch 97 | loss: 37.27715|  0:08:30s
epoch 98 | loss: 36.11983|  0:08:36s
epoch 99 | loss: 34.54843|  0:08:41s
Device used : cuda
No early stopping will be performed, last training weights will be used.
epoch 0  | loss: 312.37536|  0:00:02s
epoch 1  | loss: 81.69316|  0:00:05s
epoch 2  | loss: 77.45436|  0:00:07s
epoch 3  | loss: 76.54649|  0:00:10s
epoch 4  | loss: 75.49936|  0:00:12s
epoch 5  | loss: 74

epoch 98 | loss: 55.37838|  0:05:39s
epoch 99 | loss: 55.07784|  0:05:43s
Device used : cuda
No early stopping will be performed, last training weights will be used.
epoch 0  | loss: 248.04065|  0:00:04s
epoch 1  | loss: 86.30015|  0:00:09s
epoch 2  | loss: 81.42643|  0:00:14s
epoch 3  | loss: 78.87972|  0:00:19s
epoch 4  | loss: 79.56758|  0:00:24s
epoch 5  | loss: 77.16496|  0:00:30s
epoch 6  | loss: 76.9038 |  0:00:35s
epoch 7  | loss: 77.00301|  0:00:41s
epoch 8  | loss: 76.55679|  0:00:46s
epoch 9  | loss: 76.34117|  0:00:51s
epoch 10 | loss: 75.22937|  0:00:56s
epoch 11 | loss: 75.07828|  0:01:02s
epoch 12 | loss: 75.24536|  0:01:06s
epoch 13 | loss: 74.92307|  0:01:12s
epoch 14 | loss: 74.86076|  0:01:18s
epoch 15 | loss: 75.77875|  0:01:24s
epoch 16 | loss: 75.09118|  0:01:30s
epoch 17 | loss: 75.02481|  0:01:36s
epoch 18 | loss: 74.98275|  0:01:41s
epoch 19 | loss: 75.11447|  0:01:45s
epoch 20 | loss: 74.99015|  0:01:50s
epoch 21 | loss: 74.64398|  0:01:55s
epoch 22 | loss: 74

epoch 12 | loss: 71.85837|  0:00:46s
epoch 13 | loss: 71.28157|  0:00:49s
epoch 14 | loss: 70.88779|  0:00:53s
epoch 15 | loss: 70.83863|  0:00:57s
epoch 16 | loss: 70.41435|  0:01:01s
epoch 17 | loss: 70.12118|  0:01:04s
epoch 18 | loss: 69.67004|  0:01:08s
epoch 19 | loss: 69.3486 |  0:01:12s
epoch 20 | loss: 68.69529|  0:01:16s
epoch 21 | loss: 68.63821|  0:01:20s
epoch 22 | loss: 68.09139|  0:01:23s
epoch 23 | loss: 67.09791|  0:01:27s
epoch 24 | loss: 66.72421|  0:01:31s
epoch 25 | loss: 65.99782|  0:01:35s
epoch 26 | loss: 65.71252|  0:01:39s
epoch 27 | loss: 64.80336|  0:01:42s
epoch 28 | loss: 64.18532|  0:01:46s
epoch 29 | loss: 63.10388|  0:01:50s
epoch 30 | loss: 62.80255|  0:01:54s
epoch 31 | loss: 61.71572|  0:01:57s
epoch 32 | loss: 61.2905 |  0:02:01s
epoch 33 | loss: 60.28201|  0:02:05s
epoch 34 | loss: 59.85488|  0:02:08s
epoch 35 | loss: 59.49552|  0:02:12s
epoch 36 | loss: 58.63632|  0:02:15s
epoch 37 | loss: 58.06444|  0:02:19s
epoch 38 | loss: 57.66857|  0:02:22s
e

epoch 29 | loss: 53.79974|  0:01:34s
epoch 30 | loss: 53.6492 |  0:01:37s
epoch 31 | loss: 52.0024 |  0:01:40s
epoch 32 | loss: 51.49843|  0:01:43s
epoch 33 | loss: 50.53411|  0:01:46s
epoch 34 | loss: 49.55914|  0:01:49s
epoch 35 | loss: 48.69453|  0:01:53s
epoch 36 | loss: 47.28085|  0:01:56s
epoch 37 | loss: 47.02374|  0:01:59s
epoch 38 | loss: 45.89998|  0:02:02s
epoch 39 | loss: 44.822  |  0:02:05s
epoch 40 | loss: 43.77428|  0:02:08s
epoch 41 | loss: 43.56771|  0:02:11s
epoch 42 | loss: 42.73775|  0:02:14s
epoch 43 | loss: 42.44534|  0:02:17s
epoch 44 | loss: 41.04878|  0:02:21s
epoch 45 | loss: 40.43422|  0:02:24s
epoch 46 | loss: 40.01356|  0:02:27s
epoch 47 | loss: 38.95786|  0:02:30s
epoch 48 | loss: 38.54979|  0:02:34s
epoch 49 | loss: 38.34441|  0:02:37s
epoch 50 | loss: 37.51291|  0:02:40s
epoch 51 | loss: 36.92427|  0:02:43s
epoch 52 | loss: 36.8115 |  0:02:46s
epoch 53 | loss: 36.52337|  0:02:49s
epoch 54 | loss: 35.52908|  0:02:53s
epoch 55 | loss: 35.33866|  0:02:56s
e

epoch 46 | loss: 64.81935|  0:04:31s
epoch 47 | loss: 64.08788|  0:04:37s
epoch 48 | loss: 63.61755|  0:04:43s
epoch 49 | loss: 63.20037|  0:04:49s
epoch 50 | loss: 62.30307|  0:04:54s
epoch 51 | loss: 62.17072|  0:05:00s
epoch 52 | loss: 61.46368|  0:05:06s
epoch 53 | loss: 61.10262|  0:05:12s
epoch 54 | loss: 60.21911|  0:05:17s
epoch 55 | loss: 58.77994|  0:05:23s
epoch 56 | loss: 57.73475|  0:05:29s
epoch 57 | loss: 57.56475|  0:05:34s
epoch 58 | loss: 56.46589|  0:05:40s
epoch 59 | loss: 55.89753|  0:05:46s
epoch 60 | loss: 55.99911|  0:05:52s
epoch 61 | loss: 54.62462|  0:05:57s
epoch 62 | loss: 53.95699|  0:06:03s
epoch 63 | loss: 53.31703|  0:06:09s
epoch 64 | loss: 52.88079|  0:06:14s
epoch 65 | loss: 51.61065|  0:06:20s
epoch 66 | loss: 51.43685|  0:06:26s
epoch 67 | loss: 51.04534|  0:06:32s
epoch 68 | loss: 51.12594|  0:06:38s
epoch 69 | loss: 50.52495|  0:06:44s
epoch 70 | loss: 49.80589|  0:06:49s
epoch 71 | loss: 49.36082|  0:06:55s
epoch 72 | loss: 48.18191|  0:07:01s
e

epoch 63 | loss: 52.213  |  0:04:28s
epoch 64 | loss: 50.8777 |  0:04:32s
epoch 65 | loss: 49.42249|  0:04:36s
epoch 66 | loss: 48.01795|  0:04:40s
epoch 67 | loss: 46.49414|  0:04:44s
epoch 68 | loss: 45.97383|  0:04:48s
epoch 69 | loss: 44.71717|  0:04:53s
epoch 70 | loss: 42.80221|  0:04:57s
epoch 71 | loss: 41.44362|  0:05:01s
epoch 72 | loss: 40.69549|  0:05:05s
epoch 73 | loss: 38.33401|  0:05:09s
epoch 74 | loss: 37.0077 |  0:05:13s
epoch 75 | loss: 36.16215|  0:05:17s
epoch 76 | loss: 35.40123|  0:05:22s
epoch 77 | loss: 34.13814|  0:05:26s
epoch 78 | loss: 32.80522|  0:05:30s
epoch 79 | loss: 31.37572|  0:05:34s
epoch 80 | loss: 30.39565|  0:05:38s
epoch 81 | loss: 29.37612|  0:05:42s
epoch 82 | loss: 28.94869|  0:05:47s
epoch 83 | loss: 28.58069|  0:05:51s
epoch 84 | loss: 27.53312|  0:05:56s
epoch 85 | loss: 27.07875|  0:06:00s
epoch 86 | loss: 25.24382|  0:06:04s
epoch 87 | loss: 25.16894|  0:06:08s
epoch 88 | loss: 24.44266|  0:06:13s
epoch 89 | loss: 23.78145|  0:06:17s
e

epoch 80 | loss: 57.43742|  0:04:07s
epoch 81 | loss: 57.06923|  0:04:10s
epoch 82 | loss: 57.23766|  0:04:14s
epoch 83 | loss: 56.73759|  0:04:17s
epoch 84 | loss: 56.8834 |  0:04:20s
epoch 85 | loss: 57.11476|  0:04:23s
epoch 86 | loss: 56.8556 |  0:04:26s
epoch 87 | loss: 56.80119|  0:04:29s
epoch 88 | loss: 56.64252|  0:04:32s
epoch 89 | loss: 56.65189|  0:04:35s
epoch 90 | loss: 56.45824|  0:04:38s
epoch 91 | loss: 56.3599 |  0:04:41s
epoch 92 | loss: 56.16373|  0:04:45s
epoch 93 | loss: 56.28512|  0:04:48s
epoch 94 | loss: 56.33459|  0:04:51s
epoch 95 | loss: 56.23502|  0:04:54s
epoch 96 | loss: 56.23672|  0:04:57s
epoch 97 | loss: 56.03983|  0:05:00s
epoch 98 | loss: 56.25474|  0:05:03s
epoch 99 | loss: 55.99507|  0:05:06s
Device used : cuda
No early stopping will be performed, last training weights will be used.
epoch 0  | loss: 298.44219|  0:00:04s
epoch 1  | loss: 84.58405|  0:00:08s
epoch 2  | loss: 81.45239|  0:00:12s
epoch 3  | loss: 80.26908|  0:00:16s
epoch 4  | loss: 80

epoch 97 | loss: 72.87575|  0:09:15s
epoch 98 | loss: 72.86833|  0:09:20s
epoch 99 | loss: 73.0639 |  0:09:26s
Device used : cuda
No early stopping will be performed, last training weights will be used.
epoch 0  | loss: 318.67538|  0:00:03s
epoch 1  | loss: 79.93089|  0:00:06s
epoch 2  | loss: 76.3055 |  0:00:09s
epoch 3  | loss: 75.29215|  0:00:12s
epoch 4  | loss: 74.95161|  0:00:15s
epoch 5  | loss: 73.94045|  0:00:18s
epoch 6  | loss: 73.24791|  0:00:22s
epoch 7  | loss: 72.83431|  0:00:25s
epoch 8  | loss: 72.56326|  0:00:28s
epoch 9  | loss: 72.20476|  0:00:31s
epoch 10 | loss: 72.23794|  0:00:34s
epoch 11 | loss: 71.76419|  0:00:37s
epoch 12 | loss: 71.46746|  0:00:40s
epoch 13 | loss: 70.98155|  0:00:43s
epoch 14 | loss: 70.76432|  0:00:46s
epoch 15 | loss: 70.50987|  0:00:50s
epoch 16 | loss: 70.25897|  0:00:53s
epoch 17 | loss: 70.12916|  0:00:56s
epoch 18 | loss: 70.3433 |  0:00:59s
epoch 19 | loss: 69.76126|  0:01:02s
epoch 20 | loss: 69.54296|  0:01:05s
epoch 21 | loss: 69

epoch 11 | loss: 76.20133|  0:01:09s
epoch 12 | loss: 75.77824|  0:01:15s
epoch 13 | loss: 75.56995|  0:01:20s
epoch 14 | loss: 74.67489|  0:01:26s
epoch 15 | loss: 74.33877|  0:01:32s
epoch 16 | loss: 73.6713 |  0:01:37s
epoch 17 | loss: 73.56452|  0:01:43s
epoch 18 | loss: 73.36445|  0:01:49s
epoch 19 | loss: 73.13076|  0:01:54s
epoch 20 | loss: 73.28425|  0:02:00s
epoch 21 | loss: 72.98799|  0:02:05s
epoch 22 | loss: 72.83123|  0:02:11s
epoch 23 | loss: 72.2976 |  0:02:17s
epoch 24 | loss: 72.63213|  0:02:22s
epoch 25 | loss: 72.32402|  0:02:28s
epoch 26 | loss: 71.92464|  0:02:34s
epoch 27 | loss: 71.65977|  0:02:39s
epoch 28 | loss: 71.48458|  0:02:45s
epoch 29 | loss: 71.18916|  0:02:51s
epoch 30 | loss: 70.92085|  0:02:56s
epoch 31 | loss: 70.44972|  0:03:02s
epoch 32 | loss: 70.34704|  0:03:08s
epoch 33 | loss: 69.98215|  0:03:13s
epoch 34 | loss: 70.26058|  0:03:19s
epoch 35 | loss: 70.25467|  0:03:25s
epoch 36 | loss: 69.92986|  0:03:30s
epoch 37 | loss: 69.14297|  0:03:36s
e

epoch 28 | loss: 72.90975|  0:02:00s
epoch 29 | loss: 73.01015|  0:02:05s
epoch 30 | loss: 72.34087|  0:02:09s
epoch 31 | loss: 72.48168|  0:02:13s
epoch 32 | loss: 72.19806|  0:02:17s
epoch 33 | loss: 72.20503|  0:02:21s
epoch 34 | loss: 71.83213|  0:02:25s
epoch 35 | loss: 72.20951|  0:02:29s
epoch 36 | loss: 72.28895|  0:02:34s
epoch 37 | loss: 72.846  |  0:02:38s
epoch 38 | loss: 72.60959|  0:02:42s
epoch 39 | loss: 72.35378|  0:02:46s
epoch 40 | loss: 72.31063|  0:02:50s
epoch 41 | loss: 71.85645|  0:02:54s
epoch 42 | loss: 71.21889|  0:02:59s
epoch 43 | loss: 70.50185|  0:03:03s
epoch 44 | loss: 70.16767|  0:03:07s
epoch 45 | loss: 69.67285|  0:03:11s
epoch 46 | loss: 69.21017|  0:03:16s
epoch 47 | loss: 68.63693|  0:03:20s
epoch 48 | loss: 68.2163 |  0:03:24s
epoch 49 | loss: 68.27103|  0:03:28s
epoch 50 | loss: 68.46879|  0:03:32s
epoch 51 | loss: 68.80932|  0:03:36s
epoch 52 | loss: 67.51275|  0:03:40s
epoch 53 | loss: 66.53126|  0:03:45s
epoch 54 | loss: 65.78046|  0:03:49s
e

epoch 45 | loss: 23.49367|  0:02:23s
epoch 46 | loss: 22.21853|  0:02:27s
epoch 47 | loss: 21.39353|  0:02:30s
epoch 48 | loss: 21.44071|  0:02:33s
epoch 49 | loss: 21.33081|  0:02:36s
epoch 50 | loss: 20.53229|  0:02:39s
epoch 51 | loss: 19.34924|  0:02:42s
epoch 52 | loss: 19.21356|  0:02:45s
epoch 53 | loss: 18.4787 |  0:02:48s
epoch 54 | loss: 18.86931|  0:02:51s
epoch 55 | loss: 18.36559|  0:02:54s
epoch 56 | loss: 18.25592|  0:02:57s
epoch 57 | loss: 17.69277|  0:03:00s
epoch 58 | loss: 16.78475|  0:03:04s
epoch 59 | loss: 17.10046|  0:03:07s
epoch 60 | loss: 16.49894|  0:03:10s
epoch 61 | loss: 16.17741|  0:03:13s
epoch 62 | loss: 15.81152|  0:03:16s
epoch 63 | loss: 16.03258|  0:03:19s
epoch 64 | loss: 15.22298|  0:03:22s
epoch 65 | loss: 15.91633|  0:03:25s
epoch 66 | loss: 15.64023|  0:03:29s
epoch 67 | loss: 15.1594 |  0:03:32s
epoch 68 | loss: 14.78283|  0:03:35s
epoch 69 | loss: 14.58202|  0:03:38s
epoch 70 | loss: 14.55478|  0:03:41s
epoch 71 | loss: 14.15737|  0:03:44s
e

epoch 62 | loss: 63.5159 |  0:06:03s
epoch 63 | loss: 62.41208|  0:06:09s
epoch 64 | loss: 61.36778|  0:06:15s
epoch 65 | loss: 61.64263|  0:06:21s
epoch 66 | loss: 61.4136 |  0:06:26s
epoch 67 | loss: 61.72475|  0:06:32s
epoch 68 | loss: 60.05038|  0:06:38s
epoch 69 | loss: 58.78358|  0:06:44s
epoch 70 | loss: 58.19119|  0:06:50s
epoch 71 | loss: 58.10205|  0:06:56s
epoch 72 | loss: 58.0937 |  0:07:02s
epoch 73 | loss: 57.92984|  0:07:08s
epoch 74 | loss: 57.81713|  0:07:13s
epoch 75 | loss: 56.23532|  0:07:19s
epoch 76 | loss: 54.66535|  0:07:25s
epoch 77 | loss: 54.3632 |  0:07:31s
epoch 78 | loss: 53.03154|  0:07:36s
epoch 79 | loss: 52.48088|  0:07:42s
epoch 80 | loss: 51.72457|  0:07:48s
epoch 81 | loss: 51.01507|  0:07:53s
epoch 82 | loss: 50.63381|  0:07:59s
epoch 83 | loss: 49.7741 |  0:08:05s
epoch 84 | loss: 48.19607|  0:08:11s
epoch 85 | loss: 48.10037|  0:08:16s
epoch 86 | loss: 46.54575|  0:08:22s
epoch 87 | loss: 45.46958|  0:08:28s
epoch 88 | loss: 44.83378|  0:08:33s
e

epoch 79 | loss: 57.22862|  0:05:35s
epoch 80 | loss: 57.5161 |  0:05:39s
epoch 81 | loss: 57.51254|  0:05:43s
epoch 82 | loss: 57.075  |  0:05:47s
epoch 83 | loss: 57.2467 |  0:05:51s
epoch 84 | loss: 56.5519 |  0:05:55s
epoch 85 | loss: 56.53337|  0:05:59s
epoch 86 | loss: 56.89923|  0:06:03s
epoch 87 | loss: 56.57452|  0:06:07s
epoch 88 | loss: 56.87436|  0:06:11s
epoch 89 | loss: 56.36863|  0:06:15s
epoch 90 | loss: 56.54715|  0:06:19s
epoch 91 | loss: 56.15603|  0:06:23s
epoch 92 | loss: 55.9424 |  0:06:27s
epoch 93 | loss: 56.07956|  0:06:31s
epoch 94 | loss: 56.02763|  0:06:35s
epoch 95 | loss: 56.31972|  0:06:39s
epoch 96 | loss: 56.55101|  0:06:43s
epoch 97 | loss: 56.57901|  0:06:47s
epoch 98 | loss: 56.89295|  0:06:51s
epoch 99 | loss: 56.72315|  0:06:55s
Device used : cuda
No early stopping will be performed, last training weights will be used.
epoch 0  | loss: 264.78264|  0:00:05s
epoch 1  | loss: 89.92706|  0:00:11s
epoch 2  | loss: 86.86117|  0:00:16s
epoch 3  | loss: 86

epoch 96 | loss: 46.5602 |  0:04:29s
epoch 97 | loss: 46.72772|  0:04:32s
epoch 98 | loss: 46.50009|  0:04:34s
epoch 99 | loss: 46.32722|  0:04:37s
Device used : cuda
No early stopping will be performed, last training weights will be used.
epoch 0  | loss: 315.38642|  0:00:03s
epoch 1  | loss: 86.79656|  0:00:07s
epoch 2  | loss: 84.93751|  0:00:11s
epoch 3  | loss: 83.81853|  0:00:15s
epoch 4  | loss: 82.66033|  0:00:19s
epoch 5  | loss: 80.48818|  0:00:22s
epoch 6  | loss: 78.77209|  0:00:26s
epoch 7  | loss: 76.74848|  0:00:29s
epoch 8  | loss: 75.36483|  0:00:33s
epoch 9  | loss: 74.40809|  0:00:37s
epoch 10 | loss: 73.43167|  0:00:41s
epoch 11 | loss: 72.88184|  0:00:45s
epoch 12 | loss: 72.4716 |  0:00:48s
epoch 13 | loss: 72.36311|  0:00:52s
epoch 14 | loss: 71.8846 |  0:00:56s
epoch 15 | loss: 71.66741|  0:01:00s
epoch 16 | loss: 71.42203|  0:01:04s
epoch 17 | loss: 71.29851|  0:01:07s
epoch 18 | loss: 71.17596|  0:01:11s
epoch 19 | loss: 70.94089|  0:01:15s
epoch 20 | loss: 70

epoch 10 | loss: 71.50597|  0:00:34s
epoch 11 | loss: 71.24383|  0:00:37s
epoch 12 | loss: 70.83292|  0:00:40s
epoch 13 | loss: 70.45778|  0:00:44s
epoch 14 | loss: 70.21313|  0:00:47s
epoch 15 | loss: 69.01892|  0:00:50s
epoch 16 | loss: 68.40432|  0:00:53s
epoch 17 | loss: 67.75156|  0:00:56s
epoch 18 | loss: 66.5172 |  0:00:59s
epoch 19 | loss: 65.96439|  0:01:02s
epoch 20 | loss: 64.54594|  0:01:06s
epoch 21 | loss: 63.20927|  0:01:09s
epoch 22 | loss: 62.11961|  0:01:12s
epoch 23 | loss: 61.03058|  0:01:15s
epoch 24 | loss: 59.58116|  0:01:19s
epoch 25 | loss: 58.37465|  0:01:22s
epoch 26 | loss: 57.04529|  0:01:25s
epoch 27 | loss: 56.09862|  0:01:28s
epoch 28 | loss: 55.00239|  0:01:31s
epoch 29 | loss: 54.38779|  0:01:34s
epoch 30 | loss: 53.42028|  0:01:37s
epoch 31 | loss: 51.4962 |  0:01:41s
epoch 32 | loss: 50.97379|  0:01:44s
epoch 33 | loss: 49.76829|  0:01:47s
epoch 34 | loss: 48.61156|  0:01:50s
epoch 35 | loss: 47.73953|  0:01:53s
epoch 36 | loss: 46.91133|  0:01:56s
e

epoch 27 | loss: 70.35786|  0:02:30s
epoch 28 | loss: 69.8565 |  0:02:34s
epoch 29 | loss: 69.83153|  0:02:39s
epoch 30 | loss: 69.77896|  0:02:44s
epoch 31 | loss: 68.69457|  0:02:49s
epoch 32 | loss: 68.22579|  0:02:54s
epoch 33 | loss: 67.51633|  0:02:59s
epoch 34 | loss: 67.50294|  0:03:04s
epoch 35 | loss: 66.9921 |  0:03:09s
epoch 36 | loss: 66.59732|  0:03:14s
epoch 37 | loss: 66.04363|  0:03:19s
epoch 38 | loss: 65.64858|  0:03:24s
epoch 39 | loss: 64.48944|  0:03:29s
epoch 40 | loss: 64.04979|  0:03:34s
epoch 41 | loss: 63.61254|  0:03:39s
epoch 42 | loss: 63.3991 |  0:03:44s
epoch 43 | loss: 62.16613|  0:03:49s
epoch 44 | loss: 61.64837|  0:03:54s
epoch 45 | loss: 60.68765|  0:03:59s
epoch 46 | loss: 59.68111|  0:04:04s
epoch 47 | loss: 58.8654 |  0:04:09s
epoch 48 | loss: 58.55645|  0:04:14s
epoch 49 | loss: 58.82473|  0:04:19s
epoch 50 | loss: 59.08546|  0:04:23s
epoch 51 | loss: 57.4049 |  0:04:29s
epoch 52 | loss: 56.67611|  0:04:34s
epoch 53 | loss: 55.60401|  0:04:39s
e

epoch 44 | loss: 46.81725|  0:02:41s
epoch 45 | loss: 45.97834|  0:02:44s
epoch 46 | loss: 44.33474|  0:02:48s
epoch 47 | loss: 43.95131|  0:02:51s
epoch 48 | loss: 42.79397|  0:02:55s
epoch 49 | loss: 41.00698|  0:02:59s
epoch 50 | loss: 40.37275|  0:03:02s
epoch 51 | loss: 38.70666|  0:03:06s
epoch 52 | loss: 37.12836|  0:03:09s
epoch 53 | loss: 35.76505|  0:03:13s
epoch 54 | loss: 34.88873|  0:03:16s
epoch 55 | loss: 33.9664 |  0:03:20s
epoch 56 | loss: 33.149  |  0:03:23s
epoch 57 | loss: 32.80945|  0:03:26s
epoch 58 | loss: 31.21404|  0:03:30s
epoch 59 | loss: 31.06314|  0:03:33s
epoch 60 | loss: 30.3532 |  0:03:37s
epoch 61 | loss: 29.60607|  0:03:40s
epoch 62 | loss: 28.35645|  0:03:44s
epoch 63 | loss: 27.10131|  0:03:47s
epoch 64 | loss: 26.31028|  0:03:51s
epoch 65 | loss: 25.78284|  0:03:54s
epoch 66 | loss: 25.67565|  0:03:58s
epoch 67 | loss: 24.46742|  0:04:01s
epoch 68 | loss: 23.8152 |  0:04:05s
epoch 69 | loss: 23.0728 |  0:04:08s
epoch 70 | loss: 22.06534|  0:04:12s
e

In [None]:
# Select the configuration with the lowest validation MSE and retrain it.
search_results = search_results.reset_index(drop=True)
min_mse_index = search_results['score'].idxmin()
best_model_row = search_results.loc[min_mse_index]

# A row drawn from mixed int/float columns is upcast to a float Series, and its
# elements are scalars (no `.values`). Cast the integer hyper-parameters back to
# int before handing them to TabNet.
best_model = TabNetRegressor(
    gamma=float(best_model_row["gamma"]),
    n_a=int(best_model_row["n_a"]),
    n_d=int(best_model_row["n_d"]),
    n_steps=int(best_model_row["n_steps"]),
)
# Use the same epoch budget the candidates were scored with in the search loop.
best_model.fit(X_train.to_numpy(), y_train.values.reshape(-1, 1), max_epochs=60)

In [None]:
# Final held-out evaluation of the selected model on the test split.
y_test_pred = best_model.predict(X_test.to_numpy())
score(y_test.to_numpy(), y_test_pred)