In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.ensemble import RandomForestRegressor,GradientBoostingRegressor
from sklearn.linear_model import LinearRegression,Ridge
from sklearn.preprocessing import LabelEncoder, StandardScaler, MinMaxScaler
from sklearn.impute import SimpleImputer
from sklearn.metrics import mean_absolute_error, root_mean_squared_error,r2_score
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from pytorch_tabnet.tab_model import TabNetRegressor

from hyperopt import fmin, tpe, hp, Trials

import sys
import os

# Add the parent directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
from pre_processing import preprocess_data

device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [2]:
def load_data():
    dataset = pd.read_excel(r"globalterrorismdb_2021Jan-June_1222dist.xlsx") # 2021-2021 June
    return dataset

## Data Loading and Preprocessing

In [3]:
dataset = load_data()
X, y, dataset = preprocess_data(dataset)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

dataset.head()

Unnamed: 0,success,gname_freq,city_freq,country_freq,attacktype1_score,targtype1_score,weaptype1_score,gname_score,country_score,city_score,...,nkill_likelihood_score,region_3,region_5,region_6,region_8,region_9,region_10,region_11,region_12,nkill
0,1.0,0.228571,0.066667,0.167228,0.875,1.0,0.857143,0.977778,0.948718,0.123077,...,0.589934,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,5.0
1,1.0,0.121429,0.033333,0.096521,0.5,0.842105,0.857143,0.911111,0.871795,0.030769,...,0.347526,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0
2,1.0,0.0,0.033333,0.159371,0.875,0.894737,0.857143,0.0,0.641026,0.015385,...,0.47705,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
3,1.0,0.121429,0.033333,0.096521,0.875,0.894737,0.857143,0.911111,0.871795,0.092308,...,0.503175,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,6.0
4,1.0,0.0,0.033333,0.093154,0.875,1.0,0.857143,0.022222,0.74359,0.030769,...,0.543654,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0


## 2) TabNet
## Hyperparameter Tuning and Model Training

In [4]:
def objective(params):
    # Extract hyperparameters
    n_d = int(params['n_d'])
    n_a = int(params['n_a'])
    n_steps = int(params['n_steps'])
    gamma = params['gamma']
    lambda_sparse = params['lambda_sparse']
    lr = params['lr']
    weight_decay = params['weight_decay']
    
    # Initialize TabNet model with current hyperparameters
    tabNetModel = TabNetRegressor(
        n_d=n_d,            
        n_a=n_a,             
        n_steps=n_steps,     
        gamma=gamma,         
        lambda_sparse=lambda_sparse,
        optimizer_fn=torch.optim.AdamW,
        optimizer_params=dict(lr=lr, weight_decay=weight_decay),
        scheduler_params={"step_size":8, "gamma":0.9},
        scheduler_fn=torch.optim.lr_scheduler.StepLR,
        mask_type="entmax",
    )
    
    # Train the model and get validation loss (e.g., RMSE)
    tabNetModel.fit(
        X_train=X_train, 
        y_train=y_train.reshape(-1, 1),
        eval_set=[(X_test, y_test.reshape(-1, 1))],
        eval_name=['val'],
        eval_metric=['rmse'],
        max_epochs=40,
        patience=15,
        batch_size=64,
        virtual_batch_size=32,
        num_workers=3,
        drop_last=True
    )
    
    # Extract validation RMSE (or any other metric you prefer)
    val_rmse = tabNetModel.history.history['val_rmse'][-1]  # Get the last RMSE value
    return val_rmse

# Define the search space
space = {
    'n_d': hp.quniform('n_d', 8, 64, 16),          # Number of decision units
    'n_a': hp.quniform('n_a', 8, 64, 16),          # Number of attention units
    'n_steps': hp.quniform('n_steps', 3, 10, 2),   # Number of steps in TabNet
    'gamma': hp.uniform('gamma', 1.0, 2.0),        # Sparsity controlling factor
    'lambda_sparse': hp.loguniform('lambda_sparse', -5, -1),  # Sparsity regularization term
    'lr': hp.loguniform('lr', -5, -1),             # Learning rate (log scale)
    'weight_decay': hp.loguniform('weight_decay', -5, -1)  # Weight decay (log scale)
}

