Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions solution/solution.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ func ExtractValueOfVariable(s Solution, v symbolic.Variable) (float64, error) {

// FindValueOfExpression evaluates a symbolic expression using the values from a solution.
// It substitutes all variables in the expression with their values from the solution
// and returns the resulting scalar value.
func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error) {
// and returns the resulting symbolic expression (typically a constant).
func FindValueOfExpression(s Solution, expr symbolic.Expression) (symbolic.Expression, error) {
// Get all variables in the expression
vars := expr.Variables()

Expand All @@ -54,7 +54,7 @@ func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error
for _, v := range vars {
val, err := ExtractValueOfVariable(s, v)
if err != nil {
return 0.0, fmt.Errorf(
return nil, fmt.Errorf(
"failed to extract value for variable %v: %w",
v.ID,
err,
Expand All @@ -66,16 +66,7 @@ func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error
// Substitute all variables with their values
resultExpr := expr.SubstituteAccordingTo(subMap)

// Type assert to K (constant) to extract the float64 value
resultK, ok := resultExpr.(symbolic.K)
if !ok {
return 0.0, fmt.Errorf(
"expected substituted expression to be a constant, got type %T",
resultExpr,
)
}

return float64(resultK), nil
return resultExpr, nil
}

// GetOptimalObjectiveValue evaluates the objective function of an optimization problem
Expand All @@ -95,10 +86,19 @@ func GetOptimalObjectiveValue(sol Solution) (float64, error) {
}

// Use FindValueOfExpression to evaluate the objective at the solution point
value, err := FindValueOfExpression(sol, objectiveExpr)
resultExpr, err := FindValueOfExpression(sol, objectiveExpr)
if err != nil {
return 0.0, fmt.Errorf("failed to evaluate objective expression: %w", err)
}

return value, nil
// Type assert to K (constant) to extract the float64 value
resultK, ok := resultExpr.(symbolic.K)
if !ok {
return 0.0, fmt.Errorf(
"expected substituted expression to be a constant, got type %T",
resultExpr,
)
}

return float64(resultK), nil
}
25 changes: 21 additions & 4 deletions testing/solution/solution_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ Description:
(This seems like it is highly representative of the Gurobi solver; is there a reason to make it this way?)
*/

// Helper function to convert a symbolic.Expression to float64
func exprToFloat64(t *testing.T, expr symbolic.Expression) float64 {
resultK, ok := expr.(symbolic.K)
if !ok {
t.Fatalf("Expected result to be a constant, got type %T", expr)
}
return float64(resultK)
}

func TestSolution_ToMessage1(t *testing.T) {
// Constants
tempSol := solution.DummySolution{
Expand Down Expand Up @@ -161,11 +170,13 @@ func TestSolution_FindValueOfExpression1(t *testing.T) {
expr := v1.Multiply(symbolic.K(2.0)).Plus(v2.Multiply(symbolic.K(3.0)))

// Algorithm
result, err := solution.FindValueOfExpression(&tempSol, expr)
resultExpr, err := solution.FindValueOfExpression(&tempSol, expr)
if err != nil {
t.Errorf("FindValueOfExpression returned an error: %v", err)
}

result := exprToFloat64(t, resultExpr)

expected := 13.0
if result != expected {
t.Errorf(
Expand Down Expand Up @@ -194,11 +205,13 @@ func TestSolution_FindValueOfExpression2(t *testing.T) {
expr := symbolic.K(42.0)

// Algorithm
result, err := solution.FindValueOfExpression(&tempSol, expr)
resultExpr, err := solution.FindValueOfExpression(&tempSol, expr)
if err != nil {
t.Errorf("FindValueOfExpression returned an error: %v", err)
}

result := exprToFloat64(t, resultExpr)

expected := 42.0
if result != expected {
t.Errorf(
Expand Down Expand Up @@ -231,11 +244,13 @@ func TestSolution_FindValueOfExpression3(t *testing.T) {
expr := v1.Plus(symbolic.K(10.0))

// Algorithm
result, err := solution.FindValueOfExpression(&tempSol, expr)
resultExpr, err := solution.FindValueOfExpression(&tempSol, expr)
if err != nil {
t.Errorf("FindValueOfExpression returned an error: %v", err)
}

result := exprToFloat64(t, resultExpr)

expected := 15.5
if result != expected {
t.Errorf(
Expand Down Expand Up @@ -304,11 +319,13 @@ func TestSolution_FindValueOfExpression5(t *testing.T) {
expr := v1.Plus(v2).Multiply(v3).Plus(symbolic.K(5.0))

// Algorithm
result, err := solution.FindValueOfExpression(&tempSol, expr)
resultExpr, err := solution.FindValueOfExpression(&tempSol, expr)
if err != nil {
t.Errorf("FindValueOfExpression returned an error: %v", err)
}

result := exprToFloat64(t, resultExpr)

expected := 14.0
if result != expected {
t.Errorf(
Expand Down