# Setup

In [None]:
import chemprop
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.offsetbox import AnchoredText
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.decomposition import PCA

In [None]:
def plot_parity(y_true, y_pred, y_pred_unc=None):
    """Draw a parity (predicted vs. true) scatter plot and display it.

    The plot has a dashed y = x reference line, square axes sharing the same
    limits, and an annotation box in the upper-left corner with the MAE and
    RMSE of the predictions.

    Args:
        y_true: Reference values; iterated with ``min``/``max``, so a flat
            sequence of numbers is expected.
        y_pred: Predicted values, same length as ``y_true``.
        y_pred_unc: Optional per-point uncertainty, forwarded unchanged as
            ``yerr`` to ``plt.errorbar``.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    # Scan y_true once; its range sets the 10 % margin added on both ends.
    true_lo, true_hi = min(y_true), max(y_true)
    margin = 0.1 * (true_hi - true_lo)
    axmin = min(true_lo, min(y_pred)) - margin
    axmax = max(true_hi, max(y_pred)) + margin

    mae = mean_absolute_error(y_true, y_pred)
    # Take the root explicitly: the ``squared=False`` shortcut was deprecated
    # in scikit-learn 1.4 and removed in 1.6, so it fails on current releases.
    rmse = mean_squared_error(y_true, y_pred) ** 0.5

    # Perfect-prediction reference line.
    plt.plot([axmin, axmax], [axmin, axmax], '--k')

    # linewidth=0 suppresses the connecting line so only markers (and error
    # bars, when given) are drawn.
    plt.errorbar(y_true, y_pred, yerr=y_pred_unc, linewidth=0, marker='o', markeredgecolor='w', alpha=1, elinewidth=1)

    plt.xlim((axmin, axmax))
    plt.ylim((axmin, axmax))

    ax = plt.gca()
    ax.set_aspect('equal')

    at = AnchoredText(
        f"MAE = {mae:.2f}\nRMSE = {rmse:.2f}", prop=dict(size=10), frameon=True, loc='upper left')
    at.patch.set_boxstyle("round,pad=0.,rounding_size=0.2")
    ax.add_artist(at)

    plt.xlabel('True')
    plt.ylabel('Chemprop Predicted')

    plt.show()

In [None]:
# Command-line style arguments for a Chemprop regression training run on
# atom/bond-level targets, evaluated with 5-fold scaffold-balanced
# cross-validation and scored by RMSE.
arguments = [
    '--data_path', '../data/QM_100.csv',
    '--dataset_type', 'regression',
    '--save_dir', 'model/QM_137k_checkpoints',
    '--smiles_columns', 'smiles',
    # Each target column must be its own argv token. The previous single
    # string 'CDD hirshfeld_fukui_neu' was parsed as ONE column name
    # containing a space.
    # NOTE(review): assumes the CSV has separate 'CDD' and
    # 'hirshfeld_fukui_neu' columns - confirm against the file header.
    '--target_columns', 'CDD', 'hirshfeld_fukui_neu',
    '--is_atom_bond_targets',
    '--epochs', '100',
    '--save_smiles_splits',
    '--adding_h',
    '--show_individual_scores',
    '--split_type', 'scaffold_balanced',
    '--num_folds', '5',
    '--metric', 'rmse',
]

# Parse the list exactly as if it had been given on the command line, then
# run cross-validated training; the call returns the mean and standard
# deviation of the score.
args = chemprop.args.TrainArgs().parse_args(arguments)
mean_score, std_score = chemprop.train.cross_validate(args=args, train_func=chemprop.train.run_training)

In [None]:
from chemprop.args import HyperoptArgs
from chemprop.hyperparameter_optimization import hyperopt, chemprop_hyperopt


# Command-line style arguments for a Chemprop hyperparameter search on the
# atom/bond-level regression target 'CDD'.
hyperopt_arguments = [
    # Input data and task type.
    '--data_path', 'demo.csv',
    '--dataset_type', 'regression',
    # Output locations: best configuration, search logs, model checkpoints.
    '--config_save_path', 'hyperopt/best_hyperparams.json',
    '--log_dir', 'hyperopt/hyperopt_logs',
    '--save_dir', 'hyperopt/hyperopt_checkpoints',
    # Column selection and atom/bond-target mode.
    '--smiles_columns', 'smiles',
    '--target_columns', 'CDD',
    '--is_atom_bond_targets',
    # Training settings used for every trial.
    '--epochs', '100',
    '--save_smiles_splits',
    '--adding_h',
    '--show_individual_scores',
    # Search budget and scope.
    '--num_iters', '10',
    # NOTE(review): the original author flagged this value with '####';
    # presumably it should match the cores available on the machine - confirm.
    '--num_workers', '128',
    '--search_parameter_keywords', 'all',
]

# Parse the list as if it came from the command line and launch the search.
args = HyperoptArgs().parse_args(hyperopt_arguments)
hyperopt(args)