# Initialize Trials object to store results
trials = Trials()

# Run the optimization with Hyperopt
best = fmin(
    fn=objective,             # Objective function
    space=space,              # Search space
    algo=tpe.suggest,         # Optimization algorithm (Tree of Parzen Estimators)
    max_evals=15,             # Number of evaluations
    trials=trials             # Store trials
)

# Print the best hyperparameters found
print("Best hyperparameters:", best)


  0%|          | 0/15 [00:00<?, ?trial/s, best loss=?]




epoch 0  | loss: 44.52012| val_rmse: 7.11181 |  0:00:21s
epoch 1  | loss: 40.02954| val_rmse: 7.03802 |  0:00:36s
epoch 2  | loss: 35.6399 | val_rmse: 6.27565 |  0:00:54s
epoch 3  | loss: 30.95694| val_rmse: 6.79133 |  0:01:10s
epoch 4  | loss: 29.27898| val_rmse: 5.34931 |  0:01:27s
epoch 5  | loss: 26.85812| val_rmse: 6.86841 |  0:01:46s
epoch 6  | loss: 24.44915| val_rmse: 6.72198 |  0:02:02s
epoch 7  | loss: 24.74707| val_rmse: 4.88494 |  0:02:19s
epoch 8  | loss: 24.47092| val_rmse: 5.28352 |  0:02:37s
epoch 9  | loss: 22.90403| val_rmse: 5.42474 |  0:02:55s
epoch 10 | loss: 22.11533| val_rmse: 6.45587 |  0:03:13s
epoch 11 | loss: 22.78002| val_rmse: 5.34018 |  0:03:28s
epoch 12 | loss: 20.97734| val_rmse: 3.91207 |  0:03:44s
epoch 13 | loss: 18.74327| val_rmse: 3.49495 |  0:04:00s
epoch 14 | loss: 19.58238| val_rmse: 3.36599 |  0:04:19s
epoch 15 | loss: 18.98585| val_rmse: 3.00902 |  0:04:38s
epoch 16 | loss: 20.61077| val_rmse: 2.97067 |  0:04:58s
epoch 17 | loss: 19.15022| val_




  7%|▋         | 1/15 [11:06<2:35:37, 666.94s/trial, best loss: 3.907758157350823]




epoch 0  | loss: 42.46979| val_rmse: 7.49564 |  0:00:23s                          
epoch 1  | loss: 35.22186| val_rmse: 7.34564 |  0:00:44s                          
epoch 2  | loss: 31.32014| val_rmse: 7.26213 |  0:01:04s                          
epoch 3  | loss: 28.80977| val_rmse: 7.36305 |  0:01:23s                          
epoch 4  | loss: 30.22053| val_rmse: 7.05032 |  0:01:42s                          
epoch 5  | loss: 28.727  | val_rmse: 9.43821 |  0:02:03s                          
epoch 6  | loss: 37.64225| val_rmse: 7.51735 |  0:02:28s                          
epoch 7  | loss: 35.08614| val_rmse: 6.74595 |  0:02:53s                          
epoch 8  | loss: 30.47883| val_rmse: 13.508  |  0:03:19s                          
epoch 9  | loss: 31.67772| val_rmse: 6.96439 |  0:03:45s                          
epoch 10 | loss: 28.94419| val_rmse: 7.77224 |  0:04:11s                          
epoch 11 | loss: 31.71226| val_rmse: 8.07159 |  0:04:37s                          
epoc




 13%|█▎        | 2/15 [21:53<2:21:53, 654.87s/trial, best loss: 3.907758157350823]




