# In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error

from wtencv import *


def draw_predicted_measured_plot(
    true,
    pred,
    scatter_title=None,
    scale_span=0.25,
    plot_color="black"
):
    """Show a square scatter plot of predicted against measured values.

    ``true`` is drawn on the x-axis and ``pred`` on the y-axis. Both axes
    share the same limits and tick positions, and a black y = x line marks
    perfect agreement between the two.

    Args:
        true: Measured (reference) values; plotted on the x-axis.
        pred: Predicted values paired element-wise with ``true``; plotted
            on the y-axis.
        scatter_title: Optional figure title; no title is set when ``None``.
        scale_span: Spacing between neighbouring ticks on both axes.
        plot_color: Color passed to ``plt.scatter`` for the markers.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    plt.figure(figsize=(7, 7))

    if scatter_title is not None:
        plt.title(scatter_title)

    # One common range for both axes, covering measured and predicted values.
    plot_min = np.min((np.min(true), np.min(pred)))
    plot_max = np.max((np.max(true), np.max(pred)))

    # Ticks are multiples of scale_span; the same positions serve both axes.
    ticks = np.arange(plot_min // scale_span, plot_max / scale_span) * scale_span

    # Identity line across the whole plotted range (not a fixed interval),
    # so it stays visible whatever the magnitude of the data.
    plt.plot([plot_min, plot_max], [plot_min, plot_max], c="black")
    plt.scatter(true, pred, c=plot_color)

    # x carries `true`, y carries `pred`; the labels must match that order.
    plt.xlabel("Measured")
    plt.xlim([plot_min, plot_max])
    plt.xticks(ticks)

    plt.ylabel("Predicted")
    plt.ylim([plot_min, plot_max])
    plt.yticks(ticks)

    plt.show()


# Settings handed to the WTENCV constructor further down.
wavelet_num = 10
proportion = 4
scaling = False

# Input table; every monomer's rows are read from this single CSV file.
file_name = "data_sets.csv"
df = pd.read_csv(file_name)

# The headers of the trailing `ir_datapoints_num` columns are parsed as
# floats and used as the wave-number axis.
ir_datapoints_num = 3527
wave_numbers = np.array(df.columns)[-ir_datapoints_num:].astype(float)

# Plot color for each monomer name.
colors = dict(
    St="blue",
    THFMA="purple",
    CHMA="green",
    GMA="red",
    PACS="darkgoldenrod",
)

# Monomers processed by the loop below. Only "St" is enabled; the other
# names that have an entry in `colors` ("THFMA", "CHMA", "GMA", "PACS")
# can be appended here to process them as well.
monomer_names = ["St"]

# For every selected monomer: fit WTENCV on a train split, evaluate on the
# held-out split, then plot the predictions and the model coefficients.
for monomer_name in monomer_names:
    df_monomer = df[df["M1_name"] == monomer_name]

    # Select the trailing spectral columns first and cast them to float.
    # Calling `.values` on the whole frame (which also holds the string
    # column "M1_name") would produce an object-dtype array instead.
    X = df_monomer.iloc[:, -ir_datapoints_num:].to_numpy(dtype=float)
    y = df_monomer["Conc_M1_monomer[mmol/g]"].to_numpy()

    # Author's note: a lower max_iter results in a large number of warnings.
    wtencv = WTENCV(
        wave_numbers,
        wavelet_num=wavelet_num,
        proportion=proportion,
        scaling=scaling,
        cv_fold=2,
        max_iter=10,
        random_state=1,
    )

    # Seed the split too: WTENCV is already given random_state=1, but an
    # unseeded split makes the metrics below differ from run to run.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=1
    )

    wtencv.fit(X_train, y_train)
    y_pred = wtencv.predict(X_test)

    draw_predicted_measured_plot(
        y_test,
        y_pred,
        scatter_title=monomer_name,
        plot_color=colors[monomer_name],
    )

    # Hold-out metrics; RMSE is the square root of the mean squared error.
    r2 = r2_score(y_test, y_pred)
    mae = mean_absolute_error(y_test, y_pred)
    rmse = np.sqrt(mean_squared_error(y_test, y_pred))
    print(f"R2: {r2:.3f}, MAE: {mae:.3f}, RMSE: {rmse:.3f}")

    wtencv.visualize_coef(title=monomer_name, color=colors[monomer_name])