# In [3]:
from functools import partial

import os

import numpy as np
import pandas as pd


import cloudpickle
import json

# plotting
import matplotlib.pyplot as plt
import matplotlib.lines as mlines


from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# Hex colours that parityplot assigns, by position, to the class labels it draws.
COLORMAP = ["#DFD27F", "#C65BAA", "#27474F"]

def parityplot(filename, train, test, title=None,
               min_max=(-0.3, 0.3), ylabel="predicted", xlabel="true"):
    """Draw a parity (true vs. predicted) plot with marginal axes and save it.

    Parameters
    ----------
    filename : str or path-like
        Destination handed to ``Figure.savefig``.
    train, test : sequence
        Either ``(true, predicted)`` or ``(true, predicted, labels)``.
        When ``train`` has a third element, the points are coloured per
        label and stacked histograms of the true / predicted values are
        drawn on the marginal axes. The elements are indexed with boolean
        masks, so they are presumably NumPy arrays or pandas Series --
        confirm against callers.
    title : str, optional
        Figure title.
    min_max : tuple of float
        Lower and upper limit used for both axes, the diagonal and the
        histogram bins.
    ylabel, xlabel : str
        Axis labels of the main panel.

    Returns
    -------
    None
        The figure is written to ``filename`` and then closed.
    """
    # Start with a square Figure.
    fig = plt.figure(figsize=(6, 6))
    try:
        # Two rows and two columns with a 1:4 ratio between the marginal axes
        # and the main axes in both directions, and no gap between them.
        gs = fig.add_gridspec(2, 2, width_ratios=(4, 1), height_ratios=(1, 4),
                              wspace=0.0, hspace=0.0)
        # Create the Axes.
        pp = fig.add_subplot(gs[1, 0])
        pp_histx = fig.add_subplot(gs[0, 0], sharex=pp)
        pp_histy = fig.add_subplot(gs[1, 1], sharey=pp)
        # Diagonal marking perfect agreement.
        pp.plot(min_max, min_max)

        pp_histx.tick_params(axis="x", labelbottom=False)
        pp_histy.tick_params(axis="y", labelleft=False)

        binwidth = 0.025
        bins = np.arange(min_max[0], min_max[1] + binwidth, binwidth)

        if len(train) > 2:
            # Points far outside the axis limits: they only exist to give the
            # legend a neutral "training" / "testing" entry. They are drawn
            # only here because the two-element branch labels its own points.
            pp.scatter([-10e5], [-10e5], marker="+", color="black",
                       alpha=0.6, label="training")
            pp.scatter([-10e5], [-10e5], marker="+", color="black",
                       alpha=1, label="testing")

            # Sort the labels so every class gets the same colour on every
            # run; raw set order is not stable for string labels.
            labels = set(train[2]).union(set(test[2]))
            try:
                labels = sorted(labels)
            except TypeError:
                # Labels of mixed, non-comparable types.
                labels = sorted(labels, key=str)

            true_multi = []
            pred_multi = []
            plot_colors = []
            for idx, label in enumerate(labels):
                # Cycle through the palette instead of failing when there are
                # more classes than colours.
                color = COLORMAP[idx % len(COLORMAP)]
                in_train = train[2] == label
                in_test = test[2] == label

                pp.scatter(train[0][in_train], train[1][in_train],
                           marker="+", color=color, alpha=0.6)
                pp.scatter(test[0][in_test], test[1][in_test],
                           marker="+", color=color, label=f"{label}")

                # The "99" suffix is a hex alpha channel (0x99 / 255 = 0.6),
                # matching the alpha of the training scatter points.
                true_multi.append(train[0][in_train])
                pred_multi.append(train[1][in_train])
                plot_colors.append(color + "99")
                true_multi.append(test[0][in_test])
                pred_multi.append(test[1][in_test])
                plot_colors.append(color)

            pp_histx.hist(true_multi, bins, color=plot_colors, stacked=True)
            pp_histx.set_ylim((0.2, 25))
            pp_histy.hist(pred_multi, bins, color=plot_colors, stacked=True,
                          orientation='horizontal')
            pp_histy.set_xlim((0.2, 25))
        else:
            pp.scatter(train[0], train[1], marker="p", label="training")
            pp.scatter(test[0], test[1], marker="P", label="testing")

        pp.legend()
        pp.set_ylabel(ylabel)
        pp.set_ylim(min_max)
        pp.set_xlabel(xlabel)
        pp.set_xlim(min_max)
        fig.suptitle(title)

        fig.savefig(filename)
    finally:
        # Clearing a figure keeps it registered with pyplot; close it so
        # repeated calls do not accumulate open figures.
        plt.close(fig)