epoch 0  | loss: 52.68773| val_rmse: 7.47257 |  0:00:18s                          
epoch 1  | loss: 45.74204| val_rmse: 7.67541 |  0:00:37s                          
epoch 2  | loss: 43.26756| val_rmse: 7.39057 |  0:00:56s                          
epoch 3  | loss: 41.18277| val_rmse: 7.30956 |  0:01:15s                          
epoch 4  | loss: 39.03738| val_rmse: 7.92848 |  0:01:34s                          
epoch 5  | loss: 39.21744| val_rmse: 7.37697 |  0:01:52s                          
epoch 6  | loss: 40.51704| val_rmse: 7.08259 |  0:02:11s                          
epoch 7  | loss: 37.76796| val_rmse: 6.57183 |  0:02:30s                          
epoch 8  | loss: 35.62853| val_rmse: 6.6441  |  0:02:49s                          
epoch 9  | loss: 32.69339| val_rmse: 6.05404 |  0:03:08s                          
epoch 10 | loss: 26.91585| val_rmse: 7.70093 |  0:03:27s                          
epoch 11 | loss: 22.37024| val_rmse: 4.26296 |  0:03:47s                          
epoc




 20%|██        | 3/15 [31:20<2:02:55, 614.65s/trial, best loss: 3.907758157350823]




epoch 0  | loss: 46.6845 | val_rmse: 7.34345 |  0:00:17s                          
epoch 1  | loss: 38.96864| val_rmse: 7.36922 |  0:00:35s                          
epoch 2  | loss: 37.18931| val_rmse: 7.29388 |  0:00:52s                          
epoch 3  | loss: 34.80742| val_rmse: 7.06651 |  0:01:10s                          
epoch 4  | loss: 36.17383| val_rmse: 7.08279 |  0:01:28s                          
epoch 5  | loss: 36.6327 | val_rmse: 6.89798 |  0:01:45s                          
epoch 6  | loss: 27.4861 | val_rmse: 5.12826 |  0:02:02s                          
epoch 7  | loss: 29.74421| val_rmse: 5.77529 |  0:02:20s                          
epoch 8  | loss: 28.96998| val_rmse: 5.41032 |  0:02:37s                          
epoch 9  | loss: 26.04263| val_rmse: 5.19904 |  0:02:54s                          
epoch 10 | loss: 25.22872| val_rmse: 5.2364  |  0:03:12s                          
epoch 11 | loss: 23.6226 | val_rmse: 4.61802 |  0:03:29s                          
epoc




 27%|██▋       | 4/15 [42:57<1:58:40, 647.36s/trial, best loss: 3.907758157350823]




epoch 0  | loss: 47.72361| val_rmse: 8.78895 |  0:00:14s                          
epoch 1  | loss: 35.59987| val_rmse: 8.5928  |  0:00:30s                          
epoch 2  | loss: 31.19815| val_rmse: 7.40569 |  0:00:44s                          
epoch 3  | loss: 24.21437| val_rmse: 6.51911 |  0:00:59s                          
epoch 4  | loss: 24.99306| val_rmse: 7.26319 |  0:01:13s                          
epoch 5  | loss: 26.51727| val_rmse: 6.77357 |  0:01:28s                          
epoch 6  | loss: 18.43472| val_rmse: 6.95237 |  0:01:42s                          
epoch 7  | loss: 23.16714| val_rmse: 4.35847 |  0:01:57s                          
epoch 8  | loss: 26.74255| val_rmse: 5.53794 |  0:02:12s                          
epoch 9  | loss: 23.98174| val_rmse: 8.57241 |  0:02:27s                          
epoch 10 | loss: 21.78472| val_rmse: 3.6876  |  0:02:41s                          
epoch 11 | loss: 22.54038| val_rmse: 6.77273 |  0:02:56s                          
epoc




 33%|███▎      | 5/15 [52:48<1:44:28, 626.86s/trial, best loss: 3.907758157350823]




epoch 0  | loss: 48.37005| val_rmse: 6.88058 |  0:00:16s                          
epoch 1  | loss: 39.22877| val_rmse: 6.7868  |  0:00:32s                          
epoch 2  | loss: 38.66765| val_rmse: 6.80709 |  0:00:48s                          
epoch 3  | loss: 33.99695| val_rmse: 6.34502 |  0:01:04s                          
epoch 4  | loss: 30.51206| val_rmse: 5.53404 |  0:01:20s                          
epoch 5  | loss: 25.48104| val_rmse: 4.60912 |  0:01:36s                          
epoch 6  | loss: 27.03928| val_rmse: 4.14338 |  0:01:52s                          
epoch 7  | loss: 26.70634| val_rmse: 5.93196 |  0:02:09s                          
epoch 8  | loss: 22.13003| val_rmse: 6.01376 |  0:02:25s                          
epoch 9  | loss: 24.12558| val_rmse: 6.48985 |  0:02:41s                          
epoch 10 | loss: 23.47802| val_rmse: 5.22398 |  0:02:57s                          
epoch 11 | loss: 19.48212| val_rmse: 6.16928 |  0:03:13s                          
epoc




 40%|████      | 6/15 [1:03:35<1:35:05, 633.94s/trial, best loss: 3.440357728977523]




