In [None]:
from main import SystemDesign
import pickle


In [None]:
# Build the solvent-system model. The three system descriptor columns are the
# inputs and J0/J1/J2 (read from the curve-fit results CSV) are the regression
# targets.
system = SystemDesign(
    raw_data_path='curve_fit_results_x_is_7.csv',
    system_columns=['solvent_1', 'solvent_2', 'temperature'],
    target_columns=['J0', 'J1', 'J2'],
    extra_fitted_points=1,
)

# Fit: random-forest feature selection down to 10 features. Columns whose names
# start with any of `keep_prefixes` are presumably always retained -- TODO
# confirm against SystemDesign.train_model.
# NOTE(review): no random seed is set, so repeated runs may differ.
system.train_model(
    feature_selection_method='random_forest',
    n_features=10,
    keep_prefixes=[
        'solvent_1_pure',
        'solvent_2_pure',
        'system',
        'solubility_',
        'temperature',
    ],
    epochs=100,
    batch_size=32,
    verbose=1,
)

# Score the fitted model, then collect predictions, ground truth and MAE.
system.evaluate_model()
predictions, actuals, mae = system.get_predictions_and_metrics()

In [None]:
# Save the trained system to a file
with open('system_model.pkl', 'wb') as file:
    pickle.dump(system, file)
    
print("System model saved to 'system_model.pkl'")

# If you want to load it later, you can use:
# with open('system_model.pkl', 'rb') as file:
#     loaded_system = pickle.load(file)

In [None]:
# Reload the pickled model and verify it round-tripped to the expected class.
# Later cells in this notebook keep using the in-memory `system`;
# `loaded_system` is what you would use in a fresh session instead of retraining.
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open('system_model.pkl', 'rb') as model_in:
    loaded_system = pickle.load(model_in)

assert isinstance(loaded_system, SystemDesign), (
    f"Unpickled a {type(loaded_system).__name__}, expected SystemDesign"
)

In [None]:
from groups import ja_groups

In [None]:
# Split the processed data into features (x) and targets (y), then run the
# trained model on only the first few feature rows as a quick preview.
preview_rows = 5
x, y = system.get_data_split_df()
y_pred = system.predict_model(x[:preview_rows])

In [None]:
# Line up, by row index, the identifying metadata from the raw data, the true
# targets (y) and the model predictions (y_pred). Both joins are inner, so only
# rows that have predictions survive. Prediction columns whose names clash with
# y's gain a '_pred' suffix.
_meta_columns = ['group_index', 'temperature', 'solvent_1_pure', 'solvent_2_pure']
results_df = (
    system.dataprocess.raw_data[_meta_columns]
    .merge(y, how='inner', left_index=True, right_index=True)
    .merge(
        y_pred,
        how='inner',
        left_index=True,
        right_index=True,
        suffixes=('', '_pred'),
    )
)

In [None]:
# Show the merged table: metadata, true targets and their '_pred' counterparts
# (only the rows that received predictions, because the merges are inner).
results_df

In [None]:
import matplotlib.pyplot as plt

# Set up initial configurations for plots
# Global Matplotlib defaults for every figure below: 12 pt base text and
# ticks/legend, with larger axis labels (14), subplot titles (16) and
# figure titles (18).
_plot_style = {
    'font.size': 12,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'figure.titlesize': 18,
}
plt.rcParams.update(_plot_style)

In [None]:
from equations import JouybanAcreeModel
import numpy as np

In [None]:
# Row cursor for the plotting cell below. That cell increments `n` before using
# it, so starting at -1 makes its first run show row 0. Re-run this cell to
# restart from the first row.
n = -1

In [None]:
# Each run of this cell advances to the next row of results_df. The cursor wraps
# around, so repeated runs never index past the end. For that row, plot the
# Jouyban-Acree (JA) curve built from the empirically fitted J0-J2 against the
# curve built from the model-predicted J0_pred-J2_pred, over the experimental
# points of the row's group.
# NOTE(review): assumes y_pred's columns are named J0-J2, so the merge produced
# J0_pred..J2_pred -- confirm.
if results_df.empty:
    raise ValueError("results_df is empty -- nothing to plot")
n = (n + 1) % len(results_df)

group_index = int(results_df.iloc[n]['group_index'])
group = ja_groups[group_index]

ja_model = JouybanAcreeModel()
x_values = np.linspace(0, 1, 101)


def _ja_curve(model, x_grid, frame, row_idx, suffix=''):
    """Evaluate `model` on `x_grid` for row `row_idx` of `frame`.

    Uses the row's solvent/temperature columns and the J0/J1/J2 coefficient
    columns carrying `suffix` ('' for empirical, '_pred' for predicted).
    """
    return model.predict(
        x_grid,
        frame['solvent_1_pure'].iloc[row_idx],
        frame['solvent_2_pure'].iloc[row_idx],
        frame['temperature'].iloc[row_idx],
        frame[f'J0{suffix}'].iloc[row_idx],
        frame[f'J1{suffix}'].iloc[row_idx],
        frame[f'J2{suffix}'].iloc[row_idx],
    )


JA_fit_real = _ja_curve(ja_model, x_values, results_df, n)
# Previously built from the empirical J columns, which made both curves identical.
JA_fit_NN = _ja_curve(ja_model, x_values, results_df, n, suffix='_pred')

fig, ax = plt.subplots(figsize=(16 * 1.3 / 3, 9 * 1.3 / 3))
ax.plot(x_values, JA_fit_real, label='Empirical', color='blue')
ax.plot(x_values, JA_fit_NN, label='NN', color='red')

# Experimental data points for this group, drawn above the curves.
ax.scatter(
    group['solvent_1_weight_fraction'],
    group['solubility_g_g'],
    color='lightgray',
    label='Experimental Data',
    zorder=5,
)
ax.set_xlabel('Solvent 1 Weight Fraction')
ax.set_ylabel('Solubility (g/g)')
ax.set_title(f'Row {n} (group {group_index})')
ax.legend()
ax.grid(True)
plt.show()


