Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 180 additions & 42 deletions src/axiomatic/axtract.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import ipywidgets as widgets # type: ignore
from IPython.display import display # type: ignore
from IPython.display import display, Math, HTML # type: ignore
from dataclasses import dataclass, field
import hypernetx as hnx # type: ignore
import matplotlib.pyplot as plt
import re

OPTION_LIST = {
"Select a template": [],
Expand All @@ -18,7 +21,7 @@
"Pixel size (multispectral)",
"Swath width",
],
"PAYLOAD": [
"PAYLOAD": [
"Resolution (panchromatic)",
"Ground sampling distance (panchromatic)",
"Resolution (multispectral)",
Expand Down Expand Up @@ -117,15 +120,14 @@ def requirements_from_table(results, variable_dict):
name = key
numerical_value = value["Value"]
unit = value["Units"]
tolerance = value["Tolerance"]

requirements.append(
Requirement(
requirement_name=name,
latex_symbol=latex_symbol,
value=numerical_value,
units=unit,
tolerance=tolerance,
tolerance=0.0,
)
)

Expand Down Expand Up @@ -178,9 +180,15 @@ def display_table(change):

if selected_option in preset_options_dict:
rows = preset_options_dict[selected_option]
max_name_length = max(len(name) for name in rows)
# Update the name_label_width based on the longest row name
name_label_width[0] = f"{max_name_length + 2}ch"

if selected_option != "Select a template":
max_name_length = max(len(name) for name in rows)
# Update the name_label_width based on the longest row name
name_label_width[0] = f"{max_name_length + 2}ch"
else:
max_name_length = 40
# Update the name_label_width based on the longest row name
name_label_width[0] = f"{max_name_length + 2}ch"

# Add Headers
header_labels = [
Expand All @@ -194,16 +202,6 @@ def display_table(change):
layout=widgets.Layout(width="150px"),
style={'font_weight': 'bold'}
),
widgets.Label(
value="Tolerance",
layout=widgets.Layout(width="150px"),
style={'font_weight': 'bold'}
),
widgets.Label(
value="Accuracy",
layout=widgets.Layout(width="150px"),
style={'font_weight': 'bold'}
),
widgets.Label(
value="Units",
layout=widgets.Layout(width="150px"),
Expand All @@ -216,7 +214,6 @@ def display_table(change):
header.layout = widgets.Layout(
border='1px solid black',
padding='5px',
background_color='#f0f0f0'
)

# Add the header to the rows_output VBox
Expand Down Expand Up @@ -244,28 +241,19 @@ def display_table(change):

# Create input widgets
value_text = widgets.FloatText(
placeholder="Value",
value=default_value,
layout=widgets.Layout(width="150px"),
)
tolerance_text = widgets.FloatText(
placeholder="Tolerance", layout=widgets.Layout(width="150px")
)
accuracy_text = widgets.FloatText(
placeholder="Accuracy", layout=widgets.Layout(width="150px")
)
units_text = widgets.Text(
placeholder="Units", layout=widgets.Layout(width="150px"),
value = default_unit
layout=widgets.Layout(width="150px"),
value=default_unit
)

# Combine widgets into a horizontal box
row = widgets.HBox(
[
name_label,
value_text,
tolerance_text,
accuracy_text,
units_text,
]
)
Expand All @@ -291,16 +279,12 @@ def submit_values(_):
if key.startswith("req_"):
updated_values[variable] = {
"Value": widget.children[1].value,
"Tolerance": widget.children[2].value,
"Accuracy": widget.children[3].value,
"Units": widget.children[4].value,
"Units": widget.children[2].value,
}
else:
updated_values[key] = {
"Value": widget.children[1].value,
"Tolerance": widget.children[2].value,
"Accuracy": widget.children[3].value,
"Units": widget.children[4].value,
"Units": widget.children[2].value,
}

result["values"] = updated_values
Expand All @@ -327,18 +311,13 @@ def add_req(_):
placeholder="Value",
layout=widgets.Layout(width="150px"),
)
tolerance_text = widgets.FloatText(
placeholder="Tolerance", layout=widgets.Layout(width="150px")
)
accuracy_text = widgets.FloatText(
placeholder="Accuracy", layout=widgets.Layout(width="150px")
)

units_text = widgets.Text(
placeholder="Units", layout=widgets.Layout(width="150px")
)

new_row = widgets.HBox(
[variable_dropdown, value_text, tolerance_text, accuracy_text, units_text]
[variable_dropdown, value_text, units_text]
)

rows_output.children += (new_row,)
Expand All @@ -354,3 +333,162 @@ def add_req(_):
display(buttons_box)

return result


def display_formatted_answers(equations_dict):
"""
Display LaTeX formatted equations and numerical results from a nested
dictionary structure in Jupyter Notebook.

Parameters:
equations_dict (dict): The dictionary containing the equations.
"""
results = equations_dict.get('results', {})
print("We identified the following equations that are relevant to your requirements:")

for key, value in results.items():
latex_equation = value.get('latex_equation')
lhs = value.get('lhs')
rhs = value.get('rhs')
match = value.get('match')
if latex_equation:
display(Math(latex_equation))
print(f"For provided values:\nleft hand side = {lhs}\nright hand side = {rhs}")
if match:
print("Provided requirements fulfill this mathematical relation")
else:
print(f"No LaTeX equation found for {key}")


def display_results(equations_dict):

results = equations_dict.get('results', {})
not_match_counter = 0

for key, value in results.items():
match = value.get('match')
latex_equation = value.get('latex_equation')
lhs = value.get('lhs')
rhs = value.get('rhs')
if not match:
not_match_counter += 1
display(HTML(
'<p style="color:red; '
'font-weight:bold; '
'font-family:\'Times New Roman\'; '
'font-size:16px;">'
'Provided requirements DO NOT fulfill the following mathematical relation:'
'</p>'
))
display(Math(latex_equation))
print(f"For provided values:\nleft hand side = {lhs}\nright hand side = {rhs}")
if not_match_counter == 0:
display(HTML(
'<p style="color:green; '
'font-weight:bold; '
'font-family:\'Times New Roman\'; '
'font-size:16px;">'
'Requirements you provided do not cause any conflicts'
'</p>'
))


def _get_latex_string_format(input_string):
"""
Properly formats LaTeX strings for matplotlib when text.usetex is False.
No escaping needed since mathtext handles backslashes properly.
"""
return f"${input_string}$" # No backslash escaping required


def _get_requirements_set(requirements):
variable_set = set()
for req in requirements:
variable_set.add(req['latex_symbol'])

return variable_set


def _find_vars_in_eq(equation, variable_set):
patterns = [re.escape(var) for var in variable_set]
combined_pattern = r'|'.join(patterns)
matches = re.findall(combined_pattern, equation)
return {fr"${match}$" for match in matches}


def _add_used_vars_to_results(api_results, api_requirements):
requirements = _get_requirements_set(api_requirements)

for key, value in api_results['results'].items():
latex_equation = value.get('latex_equation')
# print(latex_equation)
if latex_equation:
used_vars = _find_vars_in_eq(latex_equation, requirements)
api_results['results'][key]['used_vars'] = used_vars

return api_results


def get_eq_hypergraph(api_results, api_requirements):
# Disable external LaTeX rendering, using matplotlib's mathtext instead
plt.rcParams['text.usetex'] = False
plt.rcParams['mathtext.fontset'] = 'stix'
plt.rcParams['font.family'] = 'serif'

api_results = _add_used_vars_to_results(api_results, api_requirements)

# Prepare the data for HyperNetX visualization
hyperedges = {}
for eq, details in api_results["results"].items():
hyperedges[_get_latex_string_format(
details["latex_equation"])] = details["used_vars"]

# Create the hypergraph using HyperNetX
H = hnx.Hypergraph(hyperedges)

# Plot the hypergraph with enhanced clarity
plt.figure(figsize=(16, 12))

# Draw the hypergraph with node and edge labels
hnx.draw(
H,
with_edge_labels=True,
edge_labels_on_edge=False,
node_labels_kwargs={'fontsize': 14},
edge_labels_kwargs={'fontsize': 14},
layout_kwargs={'seed': 42, 'scale': 2.5}
)

node_labels = list(H.nodes)
symbol_explanations = _get_node_names_for_node_lables(node_labels, api_requirements)

# Adding the symbol explanations as a legend
explanation_text = "\n".join([f"${symbol}$: {desc}" for symbol, desc in symbol_explanations])
plt.annotate(
explanation_text,
xy=(1.05, 0.5),
xycoords='axes fraction',
fontsize=14,
verticalalignment='center'
)

plt.title(r"Enhanced Hypergraph of Equations and Variables", fontsize=20)
plt.show()


def _get_node_names_for_node_lables(node_labels, api_requirements):

# Create the output list
node_names = []

# Iterate through each symbol in S
for symbol in node_labels:
# Search for the matching requirement
symbol = symbol.replace("$", "")
for req in api_requirements:
if req['latex_symbol'] == symbol:
# Add the matching tuple to SS
node_names.append((req["latex_symbol"], req["requirement_name"]))
break # Stop searching once a match is found

return node_names
Loading