In [None]:
# --- Imports
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

# Plotting
import matplotlib.pyplot as plt
import seaborn as sns

# Model imports
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import Lasso, Ridge, LinearRegression
from sklearn.preprocessing import PolynomialFeatures, StandardScaler

# --- Settings
# Use seaborn's 'whitegrid' theme for the figures drawn by the helpers below.
sns.set_theme(style='whitegrid')
# Seed NumPy's global (legacy) random generator so that code drawing from
# np.random gives the same numbers on every Restart & Run All.
np.random.seed(42)

In [None]:
# Functions for plotting data and predictions - code from Fitting_topological_data

def MSE_R2(N, MSE_train, MSE_test, R2_train, R2_test):
    """Plot train/test MSE and R2 against model complexity, one figure each.

    Two figures are shown in sequence (MSE first, then R2). In both, the
    train curve uses circle markers and the test curve uses cross markers so
    the curves stay distinguishable without relying on colour alone.

    Parameters
    ----------
    N : sequence
        Complexity values for the x-axis. The first entry is skipped
        (``N[1:]``), so each score sequence presumably holds ``len(N) - 1``
        values -- TODO confirm against the caller.
    MSE_train, MSE_test : sequence
        Mean squared error on the train / test data, one value per plotted
        complexity.
    R2_train, R2_test : sequence
        R2 score on the train / test data, one value per plotted complexity.

    Returns
    -------
    None
        The figures are displayed with ``plt.show()``.
    """
    # TODO: consider one figure with twin y-axes for a more compact view.
    complexity = N[1:]
    panels = (
        ('MSE', MSE_train, MSE_test),
        ('$R^2$', R2_train, R2_test),
    )

    for metric, train_scores, test_scores in panels:
        plt.plot(complexity, train_scores, 'o-', label='train')
        plt.plot(complexity, test_scores, 'x-', label='test')
        # Label everything: the two figures are otherwise indistinguishable
        # once rendered.
        plt.xlabel('Model complexity')
        plt.ylabel(metric)
        plt.title(f'{metric} vs. model complexity')
        # Place the legend outside the axes, to the right, so it never
        # covers the curves.
        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5),
                   ncol=1, fancybox=True)
        plt.show()



def plot_3D(x_plot, y_plot, z_plot, z_orig_plot, model='OLS'):
    """Draw the original surface and the predicted surface side by side in 3D.

    Parameters
    ----------
    x_plot, y_plot : array-like
        Grid coordinates handed to ``plot_surface`` for both panels
        (presumably 2D meshgrid arrays -- TODO confirm against the caller).
    z_plot : array-like
        Model prediction on the grid; drawn in the right-hand panel.
    z_orig_plot : array-like
        Original data on the grid; drawn in the left-hand panel.
    model : str, default 'OLS'
        Model name inserted into the right-hand panel title.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    fig = plt.figure(figsize=(12, 6))

    # (subplot position, surface heights, panel title) -- data left, fit right.
    panels = (
        (1, z_orig_plot, 'Topological data'),
        (2, z_plot, f'{model} prediction'),
    )
    for position, heights, title in panels:
        axis = fig.add_subplot(1, 2, position, projection='3d')
        axis.plot_surface(x_plot, y_plot, heights,
                          cmap='viridis', edgecolor='none')
        axis.set_title(title)

    plt.show()


def contour(x_plot, y_plot, z_plot, z_orig_plot, model='OLS'):
    """Draw filled contour plots of the original data and the model prediction.

    ``contourf`` is matplotlib's *filled* contour plot (the trailing "f"
    stands for "filled"). Both panels are drawn with the same contour levels,
    computed from the combined finite range of the two inputs, so the single
    colorbar is valid for both panels.

    Parameters
    ----------
    x_plot, y_plot : array-like
        Grid coordinates handed to ``contourf`` for both panels.
    z_plot : array-like
        Model prediction on the grid; drawn in the right-hand panel.
    z_orig_plot : array-like
        Original data on the grid; drawn in the left-hand panel.
    model : str, default 'OLS'
        Model name inserted into the right-hand panel title.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    # Function-scope import: matplotlib is already a dependency of this file,
    # and this keeps the helper self-contained.
    from matplotlib.ticker import MaxNLocator

    fig, axs = plt.subplots(1, 2, figsize=(12, 6))

    # Build one set of levels spanning both data sets. Without this each
    # contourf call picks its own levels and colour scale, and a colorbar
    # built from one panel would mislabel the other.
    finite_values = np.ma.concatenate([
        np.ma.masked_invalid(z_orig_plot).ravel(),
        np.ma.masked_invalid(z_plot).ravel(),
    ])
    levels = None  # None -> matplotlib's automatic per-panel levels
    if finite_values.count() > 0:
        z_low = float(finite_values.min())
        z_high = float(finite_values.max())
        shared_levels = MaxNLocator(nbins=8).tick_values(z_low, z_high)
        # Filled contours need at least two levels; otherwise keep the default.
        if len(shared_levels) >= 2:
            levels = shared_levels

    # Plotting the datapoints
    original_data = axs[0].contourf(x_plot, y_plot, z_orig_plot,
                                    levels=levels, cmap='viridis')
    axs[0].set_title('Topological data')

    # Plot the polynomial regression prediction
    axs[1].contourf(x_plot, y_plot, z_plot, levels=levels, cmap='viridis')
    axs[1].set_title(f'{model} prediction')

    # One colorbar for both panels; with shared levels either contour set
    # describes both, so the first is used.
    fig.colorbar(original_data, ax=axs, orientation='vertical',
                 fraction=0.02, pad=0.04)

    plt.show()


