Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 56 additions & 135 deletions src/axiomatic/axtract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import re
from dataclasses import dataclass, asdict


@dataclass
class RequirementUserInput:
requirement_name: str
Expand Down Expand Up @@ -411,139 +410,6 @@ def save_requirements(_):
return result


def display_results(equations_dict):
"""Display equation validation results in a clear, organized format."""
results = equations_dict.get("results", {})

# Helper function to convert Eq(LHS,RHS) to LHS=RHS format
def format_equation(latex_eq):
# Remove 'Eq(' from start and ')' from end
inner = latex_eq[3:-1]
# Split by comma and join with equals sign
lhs, rhs = inner.split(',', 1)
return f"{lhs} = {rhs}"

# Split results into matching and non-matching equations
matching = []
non_matching = []

for key, value in results.items():
equation_data = {
'latex': format_equation(value.get('latex_equation')),
'lhs': value.get('lhs'),
'rhs': value.get('rhs'),
'diff': abs(value.get('lhs', 0) - value.get('rhs', 0)),
'percent_diff': abs(value.get('lhs', 0) - value.get('rhs', 0)) / max(abs(value.get('rhs', 0)), 1e-10) * 100
}
if value.get('match'):
matching.append(equation_data)
else:
non_matching.append(equation_data)

# Display summary header
total = len(results)
display(HTML(
f'<h3 style="font-family:Arial">Equation Validation Summary</h3>'
f'<p style="font-family:Arial">Total equations checked: {total}<br>'
f'✅ Matching equations: {len(matching)}<br>'
f'❌ Non-matching equations: {len(non_matching)}</p>'
))

# Display non-matching equations first (if any)
if non_matching:
display(HTML(
'<div style="background-color:#fff0f0; padding:10px; border-radius:5px; margin:10px 0;">'
'<h4 style="color:#cc0000; font-family:Arial">⚠️ Equations Not Satisfied:</h4>'
))

for eq in non_matching:
display(Math(eq['latex']))
display(HTML(
f'<div style="font-family:monospace; margin-left:20px; margin-bottom:15px">'
f'Left side = {eq["lhs"]:.6g}<br>'
f'Right side = {eq["rhs"]:.6g}<br>'
f'Difference = {eq["diff"]:.6g}<br>'
f'Percent difference = {eq["percent_diff"]:.2f}%'
'</div>'
))

display(HTML('</div>'))

# Display matching equations (if any)
if matching:
display(HTML(
'<div style="background-color:#f0fff0; padding:10px; border-radius:5px; margin:10px 0;">'
'<h4 style="color:#006600; font-family:Arial">✅ Satisfied Equations:</h4>'
))

for eq in matching:
display(Math(eq['latex']))
display(HTML(
f'<div style="font-family:monospace; margin-left:20px; margin-bottom:15px">'
f'Value = {eq["lhs"]:.6g}'
'</div>'
))

display(HTML('</div>'))

def get_eq_hypergraph(api_results, requirements, with_printing=True):

list_api_requirements = [asdict(req) for req in requirements]

# Disable external LaTeX rendering, using matplotlib's mathtext instead
plt.rcParams["text.usetex"] = False
plt.rcParams["mathtext.fontset"] = "stix"
plt.rcParams["font.family"] = "serif"

api_results = _add_used_vars_to_results(api_results, list_api_requirements)

# Prepare the data for HyperNetX visualization
hyperedges = {}
for eq, details in api_results["results"].items():
hyperedges[
_get_latex_string_format(details["latex_equation"])] = details[
"used_vars"
]

# Create the hypergraph using HyperNetX
H = hnx.Hypergraph(hyperedges)

# Plot the hypergraph with enhanced clarity
plt.figure(figsize=(16, 12))

# Draw the hypergraph with node and edge labels
hnx.draw(
H,
with_edge_labels=True,
edge_labels_on_edge=False,
node_labels_kwargs={"fontsize": 14},
edge_labels_kwargs={"fontsize": 14},
layout_kwargs={"seed": 42, "scale": 2.5},
)

node_labels = list(H.nodes)
symbol_explanations = _get_node_names_for_node_lables(
node_labels,
list_api_requirements)

# Adding the symbol explanations as a legend
explanation_text = "\n".join(
[f"${symbol}$: {desc}" for symbol, desc in symbol_explanations]
)
plt.annotate(
explanation_text,
xy=(1.05, 0.5),
xycoords="axes fraction",
fontsize=14,
verticalalignment="center",
)
plt.title(r"Enhanced Hypergraph of Equations and Variables", fontsize=20)
if with_printing:
plt.show()
return H
else:
return H


def _get_node_names_for_node_lables(node_labels, api_requirements):

Expand Down Expand Up @@ -741,4 +607,59 @@ def format_equation(latex_eq):
plt.title(r"Enhanced Hypergraph of Equations and Variables", fontsize=20)
plt.show()

return None
return None


def get_numerical_values(ax_client, path, constants_of_interest):
with open(path, "rb") as f:
file = f.read()

constants = ax_client.document.constants(file=file, constants=constants_of_interest).constants

# Create a dictionary to store processed values
processed_values = {}

# Process each constant name from the constants dictionary
for constant_name in constants:
value_str = constants[constant_name] # Get the value directly from the dictionary

if value_str is None:
# Handle None values
processed_values[constant_name] = {
"Value": 0.0,
"Units": "unknown"
}
elif 'F/' in value_str:
# Handle F-number values
f_number = float(value_str.split('/')[-1])
processed_values[constant_name] = {
"Value": f_number,
"Units": "dimensionless"
}
else:
# Handle normal values with units
# Split on the last space to separate value and unit
parts = value_str.rsplit(' ', 1)
if len(parts) == 2:
value, unit = parts
processed_values[constant_name] = {
"Value": float(value),
"Units": unit
}
else:
# If no unit is found
processed_values[constant_name] = {
"Value": float(parts[0]),
"Units": "unknown"
}

# Save as custom preset
filename = os.path.basename(path)
with open("./custom_presets.json", "r+") as f:
presets = json.load(f)
presets[filename] = processed_values
f.seek(0)
json.dump(presets, f, indent=2)
f.truncate()

return processed_values
Loading