In [None]:
import os
import pandas as pd
import numpy as np
from joblib import dump
import joblib
import yaml
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Notebook configuration: pipeline parameters, result locations, working directory.
stage = "train"

# Read the pipeline parameters inside a context manager so the file handle is
# closed immediately instead of being left open for the life of the kernel.
with open("/workspace/growthcurves/params.yaml") as params_file:
    params = yaml.safe_load(params_file)
paths = params["config"]["paths"]

# Root of the result run analysed in this notebook; figures are written below it.
root_path = '/workspace/data/out/results_2024_09_25/target_3_months_sds/'
# The trailing "" keeps a trailing separator, so plain string concatenation of a
# file name onto data_save_root still yields a valid path.
data_save_root = os.path.join(root_path, "reports", "figures", "")

# Work from the result root and echo the directory actually in effect
# (os.curdir is only the constant "." and would not confirm the change).
os.chdir(root_path)
print(os.getcwd())

In [None]:
# Load the pre-trained models and the modelling dataset, then score the test split.
# NOTE(review): joblib.load unpickles its input, which can execute arbitrary code;
# only load artefacts produced by this project's own pipeline.
ebm_path = os.path.join(root_path, paths["models"], "other", "ExplainableBoostingRegressor_model.joblib")
ebm = joblib.load(ebm_path)

linreg_path = os.path.join(root_path, paths["models"], "other", "LinearRegression_model.joblib")
linreg = joblib.load(linreg_path)

symbolic_path = os.path.join(root_path, paths["models"], "symbolic", "gpg_model.joblib")
symbolic = joblib.load(symbolic_path)

# Train/test splits saved by the feature-building step.
model_data_path = os.path.join(root_path, paths["features"], "modelling_dataset.joblib")
modelling_dataset = joblib.load(model_data_path)

x_train = modelling_dataset['x_train']
y_train = modelling_dataset['y_train']
x_test = modelling_dataset['x_test']
# Work on a copy: the prediction columns added below must not leak back into the
# loaded dataset (or into a parent frame, should this object be a view of one).
# Presumably y_test is a DataFrame holding the true target column - TODO confirm.
y_test = modelling_dataset['y_test'].copy()

# One prediction column per model, aligned with the rows of x_test.
y_test['prediction_ebm'] = ebm.predict(x_test)
y_test['prediction_symbolic'] = symbolic.predict(x_test)
y_test['prediction_linreg'] = linreg.predict(x_test)

In [None]:
# Predicted vs. true IGF-1 SDS on the test split, one panel per model.
# Each entry is (prediction column in y_test, panel title), in left-to-right order.
panels = [
    ("prediction_linreg", "Linear Regression"),
    ("prediction_ebm", "Explainable Boosting Machine"),
    ("prediction_symbolic", "Symbolic Regression"),
]
# End points of the dashed y = x reference line drawn on every panel.
identity_range = [-1, 3]

fig, axes = plt.subplots(1, len(panels), figsize=(15, 5), sharex=True, sharey=True)

for ax, (prediction_col, panel_title) in zip(axes, panels):
    sns.scatterplot(ax=ax, data=y_test, x='target_igf_1_sds', y=prediction_col)
    # Draw directly on the Axes rather than through the pyplot "current axes" state.
    ax.plot(identity_range, identity_range, color='black', linestyle='dashed')
    ax.set_title(panel_title, size=20)
    ax.set_xlabel('IGF-1 SDS true value', size=16)

# The y axis is shared, so only the left-most panel gets the readable label.
axes[0].set_ylabel('IGF-1 SDS predicted', size=16)
fig.tight_layout()

# Make sure the figures directory exists so saving cannot fail on a fresh run.
os.makedirs(data_save_root, exist_ok=True)
fig.savefig(os.path.join(data_save_root, 'predictions_3_months.png'))
plt.show()

In [None]:
# Bare last expression: the notebook displays the repr of the symbolic model
# object that was loaded from gpg_model.joblib.
symbolic