epoch 0  | loss: 46.4163 | val_rmse: 8.66161 |  0:00:17s                            
epoch 1  | loss: 45.29433| val_rmse: 7.34849 |  0:00:35s                            
epoch 2  | loss: 38.95731| val_rmse: 6.55646 |  0:00:52s                            
epoch 3  | loss: 37.22578| val_rmse: 7.45708 |  0:01:10s                            
epoch 4  | loss: 34.90919| val_rmse: 7.03933 |  0:01:27s                            
epoch 5  | loss: 34.14817| val_rmse: 6.7277  |  0:01:44s                            
epoch 6  | loss: 32.60859| val_rmse: 6.43901 |  0:02:01s                            
epoch 7  | loss: 34.13007| val_rmse: 5.5476  |  0:02:19s                            
epoch 8  | loss: 31.99235| val_rmse: 6.70196 |  0:02:36s                            
epoch 9  | loss: 33.59787| val_rmse: 6.12231 |  0:02:53s                            
epoch 10 | loss: 31.36766| val_rmse: 7.25693 |  0:03:11s                            
epoch 11 | loss: 33.61009| val_rmse: 6.38646 |  0:03:28s         




 47%|████▋     | 7/15 [1:15:16<1:27:27, 655.90s/trial, best loss: 3.440357728977523]




epoch 0  | loss: 40.72873| val_rmse: 7.07804 |  0:00:15s                            
epoch 1  | loss: 33.73166| val_rmse: 6.90429 |  0:00:30s                            
epoch 2  | loss: 32.44822| val_rmse: 6.76643 |  0:00:45s                            
epoch 3  | loss: 30.18685| val_rmse: 6.02545 |  0:01:00s                            
epoch 4  | loss: 31.92297| val_rmse: 5.66758 |  0:01:15s                            
epoch 5  | loss: 27.68958| val_rmse: 6.12627 |  0:01:30s                            
epoch 6  | loss: 30.41017| val_rmse: 5.99964 |  0:01:45s                            
epoch 7  | loss: 24.43928| val_rmse: 5.01848 |  0:02:00s                            
epoch 8  | loss: 24.78031| val_rmse: 6.00547 |  0:02:14s                            
epoch 9  | loss: 25.25455| val_rmse: 4.83962 |  0:02:28s                            
epoch 10 | loss: 22.70502| val_rmse: 4.8294  |  0:02:43s                            
epoch 11 | loss: 24.10798| val_rmse: 5.02326 |  0:02:58s         




 53%|█████▎    | 8/15 [1:23:46<1:11:05, 609.42s/trial, best loss: 3.440357728977523]




epoch 0  | loss: 41.9324 | val_rmse: 7.35828 |  0:00:15s                            
epoch 1  | loss: 42.47064| val_rmse: 7.28254 |  0:00:32s                            
epoch 2  | loss: 40.87052| val_rmse: 6.98802 |  0:00:47s                            
epoch 3  | loss: 38.65668| val_rmse: 7.32098 |  0:01:03s                            
epoch 4  | loss: 38.37262| val_rmse: 7.02017 |  0:01:19s                            
epoch 5  | loss: 38.7854 | val_rmse: 6.7361  |  0:01:35s                            
epoch 6  | loss: 37.87817| val_rmse: 6.85647 |  0:01:50s                            
epoch 7  | loss: 32.34219| val_rmse: 5.78783 |  0:02:06s                            
epoch 8  | loss: 27.93181| val_rmse: 5.78574 |  0:02:22s                            
epoch 9  | loss: 25.57796| val_rmse: 5.66175 |  0:02:37s                            
epoch 10 | loss: 25.20907| val_rmse: 6.38491 |  0:02:53s                            
epoch 11 | loss: 26.18397| val_rmse: 5.79197 |  0:03:09s         




 60%|██████    | 9/15 [1:34:12<1:01:26, 614.42s/trial, best loss: 3.184476332710138]




