In [None]:
import os
import sys

import hydra
import numpy as np
import pandas as pd
from ax import (Arm, ChoiceParameter, ComparisonOp, Data, FixedParameter,
                Metric, Models, Objective, OptimizationConfig,
                OutcomeConstraint, Parameter, ParameterType, RangeParameter,
                Runner, SearchSpace, SimpleExperiment)

# Make the project-local packages imported below (decoding, modeling,
# preprocessing, optimization, tnsbmi) importable by adding their checkout
# directories to sys.path. The checkout root can be overridden with the
# NEUROSCIENCE_CODE_ROOT environment variable; the default is the original
# author's local checkout, so behaviour on that machine is unchanged.
CODE_ROOT = os.environ.get(
    "NEUROSCIENCE_CODE_ROOT",
    r"C:\Users\jesga\OneDrive\Desktop\NeuroscienceLeuven\Code",
)
for _pkg_dir in ("Decodingv7", "tnsbmi", "tnspython"):
    _pkg_path = os.path.join(CODE_ROOT, _pkg_dir)
    # Skip entries that are already present so re-running this cell in the
    # same kernel does not keep growing sys.path with duplicates.
    if _pkg_path not in sys.path:
        sys.path.append(_pkg_path)
import json
import logging
import warnings

from ax import save
from ax.storage.runner_registry import register_runner
from decoding.data import get_alignOn, get_timesteps
from decoding.parameters import Preprocessing_DefaultParameters
from modeling.defaults import CLASSIFIERS, FEATURE_EXTRACTORS, PREPROCESSORS
from modeling.pipeline_builder import PipelineBuilder
from modeling.pipelines import DecoderPipeline
from omegaconf import DictConfig, OmegaConf
from preprocessing.ratesTransformer import RatesTransformer
from preprocessing.reachTuningTransformer import ReachTuningTransformer
from sklearn.model_selection import train_test_split
from tnsbmi import dataconversion
from tnsbmi.binning import Bin
from ax import load
import pickle
from optimization.utils import SearchSpaceGen
from ax.service.ax_client import AxClient
from ax.utils.measurement.synthetic_functions import hartmann6
from ax.utils.notebook.plotting import render, init_notebook_plotting
from ax.plot.contour import interact_contour, plot_contour
from ax.plot.diagnostic import interact_cross_validation
from ax.plot.scatter import(
    interact_fitted,
    plot_objective_vs_constraints,
    tile_fitted,
    plot_fitted
)
from ax.plot.marginal_effects import plot_marginal_effects
from ax.plot.slice import plot_slice, interact_slice
from ax.modelbridge.cross_validation import cross_validate
# Set up Ax's notebook plotting so the render(...) calls in later cells display
# their figures in this notebook.
init_notebook_plotting()

# Parameters

In [None]:
# Default run directory to analyse (presumably the papermill "parameters" cell);
# the following cell reassigns root_dir when a value is injected.
# NOTE(review): absolute path on the original author's machine — pass root_dir
# as a parameter (or edit this line) when running elsewhere.
root_dir = r"C:\Users\jesga\OneDrive\Desktop\NeuroscienceLeuven\Code\outputs\different_positions\vgrasp_Sky_20200819_1137_B_trials\SVM_ICA\Basic_SVM_f1_weighted\200\2020-12-15_18-14-39"

In [None]:
# Parameters
# NOTE(review): this injected path mixes "/" and "\\" separators, so as written
# it likely only resolves on Windows.
root_dir = "outputs/different_objects/vgrasp_Sky_20200928_1312_B_trials/Anova_RF/Anova_RF_f1_weighted/Default\\2020-12-30_17-11-12"


# Config Loading

In [None]:
# Load the Hydra config saved alongside the optimization run in root_dir.
# root_dir may mix "/" and "\\" separators (e.g. a Windows-generated path
# injected as a parameter); normalise both to the platform separator so the
# path also resolves on POSIX. Both steps are idempotent, so re-running this
# cell is safe. (Assumes no directory name legitimately contains a backslash.)
root_dir = os.path.normpath(root_dir.replace("\\", "/"))
config_path = os.path.join(root_dir, ".hydra", "config.yaml")
if not os.path.isfile(config_path):
    # Fail with the resolved path rather than a less specific loader error.
    raise FileNotFoundError(f"Hydra config not found: {config_path!r} (check root_dir)")
cfg = OmegaConf.load(config_path)

In [None]:
# Dump the loaded run configuration as YAML under a "Config:" header.
print("Config:", OmegaConf.to_yaml(cfg), sep="\n")

# Experiment Results

In [None]:
# Rebuild the AxClient from the JSON snapshot stored in the run directory.
restore_path = os.path.join(root_dir, "ax_client_snapshot.json")
ax_client = AxClient.load_from_json_file(filepath=restore_path)

In [None]:
# Tabulate the experiment's trials; the frame includes a column named after
# cfg.obj_name, which the next cell sorts on.
trials_df = ax_client.get_trials_data_frame() 

In [None]:
# Sort trials from best to worst on the optimization objective. The direction is
# read from the experiment itself instead of being hard-coded as descending, so
# the ranking stays correct for minimized objectives (e.g. a loss) as well as
# maximized ones (e.g. an F1 score). Trials with a missing value sort last.
objective_minimized = ax_client.experiment.optimization_config.objective.minimize
trials_df = trials_df.sort_values(by=[cfg.obj_name], ascending=objective_minimized)
# Show the 5 best trials.
trials_df.head(5)

In [None]:
# Objective value achieved over the course of the trials.
optimization_trace = ax_client.get_optimization_trace()
render(optimization_trace)

In [None]:
# Relative importance of each search-space parameter, as reported by Ax.
feature_importances = ax_client.get_feature_importances()
render(feature_importances)

# Model Results

In [None]:
# Model bridge currently held by the generation strategy; the contour, slice,
# cross-validation and fitted-value plots below are all built from it.
model = ax_client.generation_strategy.model
# GenerationStrategy.model is Optional (None when no model has been
# instantiated). Fail here with a clear message instead of deep inside a
# plotting call.
if model is None:
    raise RuntimeError(
        "The generation strategy holds no model; the model-based plots below "
        "(contour, slice, cross-validation, fitted values) require one."
    )

# Contour Plots

In [None]:
# Interactive contour of the model's predicted objective over parameter pairs.
contour_plot = interact_contour(model=model, metric_name=cfg.obj_name)
render(contour_plot)

# Slice Plots

In [None]:
# One slice plot per parameter declared with type "range" in the search-space
# config; parameters of other types are skipped.
plot_params = [
    name
    for name, spec in cfg.optimization.space.items()
    if spec.type == "range"
]
for param_name in plot_params:
    render(interact_slice(model, param_name=param_name))

# Cross Validation Plots

In [None]:
# Cross-validate the model and plot predicted vs. observed values.
cv_results = cross_validate(model)
cv_plot = interact_cross_validation(cv_results)
render(cv_plot)

# Fitted Values Plot (interactive)

In [None]:
# Model-fitted values per arm, shown on the absolute (non-relative) scale.
fitted_plot = interact_fitted(model, rel=False)
render(fitted_plot)

# Trade-Off Plots

In [None]:
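# TODO: set obj_metric (presumably cfg.obj_name) and uncomment to plot the
# objective against the outcome constraints.
# obj_metric = 
# render(plot_objective_vs_constraints(model, obj_metric , rel=False))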