def run_regressor_nested_cv(feats, target, model, params,
                            sample_class=None, view_class=None,
                            test_split=0.2, unit="eV",
                            xlabel=r"$\mathrm{E}_{\mathrm{ads}}$",
                            name="test", scaler=None, pp_kws=None):
    """Tune and evaluate a regressor on five random train/test splits.

    For each of five splits (``random_state`` 1868 to 1872) a
    ``GridSearchCV`` (5-fold, negative MSE scoring, ``n_jobs=8``) is fitted
    on the training part, evaluated on both parts, plotted with
    ``parityplot`` and pickled. Everything is written below
    ``outputs/<name>``: ``scaler.pt`` (only if a scaler is given),
    ``ppNN.pdf``, ``modelNN.pt`` and ``metrics.json``.

    Parameters
    ----------
    feats, target : array-like
        Features and regression target, split together.
    model : callable
        Called without arguments to create the estimator to tune.
    params : dict or list of dict
        Parameter grid handed to ``GridSearchCV``.
    sample_class : array-like, optional
        Passed as ``stratify`` to ``train_test_split``.
    view_class : array-like, optional
        Per-sample labels split alongside the data and forwarded to
        ``parityplot`` for colouring. Defaults to zeros shaped like
        ``target``.
    test_split : float
        ``test_size`` of each split.
    unit : str
        Unit shown in the plot title and axis labels.
    xlabel : str
        Label of the x axis (the unit is appended). The default is a raw
        string so its LaTeX backslashes are not read as escape sequences.
    name : str
        Name of the output sub-folder.
    scaler : callable, optional
        Called without arguments to create an object with ``fit`` and
        ``transform``; when given, the features are scaled with it.
    pp_kws : dict, optional
        Extra keyword arguments for ``parityplot``.

    Returns
    -------
    dict
        Maps each metric name (``train_mae``, ``test_mae``, ...) to a
        ``(mean, std)`` tuple of strings formatted with four decimals.
    """
    pp_kws = {} if pp_kws is None else pp_kws
    folder = f"outputs/{name}"
    os.makedirs(folder, exist_ok=True)

    if scaler:
        # NOTE(review): the scaler is fitted on all samples before the
        # split, so statistics of the held-out data leak into the scaling.
        # Kept as is because fitting per split would change the single
        # "scaler.pt" artefact written here.
        scaler = scaler()
        scaler.fit(feats)
        feats = scaler.transform(feats)
        with open(os.path.join(folder, "scaler.pt"), "wb") as f:
            cloudpickle.dump(scaler, f)

    # One list entry per split for every metric.
    metrics = {
        "train_mae": [],
        "test_mae": [],
        "train_mse": [],
        "test_mse": [],
        "train_r2": [],
        "test_r2": [],
    }

    def format_eval():
        # Summarise the most recent split as "test (train)".
        return f"MAE [{unit}]: {metrics['test_mae'][-1]:.4f} ({metrics['train_mae'][-1]:.4f}) - "\
               f"R²: {metrics['test_r2'][-1]:.2f} ({metrics['train_r2'][-1]:.2f})"

    view_class = np.zeros_like(target) if view_class is None else view_class

    for i in range(5):
        feats_cv, feats_test, target_cv, target_test, class_cv, class_test = \
            train_test_split(feats, target, view_class,
                             random_state=1868 + i, test_size=test_split,
                             stratify=sample_class)
        hypmodel = GridSearchCV(model(), params,
                                cv=5, scoring="neg_mean_squared_error", n_jobs=8)
        hypmodel.fit(feats_cv, target_cv)
        pred_cv = hypmodel.predict(feats_cv)
        pred_test = hypmodel.predict(feats_test)

        # "split" rather than "name", so the ``name`` argument is not
        # overwritten by the loop.
        for truth, est, split in ((target_cv, pred_cv, "train"),
                                  (target_test, pred_test, "test")):
            metrics[f"{split}_mae"].append(mean_absolute_error(truth, est))
            metrics[f"{split}_mse"].append(mean_squared_error(truth, est))
            metrics[f"{split}_r2"].append(r2_score(truth, est))

        parityplot(os.path.join(folder, f"pp{str(i).zfill(2)}.pdf"),
                   (target_cv, pred_cv, class_cv),
                   (target_test, pred_test, class_test),
                   title=format_eval(), xlabel=f"{xlabel} [{unit}]",
                   ylabel=f"prediction [{unit}]",
                   **pp_kws)

        with open(os.path.join(folder, f"model{str(i).zfill(2)}.pt"), "wb") as f:
            cloudpickle.dump(hypmodel, f)

    with open(os.path.join(folder, "metrics.json"), 'w') as f:
        json.dump(metrics, f)

    eval_agg = {
        key: (f"{sum(vals)/len(vals):.4f}", f"{np.std(vals):.4f}")
        for key, vals in metrics.items()
    }
    return eval_agg