epoch 0  | loss: 48.7147 | val_rmse: 9.6567  |  0:00:14s                            
epoch 1  | loss: 41.94994| val_rmse: 7.28022 |  0:00:28s                            
epoch 2  | loss: 37.38236| val_rmse: 7.29872 |  0:00:43s                            
epoch 3  | loss: 36.50275| val_rmse: 7.24035 |  0:00:57s                            
epoch 4  | loss: 36.22497| val_rmse: 7.01771 |  0:01:12s                            
epoch 5  | loss: 36.05574| val_rmse: 5.70319 |  0:01:26s                            
epoch 6  | loss: 30.57688| val_rmse: 7.25905 |  0:01:40s                            
epoch 7  | loss: 26.73726| val_rmse: 7.77247 |  0:01:55s                            
epoch 8  | loss: 30.05554| val_rmse: 6.71503 |  0:02:10s                            
epoch 9  | loss: 29.27158| val_rmse: 6.41209 |  0:02:24s                            
epoch 10 | loss: 31.1797 | val_rmse: 6.50188 |  0:02:38s                            
epoch 11 | loss: 29.22611| val_rmse: 5.56091 |  0:02:53s         




 67%|██████▋   | 10/15 [1:43:51<50:17, 603.45s/trial, best loss: 3.184476332710138] 




epoch 0  | loss: 45.70389| val_rmse: 7.1332  |  0:00:14s                           
epoch 1  | loss: 36.74465| val_rmse: 6.80873 |  0:00:28s                           
epoch 2  | loss: 37.56989| val_rmse: 7.17225 |  0:00:43s                           
epoch 3  | loss: 35.45186| val_rmse: 6.85704 |  0:00:57s                           
epoch 4  | loss: 35.30896| val_rmse: 6.44631 |  0:01:11s                           
epoch 5  | loss: 32.4373 | val_rmse: 5.72573 |  0:01:25s                           
epoch 6  | loss: 31.47191| val_rmse: 6.72299 |  0:01:40s                           
epoch 7  | loss: 33.73331| val_rmse: 6.10194 |  0:01:54s                           
epoch 8  | loss: 33.75272| val_rmse: 5.70478 |  0:02:08s                           
epoch 9  | loss: 27.65723| val_rmse: 4.68454 |  0:02:22s                           
epoch 10 | loss: 26.65429| val_rmse: 5.18918 |  0:02:37s                           
epoch 11 | loss: 26.03271| val_rmse: 4.31829 |  0:02:51s                    




 73%|███████▎  | 11/15 [1:53:27<39:41, 595.28s/trial, best loss: 3.184476332710138]




epoch 0  | loss: 45.61532| val_rmse: 7.42332 |  0:00:14s                           
epoch 1  | loss: 40.39279| val_rmse: 7.32526 |  0:00:28s                           
epoch 2  | loss: 31.56296| val_rmse: 7.52681 |  0:00:43s                           
epoch 3  | loss: 29.16207| val_rmse: 7.30151 |  0:00:57s                           
epoch 4  | loss: 27.81351| val_rmse: 7.27463 |  0:01:11s                           
epoch 5  | loss: 25.8471 | val_rmse: 7.1475  |  0:01:26s                           
epoch 6  | loss: 29.04041| val_rmse: 7.38143 |  0:01:40s                           
epoch 7  | loss: 27.13556| val_rmse: 6.35913 |  0:01:54s                           
epoch 8  | loss: 24.69711| val_rmse: 4.35939 |  0:02:09s                           
epoch 9  | loss: 25.0018 | val_rmse: 6.48669 |  0:02:23s                           
epoch 10 | loss: 27.2824 | val_rmse: 8.18886 |  0:02:38s                           
epoch 11 | loss: 29.28252| val_rmse: 6.5383  |  0:02:52s                    




 80%|████████  | 12/15 [2:00:13<26:52, 537.54s/trial, best loss: 3.184476332710138]




