# Backtracking

Backtracking is a systematic exploration of the search space of a particular problem.

General backtracking procedure, given a partial solution s:
- Check whether s is a solution. If it is, process it (problem dependent).
- Generate all extended solutions starting at s.
- Check the border conditions, discarding the extensions that violate the constraints of the problem.
- Recursively call this procedure for every remaining extended solution.

Backtracking can be used to explore all solutions, to find the first solution (one with a particular property), or to find the optimal solution.

## Solution Space

Let’s assume that a solution can be modeled as a vector $s = (s_1, s_2, \dots, s_n)$, where each $s_i$ takes values from a finite set $S_i$.

- A candidate (partial) solution of length $k \le n$ is of the form: $$ s^{(k)} = (s_1,\dots,s_k) $$
- Extending the candidate $s^{(k)}$ is achieved by adding one element $s_{k+1} \in S_{k+1}$: $$ s^{(k+1)} = (s_1,\dots,s_k,s_{k+1}) $$

## Generic Algorithm

### Algorithm

```
backtracking(A, k):
    if done:
        return
    if is_solution(A, k):
        process_solution(A, k)
    for each c in extend_solution(A, k):
        A[k] = c
        if test(A, k):
            backtracking(A, k+1)
```

- `is_solution(A, k)` indicates whether the argument is a complete solution to the problem.
- `process_solution(A, k)` processes a solution to the problem.
- `extend_solution(A, k)` generates, from a partial solution, all candidates that extend it by one step.
- `test(A, k)` returns true if the extended solution is valid.
- `done` is a global variable that signals that no more backtracking is required.
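The pseudocode above can be sketched in Python as follows. This is a minimal sketch under some assumptions, not a definitive implementation: the name `backtrack`, the callback parameters, and the toy permutation problem are illustrative choices, the `done` flag is left out, and the search stops descending once a complete solution has been processed.

```python
from typing import Callable, Iterable, List


def backtrack(
    partial: List[int],
    is_solution: Callable[[List[int]], bool],
    process_solution: Callable[[List[int]], None],
    candidates: Callable[[List[int]], Iterable[int]],
    test: Callable[[List[int]], bool],
) -> None:
    """Explore every valid extension of `partial`, depth first."""
    if is_solution(partial):
        process_solution(partial)
        return
    for c in candidates(partial):
        # Build a new list so that sibling branches are not affected.
        extended = partial + [c]
        if test(extended):
            backtrack(extended, is_solution, process_solution, candidates, test)


# Toy problem (an assumption for illustration): all permutations of 0, 1, 2.
n = 3
found = []
backtrack(
    [],
    is_solution=lambda p: len(p) == n,
    process_solution=found.append,
    candidates=lambda p: range(n),
    test=lambda p: len(set(p)) == len(p),  # no repeated values
)
print(found)
```

Passing the problem-specific pieces as functions keeps the recursion itself identical across problems, which mirrors the three modeling questions in the approach.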

### Approach

- How do we model $s$?
- How do we extend a solution?
- How do we process the solution?

### Example

#### Problem

Find **all** subsets of size $n$ of a total of $m$ elements.

#### Approach

- We model subsets as a boolean array `s[]` in which `s[i] == True` indicates that the *i*th element belongs to the subset.
- We extend a partial solution by appending either `True` or `False` to its end.
- We print a solution once we find it.
    - A **partial solution** is any array with `len(s) < m`; it can still lead to a solution only if $\sum_i s[i] \le n$.
    - A **solution** to the problem is any array with $\sum_i s[i] = n$ and `len(s) == m`.

#### Implementation

##### Check if a partial solution is a solution to the problem

```python
def is_solution(a, n, m):
    # A solution has decided all m elements and selected exactly n of them.
    return len(a) == m and sum(a) == n
```

##### Extend a partial solution

```python
def extend_solution(a, m):
    # Decide the next element: include it (True) or leave it out (False).
    if len(a) < m:
        for c in [True, False]:
            yield a + [c]
```

Why `a + [c]` and not `a.append(c)`?
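One way to see the difference is to run both variants side by side. In this sketch, `extend_copy` and `extend_in_place` are hypothetical helpers written only for the comparison; they are not part of the implementation.

```python
def extend_copy(a):
    # `a + [c]` builds a new list for every extension.
    return [a + [c] for c in (True, False)]


def extend_in_place(a):
    # `a.append(c)` mutates `a` and returns None, so every "extension"
    # is the same list object, which keeps growing.
    extensions = []
    for c in (True, False):
        a.append(c)
        extensions.append(a)
    return extensions


partial = [True]
print(extend_copy(partial))      # [[True, True], [True, False]]
print(partial)                   # [True]

print(extend_in_place(partial))  # [[True, True, False], [True, True, False]]
print(partial)                   # [True, True, False]
```

With `a + [c]`, each recursive call receives its own partial solution, so nothing has to be undone when the search backtracks.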

##### Process the solution

```python
def process_solution(a):
    # Collect the indices of the elements that belong to the subset.
    friendly = [str(i) for i, selected in enumerate(a) if selected]
    print("Subset with the following elements " + ", ".join(friendly))
```

In this specific case, we only print the solution.

##### All together

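```python
def backtracking(n, m, a=None):
    # Avoid a mutable default argument; start from the empty partial solution.
    if a is None:
        a = []
    if is_solution(a, n, m):
        process_solution(a)
    else:
        for a_extended in extend_solution(a, m):
            backtracking(n, m, a_extended)
```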

Now let's find all subsets of size 4 out of a set with 6 elements.

```python
backtracking(4, 6)
```
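The implementation above visits all $2^m$ boolean arrays, because no extension is ever discarded. The following sketch adds border conditions that cut a branch as soon as it cannot reach exactly $n$ selected elements, and cross-checks the result against `itertools.combinations`. The function name `subsets_of_size` and the choice to return index tuples instead of printing are assumptions made for this example.

```python
from itertools import combinations


def subsets_of_size(n, m, a=None, found=None):
    """Collect every size-n subset of range(m) as a tuple of indices."""
    if a is None:
        a = []
    if found is None:
        found = []
    chosen = sum(a)
    remaining = m - len(a)
    # Border conditions: too many elements chosen, or too few left to reach n.
    if chosen > n or chosen + remaining < n:
        return found
    if len(a) == m:
        found.append(tuple(i for i, selected in enumerate(a) if selected))
        return found
    for c in (True, False):
        subsets_of_size(n, m, a + [c], found)
    return found


result = subsets_of_size(4, 6)
print(len(result))  # 15
assert sorted(result) == sorted(combinations(range(6), 4))
```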

## Other Examples

### Task Assignment

#### Problem

Given $n$ workers, $n$ tasks and a cost matrix $C$, where $C[i,j]$ represents how many man-hours it takes worker $i$ to complete task $j$, find the task assignment (one task per worker) that minimizes the total cost.

#### Modeling

- We model the assignment as a list of tuples $(w, t)$ which assigns worker $w$ to task $t$.
- We extend a solution by adding a new tuple to the list, composed of an unassigned worker and an unfulfilled task.
- Once we find a valid solution, we calculate its cost and keep the assignment with the lowest cost.
    - A valid assignment is one in which every worker has a task assigned and every task has a worker assigned to it.

#### Implementation

##### Check solution

```python
def is_solution(assignment, n):
    # Complete when all n workers and all n tasks appear exactly once.
    workers = {w for (w, t) in assignment}
    tasks = {t for (w, t) in assignment}
    return len(assignment) == n and len(workers) == n and len(tasks) == n
```

##### Extend solution

```python
def extend_solution(assignment, n):
    # Workers and tasks already used by the partial assignment.
    workers = {w for (w, t) in assignment}
    tasks = {t for (w, t) in assignment}
    free_workers = [w for w in range(n) if w not in workers]
    free_tasks = [t for t in range(n) if t not in tasks]
    # Pair every free worker with every free task.
    for w in free_workers:
        for t in free_tasks:
            yield assignment + [(w, t)]
```

##### Process solution

```python
def process_solution(assignment, costs):
    # Keep the cheapest complete assignment seen so far in a global variable.
    global best_assignment
    cost = sum(costs[w][t] for (w, t) in assignment)
    if best_assignment is None or cost < best_assignment[1]:
        best_assignment = (assignment, cost)
```

##### Backtracking

```python
def backtracking(costs, assignment=None):
    # Avoid a mutable default argument; start from the empty assignment.
    if assignment is None:
        assignment = []
    n = len(costs)
    if is_solution(assignment, n):
        process_solution(assignment, costs)
    else:
        for extended_assignment in extend_solution(assignment, n):
            backtracking(costs, extended_assignment)
```

```python
best_assignment = None
costs = [
    [4, 2, 3, 1],
    [9, 3, 4, 2],
    [2, 4, 6, 2],
    [7, 3, 1, 0],
]

backtracking(costs)

print(f"The best assignment has a cost of {best_assignment[1]}, corresponding to:")
for (w, t) in best_assignment[0]:
    print(f"\tAssign worker {w} to task {t}.")
```
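Because the extension step pairs every free worker with every free task, the search above reaches each complete assignment once per ordering of the workers, that is, $n!$ times. A possible refinement, sketched here under the assumption that all costs are non-negative, assigns workers in index order and abandons a branch as soon as its cost is no better than the best complete assignment found so far. The name `best_assignment_bnb` is hypothetical, and the result is cross-checked against plain enumeration with `itertools.permutations`.

```python
from itertools import permutations


def best_assignment_bnb(costs):
    """Return (tasks, cost), where tasks[w] is the task given to worker w."""
    n = len(costs)
    best = [None, float("inf")]  # [tasks, cost] of the best assignment so far

    def search(tasks, cost):
        if cost >= best[1]:
            return  # prune: valid only because costs are non-negative
        w = len(tasks)  # next worker to assign
        if w == n:
            best[0], best[1] = tasks, cost
            return
        for t in range(n):
            if t not in tasks:
                search(tasks + [t], cost + costs[w][t])

    search([], 0)
    return best[0], best[1]


costs = [
    [4, 2, 3, 1],
    [9, 3, 4, 2],
    [2, 4, 6, 2],
    [7, 3, 1, 0],
]
tasks, cost = best_assignment_bnb(costs)

# Cross-check against plain enumeration of all n! assignments.
brute = min(
    sum(costs[w][t] for w, t in enumerate(p)) for p in permutations(range(4))
)
assert cost == brute
print(tasks, cost)
```

Fixing the order of the workers removes the duplicated work, and the cost bound is what turns the exhaustive search into a search for the optimal solution only.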

### Sudoku

#### Problem

- A Sudoku is a $9 \times 9$ board in which no value may be repeated within any:
    - Row
    - Column
    - $3 \times 3$ block
- The values are the numbers from $1$ to $9$.

```
+-------+-------+-------+
|       |       |       |
|       |     3 |   5 8 |
|     1 |   2   |       |
+-------+-------+-------+
|       | 5   7 |       |
|     4 |       | 1     |
|   9   |       |       |
+-------+-------+-------+
| 5     |       |   7 3 |
|     2 |   1   |       |
|       |   4   |     9 |
+-------+-------+-------+
```

#### Modeling

- We model the sudoku as a list of 81 elements, where:
    - for every row $r \in \{0, \dots, 8\}$, the indices of the elements in that row are $[9r + 0, 9r + 1, \dots, 9r + 8]$
    - for every column $c \in \{0, \dots, 8\}$, the indices of the elements in that column are $[c + 9 \cdot 0, c + 9 \cdot 1, \dots, c + 9 \cdot 8]$
    - for every block $b \in \{0, \dots, 8\}$ and position $p \in \{0, \dots, 8\}$ within that block, the index of the element is $27(b//3) + 3(b\%3) + 9(p//3) + p\%3$
    - the list may have empty elements (`None`), which need to be filled
- We extend the solution by replacing the first empty element with a value.
- After each extension, we test that no row, column or block contains a repeated value; a board that passes this test and has no empty elements is a solution.
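The index arithmetic for rows, columns and blocks can be checked with a few small helpers. The function names below are hypothetical and only serve this illustration; the final loop confirms that each family of units covers all 81 cells exactly once.

```python
def row_indices(r):
    # Row r occupies nine consecutive positions starting at 9*r.
    return [9 * r + c for c in range(9)]


def column_indices(c):
    # Column c takes one position from each row, nine apart.
    return [c + 9 * r for r in range(9)]


def block_indices(b):
    # Block b starts at row 3*(b//3), column 3*(b%3).
    start = 27 * (b // 3) + 3 * (b % 3)
    return [start + 9 * (p // 3) + p % 3 for p in range(9)]


print(row_indices(1))     # [9, 10, 11, 12, 13, 14, 15, 16, 17]
print(column_indices(1))  # [1, 10, 19, 28, 37, 46, 55, 64, 73]
print(block_indices(4))   # [30, 31, 32, 39, 40, 41, 48, 49, 50]

# Every cell belongs to exactly one row, one column and one block.
for indices in (row_indices, column_indices, block_indices):
    assert sorted(i for k in range(9) for i in indices(k)) == list(range(81))
```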

#### Implementation

##### Check solution

```python
def is_solution(solution):
    # The board is complete when it has no empty cells left.
    return all(x is not None for x in solution)
```

##### Extend solution

```python
def extend_solution(solution):
    # Fill the first empty cell with every candidate value.
    idx = next((i for i in range(81) if solution[i] is None), None)
    if idx is not None:
        for value in range(1, 10):
            new_solution = list(solution)
            new_solution[idx] = value
            yield new_solution
```

```python
def test(solution):
    # Unit i is checked as row i, column i and block i at the same time;
    # j runs over the nine cells of each unit.
    for i in range(9):
        row = [solution[9*i + j] for j in range(9)]
        column = [solution[i + 9*j] for j in range(9)]
        block = [solution[27*(i//3) + 3*(i%3) + 9*(j//3) + j%3] for j in range(9)]
        for unit in (row, column, block):
            filled = [v for v in unit if v is not None]
            if len(set(filled)) != len(filled):
                return False
    return True
```

##### Process solution

```python
def process_solution(solution):
    # Print the board row by row, with a separator every three rows.
    for r in range(9):
        if r % 3 == 0:
            print("+-------+-------+-------+")
        print(
            '|', solution[9*r], solution[9*r + 1], solution[9*r + 2],
            '|', solution[9*r + 3], solution[9*r + 4], solution[9*r + 5],
            '|', solution[9*r + 6], solution[9*r + 7], solution[9*r + 8],
            '|'
        )
    print("+-------+-------+-------+")
```

##### Backtracking

```python
def backtracking(solution=None):
    # `done` is a module-level flag; set it to False before each search.
    global done
    if solution is None:
        solution = [None] * 81
    if done:
        return
    if is_solution(solution):
        done = True
        process_solution(solution)
    else:
        for next_solution in extend_solution(solution):
            if test(next_solution):
                backtracking(next_solution)
```

```python
done = False
N = None
sudoku = [
    N, N, 3, 9, N, N, N, 5, 1,
    5, 4, 6, N, 1, 8, 3, N, N,
    N, N, N, N, N, 7, 4, 2, N,
    N, N, 9, N, 5, N, N, 3, N,
    2, N, N, 6, N, 3, N, N, 4,
    N, 8, N, N, 7, N, 2, N, N,
    N, 9, 7, 3, N, N, N, N, N,
    N, N, 1, 8, 2, N, 9, 4, 7,
    8, 5, N, N, N, 4, 6, N, N
]
backtracking(sudoku)
```
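The global `done` flag can be avoided by returning the first solution found and letting the return value stop the search. The following is a self-contained sketch of that alternative, not the implementation used above: `units`, `is_valid` and `solve` are hypothetical names, and the test board is generated from a shifting pattern (an assumption chosen so that a solution is known to exist) with every third cell blanked.

```python
def units():
    """Index lists for the 9 rows, 9 columns and 9 blocks of an 81-cell board."""
    rows = [[9 * r + c for c in range(9)] for r in range(9)]
    cols = [[9 * r + c for r in range(9)] for c in range(9)]
    blocks = [
        [27 * (b // 3) + 3 * (b % 3) + 9 * (p // 3) + p % 3 for p in range(9)]
        for b in range(9)
    ]
    return rows + cols + blocks


UNITS = units()


def is_valid(board):
    """True if no row, column or block repeats a filled-in value."""
    for unit in UNITS:
        filled = [board[i] for i in unit if board[i] is not None]
        if len(set(filled)) != len(filled):
            return False
    return True


def solve(board):
    """Return the first completed board found, or None if there is none."""
    if None not in board:
        return board if is_valid(board) else None
    idx = board.index(None)  # first empty cell
    for value in range(1, 10):
        candidate = board[:idx] + [value] + board[idx + 1:]
        if is_valid(candidate):
            solved = solve(candidate)
            if solved is not None:
                return solved  # propagate the first solution upwards
    return None


# A valid full board built from a shifting pattern, with every third cell blanked.
full = [(3 * (r % 3) + r // 3 + c) % 9 + 1 for r in range(9) for c in range(9)]
puzzle = [None if i % 3 == 0 else v for i, v in enumerate(full)]
solution = solve(puzzle)
assert solution is not None and is_valid(solution)
```

Returning the board also makes the solver easier to reuse, since the caller decides whether to print the result, validate it or pass it on.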