From b076eaa629d258f45ea81be2fd601021d5a40d60 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 14 Oct 2025 23:03:01 +0000 Subject: [PATCH 1/3] Initial plan From 85577ea1094cd323bc0e39f8280aa0e2083a4cda Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 14 Oct 2025 23:07:42 +0000 Subject: [PATCH 2/3] Change FindValueOfExpression signature to return symbolic.Expression Co-authored-by: kwesiRutledge <9002730+kwesiRutledge@users.noreply.github.com> --- solution/solution.go | 30 +++++++++++++------------- testing/solution/solution_test.go | 36 +++++++++++++++++++++++++++---- 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/solution/solution.go b/solution/solution.go index 9d805a0..bae1285 100644 --- a/solution/solution.go +++ b/solution/solution.go @@ -44,8 +44,8 @@ func ExtractValueOfVariable(s Solution, v symbolic.Variable) (float64, error) { // FindValueOfExpression evaluates a symbolic expression using the values from a solution. // It substitutes all variables in the expression with their values from the solution -// and returns the resulting scalar value. -func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error) { +// and returns the resulting symbolic expression (typically a constant). +func FindValueOfExpression(s Solution, expr symbolic.Expression) (symbolic.Expression, error) { // Get all variables in the expression vars := expr.Variables() @@ -54,7 +54,7 @@ func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error for _, v := range vars { val, err := ExtractValueOfVariable(s, v) if err != nil { - return 0.0, fmt.Errorf( + return nil, fmt.Errorf( "failed to extract value for variable %v: %w", v.ID, err, @@ -66,16 +66,7 @@ func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error // Substitute all variables with their values resultExpr := expr.SubstituteAccordingTo(subMap) - // Type assert to K (constant) to extract the float64 value - resultK, ok := resultExpr.(symbolic.K) - if !ok { - return 0.0, fmt.Errorf( - "expected substituted expression to be a constant, got type %T", - resultExpr, - ) - } - - return float64(resultK), nil + return resultExpr, nil } // GetOptimalObjectiveValue evaluates the objective function of an optimization problem @@ -95,10 +86,19 @@ func GetOptimalObjectiveValue(sol Solution) (float64, error) { } // Use FindValueOfExpression to evaluate the objective at the solution point - value, err := FindValueOfExpression(sol, objectiveExpr) + resultExpr, err := FindValueOfExpression(sol, objectiveExpr) if err != nil { return 0.0, fmt.Errorf("failed to evaluate objective expression: %w", err) } - return value, nil + // Type assert to K (constant) to extract the float64 value + resultK, ok := resultExpr.(symbolic.K) + if !ok { + return 0.0, fmt.Errorf( + "expected substituted expression to be a constant, got type %T", + resultExpr, + ) + } + + return float64(resultK), nil } diff --git a/testing/solution/solution_test.go b/testing/solution/solution_test.go index 6436c91..2176601 100644 --- a/testing/solution/solution_test.go +++ b/testing/solution/solution_test.go @@ -161,11 +161,18 @@ func TestSolution_FindValueOfExpression1(t *testing.T) { expr := v1.Multiply(symbolic.K(2.0)).Plus(v2.Multiply(symbolic.K(3.0))) // Algorithm - result, err := solution.FindValueOfExpression(&tempSol, expr) + resultExpr, err := solution.FindValueOfExpression(&tempSol, expr) if err != nil { t.Errorf("FindValueOfExpression returned an error: %v", err) } + // Convert to float64 + resultK, ok := resultExpr.(symbolic.K) + if !ok { + t.Errorf("Expected result to be a constant, got type %T", resultExpr) + } + result := float64(resultK) + expected := 13.0 if result != expected { t.Errorf( @@ -194,11 +201,18 @@ func TestSolution_FindValueOfExpression2(t *testing.T) { expr := symbolic.K(42.0) // Algorithm - result, err := solution.FindValueOfExpression(&tempSol, expr) + resultExpr, err := solution.FindValueOfExpression(&tempSol, expr) if err != nil { t.Errorf("FindValueOfExpression returned an error: %v", err) } + // Convert to float64 + resultK, ok := resultExpr.(symbolic.K) + if !ok { + t.Errorf("Expected result to be a constant, got type %T", resultExpr) + } + result := float64(resultK) + expected := 42.0 if result != expected { t.Errorf( @@ -231,11 +245,18 @@ func TestSolution_FindValueOfExpression3(t *testing.T) { expr := v1.Plus(symbolic.K(10.0)) // Algorithm - result, err := solution.FindValueOfExpression(&tempSol, expr) + resultExpr, err := solution.FindValueOfExpression(&tempSol, expr) if err != nil { t.Errorf("FindValueOfExpression returned an error: %v", err) } + // Convert to float64 + resultK, ok := resultExpr.(symbolic.K) + if !ok { + t.Errorf("Expected result to be a constant, got type %T", resultExpr) + } + result := float64(resultK) + expected := 15.5 if result != expected { t.Errorf( @@ -304,11 +325,18 @@ func TestSolution_FindValueOfExpression5(t *testing.T) { expr := v1.Plus(v2).Multiply(v3).Plus(symbolic.K(5.0)) // Algorithm - result, err := solution.FindValueOfExpression(&tempSol, expr) + resultExpr, err := solution.FindValueOfExpression(&tempSol, expr) if err != nil { t.Errorf("FindValueOfExpression returned an error: %v", err) } + // Convert to float64 + resultK, ok := resultExpr.(symbolic.K) + if !ok { + t.Errorf("Expected result to be a constant, got type %T", resultExpr) + } + result := float64(resultK) + expected := 14.0 if result != expected { t.Errorf( From 34c98d16a31eb28b57c078c3eb51247ced1f7a6c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 14 Oct 2025 23:09:10 +0000 Subject: [PATCH 3/3] Add helper function to reduce code duplication in tests Co-authored-by: kwesiRutledge <9002730+kwesiRutledge@users.noreply.github.com> --- testing/solution/solution_test.go | 37 +++++++++++-------------------- 1 file changed, 13 insertions(+), 24 deletions(-) diff --git a/testing/solution/solution_test.go b/testing/solution/solution_test.go index 2176601..c3a4f71 100644 --- a/testing/solution/solution_test.go +++ b/testing/solution/solution_test.go @@ -17,6 +17,15 @@ Description: (This seems like it is highly representative of the Gurobi solver; is there a reason to make it this way?) */ +// Helper function to convert a symbolic.Expression to float64 +func exprToFloat64(t *testing.T, expr symbolic.Expression) float64 { + resultK, ok := expr.(symbolic.K) + if !ok { + t.Fatalf("Expected result to be a constant, got type %T", expr) + } + return float64(resultK) +} + func TestSolution_ToMessage1(t *testing.T) { // Constants tempSol := solution.DummySolution{ @@ -166,12 +175,7 @@ func TestSolution_FindValueOfExpression1(t *testing.T) { t.Errorf("FindValueOfExpression returned an error: %v", err) } - // Convert to float64 - resultK, ok := resultExpr.(symbolic.K) - if !ok { - t.Errorf("Expected result to be a constant, got type %T", resultExpr) - } - result := float64(resultK) + result := exprToFloat64(t, resultExpr) expected := 13.0 if result != expected { @@ -206,12 +210,7 @@ func TestSolution_FindValueOfExpression2(t *testing.T) { t.Errorf("FindValueOfExpression returned an error: %v", err) } - // Convert to float64 - resultK, ok := resultExpr.(symbolic.K) - if !ok { - t.Errorf("Expected result to be a constant, got type %T", resultExpr) - } - result := float64(resultK) + result := exprToFloat64(t, resultExpr) expected := 42.0 if result != expected { @@ -250,12 +249,7 @@ func TestSolution_FindValueOfExpression3(t *testing.T) { t.Errorf("FindValueOfExpression returned an error: %v", err) } - // Convert to float64 - resultK, ok := resultExpr.(symbolic.K) - if !ok { - t.Errorf("Expected result to be a constant, got type %T", resultExpr) - } - result := float64(resultK) + result := exprToFloat64(t, resultExpr) expected := 15.5 if result != expected { @@ -330,12 +324,7 @@ func TestSolution_FindValueOfExpression5(t *testing.T) { t.Errorf("FindValueOfExpression returned an error: %v", err) } - // Convert to float64 - resultK, ok := resultExpr.(symbolic.K) - if !ok { - t.Errorf("Expected result to be a constant, got type %T", resultExpr) - } - result := float64(resultK) + result := exprToFloat64(t, resultExpr) expected := 14.0 if result != expected {