# Functions for the pipeline

## Plot predictions
(1) The shaded region should be specified as a prediction interval (e.g. 0.89) rather than as a number of standard deviations (e.g. 3). <br/>
(2) The data points should be drawn in front of the lines rather than behind them.

In [None]:
def plot_pp(pp, x_train, x_test, y_train, y_test, label_column, d,
               lines = True, size = 50, std = 3):
    '''
    Plot predictive mean, uncertainty band and data, one figure per label.

    One figure is created for every unique value in d[label_column]. Each
    figure shows the training points, the held-out points, the mean of
    pp["y_pred"] over axis 0, and a band of +/- std standard deviations
    around that mean. Nothing is returned; the figures are left open on
    the current matplotlib state.

    pp: posterior (or prior) predictive; indexed with the keys "y_pred",
        and (only when lines is truthy) "α" and "β".
        Assumes axis 0 of every entry is the draw axis, and that in the
        multilevel case "α"/"β" have one column per label -- TODO confirm.
    x_train: train x data (np.array)
    x_test: test x data (np.array)
    y_train: train y data (np.array)
    y_test: test y data (np.array)
    label_column: column containing labels for the titles (str)
    d: your data (pd.DataFrame)
    lines: True for Kruschke-style plots with individual regression draws.
    size: how many lines to draw if lines is truthy.
    std: half-width of the shaded area, in standard deviations.

    Multilevel layout assumption: when there is more than one label, the
    arrays are sliced into consecutive chunks of len(np.unique(x_test))
    (resp. len(np.unique(x_train))) elements, one chunk per label in the
    order of np.unique(labels). This presumes every group shares the same
    x values and the data is sorted by group -- verify against the caller.

    ### to be done.
    1. check that it works for several groups (non-multilevel).
    2. plot format in cases of many subplots?
    '''

    # go into data and take out labels
    labels = np.unique(d[label_column].values)
    multilevel = len(labels) > 1

    # chunk sizes used to step through the concatenated groups
    len_test = len(np.unique(x_test))
    len_train = len(np.unique(x_train))

    y_pred = pp["y_pred"]  # should not be hard-coded of course.

    # Raw string keeps the mathtext backslashes literal; the multiplier is
    # taken from `std` so the legend always matches the band that is drawn.
    band_label = rf'Uncertainty Interval ($\mu\pm{std}\sigma$)'

    # Only touch the parameter draws when they are actually needed.
    if lines:
        draws = pp['α'].shape[0]

    for i, group_label in enumerate(labels):

        if multilevel:
            # consecutive, equally sized chunk belonging to group i
            test_idx = slice(i * len_test, (i + 1) * len_test)
            train_idx = slice(i * len_train, (i + 1) * len_train)
            group_pred = y_pred[:, test_idx]
        else:
            # single level: use everything as-is
            test_idx = train_idx = slice(None)
            group_pred = y_pred

        x_test_group = x_test[test_idx]
        y_pred_mean = group_pred.mean(axis=0)
        y_pred_std = group_pred.std(axis=0)

        plt.figure(figsize=(16, 8))
        # zorder=10 keeps both point clouds in front of the lines and band
        plt.scatter(x_train[train_idx], y_train[train_idx],
                    c='k', zorder=10, label='Data')
        plt.scatter(x_test_group, y_test[test_idx],
                    c="red", zorder=10, label='Held-out')
        plt.plot(x_test_group, y_pred_mean,
                 label='Prediction Mean', linewidth=5, c="k")
        plt.fill_between(x_test_group,
                         y_pred_mean - std * y_pred_std,
                         y_pred_mean + std * y_pred_std,
                         alpha=0.2, label=band_label)

        ## optionally add lines
        if lines:
            # random draw indices; the same index pairs α with β
            samples = [random.randrange(draws) for _ in range(size)]

            if multilevel:
                alpha = [pp["α"][sample, i] for sample in samples]
                beta = [pp["β"][sample, i] for sample in samples]
            else:
                alpha = [pp["α"][sample] for sample in samples]
                beta = [pp["β"][sample] for sample in samples]

            # draw each sampled regression line over this group's x only
            for a, b in zip(alpha, beta):
                plt.plot(x_test_group, a + b * x_test_group, c="k", alpha=0.4)

        ## labeling.
        plt.xlabel('$x$')
        plt.ylabel('$y$')
        plt.title(label = f"{group_label}", fontsize = 20)
        plt.legend(loc='upper left')
    

## MSE 
Compute the mean squared error (MSE) between the predictions and the true values.

In [None]:
# mean squared error (implement more) 
def MSE_fun(y_true, y_pred):
    '''
    Return the mean squared error between y_true and y_pred.

    y_true: observed values (scalar or array-like).
    y_pred: predicted values, broadcastable against y_true.
    '''
    residuals = np.subtract(y_true, y_pred)
    squared_residuals = np.square(residuals)
    return squared_residuals.mean()

# Ideas


## (1) Show the residuals together with the predictions:
https://stackoverflow.com/questions/24116318/how-to-show-residual-in-the-bottom-of-a-matplotlib-plot