def run_regressor_manual(feats_cv, target_cv,
                         feats_test, target_test,
                         model, params,
                         name="test", scaler=None, pp_kws=None):
    """Tune a regressor on a given training set and evaluate on a given test set.

    A ``GridSearchCV`` (5-fold, negative MSE scoring, ``n_jobs=8``) is
    fitted on the training data, evaluated on both sets, plotted with
    ``parityplot`` and pickled. Everything is written below
    ``outputs_manual/<name>``: ``scaler.pt`` (only if a scaler is given),
    ``pp.pdf``, ``model.pt`` and ``metrics.json``.

    Parameters
    ----------
    feats_cv, target_cv : array-like
        Training features and target, used for the grid search.
    feats_test, target_test : array-like
        Held-out features and target.
    model : callable
        Called without arguments to create the estimator to tune.
    params : dict or list of dict
        Parameter grid handed to ``GridSearchCV``.
    name : str
        Name of the output sub-folder.
    scaler : callable, optional
        Called without arguments to create an object with ``fit`` and
        ``transform``; it is fitted on the training features only and
        applied to both sets.
    pp_kws : dict, optional
        Extra keyword arguments for ``parityplot``.

    Returns
    -------
    dict
        Maps each metric name (``train_mae``, ``test_mae``, ...) to a
        ``(mean, std)`` tuple of floats over the single recorded run.
    """
    pp_kws = {} if pp_kws is None else pp_kws
    folder = f"outputs_manual/{name}"
    os.makedirs(folder, exist_ok=True)

    if scaler:
        scaler = scaler()
        scaler.fit(feats_cv)
        feats_cv = scaler.transform(feats_cv)
        feats_test = scaler.transform(feats_test)
        with open(os.path.join(folder, "scaler.pt"), "wb") as f:
            cloudpickle.dump(scaler, f)

    # Lists (with a single entry here) keep the layout of metrics.json
    # identical to the one written by the repeated-split variant.
    metrics = {
        "train_mae": [],
        "test_mae": [],
        "train_mse": [],
        "test_mse": [],
        "train_r2": [],
        "test_r2": [],
    }

    def format_eval():
        # "test (train)" for both metrics; the " - " keeps the two parts
        # from running into each other in the plot title.
        return f"mae: {metrics['test_mae'][-1]:.4f} ({metrics['train_mae'][-1]:.4f}) - "\
               f"R²: {metrics['test_r2'][-1]:.2f} ({metrics['train_r2'][-1]:.2f})"

    hypmodel = GridSearchCV(model(), params,
                            cv=5, scoring="neg_mean_squared_error", n_jobs=8)
    hypmodel.fit(feats_cv, target_cv)
    pred_cv = hypmodel.predict(feats_cv)
    pred_test = hypmodel.predict(feats_test)

    # "split" rather than "name", so the ``name`` argument is not
    # overwritten by the loop.
    for truth, est, split in ((target_cv, pred_cv, "train"),
                              (target_test, pred_test, "test")):
        metrics[f"{split}_mae"].append(mean_absolute_error(truth, est))
        metrics[f"{split}_mse"].append(mean_squared_error(truth, est))
        metrics[f"{split}_r2"].append(r2_score(truth, est))

    parityplot(os.path.join(folder, "pp.pdf"),
               (target_cv, pred_cv),
               (target_test, pred_test),
               title=format_eval(), **pp_kws)

    with open(os.path.join(folder, "model.pt"), "wb") as f:
        cloudpickle.dump(hypmodel, f)

    with open(os.path.join(folder, "metrics.json"), 'w') as f:
        json.dump(metrics, f)

    eval_agg = {
        key: (sum(vals) / len(vals), np.std(vals))
        for key, vals in metrics.items()
    }
    return eval_agg

from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, MultiOutputMixin

class StratifiedMedianRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
    """Baseline regressor that predicts the median target of each stratum.

    ``fit`` flattens ``X`` and stores, for every distinct value in it, the
    median of the targets observed with that value. ``predict`` looks each
    flattened input value up in that table. The flattening means ``X`` is
    presumably a single column of category codes -- confirm against callers.
    """

    def __init__(self):
        # Maps each distinct value seen in X during fit to its median target.
        self.lookup_table = {}

    def fit(self, X, y):
        """Build the value -> median-target table from ``X`` and ``y``.

        Returns ``self``.
        """
        # Coerce so DataFrames, Series and plain lists work like ndarrays.
        x_flat = np.asarray(X).ravel()
        y = np.asarray(y)
        # Rebuild the table on every fit so strata from an earlier fit do
        # not survive a refit on different data.
        self.lookup_table = {
            ux: np.median(y[x_flat == ux]) for ux in np.unique(x_flat)
        }
        return self

    def predict(self, X):
        """Return the stored median for every value of the flattened ``X``.

        Raises ``KeyError`` for a value that was not present during ``fit``.
        """
        x_flat = np.asarray(X).ravel()
        return np.array([self.lookup_table[ux] for ux in x_flat])

    def score(self, X, y, sample_weight=None):
        """Delegate to the inherited ``score``; ``X`` may be ``None``.

        When ``X`` is ``None`` a single column of zeros with ``len(y)`` rows
        is substituted, so the prediction is looked up under the key 0.
        """
        if X is None:
            X = np.zeros(shape=(len(y), 1))
        return super().score(X, y, sample_weight)