diff --git a/solution/solution.go b/solution/solution.go index 9d805a0..bae1285 100644 --- a/solution/solution.go +++ b/solution/solution.go @@ -44,8 +44,8 @@ func ExtractValueOfVariable(s Solution, v symbolic.Variable) (float64, error) { // FindValueOfExpression evaluates a symbolic expression using the values from a solution. // It substitutes all variables in the expression with their values from the solution -// and returns the resulting scalar value. -func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error) { +// and returns the resulting symbolic expression (typically a constant). +func FindValueOfExpression(s Solution, expr symbolic.Expression) (symbolic.Expression, error) { // Get all variables in the expression vars := expr.Variables() @@ -54,7 +54,7 @@ func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error for _, v := range vars { val, err := ExtractValueOfVariable(s, v) if err != nil { - return 0.0, fmt.Errorf( + return nil, fmt.Errorf( "failed to extract value for variable %v: %w", v.ID, err, @@ -66,16 +66,7 @@ func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error // Substitute all variables with their values resultExpr := expr.SubstituteAccordingTo(subMap) - // Type assert to K (constant) to extract the float64 value - resultK, ok := resultExpr.(symbolic.K) - if !ok { - return 0.0, fmt.Errorf( - "expected substituted expression to be a constant, got type %T", - resultExpr, - ) - } - - return float64(resultK), nil + return resultExpr, nil } // GetOptimalObjectiveValue evaluates the objective function of an optimization problem @@ -95,10 +86,19 @@ func GetOptimalObjectiveValue(sol Solution) (float64, error) { } // Use FindValueOfExpression to evaluate the objective at the solution point - value, err := FindValueOfExpression(sol, objectiveExpr) + resultExpr, err := FindValueOfExpression(sol, objectiveExpr) if err != nil { return 0.0, fmt.Errorf("failed to evaluate objective expression: %w", err) } - return value, nil + // Type assert to K (constant) to extract the float64 value + resultK, ok := resultExpr.(symbolic.K) + if !ok { + return 0.0, fmt.Errorf( + "expected substituted expression to be a constant, got type %T", + resultExpr, + ) + } + + return float64(resultK), nil } diff --git a/testing/solution/solution_test.go b/testing/solution/solution_test.go index 6436c91..c3a4f71 100644 --- a/testing/solution/solution_test.go +++ b/testing/solution/solution_test.go @@ -17,6 +17,15 @@ Description: (This seems like it is highly representative of the Gurobi solver; is there a reason to make it this way?) */ +// Helper function to convert a symbolic.Expression to float64 +func exprToFloat64(t *testing.T, expr symbolic.Expression) float64 { + resultK, ok := expr.(symbolic.K) + if !ok { + t.Fatalf("Expected result to be a constant, got type %T", expr) + } + return float64(resultK) +} + func TestSolution_ToMessage1(t *testing.T) { // Constants tempSol := solution.DummySolution{ @@ -161,11 +170,13 @@ func TestSolution_FindValueOfExpression1(t *testing.T) { expr := v1.Multiply(symbolic.K(2.0)).Plus(v2.Multiply(symbolic.K(3.0))) // Algorithm - result, err := solution.FindValueOfExpression(&tempSol, expr) + resultExpr, err := solution.FindValueOfExpression(&tempSol, expr) if err != nil { t.Errorf("FindValueOfExpression returned an error: %v", err) } + result := exprToFloat64(t, resultExpr) + expected := 13.0 if result != expected { t.Errorf( @@ -194,11 +205,13 @@ func TestSolution_FindValueOfExpression2(t *testing.T) { expr := symbolic.K(42.0) // Algorithm - result, err := solution.FindValueOfExpression(&tempSol, expr) + resultExpr, err := solution.FindValueOfExpression(&tempSol, expr) if err != nil { t.Errorf("FindValueOfExpression returned an error: %v", err) } + result := exprToFloat64(t, resultExpr) + expected := 42.0 if result != expected { t.Errorf( @@ -231,11 +244,13 @@ func TestSolution_FindValueOfExpression3(t *testing.T) { expr := v1.Plus(symbolic.K(10.0)) // Algorithm - result, err := solution.FindValueOfExpression(&tempSol, expr) + resultExpr, err := solution.FindValueOfExpression(&tempSol, expr) if err != nil { t.Errorf("FindValueOfExpression returned an error: %v", err) } + result := exprToFloat64(t, resultExpr) + expected := 15.5 if result != expected { t.Errorf( @@ -304,11 +319,13 @@ func TestSolution_FindValueOfExpression5(t *testing.T) { expr := v1.Plus(v2).Multiply(v3).Plus(symbolic.K(5.0)) // Algorithm - result, err := solution.FindValueOfExpression(&tempSol, expr) + resultExpr, err := solution.FindValueOfExpression(&tempSol, expr) if err != nil { t.Errorf("FindValueOfExpression returned an error: %v", err) } + result := exprToFloat64(t, resultExpr) + expected := 14.0 if result != expected { t.Errorf(