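### Scenario A
Last updated Mar 11, 2022

This notebook trains DNN (MLP), TabNet and XGBoost models on the EXP_1 (Scenario A) data. Each model goes through two steps:

1. A hyper-parameter grid search, whose selected parameters are saved to disk.
2. Confidence-interval estimation, which reloads those saved parameters.

DNN and TabNet are trained both from scratch and with domain adaptation (starting from pre-trained weights). DNN and XGBoost are also evaluated on an extended study period.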

In [None]:
from sys import path as pylib
import os

# Make the project root (the parent of the current working directory) importable so that the
# `transfer_learning`, `utils`, `models` and `configs` packages imported below can be resolved.
# NOTE(review): '../' is resolved against the current working directory, so the notebook is
# expected to be launched from its own folder.
# The membership check keeps this cell idempotent: re-running it does not keep appending
# duplicate entries to sys.path.
project_root = os.path.abspath('../')
if project_root not in pylib:
    pylib.append(project_root)

In [None]:
import pickle
from transfer_learning.transfer_utils import get_device
from utils.utility_functions import get_best_params
from models.models_manager import BasicMLPModelManager, BasicTabNetModelManager, BasicXGBoostModelManager
from models.trainers import train_with_grid_search, train_with_CI

from configs.features_sets import ALL_FEATURES
from configs.training import BIN_CLASSIFICATION_WEIGHTS, MLP_EVAL_METRIC, TABNET_EVAL_METRIC, XGBOOST_EVAL_METRIC
from configs.files import *

In [None]:
# Select the compute device once. Every DNN cell below passes it to BasicMLPModelManager;
# the TabNet and XGBoost managers are constructed without it.
# NOTE(review): presumably returns a torch device (GPU when available) — confirm in
# transfer_learning.transfer_utils.get_device.
device = get_device()

## 1. DNN

### 1.1. DNN (from scratch) - grid search

In [None]:
# Grid search for the MLP trained from scratch on the EXP_1 data.
# Writes the full search table to EXP_1_MLP_GS_FILE and the selected parameters to
# EXP_1_MLP_WEEKLY_PARAMS_FILE (reloaded by section 1.2).
model_manager = BasicMLPModelManager(device=device)
MLP_GS_results, weekly_best_params = train_with_grid_search(
    model_manager=model_manager,
    train_data_path=PREVENT_EXP1_PATH,
    test_data_path=TEST_EXP1_PATH,
    features=ALL_FEATURES,
    interval=7,
    eval_metric=MLP_EVAL_METRIC,
    class_weights=BIN_CLASSIFICATION_WEIGHTS,
)
MLP_GS_results.to_csv(EXP_1_MLP_GS_FILE, index=False)

with open(EXP_1_MLP_WEEKLY_PARAMS_FILE, 'wb') as params_file:
    pickle.dump(weekly_best_params, params_file)


### 1.2. DNN (from scratch) - Confidence interval estimation

In [None]:
# Confidence-interval estimation for the from-scratch MLP, using the parameters that the
# grid search of section 1.1 saved to EXP_1_MLP_WEEKLY_PARAMS_FILE.
model_manager = BasicMLPModelManager(device=device)

with open(EXP_1_MLP_WEEKLY_PARAMS_FILE, 'rb') as params_file:
    best_mlp_params = pickle.load(params_file)
params = get_best_params(params_AUC_pairs=best_mlp_params, n_weeks=3)

MLP_CI_results = train_with_CI(
    model_manager=model_manager,
    train_data_path=PREVENT_EXP1_PATH,
    test_data_path=TEST_EXP1_PATH,
    features=ALL_FEATURES,
    interval=7,
    params=params,
    eval_metric=MLP_EVAL_METRIC,
    class_weights=BIN_CLASSIFICATION_WEIGHTS,
)
MLP_CI_results.to_csv(EXP_1_MLP_CI_FILE, index=False)


### 1.3. DNN (domain adaptation) - grid search

In [None]:
# Grid search for the MLP initialised from pre-trained weights (domain adaptation).
# The results use their own name (MLP_TL_GS_results) so that the from-scratch results of
# section 1.1 (MLP_GS_results) are not overwritten in the kernel namespace.
model_manager = BasicMLPModelManager(from_pretrained=True, device=device,
                                     pretrained_path='../transfer_learning/pretrained_models/EXP_1_MLP.pkl')  # Add pre-trained models here
MLP_TL_GS_results, weekly_best_params = train_with_grid_search(model_manager=model_manager,
                                                               train_data_path=PREVENT_EXP1_PATH,
                                                               features=ALL_FEATURES, test_data_path=TEST_EXP1_PATH,
                                                               interval=7, eval_metric=MLP_EVAL_METRIC,
                                                               class_weights=BIN_CLASSIFICATION_WEIGHTS)
MLP_TL_GS_results.to_csv(EXP_1_MLP_TL_GS_FILE, index=False)
# Persist the selected parameters; section 1.4 reloads them from this file.
with open(EXP_1_MLP_TL_WEEKLY_PARAMS_FILE, 'wb') as f:
    pickle.dump(weekly_best_params, f)


### 1.4. DNN (domain adaptation) - Confidence interval estimation

In [None]:
# Confidence-interval estimation for the domain-adapted MLP, using the parameters saved by
# the grid search of section 1.3.
model_manager = BasicMLPModelManager(
    from_pretrained=True,
    device=device,
    pretrained_path='../transfer_learning/pretrained_models/EXP_1_MLP.pkl',  # pre-trained weights to adapt from
)

with open(EXP_1_MLP_TL_WEEKLY_PARAMS_FILE, 'rb') as params_file:
    best_mlp_params = pickle.load(params_file)
params = get_best_params(params_AUC_pairs=best_mlp_params, n_weeks=3)

MLP_TL_CI_results = train_with_CI(
    model_manager=model_manager,
    train_data_path=PREVENT_EXP1_PATH,
    test_data_path=TEST_EXP1_PATH,
    features=ALL_FEATURES,
    interval=7,
    params=params,
    eval_metric=MLP_EVAL_METRIC,
    class_weights=BIN_CLASSIFICATION_WEIGHTS,
)
MLP_TL_CI_results.to_csv(EXP_1_MLP_TL_CI_FILE, index=False)


### 1.5. DNN (domain adaptation) - Extended study period - Grid search

In [None]:
# Grid search for the domain-adapted MLP on the extended study period: training uses
# PREVENT_EXPANDED_EXP1_PATH and evaluation uses TEST_SHRINKED_EXP1_PATH.
model_manager = BasicMLPModelManager(
    from_pretrained=True,
    device=device,
    pretrained_path='../transfer_learning/pretrained_models/EXP_1_MLP.pkl',  # pre-trained weights to adapt from
)
MLP_EXPANDED_GS_results, weekly_best_params = train_with_grid_search(
    model_manager=model_manager,
    train_data_path=PREVENT_EXPANDED_EXP1_PATH,
    test_data_path=TEST_SHRINKED_EXP1_PATH,
    features=ALL_FEATURES,
    interval=7,
    eval_metric=MLP_EVAL_METRIC,
    class_weights=BIN_CLASSIFICATION_WEIGHTS,
)
MLP_EXPANDED_GS_results.to_csv(EXP_1_MLP_EXPANDED_TL_GS_FILE, index=False)

# Section 1.6 reloads these parameters.
with open(EXP_1_MLP_EXPANDED_TL_WEEKLY_PARAMS_FILE, 'wb') as params_file:
    pickle.dump(weekly_best_params, params_file)


### 1.6. DNN (domain adaptation) - Extended study period - Confidence interval estimation

In [None]:
# Confidence-interval estimation for the domain-adapted MLP on the extended study period,
# using the parameters saved by section 1.5.
# The results use their own name (MLP_EXPANDED_TL_CI_results); the original reused
# MLP_TL_CI_results, which silently overwrote the section 1.4 results in the kernel.
model_manager = BasicMLPModelManager(from_pretrained=True, device=device,
                                     pretrained_path='../transfer_learning/pretrained_models/EXP_1_MLP.pkl')  # Add pre-trained models here

with open(EXP_1_MLP_EXPANDED_TL_WEEKLY_PARAMS_FILE, 'rb') as f:
    best_mlp_params = pickle.load(f)
params = get_best_params(params_AUC_pairs=best_mlp_params, n_weeks=3)

MLP_EXPANDED_TL_CI_results = train_with_CI(model_manager=model_manager, train_data_path=PREVENT_EXPANDED_EXP1_PATH,
                                           features=ALL_FEATURES, test_data_path=TEST_SHRINKED_EXP1_PATH, interval=7,
                                           params=params, eval_metric=MLP_EVAL_METRIC,
                                           class_weights=BIN_CLASSIFICATION_WEIGHTS)
MLP_EXPANDED_TL_CI_results.to_csv(EXP_1_MLP_EXPANDED_TL_CI_FILE, index=False)


## 2. TabNet

### 2.1. TabNet (from scratch) - grid search

In [None]:
# Grid search for TabNet trained from scratch on the EXP_1 data.
# The selected parameters are saved for section 2.2.
model_manager = BasicTabNetModelManager()
TABNET_GS_results, weekly_best_params = train_with_grid_search(
    model_manager=model_manager,
    train_data_path=PREVENT_EXP1_PATH,
    test_data_path=TEST_EXP1_PATH,
    features=ALL_FEATURES,
    interval=7,
    eval_metric=TABNET_EVAL_METRIC,
    class_weights=BIN_CLASSIFICATION_WEIGHTS,
)
TABNET_GS_results.to_csv(EXP_1_TABNET_GS_FILE, index=False)

with open(EXP_1_TABNET_WEEKLY_PARAMS_FILE, 'wb') as params_file:
    pickle.dump(weekly_best_params, params_file)

### 2.2. TabNet (from scratch) - Confidence interval estimation

In [None]:
# Confidence-interval estimation for TabNet trained from scratch, using the parameters
# saved by the grid search of section 2.1.
model_manager = BasicTabNetModelManager()

# Named consistently with best_mlp_params / best_xgboost_params (was misspelled "TanNet").
with open(EXP_1_TABNET_WEEKLY_PARAMS_FILE, 'rb') as f:
    best_tabnet_params = pickle.load(f)
params = get_best_params(params_AUC_pairs=best_tabnet_params, n_weeks=3)

TabNet_CI_results = train_with_CI(model_manager=model_manager, train_data_path=PREVENT_EXP1_PATH, features=ALL_FEATURES,
                                  test_data_path=TEST_EXP1_PATH, interval=7, params=params,
                                  eval_metric=TABNET_EVAL_METRIC, class_weights=BIN_CLASSIFICATION_WEIGHTS)
# NOTE(review): every other TabNet output constant uses the `EXP_1_TABNET_` prefix; confirm that
# configs.files really defines `EXP_1_TabNet_CI_FILE` (and not `EXP_1_TABNET_CI_FILE`).
TabNet_CI_results.to_csv(EXP_1_TabNet_CI_FILE, index=False)

### 2.3. TabNet (domain adaptation) - grid search

In [None]:
# Grid search for TabNet initialised from pre-trained weights (domain adaptation).
# The results use their own name (TABNET_TL_GS_results) so that the from-scratch results of
# section 2.1 (TABNET_GS_results) are not overwritten in the kernel namespace.
model_manager = BasicTabNetModelManager(from_pretrained=True,
                                        pretrained_path='../transfer_learning/pretrained_models/TabNet.zip')  # Add pre-trained models here
TABNET_TL_GS_results, weekly_best_params = train_with_grid_search(model_manager=model_manager,
                                                                  train_data_path=PREVENT_EXP1_PATH,
                                                                  features=ALL_FEATURES, test_data_path=TEST_EXP1_PATH,
                                                                  interval=7, eval_metric=TABNET_EVAL_METRIC,
                                                                  class_weights=BIN_CLASSIFICATION_WEIGHTS)
TABNET_TL_GS_results.to_csv(EXP_1_TABNET_TL_GS_FILE, index=False)
# Persist the selected parameters; section 2.4 reloads them from this file.
with open(EXP_1_TABNET_TL_WEEKLY_PARAMS_FILE, 'wb') as f:
    pickle.dump(weekly_best_params, f)

### 2.4. TabNet (domain adaptation) - Confidence interval estimation

In [None]:
# Confidence-interval estimation for the domain-adapted TabNet, using the parameters saved by
# section 2.3.
# The results use their own name (TABNET_TL_CI_results); the original reused TabNet_CI_results,
# which silently overwrote the section 2.2 results in the kernel.
model_manager = BasicTabNetModelManager(from_pretrained=True,
                                        pretrained_path='../transfer_learning/pretrained_models/TabNet.zip')  # Add pre-trained models here
with open(EXP_1_TABNET_TL_WEEKLY_PARAMS_FILE, 'rb') as f:
    best_tabnet_params = pickle.load(f)
params = get_best_params(params_AUC_pairs=best_tabnet_params, n_weeks=3)

TABNET_TL_CI_results = train_with_CI(model_manager=model_manager, train_data_path=PREVENT_EXP1_PATH, features=ALL_FEATURES,
                                     test_data_path=TEST_EXP1_PATH, interval=7, params=params,
                                     eval_metric=TABNET_EVAL_METRIC, class_weights=BIN_CLASSIFICATION_WEIGHTS)
TABNET_TL_CI_results.to_csv(EXP_1_TABNET_TL_CI_FILE, index=False)

## 3. XGBoost

### 3.1. XGBoost - grid search

In [None]:
# Grid search (GS) for XGBoost on the EXP_1 data. The result name now says "GS", matching
# EXP_1_XGBOOST_GS_FILE; it was previously misnamed "GC".
# The selected parameters are saved for section 3.2.
xgboost_GS_results, weekly_best_params = train_with_grid_search(model_manager=BasicXGBoostModelManager(),
                                                                train_data_path=PREVENT_EXP1_PATH,
                                                                test_data_path=TEST_EXP1_PATH, threshold=0.5,
                                                                features=ALL_FEATURES, interval=7,
                                                                class_weights=BIN_CLASSIFICATION_WEIGHTS,
                                                                eval_metric=XGBOOST_EVAL_METRIC)
xgboost_GS_results.to_csv(EXP_1_XGBOOST_GS_FILE, index=False)
with open(EXP_1_XGBOOST_WEEKLY_PARAMS_FILE, 'wb') as f:
    pickle.dump(weekly_best_params, f)


### 3.2. XGBoost - Confidence interval estimation

In [None]:
# Confidence-interval estimation for XGBoost, using the parameters saved by the grid search
# of section 3.1. A fresh BasicXGBoostModelManager is created for this run.
with open(EXP_1_XGBOOST_WEEKLY_PARAMS_FILE, 'rb') as params_file:
    best_xgboost_params = pickle.load(params_file)
params = get_best_params(params_AUC_pairs=best_xgboost_params, n_weeks=3)

xgboost_CI_results = train_with_CI(
    model_manager=BasicXGBoostModelManager(),
    train_data_path=PREVENT_EXP1_PATH,
    test_data_path=TEST_EXP1_PATH,
    features=ALL_FEATURES,
    interval=7,
    params=params,
    eval_metric=XGBOOST_EVAL_METRIC,
    class_weights=BIN_CLASSIFICATION_WEIGHTS,
    threshold=0.5,
)
xgboost_CI_results.to_csv(EXP_1_XGBOOST_CI_FILE, index=False)


### 3.3. XGBoost - Extended study period - Grid search

In [None]:
# Grid search (GS) for XGBoost on the extended study period: training uses
# PREVENT_EXPANDED_EXP1_PATH and evaluation uses TEST_SHRINKED_EXP1_PATH.
# The result name now says "GS", matching EXP_1_XGBOOST_EXPANDED_GS_FILE; it was previously
# misnamed "GC". The selected parameters are saved for section 3.4.
xgboost_EXPANDED_GS_results, weekly_best_params = train_with_grid_search(model_manager=BasicXGBoostModelManager(),
                                                                         train_data_path=PREVENT_EXPANDED_EXP1_PATH,
                                                                         test_data_path=TEST_SHRINKED_EXP1_PATH,
                                                                         features=ALL_FEATURES, interval=7,
                                                                         class_weights=BIN_CLASSIFICATION_WEIGHTS,
                                                                         eval_metric=XGBOOST_EVAL_METRIC,
                                                                         threshold=0.5)
xgboost_EXPANDED_GS_results.to_csv(EXP_1_XGBOOST_EXPANDED_GS_FILE, index=False)
with open(EXP_1_XGBOOST_EXPANDED_WEEKLY_PARAMS_FILE, 'wb') as f:
    pickle.dump(weekly_best_params, f)


### 3.4. XGBoost - Extended study period - Confidence interval estimation

In [None]:
# Confidence-interval estimation for XGBoost on the extended study period, using the
# parameters saved by the grid search of section 3.3.
with open(EXP_1_XGBOOST_EXPANDED_WEEKLY_PARAMS_FILE, 'rb') as params_file:
    best_xgboost_params = pickle.load(params_file)
params = get_best_params(params_AUC_pairs=best_xgboost_params, n_weeks=3)

xgboost_EXPANDED_CI_results = train_with_CI(
    model_manager=BasicXGBoostModelManager(),
    train_data_path=PREVENT_EXPANDED_EXP1_PATH,
    test_data_path=TEST_SHRINKED_EXP1_PATH,
    features=ALL_FEATURES,
    interval=7,
    params=params,
    eval_metric=XGBOOST_EVAL_METRIC,
    class_weights=BIN_CLASSIFICATION_WEIGHTS,
    threshold=0.5,
)
xgboost_EXPANDED_CI_results.to_csv(EXP_1_XGBOOST_EXPANDED_CI_FILE, index=False)