In [None]:
import matplotlib.pyplot as plt
import numpy as np

from mainv3 import SystemDesign
from data_module import DataProcessor
from equations import JouybanAcreeModel
from groups import ja_groups

# Global Matplotlib font sizes used by every figure in this notebook:
# a 12 pt base font, with axis labels, subplot titles and the figure-level
# title progressively larger so they stand out from tick and legend text.
plt.rc('font', size=12)                      # default text
plt.rc('axes', labelsize=14, titlesize=16)   # axis labels / subplot titles
plt.rc('xtick', labelsize=12)                # x-axis tick labels
plt.rc('ytick', labelsize=12)                # y-axis tick labels
plt.rc('legend', fontsize=12)                # legend entries
plt.rc('figure', titlesize=18)               # suptitle

In [None]:
# Load the saved system and predict JA coefficients (J0/J1/J2) for its input rows.
# NOTE(review): the ".pkl" file is presumably unpickled by SystemDesign.load —
# only load model files from a trusted source.
system_load = SystemDesign.load("../../output/models/vae_system_s_2_50_features.pkl")

x, y = system_load.get_data_split_df()
y_pred = system_load.predict_model(x)

# Reference coefficients read from the curve-fit results file; they are merged
# below and carry the "_JA5" column suffix.
otherDataProcessor, _ = DataProcessor.CreateDataProcessor("curve_fit_results_x_is_3.csv")

# One row per unique combination of conditions, raw_data J0/J1/J2, predicted
# J*_pred and curve-fit J*_JA5 coefficients.
# Both sides are deduplicated *before* the group_index join: raw_data may hold
# several identical rows per group, and joining the undeduplicated frames is
# many-to-many, materialising every pairwise combination only to throw the
# duplicates away afterwards. Deduplicating first yields the same set of rows in
# the same order; the index is then reset so .iloc and label indexing agree.
# NOTE(review): aligning y_pred by index assumes predict_model preserves the
# index of x / raw_data — confirm.
results_df = (
    system_load.dataprocess.raw_data[
        ['group_index', 'temperature', 'solvent_1_pure', 'solvent_2_pure', 'J0', 'J1', 'J2']
    ]
    .merge(y_pred, left_index=True, right_index=True, suffixes=('', '_pred'))
    .drop_duplicates()
    .merge(
        otherDataProcessor.raw_data[['group_index', 'J0', 'J1', 'J2']].drop_duplicates(),
        on='group_index',
        suffixes=('', '_JA5'),
    )
    .reset_index(drop=True)
)

In [None]:
# Row cursor into results_df for the plotting cell below. That cell advances
# the cursor before using it, so starting at -1 makes its first run show row 0.
# Re-run this cell to restart from the first row.
n = -1

In [None]:
# Each run of this cell advances to the next row of results_df, wrapping back to
# the first row after the last one instead of raising IndexError. It then plots
# the JA model curve from the curve-fit ("_JA5") coefficients and from the
# network-predicted ("_pred") coefficients against that group's experimental
# points. Requires the `n = -1` cell above to have been run first.
assert len(results_df) > 0, "results_df is empty; nothing to plot"
n = (n + 1) % len(results_df)

row = results_df.iloc[n]
group_index = int(row['group_index'])
group = ja_groups[group_index]

ja_model = JouybanAcreeModel()
x_values = np.linspace(0, 1, 101)  # 101 evenly spaced points on [0, 1]


def _ja_curve(model, xs, result_row, coef_suffix):
    """Evaluate ``model.predict`` over ``xs`` for one results_df row.

    ``coef_suffix`` selects the coefficient columns passed as J0/J1/J2:
    ``'_JA5'`` for the curve-fit coefficients, ``'_pred'`` for the predictions.
    """
    return model.predict(
        xs,
        result_row['solvent_1_pure'],
        result_row['solvent_2_pure'],
        result_row['temperature'],
        result_row['J0' + coef_suffix],
        result_row['J1' + coef_suffix],
        result_row['J2' + coef_suffix],
    )


JA_fit_real = _ja_curve(ja_model, x_values, row, '_JA5')
JA_fit_NN = _ja_curve(ja_model, x_values, row, '_pred')

# 16:9 aspect ratio, scaled down to 1.3/3 of a 16x9-inch canvas.
fig, ax = plt.subplots(figsize=(16 * 1.3 / 3, 9 * 1.3 / 3))
ax.plot(x_values, JA_fit_real, label='Empirical', color='blue')
ax.plot(x_values, JA_fit_NN, label='NN', color='red')

# Overlay the measured data points for this group.
ax.scatter(group['solvent_1_weight_fraction'], group['solubility_g_g'], color='gray', label='Experimental Data')
ax.set_xlabel('Solvent 1 Weight Fraction')
ax.set_ylabel('Solubility (g/g)')
ax.set_title(f"Group {group_index}, T = {row['temperature']}")
ax.legend()
ax.grid(True)
plt.show()