epoch 0  | loss: 46.87481| val_rmse: 7.45206 |  0:00:17s                           
epoch 1  | loss: 38.98568| val_rmse: 6.84612 |  0:00:33s                           
epoch 2  | loss: 33.16784| val_rmse: 6.92544 |  0:00:50s                           
epoch 3  | loss: 31.02076| val_rmse: 6.74913 |  0:01:07s                           
epoch 4  | loss: 26.4053 | val_rmse: 7.19159 |  0:01:24s                           
epoch 5  | loss: 25.834  | val_rmse: 6.77086 |  0:01:41s                           
epoch 6  | loss: 27.3046 | val_rmse: 6.02683 |  0:01:57s                           
epoch 7  | loss: 26.87679| val_rmse: 6.65814 |  0:02:14s                           
epoch 8  | loss: 24.75146| val_rmse: 6.0266  |  0:02:31s                           
epoch 9  | loss: 26.18322| val_rmse: 6.09102 |  0:02:48s                           
epoch 10 | loss: 25.06315| val_rmse: 7.40076 |  0:03:05s                           
epoch 11 | loss: 19.14066| val_rmse: 7.25169 |  0:03:21s                    




 87%|████████▋ | 13/15 [2:11:30<19:19, 579.83s/trial, best loss: 3.184476332710138]




epoch 0  | loss: 44.68319| val_rmse: 7.09161 |  0:00:15s                           
epoch 1  | loss: 35.32821| val_rmse: 7.12368 |  0:00:31s                           
epoch 2  | loss: 36.8371 | val_rmse: 7.21936 |  0:00:47s                           
epoch 3  | loss: 32.17844| val_rmse: 6.76721 |  0:01:02s                           
epoch 4  | loss: 27.67533| val_rmse: 5.21716 |  0:01:18s                           
epoch 5  | loss: 29.03589| val_rmse: 7.14165 |  0:01:33s                           
epoch 6  | loss: 27.76353| val_rmse: 6.67806 |  0:01:49s                           
epoch 7  | loss: 25.68824| val_rmse: 5.21645 |  0:02:05s                           
epoch 8  | loss: 22.38087| val_rmse: 6.2421  |  0:02:20s                           
epoch 9  | loss: 22.42354| val_rmse: 5.6841  |  0:02:36s                           
epoch 10 | loss: 22.26186| val_rmse: 5.02439 |  0:02:52s                           
epoch 11 | loss: 26.70617| val_rmse: 4.72539 |  0:03:07s                    




 93%|█████████▎| 14/15 [2:19:52<09:16, 556.28s/trial, best loss: 3.184476332710138]




epoch 0  | loss: 45.58073| val_rmse: 7.1364  |  0:00:13s                           
epoch 1  | loss: 37.74866| val_rmse: 6.50126 |  0:00:25s                           
epoch 2  | loss: 31.62779| val_rmse: 6.75511 |  0:00:38s                           
epoch 3  | loss: 30.85029| val_rmse: 6.48796 |  0:00:51s                           
epoch 4  | loss: 26.15702| val_rmse: 6.54951 |  0:01:04s                           
epoch 5  | loss: 24.53813| val_rmse: 7.20062 |  0:01:17s                           
epoch 6  | loss: 23.09699| val_rmse: 6.05022 |  0:01:30s                           
epoch 7  | loss: 22.73302| val_rmse: 5.73868 |  0:01:43s                           
epoch 8  | loss: 22.79028| val_rmse: 6.20147 |  0:01:56s                           
epoch 9  | loss: 21.28281| val_rmse: 5.99148 |  0:02:09s                           
epoch 10 | loss: 22.45807| val_rmse: 5.34003 |  0:02:22s                           
epoch 11 | loss: 21.40658| val_rmse: 5.36677 |  0:02:35s                    




100%|██████████| 15/15 [2:28:36<00:00, 594.46s/trial, best loss: 3.184476332710138]
Best hyperparameters: {'gamma': np.float64(1.5494371504360784), 'lambda_sparse': np.float64(0.021527849243156594), 'lr': np.float64(0.037088099470451996), 'n_a': np.float64(48.0), 'n_d': np.float64(32.0), 'n_steps': np.float64(8.0), 'weight_decay': np.float64(0.16622206018645966)}
