# Backtrack
- [Backtracking](#backtracking)
- [Generation](#generation)
- [More constrained backtracking](#more-constrained-backtracking)
- [Problems](#problems)

## Backtracking
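- Efficiently explore all possibilities in a problem by building candidate solutions step by step
- 'Pruning': abandon a path early
    - An optimization that abandons a path as soon as it is clear that the path cannot lead to a solution
- Similar to a depth-first traversal of a tree, where each node is a partial solution
- Exhaustive search vs. backtracking:
    exhaustive search tries all possibilities; backtracking prunes paths that cannot succeed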
- When to use
    - Find all of something
    - Need to check all logical possibilities
- Time complexity: usually exponential, so it is practical only for small inputs (e.g. `n <= 15`)
- Think of the search space as a graph: decide how to represent the state at each node
## Generation
- Example 1: [46. Permutations](https://leetcode.com/problems/permutations/)
- Example 2: [78. Subsets](https://leetcode.com/problems/subsets/)
- Example 3: [77. Combinations](https://leetcode.com/problems/combinations/)

In [None]:
# [46. Permutations](https://leetcode.com/problems/permutations/)
class Solution:
    def permute(self, nums: List[int]) -> List[List[int]]:
        """Return every ordering of ``nums``.

        Each element is used exactly once per permutation. Usage is tracked by
        index rather than by value, so repeated values are not collapsed and
        every permutation has ``len(nums)`` entries. Permutations are emitted
        in lexicographic order of element positions.
        """
        ans = []
        used = [False] * len(nums)
        path = []

        def backtrack():
            if len(path) == len(nums):
                ans.append(path.copy())  # snapshot: path is mutated below
                return
            for i, value in enumerate(nums):
                if used[i]:
                    continue
                used[i] = True
                path.append(value)
                backtrack()
                path.pop()  # undo the choice before trying the next element
                used[i] = False

        backtrack()
        return ans

In [None]:
# [78. Subsets](https://leetcode.com/problems/subsets/)
class Solution:
    def subsets(self, nums: List[int]) -> List[List[int]]:
        """Return all subsets of ``nums`` in depth-first (pre-order) order.

        A subset is only ever extended with elements at later indices, so each
        combination of positions is emitted exactly once.
        """
        result = []
        pending = [(0, [])]  # (first index still allowed, subset built so far)
        while pending:
            start, chosen = pending.pop()
            result.append(chosen)
            # Push later choices first so the smallest index is expanded next,
            # reproducing the recursive pre-order traversal.
            for k in reversed(range(start, len(nums))):
                pending.append((k + 1, chosen + [nums[k]]))
        return result

In [None]:
# [797. All Paths From Source to Target](https://leetcode.com/problems/all-paths-from-source-to-target/)
class Solution:
    def allPathsSourceTarget(self, graph: List[List[int]]) -> List[List[int]]:
        """Return every path from node 0 to node ``len(graph) - 1``.

        ``graph[i]`` lists the nodes reachable from node ``i``. There is no
        cycle detection, so this assumes the input is a DAG (LeetCode 797's
        constraint); a cycle would recurse forever.
        """
        target = len(graph) - 1
        ans = []

        def dfs(node, path):
            # ``path`` always ends with ``node``; it is a fresh list per call.
            if node == target:
                ans.append(path)
                return
            for nxt in graph[node]:
                dfs(nxt, path + [nxt])

        dfs(0, [0])
        return ans

In [None]:
# [17. Letter Combinations of a Phone Number](https://leetcode.com/problems/letter-combinations-of-a-phone-number/)
class Solution:
    def letterCombinations(self, digits: str) -> List[str]:
        """Return every letter string the keypad ``digits`` can spell.

        Combinations are ordered as if the leftmost digit varied slowest.
        An empty input yields an empty list (not ``[""]``).
        """
        if digits == "":
            return []
        keypad = {'2':['a','b','c'], '3':['d','e','f'], '4':['g','h','i'], '5':['j','k','l'], '6':['m','n','o'], '7':['p','q','r','s'], '8':['t','u','v'], '9':['w','x','y','z']}
        combos = ['']
        # Extend every existing prefix with each letter of the next digit.
        for digit in digits:
            combos = [prefix + letter for prefix in combos for letter in keypad[digit]]
        return combos

## More Constrained backtracking
- Example 1: [39. Combination Sum](https://leetcode.com/problems/combination-sum/)
- Example 2: [52. N-Queens II](https://leetcode.com/problems/n-queens-ii/)
- Example 3: [79. Word Search](https://leetcode.com/problems/word-search/)
## Problems

In [None]:
# [39. Combination Sum](https://leetcode.com/problems/combination-sum/)
class Solution:
    def combinationSum(self, candidates: List[int], target: int) -> List[List[int]]:
        """Return every combination of ``candidates`` that sums to ``target``.

        A candidate may be reused any number of times. Candidate indices are
        non-decreasing along a combination, so (given distinct candidates) the
        same multiset is never produced twice in different orders.
        """
        found = []

        def explore(start, chosen, left):
            if left < 0:
                return  # overshot the target
            if left == 0:
                found.append(chosen)
                return
            for idx in range(start, len(candidates)):
                pick = candidates[idx]
                # Pass ``idx`` (not ``idx + 1``) so the same candidate may repeat.
                explore(idx, chosen + [pick], left - pick)

        explore(0, [], target)
        return found

In [None]:
# [52. N-Queens II](https://leetcode.com/problems/n-queens-ii/)
class Solution:
    def totalNQueens(self, n: int) -> int:
        """Count the placements of ``n`` non-attacking queens on an n x n board.

        Queens are placed one per row. Square (r, c) conflicts with an earlier
        queen iff its column, its ``r - c`` diagonal or its ``r + c``
        anti-diagonal is already occupied, so three sets give O(1) checks.
        """
        taken_cols, taken_diag, taken_anti = set(), set(), set()

        def place_row(r):
            if r == n:
                return 1  # every row holds a queen: one complete solution
            total = 0
            for c in range(n):
                if c in taken_cols or r - c in taken_diag or r + c in taken_anti:
                    continue
                taken_cols.add(c)
                taken_diag.add(r - c)
                taken_anti.add(r + c)
                total += place_row(r + 1)
                # Lift the queen again so the next column can be tried.
                taken_cols.discard(c)
                taken_diag.discard(r - c)
                taken_anti.discard(r + c)
            return total

        return place_row(0)

In [None]:
# [79. Word Search](https://leetcode.com/problems/word-search/)
class Solution:
    def exist(self, board: List[List[str]], word: str) -> bool:
        """Return True if ``word`` can be traced through adjacent board cells.

        Consecutive letters must be horizontally or vertically adjacent, and a
        cell may be used at most once within a single path.
        """
        rows = len(board)
        cols = len(board[0]) if rows else 0  # an empty board has no cells to try
        steps = ((0, 1), (0, -1), (-1, 0), (1, 0))

        def match_from(r, c, pos, visiting):
            if pos == len(word):
                return True  # every letter has been matched
            in_bounds = 0 <= r < rows and 0 <= c < cols
            if not in_bounds or (r, c) in visiting or board[r][c] != word[pos]:
                return False
            visiting.add((r, c))
            if any(match_from(r + dr, c + dc, pos + 1, visiting) for dr, dc in steps):
                return True
            visiting.discard((r, c))  # backtrack: free the cell for other paths
            return False

        return any(match_from(r, c, 0, set()) for r in range(rows) for c in range(cols))

In [None]:
# [967. Numbers With Same Consecutive Differences](https://leetcode.com/problems/numbers-with-same-consecutive-differences/)
class Solution:
    def numsSameConsecDiff(self, n: int, k: int) -> List[int]:
        """Return all ``n``-digit integers whose adjacent digits differ by ``k``.

        Leading zeros are not allowed, so the first digit ranges over 1-9.
        Digits are tried in ascending order, so the results are ascending.
        """
        results = []

        def to_int(digits):
            value = digits[0]
            for d in digits[1:]:
                value = value * 10 + d
            return value

        def extend(digits):
            if len(digits) == n:
                results.append(to_int(digits))
                return
            if not digits:
                candidates = range(1, 10)  # first digit: no leading zero
            else:
                last = digits[-1]
                candidates = [d for d in range(10) if abs(d - last) == k]
            for d in candidates:
                extend(digits + [d])

        extend([])
        return results

In [None]:
# [216. Combination Sum III](https://leetcode.com/problems/combination-sum-iii/)
class Solution:
    def combinationSum3(self, k: int, n: int) -> List[List[int]]:
        """Return every set of ``k`` distinct digits from 1-9 that sums to ``n``.

        Digits are chosen in strictly increasing order, so each set appears
        once, as an ascending list.
        """
        combos = []

        def grow(chosen, next_digit, total):
            if len(chosen) == k and total == n:
                combos.append(chosen)
                return
            if len(chosen) >= k or total > n:
                return  # too many digits, or already over the target sum
            for digit in range(next_digit, 10):
                grow(chosen + [digit], digit + 1, total + digit)

        grow([], 1, 0)
        return combos

## Problems
- [91. Decode Ways](https://leetcode.com/problems/decode-ways/)
- [465. Optimal Account Balancing](https://leetcode.com/problems/optimal-account-balancing/)
- [22. Generate Parentheses](https://leetcode.com/problems/generate-parentheses/)
- [489. Robot Room Cleaner](https://leetcode.com/problems/robot-room-cleaner/)


In [None]:
# [489. Robot Room Cleaner](https://leetcode.com/problems/robot-room-cleaner/)
class Solution:
    def cleanRoom(self, robot):
        """
        Clean every cell reachable from the robot's starting cell.

        Positions are tracked relative to the start cell, (0, 0), and the
        robot is assumed to initially face direction index 0 of
        ``directions``. The robot is left back on the start cell facing its
        original heading.

        :type robot: Robot
        :rtype: None
        """
        visited = set()  # relative (x, y) cells already cleaned
        # Direction vectors indexed 0..3. dfs steps through them with
        # successive turnRight() calls, so this order presumably lists the
        # headings clockwise (up, right, down, left). This assumes turnRight()
        # rotates 90 degrees clockwise; confirm against the Robot API.
        directions = [(-1,0),(0,1),(1,0),(0,-1)] 

        def go_back():
            # Turn around, step back one cell, then turn around again, so the
            # robot ends on the previous cell with its previous heading.
            robot.turnRight()
            robot.turnRight()
            robot.move()
            robot.turnRight()
            robot.turnRight()
        

        def dfs(x,y,d):
            # Clean (x, y) and explore its four neighbours. On entry the robot
            # stands on (x, y) facing direction index d; on exit it faces d again.
            robot.clean()
            visited.add((x,y))
            for i in range(4):
                # After i turnRight() calls in this loop, the robot faces next_d.
                next_d = (d + i) % 4
                next_x = x + directions[next_d][0]
                next_y = y + directions[next_d][1]
                # Short-circuit: only try to move toward unvisited cells. The
                # search recurses only when move() returns truthy (presumably
                # "moved; no obstacle"). Blocked cells are never added to
                # visited, so they may be re-probed from other neighbours.
                if (next_x, next_y) not in visited and robot.move():
                    dfs(next_x,next_y, next_d)
                    # Step back to (x, y); the heading is next_d again afterwards.
                    go_back()
                # Rotate to the next direction. After four turns the robot
                # faces d again, which the caller's go_back() relies on.
                robot.turnRight()
        dfs(0,0,0)

In [None]:
# [91. Decode Ways](https://leetcode.com/problems/decode-ways/)
class Solution:
    def numDecodings(self, s: str) -> int:
        """Count the ways to decode digit string ``s`` with 'A'=1 ... 'Z'=26.

        Bottom-up DP over suffixes: ``ways(i)`` is the number of decodings of
        ``s[i:]``.
        - A '0' cannot start a code.
        - A single non-zero digit always can.
        - A two-digit code is allowed when it is at most 26. It is never
          below 10, because its first digit is non-zero.

        Only the two most recent values are kept, so this runs in O(len(s))
        time with O(1) extra space and no recursion-depth limit. The empty
        string has exactly one (empty) decoding.
        """
        ways_next, ways_after_next = 1, 0  # ways(len(s)), ways(len(s) + 1)
        for i in range(len(s) - 1, -1, -1):
            if s[i] == "0":
                ways_here = 0
            else:
                ways_here = ways_next  # decode s[i] on its own
                if i + 1 < len(s) and int(s[i:i + 2]) <= 26:
                    ways_here += ways_after_next  # decode s[i:i+2] as one letter
            ways_next, ways_after_next = ways_here, ways_next
        return ways_next

In [None]:
# [465. Optimal Account Balancing](https://leetcode.com/problems/optimal-account-balancing/)
class Solution:
    def minTransfers(self, transactions: List[List[int]]) -> int:
        """Return the minimum number of transfers that settles every debt.

        Each transaction ``[giver, receiver, amount]`` moves ``amount`` from
        ``giver`` to ``receiver``. Only each person's net balance matters, so
        balances are computed first and people who are already even are
        dropped. Person ids may be any hashable values; no fixed number of
        people is assumed.

        The search settles the first unsettled balance against each later
        balance of the opposite sign and keeps the cheapest outcome.
        """
        balance = collections.defaultdict(int)
        for giver, receiver, amount in transactions:
            balance[giver] -= amount
            balance[receiver] += amount
        debts = [value for value in balance.values() if value != 0]

        def settle(start: int) -> int:
            # Skip balances that earlier transfers have already zeroed out.
            while start < len(debts) and debts[start] == 0:
                start += 1
            if start == len(debts):
                return 0
            # Safe upper bound: every transfer settles at least one balance.
            best = len(debts)
            seen = set()
            for j in range(start + 1, len(debts)):
                # Only an opposite-sign balance can absorb debts[start]. Equal
                # values at different positions lead to the same subproblem,
                # so each value is tried only once.
                if debts[j] * debts[start] >= 0 or debts[j] in seen:
                    continue
                seen.add(debts[j])
                debts[j] += debts[start]  # move all of debts[start] onto j
                best = min(best, 1 + settle(start + 1))
                debts[j] -= debts[start]  # undo before trying the next j
            return best

        return settle(0)

In [None]:
# [22. Generate Parentheses](https://leetcode.com/problems/generate-parentheses/)

class Solution:
    def generateParenthesis(self, n: int) -> List[str]:
        """Return all balanced strings of ``n`` pairs of parentheses.

        "(" is tried before ")" at every position, so the strings come out in
        ascending lexicographic order.
        """
        out = []

        def build(opened, text, unclosed):
            # opened: "(" placed so far; unclosed: "(" still awaiting a ")".
            if opened == n:
                # Every "(" is placed: the only completion closes the rest.
                out.append(text + ')' * unclosed)
                return
            build(opened + 1, text + '(', unclosed + 1)
            if unclosed > 0:
                build(opened, text + ')', unclosed - 1)

        build(0, "", 0)
        return out

# @2024-09-02
class Solution:
    def generateParenthesis(self, n: int) -> List[str]:
        """Return all balanced strings of ``n`` pairs of parentheses.

        Iterative depth-first search over partial strings. "(" is expanded
        before ")", so the results come out in ascending lexicographic order.
        """
        results = []
        # Each entry: (prefix, "(" still available, ")" needed to close opens).
        todo = [("", n, 0)]
        while todo:
            prefix, opens_left, closes_owed = todo.pop()
            if opens_left == 0 and closes_owed == 0:
                results.append(prefix)
                continue
            # LIFO stack: push the ")" branch first so "(" is expanded first.
            if closes_owed > 0:
                todo.append((prefix + ')', opens_left, closes_owed - 1))
            if opens_left > 0:
                todo.append((prefix + '(', opens_left - 1, closes_owed + 1))
        return results