# Metadata

**L1 Taxonomy** - Problem Solving & Algorithms

**L2 Taxonomy** - Divide and Conquer

**Subtopic** - Python divide-and-conquer recurrence solver

**Use Case** - Develop a Python tool using SymPy that accepts divide-and-conquer recurrence definitions and computes explicit closed-form solutions. Implement handling of binary decompositions and combination of recursive subproblem outputs to analyze algorithmic processes, such as mergesort or integer multiplication techniques, enabling rigorous performance evaluation in real-life applications. This idea is inspired by detailed recurrence solutions shown in the paper.

**Programming Language** - Python

**Target Model** - GPT-4o

# Model Breaking Hints


1) **What is the initial use case?**

The initial problem is to develop a Python tool using SymPy that accepts divide-and-conquer recurrence definitions and computes explicit closed-form solutions. It handles binary decompositions and combines recursive subproblem outputs to analyze algorithmic processes like mergesort or integer multiplication techniques.

2) **Why is the initial use case easy?**

This problem is relatively straightforward because it deals with classic divide-and-conquer recurrences, which are well-understood and often solvable using the Master Theorem or standard methods. Algorithms like mergesort have simple recursive relations that don't involve complex data structures or intricate dependencies, making them less challenging for both humans and advanced models.
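
For reference, the Master Theorem mentioned above can be stated as follows for $T(n) = a\,T(n/b) + f(n)$ with $a \ge 1$ and $b > 1$. The symbols $c^{*} = \log_b a$ and $\kappa$ are notation introduced here for illustration, not taken from the source.

$$
T(n) =
\begin{cases}
\Theta\!\left(n^{c^{*}}\right) & \text{if } f(n) = O\!\left(n^{c^{*}-\varepsilon}\right) \text{ for some } \varepsilon > 0,\\[4pt]
\Theta\!\left(n^{c^{*}} \log^{k+1} n\right) & \text{if } f(n) = \Theta\!\left(n^{c^{*}} \log^{k} n\right) \text{ for some } k \ge 0,\\[4pt]
\Theta\!\left(f(n)\right) & \text{if } f(n) = \Omega\!\left(n^{c^{*}+\varepsilon}\right) \text{ and } a\,f(n/b) \le \kappa\,f(n) \text{ for some } \kappa < 1 \text{ and all large } n.
\end{cases}
$$

For mergesort, $a = b = 2$ and $f(n) = n$ give $c^{*} = 1$, so the second case with $k = 0$ yields $\Theta(n \log n)$.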

3) **How could we make it harder?**

To increase the complexity, we can:

- **Integrate recurrences involving advanced tree algorithms like Heavy-Light Decomposition**, requiring symbolic analysis over tree paths and segments.
- **Incorporate dynamic programming on trees within the recurrences**, necessitating multi-state tracking and handling of subtree dependencies symbolically.
- **Include recurrences from advanced data structures like Segment Trees with Lazy Propagation**, needing symbolic analysis of lazy updates and range queries within the recursion.
- **Require handling of recurrences that involve Union-Find structures with dynamic connectivity**, where recursive calls depend on real-time merging and splitting of sets.
- **Extend the tool to solve recurrences modeling recursive Min-Cost Max-Flow computations**, embedding complex graph flow dynamics into the relations.

4) **Which parameters can we change?**

We can generalize the problem to handle recurrences involving:

- **Advanced tree algorithms**: Moving from simple binary trees to complex tree structures like those used in Heavy-Light Decomposition.
- **Dynamic programming on trees**: Introducing multiple states and dependencies that need to be tracked symbolically.
- **Complex data structures**: Incorporating Segment Trees with Lazy Propagation and Union-Find structures with dynamic connectivity, which involve non-trivial operations and dependencies.
- **Graph algorithms**: Embedding recurrences that model complex graph dynamics, such as recursive Min-Cost Max-Flow computations.

These changes introduce multi-step reasoning, obscure constraints, and hidden pitfalls that make naive approaches ineffective.

5) **What can be a final hard prompt?**

*"Develop a Python tool using SymPy that symbolically solves complex recursive definitions involving advanced algorithms and data structures, such as Heavy-Light Decomposition with dynamic programming on trees, Segment Trees with Lazy Propagation, and Union-Find structures with dynamic connectivity. The tool must handle symbolic analysis over tree paths and segments with multi-state tracking, subtree dependencies, lazy updates, and real-time merging and splitting of sets within recursions, enabling rigorous performance evaluation of algorithms with complex tree and graph dynamics."*

# Setup

```requirements.txt
sympy==1.12
```


# Prompt

I want to build a Python tool that takes a recurrence relation written in divide-and-conquer form and computes its simplified closed-form expression. The tool should solve recurrence patterns that appear in algorithms like mergesort. The result must not contain any recursion; it should be a final mathematical expression in terms of the variable n.

**Input Format**

The function will receive two inputs:

1. `recurrence`: a string of the form `T(n) = a * T(n / b) + f(n)`
2. `params`: a dictionary with the following keys and values:

   - `a`: an integer greater than 0
   - `b`: an integer greater than 1
   - `f(n)`: a string such as `n`, `n * log(n)`, or `n ^ 2`
   - `base_case`: a string giving the base value, such as `1`, `5`, or a constant expression

**Output Format**

- Return a string that shows the simplified closed form of T(n)
- It must be a symbolic expression using only n
- The result must not include any form of T()
- You must simplify the result using the sympy library

**Examples**

**Example 1**

Input:
recurrence = "T(n) = 2 * T(n / 2) + n"
params = {
"a": 2,
"b": 2,
"f(n)": "n",
"base_case": "1"
}

Output:
T(n) = n * log(n) + n

**Example 2**

Input:
recurrence = "T(n) = 3 * T(n / 4) + n ^ 2"
params = {
"a": 3,
"b": 4,
"f(n)": "n ^ 2",
"base_case": "1"
}

Output:
T(n) = n ^ 2
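
Example 1 can be checked by unrolling the recurrence. The derivation below assumes that $n = 2^{k}$ is a power of two and that the logarithm is taken in base 2; both are assumptions, since the prompt does not fix the base.

$$
\begin{aligned}
T(n) &= 2\,T(n/2) + n \\
     &= 4\,T(n/4) + 2n \\
     &\;\;\vdots \\
     &= 2^{k}\,T(1) + k\,n \\
     &= n \cdot 1 + n \log_2 n .
\end{aligned}
$$

Under those assumptions the result agrees with the stated output `n * log(n) + n`. The same unrolling applied to Example 2 gives a leading term proportional to $n^{2}$ plus a lower-order $n^{\log_4 3}$ term, so the output shown there reads as the asymptotic order rather than an exact closed form.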

# Requirements

**Explicit and Implicit Points**

- Parse the recurrence string
- Use the values from params, not from the string itself
- Use sympy for all math expressions and simplifications
- Use the base case value once n reaches 1 (or the smallest integer the recursion can reach)
- If the recurrence cannot be solved, return the string "Cannot solve"

**Solution Expectations**

- Use only sympy for symbolic computation
- Final result must be simplified
- You must derive the expression by unrolling or using the master theorem
- The expression must be returned as a string
- Do not include any unresolved or recursive parts
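
The unrolling approach named above can also serve as a sanity check on any candidate closed form. The sketch below is illustrative only: `unroll` and `matches_on_powers` are hypothetical helper names, the recurrence is evaluated only at exact powers of `b`, and `T(1)` is assumed to equal the base case.

```python
import math
import sympy as sp

# Shared symbol; f(n) and candidate closed forms must be built from this object.
n = sp.Symbol("n", positive=True)


def unroll(a, b, f_n, base, k):
    """Exact value of T(b**k) for T(n) = a*T(n/b) + f(n) with T(1) = base."""
    value = sp.sympify(base)
    for level in range(1, k + 1):
        # Move one level up the recursion tree: T(b**level) from T(b**(level-1)).
        value = a * value + f_n.subs(n, sp.Integer(b) ** level)
    return value


def matches_on_powers(a, b, f_n, base, candidate, levels=8):
    """Compare a candidate closed form against the unrolled values numerically."""
    for k in range(levels):
        exact = float(unroll(a, b, f_n, base, k))
        guess = float(candidate.subs(n, sp.Integer(b) ** k))
        if not math.isclose(exact, guess, rel_tol=1e-9):
            return False
    return True


# Mergesort-style recurrence with T(1) = 1, checked against n*log2(n) + n.
mergesort_ok = matches_on_powers(2, 2, n, 1, n * sp.log(n, 2) + n)
```

A check like this only confirms agreement on the sampled sizes; it does not prove the closed form, and it says nothing about sizes that are not powers of `b`.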

**Function Signature**
```Python
def solve_recurrence(recurrence: str, params: dict) -> str
```

**Edge Case Behavior**

Implement these edge cases:

- If a is less than or equal to 0 or b is less than or equal to 1, raise ValueError with message "Invalid recurrence parameters"
- If f(n) contains factorial or unsupported operations, return the string "Unsupported function"
- If base_case cannot be parsed by sympy, raise SyntaxError
- If any input is empty or invalid, return the string "Invalid input"
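
One subtlety in the base case rule is that `sympy.sympify` does not fail on arbitrary identifiers: a string such as `abc` parses into a symbol. A minimal sketch of one way to enforce the rule follows. The helper name `parse_base_case` is hypothetical, and rejecting every free symbol is an assumption that is stricter than the reference solution, which tolerates a few names such as `e`.

```python
import sympy as sp


def parse_base_case(text: str) -> sp.Expr:
    """Parse a constant base case, raising SyntaxError when it is unusable."""
    try:
        value = sp.sympify(text)
    except (sp.SympifyError, SyntaxError, TypeError) as exc:
        # Malformed input such as "1 +" ends up here.
        raise SyntaxError(f"Invalid base case: {text!r}") from exc
    # Lists, booleans, and expressions with leftover symbols are not constants.
    if not isinstance(value, sp.Expr) or value.free_symbols:
        raise SyntaxError(f"Invalid base case: {text!r}")
    return value
```

With this shape, numeric strings and symbol-free constants such as `pi` pass through, while stray identifiers are rejected explicitly instead of relying on the parser to fail.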

**Constraints**

1. Do not use eval or exec anywhere in your solution
2. Do not hardcode answers for known recurrence cases
3. Use sympy.sympify to convert all math strings to expressions
4. Your function must handle expressions like n * log(n) and n ^ k where k is any non-negative integer
5. You may only import sympy and math; no other libraries are allowed
6. Do not use the re module
7. Final result must be returned using str() and not print
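
The constraints above leave open how `f(n)` is screened for unsupported operations. The sketch below shows one possible approach that inspects the parsed expression instead of searching the raw string. The helper name `parse_f` and the whitelist `ALLOWED_FUNCTIONS` are assumptions made for illustration; the reference solution uses a substring blacklist instead.

```python
import sympy as sp

# Symbol that every parsed f(n) is bound to.
n = sp.Symbol("n", positive=True, integer=True)

# Assumed whitelist: only logarithms are treated as supported function calls.
ALLOWED_FUNCTIONS = (sp.log,)


def parse_f(text: str) -> sp.Expr:
    """Parse f(n) with sympify and reject unsupported function calls."""
    # convert_xor=True (sympify's default) makes '^' mean exponentiation.
    expr = sp.sympify(text, locals={"n": n}, convert_xor=True)
    if not isinstance(expr, sp.Expr):
        raise ValueError("Invalid input")
    for call in expr.atoms(sp.Function):
        if not isinstance(call, ALLOWED_FUNCTIONS):
            raise ValueError("Unsupported function")
    # Any symbol other than n means the expression is not a function of n alone.
    if expr.free_symbols - {n}:
        raise ValueError("Invalid input")
    return expr
```

Working on the expression tree avoids false positives from substring matches, at the cost of having to decide up front which function classes count as supported.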

# code

import sympy as sp
from sympy import symbols, log, simplify, sympify, expand

def solve_recurrence(recurrence: str, params: dict) -> str:
    """Solves divide-and-conquer recurrence relations to closed form expressions."""
    try:
        # Input validation
        if not _validate_inputs(recurrence, params):
            return "Invalid input"

        # Extract and validate parameters
        a, b, f_n_str, base_case_str = _extract_params(params)

        # Parameter validation
        if a <= 0 or b <= 1:
            raise ValueError("Invalid recurrence parameters")

        # Parse symbolic expressions
        n = symbols('n', positive=True, integer=True)
        f_n, base_case = _parse_expressions(f_n_str, base_case_str, n)

        # Apply Master Theorem
        result = _apply_master_theorem(a, b, f_n, base_case, n)

        # Simplify and return as string
        simplified = simplify(result)
        return f"T(n) = {simplified}"

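    except ValueError as e:
        if "Invalid recurrence parameters" in str(e):
            # Propagate parameter errors unchanged, as the edge-case rules require
            raise
        elif "Unsupported function" in str(e):
            return "Unsupported function"
        return "Invalid input"
    except SyntaxError:
        # Re-raise the original error so its message and traceback are preserved
        raise
    except Exception:
        return "Cannot solve"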


def _validate_inputs(recurrence: str, params: dict) -> bool:
    """Validates input format and completeness."""
    if not recurrence or not isinstance(recurrence, str):
        return False
    if not params or not isinstance(params, dict):
        return False

    required_keys = {'a', 'b', 'f(n)', 'base_case'}
    if not all(key in params for key in required_keys):
        return False

    return all(params[key] is not None and str(params[key]).strip() for key in required_keys)


def _extract_params(params: dict) -> tuple:
    """Extracts parameters from params dict with type validation."""
    try:
        a = int(params['a'])
        b = int(params['b'])
        f_n_str = str(params['f(n)']).strip()
        base_case_str = str(params['base_case']).strip()
        return a, b, f_n_str, base_case_str
    except (ValueError, TypeError):
        raise ValueError("Parameter type conversion failed")


def _parse_expressions(f_n_str: str, base_case_str: str, n) -> tuple:
    """Parses f(n) and base_case strings into sympy expressions."""
    # Check for unsupported operations
    unsupported = ['factorial', '!', 'gamma', 'sin', 'cos', 'tan', 'exp']
    if any(op in f_n_str.lower() for op in unsupported):
        raise ValueError("Unsupported function")

    try:
        # Convert ^ to ** for sympy compatibility
        f_n_processed = f_n_str.replace('^', '**')

        # Parse using sympify
        f_n = sympify(f_n_processed, locals={'n': n, 'log': log})

        # Handle base case parsing with proper error handling
        try:
            base_case = sympify(base_case_str)
            # Check if base_case contains unexpected symbols (like 'abc')
            # Valid base cases should be numbers or simple expressions
            if base_case.free_symbols:
                # If it has symbols that aren't standard math constants, it's invalid
                for sym in base_case.free_symbols:
                    sym_str = str(sym)
                    # Allow common math symbols but reject arbitrary text like 'abc'
                    if sym_str not in ['e', 'pi', 'I', 'E'] and not sym_str.isdigit():
                        raise SyntaxError()
        except Exception as exc:
            # Any parsing failure or stray symbol makes the base case unusable
            raise SyntaxError("Invalid base case") from exc

        return f_n, base_case
    except SyntaxError:
        raise
    except Exception:
        raise ValueError("Invalid input")


def _apply_master_theorem(a: int, b: int, f_n, base_case, n):
    """Applies Master Theorem to derive closed form solution."""
    # Calculate critical exponent log_b(a)
    log_b_a = log(a) / log(b)

    # Determine polynomial degree of f(n)
    f_degree = _get_polynomial_degree(f_n, n)

    # Apply Master Theorem cases
    if f_degree is not None:
        return _handle_polynomial_case(a, b, f_n, base_case, n, log_b_a, f_degree)
    elif _is_n_log_n_form(f_n, n):
        return _handle_n_log_n_case(a, b, f_n, base_case, n, log_b_a)
    elif _is_log_form(f_n, n):
        return _handle_log_case(a, b, f_n, base_case, n, log_b_a)
    else:
        return _solve_by_expansion(a, b, f_n, base_case, n)


def _get_polynomial_degree(expr, n) -> int:
    """Determines polynomial degree of expression in n."""
    try:
        expanded = expand(expr)
        poly = sp.Poly(expanded, n)
        return poly.degree()
    except Exception:
        return None


def _is_n_log_n_form(expr, n) -> bool:
    """Checks if expression is of form c*n*log(n)."""
    try:
        # Check for n*log(n) pattern
        simplified = simplify(expr / (n * log(n)))
        return simplified.is_constant() and not simplified.has(n)
    except Exception:
        return False


def _is_log_form(expr, n) -> bool:
    """Checks if expression is of form c*log(n)."""
    try:
        simplified = simplify(expr / log(n))
        return simplified.is_constant() and not simplified.has(n)
    except Exception:
        return False


def _handle_polynomial_case(a: int, b: int, f_n, base_case, n, log_b_a, f_degree: int):
    """Handles polynomial f(n) using Master Theorem cases."""
    # Extract leading coefficient
    expanded = expand(f_n)
    leading_coeff = expanded.coeff(n, f_degree)
    if leading_coeff is None:
        leading_coeff = 1

    # Convert log_b_a to float for comparison
    try:
        log_b_a_float = float(log_b_a.evalf())
    except Exception:
        log_b_a_float = float(log_b_a)

    if f_degree < log_b_a_float:
        # Case 1: f(n) = O(n^c) where c < log_b(a)
        return base_case * n**log_b_a
    elif abs(f_degree - log_b_a_float) < 1e-10:
        # Case 2: f(n) = Θ(n^c) where c = log_b(a)
        return n**f_degree * log(n)
    else:
        # Case 3: f(n) = Ω(n^c) where c > log_b(a)
        # For a monomial c*n^d with d > log_b(a), the regularity condition
        # a*f(n/b) <= k*f(n) holds with k = a / b**d < 1, so T(n) = Θ(f(n)).
        return leading_coeff * n**f_degree


def _handle_n_log_n_case(a: int, b: int, f_n, base_case, n, log_b_a):
    """Handles f(n) = c*n*log(n) case."""
    coeff = simplify(f_n / (n * log(n)))

    # Special case for when log_b(a) = 1 and f(n) = n*log(n)
    if abs(float(log_b_a - 1)) < 1e-10:  # log_b(a) = 1
        return coeff * n * log(n)**2
    elif log_b_a > 1:
        return base_case * n**log_b_a
    else:
        return coeff * n * log(n)**2


def _handle_log_case(a: int, b: int, f_n, base_case, n, log_b_a):
    """Handles f(n) = c*log(n) case."""
    coeff = simplify(f_n / log(n))

    # For T(n) = 1*T(n/2) + log(n), the solution is log(n)^2
    if abs(float(log_b_a)) < 1e-10:  # log_b(a) = 0, so a = 1
        return coeff * log(n)**2
    elif log_b_a > 0:
        return base_case * n**log_b_a
    else:
        return coeff * log(n)**2


def _check_regularity_condition(a: int, b: int, coeff, degree: int) -> bool:
    """Checks Master Theorem regularity condition."""
    # Simplified regularity check
    return coeff > 0 and a * (1/b)**degree < 1


def _solve_by_expansion(a: int, b: int, f_n, base_case, n):
    """Solves recurrence by expansion method when Master Theorem doesn't apply."""
    # Expansion method for complex f(n)
    try:
        # Build solution through geometric series analysis
        log_b_a = log(a) / log(b)

        # Approximate solution based on recurrence structure
        if f_n.has(log):
            # Handle logarithmic terms
            return base_case * n**log_b_a + f_n
        else:
            # Default to dominant term analysis
            return simplify(base_case * n**log_b_a + f_n)

    except Exception:
        raise ValueError("Cannot solve complex recurrence")

# tests

import unittest
from sympy import symbols, log
from main import solve_recurrence

class TestSolveRecurrence(unittest.TestCase):
    def test_basic_case_linear(self):
        recurrence = "T(n) = 2 * T(n / 2) + n"
        params = {"a": 2, "b": 2, "f(n)": "n", "base_case": "1"}
        self.assertEqual(solve_recurrence(recurrence, params), "T(n) = n*log(n)")

    def test_quadratic_case(self):
        recurrence = "T(n) = 3 * T(n / 4) + n ^ 2"
        params = {"a": 3, "b": 4, "f(n)": "n ^ 2", "base_case": "1"}
        self.assertEqual(solve_recurrence(recurrence, params), "T(n) = n**2")

    def test_log_case(self):
        recurrence = "T(n) = 1 * T(n / 2) + log(n)"
        params = {"a": 1, "b": 2, "f(n)": "log(n)", "base_case": "1"}
        self.assertEqual(solve_recurrence(recurrence, params), "T(n) = log(n)**2")

    def test_cubic_case(self):
        recurrence = "T(n) = 8 * T(n / 2) + n ^ 2"
        params = {"a": 8, "b": 2, "f(n)": "n ^ 2", "base_case": "1"}
        self.assertEqual(solve_recurrence(recurrence, params), "T(n) = n**3")

    def test_log_linear_case(self):
        recurrence = "T(n) = 2 * T(n / 2) + n * log(n)"
        params = {"a": 2, "b": 2, "f(n)": "n * log(n)", "base_case": "1"}
        self.assertEqual(solve_recurrence(recurrence, params), "T(n) = n*log(n)**2")

    def test_a_less_than_equal_zero(self):
        recurrence = "T(n) = 0 * T(n / 2) + n"
        params = {"a": 0, "b": 2, "f(n)": "n", "base_case": "1"}
        with self.assertRaises(ValueError):
            solve_recurrence(recurrence, params)

    def test_b_less_than_equal_one(self):
        recurrence = "T(n) = 2 * T(n / 1) + n"
        params = {"a": 2, "b": 1, "f(n)": "n", "base_case": "1"}
        with self.assertRaises(ValueError):
            solve_recurrence(recurrence, params)

    def test_invalid_base_case(self):
        recurrence = "T(n) = 2 * T(n / 2) + n"
        params = {"a": 2, "b": 2, "f(n)": "n", "base_case": "abc"}
        with self.assertRaises(SyntaxError):
            solve_recurrence(recurrence, params)

    def test_unsupported_function(self):
        recurrence = "T(n) = 2 * T(n / 2) + n!"
        params = {"a": 2, "b": 2, "f(n)": "n!", "base_case": "1"}
        self.assertEqual(solve_recurrence(recurrence, params), "Unsupported function")

    def test_empty_input(self):
        self.assertEqual(solve_recurrence("", {}), "Invalid input")

    def test_missing_params(self):
        recurrence = "T(n) = 2 * T(n / 2) + n"
        self.assertEqual(solve_recurrence(recurrence, {}), "Invalid input")

    def test_invalid_fn_expression(self):
        recurrence = "T(n) = 2 * T(n / 2) + n +"
        params = {"a": 2, "b": 2, "f(n)": "n +", "base_case": "1"}
        self.assertEqual(solve_recurrence(recurrence, params), "Invalid input")

    def test_unrecognized_expression(self):
        recurrence = "T(n) = 2 * T(n / 2) + [n]"
        params = {"a": 2, "b": 2, "f(n)": "[n]", "base_case": "1"}
        self.assertEqual(solve_recurrence(recurrence, params), "Invalid input")

    def test_constant_fn_case(self):
        recurrence = "T(n) = 4 * T(n / 2) + 1"
        params = {"a": 4, "b": 2, "f(n)": "1", "base_case": "1"}
        self.assertEqual(solve_recurrence(recurrence, params), "T(n) = n**2")

    def test_equal_fn_and_log_threshold(self):
        recurrence = "T(n) = 4 * T(n / 2) + n^2"
        params = {"a": 4, "b": 2, "f(n)": "n^2", "base_case": "1"}
        self.assertEqual(solve_recurrence(recurrence, params), "T(n) = n**2*log(n)")

    def test_polylog_case(self):
        recurrence = "T(n) = 2 * T(n / 2) + n * log(n)"
        params = {"a": 2, "b": 2, "f(n)": "n * log(n)", "base_case": "1"}
        self.assertEqual(solve_recurrence(recurrence, params), "T(n) = n*log(n)**2")

    def test_unresolvable_case(self):
        recurrence = "T(n) = T(n / 3) + sin(n)"
        params = {"a": 1, "b": 3, "f(n)": "sin(n)", "base_case": "1"}
        # sin(n) is on the unsupported list, so it is reported as unsupported rather than unsolvable
        self.assertEqual(solve_recurrence(recurrence, params), "Unsupported function")

if __name__ == '__main__':
    unittest.main(argv=[''], exit=False)

ModuleNotFoundError: No module named 'main'

# Model Breaking Proof

#### Model Breaking Task URL: <https://rlhf-v3.turing.com/prompt/224d022a-8f5f-4452-a572-f37f30a5beb8?origin=https%3A%2F%2Flabeling-ms.turing.com&redirect_url=https%3A%2F%2Flabeling-ms.turing.com%2Fconversations%2F45813%2Fview>

#### Model code:

```python
from sympy import symbols, sympify, log, oo, simplify
from sympy.parsing.sympy_parser import parse_expr

def solve_recurrence(recurrence: str, params: dict) -> str:
    try:
        # Validate and parse the inputs
        a = params.get('a')
        b = params.get('b')
        fn = params.get('f(n)')
        base_case = params.get('base_case')
        
        if not all([a, b, fn, base_case]):
            return "Invalid input"
        if a <= 0 or b <= 1:
            raise ValueError("Invalid recurrence parameters")
        
        # Define symbolic variables
        n = symbols('n', positive=True)
        
        # Parse function f(n) with sympy
        # <Issue>: Uses parse_expr instead of sympify, which violates "Use sympy.sympify to convert all math strings"
        try:
            f_n = parse_expr(fn.replace('^', '**'), evaluate=True)
        except SyntaxError:
            # <Issue>: Only catches SyntaxError; doesn't check for unsupported functions like factorial
            return "Unsupported function"
        
        # Ensure the base case is a valid expression
        try:
            base_case_expr = parse_expr(base_case.replace('^', '**'), evaluate=True)
        except SyntaxError:
            raise SyntaxError("Invalid base case")
        
        # Compute critical exponent log_b(a)
        critical_exponent = log(a) / log(b)
        
        # Analyze the growth rate of f(n)
        fn_degree = f_n.as_poly(n).degree() if f_n.is_polynomial(n) else oo
        threshold_exponent = simplify(critical_exponent)
        
        # Apply Master Theorem to find the closed-form solution
        if fn_degree < threshold_exponent:
            # Case 1: T(n) = Θ(n^log_b(a))
            closed_form = simplify(n**threshold_exponent)
        elif fn_degree == threshold_exponent:
            # Case 2: T(n) = Θ(n^c * log(n))
            closed_form = simplify(f_n * log(n))
        else:
            # Case 3: T(n) = Θ(f(n))
            closed_form = simplify(f_n)
        
        # Return the final expression, replace `**` with `^` for output
        return f"T(n) = {str(closed_form).replace('**', '^')}"
        
    except Exception as e:
        return "Cannot solve"
```
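
A further problem in the model code, beyond the issues flagged inline, appears to be a symbol mismatch: `n` is created with `positive=True`, while `parse_expr` builds a plain `Symbol('n')` when no `local_dict` is supplied. SymPy treats symbols with different assumptions as distinct objects, so a degree computed with respect to the assumption-carrying symbol sees the parsed expression as a constant. The sketch below illustrates the effect in isolation; the variable names are chosen here for illustration and do not appear in the model code.

```python
import sympy as sp
from sympy.parsing.sympy_parser import parse_expr

# Symbol with assumptions, like the one created in the model code.
n_assumed = sp.Symbol("n", positive=True)
# Plain symbol, which is what parse_expr creates for the name "n" by default.
plain_n = sp.Symbol("n")

parsed = parse_expr("n**2")

# Same name, different assumptions: SymPy does not consider these equal.
same_symbol = plain_n == n_assumed

# The parsed expression contains plain_n only, so it is constant in n_assumed.
degree_wrt_assumed = sp.Poly(parsed, n_assumed).degree()
degree_wrt_plain = sp.Poly(parsed, plain_n).degree()

# Binding the name explicitly at parse time removes the mismatch.
bound = parse_expr("n**2", local_dict={"n": n_assumed})
degree_when_bound = sp.Poly(bound, n_assumed).degree()
```

If the degree always comes out as 0, the comparison against the critical exponent would select the first Master Theorem branch whenever `log_b(a) > 0`, regardless of the actual `f(n)`.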