diff --git a/src/axiomatic/axtract.py b/src/axiomatic/axtract.py index 9dda406..4c16fec 100644 --- a/src/axiomatic/axtract.py +++ b/src/axiomatic/axtract.py @@ -1,6 +1,9 @@ import ipywidgets as widgets # type: ignore -from IPython.display import display # type: ignore +from IPython.display import display, Math, HTML # type: ignore from dataclasses import dataclass, field +import hypernetx as hnx # type: ignore +import matplotlib.pyplot as plt +import re OPTION_LIST = { "Select a template": [], @@ -18,7 +21,7 @@ "Pixel size (multispectral)", "Swath width", ], - "PAYLOAD": [ + "PAYLOAD": [ "Resolution (panchromatic)", "Ground sampling distance (panchromatic)", "Resolution (multispectral)", @@ -117,7 +120,6 @@ def requirements_from_table(results, variable_dict): name = key numerical_value = value["Value"] unit = value["Units"] - tolerance = value["Tolerance"] requirements.append( Requirement( @@ -125,7 +127,7 @@ def requirements_from_table(results, variable_dict): latex_symbol=latex_symbol, value=numerical_value, units=unit, - tolerance=tolerance, + tolerance=0.0, ) ) @@ -178,9 +180,15 @@ def display_table(change): if selected_option in preset_options_dict: rows = preset_options_dict[selected_option] - max_name_length = max(len(name) for name in rows) - # Update the name_label_width based on the longest row name - name_label_width[0] = f"{max_name_length + 2}ch" + + if selected_option != "Select a template": + max_name_length = max(len(name) for name in rows) + # Update the name_label_width based on the longest row name + name_label_width[0] = f"{max_name_length + 2}ch" + else: + max_name_length = 40 + # Update the name_label_width based on the longest row name + name_label_width[0] = f"{max_name_length + 2}ch" # Add Headers header_labels = [ @@ -194,16 +202,6 @@ def display_table(change): layout=widgets.Layout(width="150px"), style={'font_weight': 'bold'} ), - widgets.Label( - value="Tolerance", - layout=widgets.Layout(width="150px"), - style={'font_weight': 'bold'} - ), - widgets.Label( - value="Accuracy", - layout=widgets.Layout(width="150px"), - style={'font_weight': 'bold'} - ), widgets.Label( value="Units", layout=widgets.Layout(width="150px"), @@ -216,7 +214,6 @@ def display_table(change): header.layout = widgets.Layout( border='1px solid black', padding='5px', - background_color='#f0f0f0' ) # Add the header to the rows_output VBox @@ -244,19 +241,12 @@ def display_table(change): # Create input widgets value_text = widgets.FloatText( - placeholder="Value", value=default_value, layout=widgets.Layout(width="150px"), ) - tolerance_text = widgets.FloatText( - placeholder="Tolerance", layout=widgets.Layout(width="150px") - ) - accuracy_text = widgets.FloatText( - placeholder="Accuracy", layout=widgets.Layout(width="150px") - ) units_text = widgets.Text( - placeholder="Units", layout=widgets.Layout(width="150px"), - value = default_unit + layout=widgets.Layout(width="150px"), + value=default_unit ) # Combine widgets into a horizontal box @@ -264,8 +254,6 @@ def display_table(change): [ name_label, value_text, - tolerance_text, - accuracy_text, units_text, ] ) @@ -291,16 +279,12 @@ def submit_values(_): if key.startswith("req_"): updated_values[variable] = { "Value": widget.children[1].value, - "Tolerance": widget.children[2].value, - "Accuracy": widget.children[3].value, - "Units": widget.children[4].value, + "Units": widget.children[2].value, } else: updated_values[key] = { "Value": widget.children[1].value, - "Tolerance": widget.children[2].value, - "Accuracy": widget.children[3].value, - "Units": widget.children[4].value, + "Units": widget.children[2].value, } result["values"] = updated_values @@ -327,18 +311,13 @@ def add_req(_): placeholder="Value", layout=widgets.Layout(width="150px"), ) - tolerance_text = widgets.FloatText( - placeholder="Tolerance", layout=widgets.Layout(width="150px") - ) - accuracy_text = widgets.FloatText( - placeholder="Accuracy", layout=widgets.Layout(width="150px") - ) + units_text = widgets.Text( placeholder="Units", layout=widgets.Layout(width="150px") ) new_row = widgets.HBox( - [variable_dropdown, value_text, tolerance_text, accuracy_text, units_text] + [variable_dropdown, value_text, units_text] ) rows_output.children += (new_row,) @@ -354,3 +333,162 @@ def add_req(_): display(buttons_box) return result + + +def display_formatted_answers(equations_dict): + """ + Display LaTeX formatted equations and numerical results from a nested + dictionary structure in Jupyter Notebook. + + Parameters: + equations_dict (dict): The dictionary containing the equations. + """ + results = equations_dict.get('results', {}) + print("We identified the following equations that are relevant to your requirements:") + + for key, value in results.items(): + latex_equation = value.get('latex_equation') + lhs = value.get('lhs') + rhs = value.get('rhs') + match = value.get('match') + if latex_equation: + display(Math(latex_equation)) + print(f"For provided values:\nleft hand side = {lhs}\nright hand side = {rhs}") + if match: + print("Provided requirements fulfill this mathematical relation") + else: + print(f"No LaTeX equation found for {key}") + + +def display_results(equations_dict): + + results = equations_dict.get('results', {}) + not_match_counter = 0 + + for key, value in results.items(): + match = value.get('match') + latex_equation = value.get('latex_equation') + lhs = value.get('lhs') + rhs = value.get('rhs') + if not match: + not_match_counter += 1 + display(HTML( + '
' + 'Provided requirements DO NOT fulfill the following mathematical relation:' + '
' + )) + display(Math(latex_equation)) + print(f"For provided values:\nleft hand side = {lhs}\nright hand side = {rhs}") + if not_match_counter == 0: + display(HTML( + '' + 'Requirements you provided do not cause any conflicts' + '
' + )) + + +def _get_latex_string_format(input_string): + """ + Properly formats LaTeX strings for matplotlib when text.usetex is False. + No escaping needed since mathtext handles backslashes properly. + """ + return f"${input_string}$" # No backslash escaping required + + +def _get_requirements_set(requirements): + variable_set = set() + for req in requirements: + variable_set.add(req['latex_symbol']) + + return variable_set + + +def _find_vars_in_eq(equation, variable_set): + patterns = [re.escape(var) for var in variable_set] + combined_pattern = r'|'.join(patterns) + matches = re.findall(combined_pattern, equation) + return {fr"${match}$" for match in matches} + + +def _add_used_vars_to_results(api_results, api_requirements): + requirements = _get_requirements_set(api_requirements) + + for key, value in api_results['results'].items(): + latex_equation = value.get('latex_equation') + # print(latex_equation) + if latex_equation: + used_vars = _find_vars_in_eq(latex_equation, requirements) + api_results['results'][key]['used_vars'] = used_vars + + return api_results + + +def get_eq_hypergraph(api_results, api_requirements): + # Disable external LaTeX rendering, using matplotlib's mathtext instead + plt.rcParams['text.usetex'] = False + plt.rcParams['mathtext.fontset'] = 'stix' + plt.rcParams['font.family'] = 'serif' + + api_results = _add_used_vars_to_results(api_results, api_requirements) + + # Prepare the data for HyperNetX visualization + hyperedges = {} + for eq, details in api_results["results"].items(): + hyperedges[_get_latex_string_format( + details["latex_equation"])] = details["used_vars"] + + # Create the hypergraph using HyperNetX + H = hnx.Hypergraph(hyperedges) + + # Plot the hypergraph with enhanced clarity + plt.figure(figsize=(16, 12)) + + # Draw the hypergraph with node and edge labels + hnx.draw( + H, + with_edge_labels=True, + edge_labels_on_edge=False, + node_labels_kwargs={'fontsize': 14}, + edge_labels_kwargs={'fontsize': 14}, + layout_kwargs={'seed': 42, 'scale': 2.5} + ) + + node_labels = list(H.nodes) + symbol_explanations = _get_node_names_for_node_lables(node_labels, api_requirements) + + # Adding the symbol explanations as a legend + explanation_text = "\n".join([f"${symbol}$: {desc}" for symbol, desc in symbol_explanations]) + plt.annotate( + explanation_text, + xy=(1.05, 0.5), + xycoords='axes fraction', + fontsize=14, + verticalalignment='center' + ) + + plt.title(r"Enhanced Hypergraph of Equations and Variables", fontsize=20) + plt.show() + + +def _get_node_names_for_node_lables(node_labels, api_requirements): + + # Create the output list + node_names = [] + + # Iterate through each symbol in S + for symbol in node_labels: + # Search for the matching requirement + symbol = symbol.replace("$", "") + for req in api_requirements: + if req['latex_symbol'] == symbol: + # Add the matching tuple to SS + node_names.append((req["latex_symbol"], req["requirement_name"])) + break # Stop searching once a match is found + + return node_names