# Digit DP
Digit DP is commonly used to solve problems where we are asked to **count the number of integers `x` between two integers `a` and `b` such that `x` satisfies a specific property related to its digits.** In some problems we are given an array of digits (0-9) to be used to construct `x`.

## Common Approach:
The common approach is as follows:
- First, we define a recursive function `dfs(n)` that tells us the number of such integers in `[1, n]`.
- Then the number of such integers between `a` and `b` (inclusive) is given by `dfs(b) - dfs(a - 1)`.

## Defining the DFS Function:    
To define `dfs(n)`, there are essentially 2 constraints:
1. The constraint given in the problem
2. The number `x` that we are constructing must not exceed `n`

How to deal with the first constraint depends highly on the problem itself, but there is a general approach for the second constraint:
- We treat `n` as an array of digits and do DP on it.
- Similar to regular DP, at each position we can choose to use or not use it.
- If we choose not to use one or more leading positions, `x` will have fewer digits than `n`, so it is always smaller than `n`.
  - Case 0: for all the remaining positions in `x`, we can choose whatever digit we want without limitation.
  - Note that you can skip a position **only if you haven't used a position before**. Once you have used a position, you must use all positions that follow.
- If we use a position, we have 2 cases:
  - Case 1: The digit we choose for `x` is smaller than the digit in `n`. In this scenario there is no limitation on the digits we choose in the next recursion, since no matter what digit we choose next, `x` is already guaranteed to be smaller than `n`.
  - Case 2: The digit we choose for `x` is equal to the digit in `n`. In this case we need to compare the next digit we choose against its corresponding digit in `n` again in the next recursion.
  - Note that we are not allowed to choose a digit larger than the corresponding digit in `n`.
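
The cases above can be sketched as a small memoized recursion. This is only an illustration under assumptions: the property (digit sum equals `target`) and the names `count_digit_sum` and `limit` are made up for the example. Leading zeros play the role of skipped positions (Case 0), which is safe here because they do not change the digit sum.

```python
from functools import lru_cache


def count_digit_sum(n: int, target: int) -> int:
    """Count x in [0, n] whose digit sum equals target (example property)."""
    digits = list(map(int, str(n)))
    m = len(digits)

    @lru_cache(maxsize=None)
    def dfs(i: int, total: int, free: bool) -> int:
        if total > target:
            return 0
        if i == m:
            return 1 if total == target else 0
        # When not free, the digit of n caps our choice; otherwise any digit works.
        limit = 9 if free else digits[i]
        ans = 0
        for d in range(limit + 1):
            # Case 1 (d < limit) frees every later position;
            # Case 2 (d == limit) keeps the bound for the next position.
            ans += dfs(i + 1, total + d, free or d < limit)
        return ans

    return dfs(0, 0, False)
```

For example, `count_digit_sum(259, 7)` counts the integers up to 259 whose digits add up to 7.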

## Implementation:
When implementing the DFS function, we usually need the following parameters:
- `i`: the index of the position in `n` that we are currently filling
- `free`: whether we are allowed to choose the current digit freely, without constraint, in this recursive step

A fairly common approach is to **first count the numbers with fewer digits than `n`**, which can usually be done in O(m) time, where `m` is the number of digits of `n` (see Q1 for how we did it).
At the same time, we can usually **construct a `cnts[]` array** to help us deal with Case 1 in O(1) time.
Then we count the numbers with exactly the same number of digits as `n`:
- If we choose a digit smaller than the first digit of `n` to be our first digit, then all other positions are free.
- If we choose a digit that is the same, we need to recurse for the next positions.
- We cannot choose a digit that is larger.

---
### Q1. Count Number With Unique Digits (LC.357)
*Given an integer n, return the count of all numbers with unique digits, x, where 0 <= x < 10^n.*

**Solution:**    
For numbers with 1 digit, we have 10 numbers (0 to 9)  
For a number of 2 digits, we have 9 * 9 = 81 numbers, since 0 cannot be at the beginning       
For a number of 3 digits, we have 9 * 9 * 8    
...   
For a number of n digits, we have 9 * 9 * 8 * 7 * ... * (9 - n + 2) numbers         
Therefore, we just need to add them all up

**Note:**
This problem may seem very simple, but it is a foundational building block: the same counting is commonly reused as a part of more complex problems.

In [7]:
class Solution:
    def countNumbersWithUniqueDigits(self, n):
        if n == 0:
            ans = 1
        else:
            ans = 10
            for i in range(2, n + 1):
                curDigitCnt = 9
                for j in range(9, 9 - i + 1, -1):
                    curDigitCnt *= j
                ans += curDigitCnt
        return ans

---
### Q2. Numbers At Most N Given Digit Set (LC.902)
*Given an array of digits which is sorted in non-decreasing order.      
You can write numbers using each digits[i] as many times as we want. For example, if digits = ['1','3','5'], we may write numbers such as '13', '551', and '1351315'.*
*Return the number of positive integers that can be generated that are less than or equal to a given integer n.*

**Solution:**        
We treat the number `n` as an array of digits. For example, suppose `n = 259` and `digits = [1, 2, 5, 7]`.
For a digit in `n`, there are 3 cases:
- We ignore the digit, so 259 can become 59 and then 9. But note that if we take a digit on the left, we must take all digits to its right: if we take 2, we must take all of 259.
- We take the digit, and the digit we choose from `digits` is smaller. In this case, we have no limitation on the digits we choose next. For example, if we take 2 from 259 and choose 1 as our first digit, then the two remaining positions have no limitation.
- We take the digit, and the digit we choose from `digits` is the same. In this case, we have to compare again in the next position. For example, if we take 2 from 259 and choose 2 as our first digit, we cannot choose 7 in the next position.

In [25]:
class Solution:
    def atMostNGivenDigitSet(self, digits, n):
        number = []
        while n > 0:
            number.append(n % 10)
            n //= 10
        number.reverse()
        m = len(number)

        # hasNoPrev: have we used a position before?
        # free: are we free to choose any digit for the current position?
        def dfs(i, hasNoPrev, free):
            if i == m:
                return 0 if hasNoPrev else 1
            cur = number[i]
            ans = 0
            if hasNoPrev:
                # Case 0: Don't use the current digit (only available if we never used a position before)
                ans += dfs(i + 1, True, True)
            if not free:
                for digit in digits:
                    digit = int(digit)
                    if digit < cur:
                        # Case 1: use the position and choose a smaller digit; the rest is free
                        ans += dfs(i + 1, False, True)
                    elif digit == cur:
                        # Case 2: use the position and choose the same digit; still bounded by n
                        ans += dfs(i + 1, False, False)
                    else:
                        # We cannot choose a larger digit, so break since digits is sorted
                        break
            else:
                # Already free (Case 0 or Case 1 happened earlier): any digit works here
                ans += len(digits) * dfs(i + 1, False, True)
            return ans

        return dfs(0, True, False)

**Optimization:**       
This problem can actually be solved with a simpler approach that borrows the idea from Q1.  

We don't need to recurse in every case: if we choose a smaller digit for a position, the cnt of valid numbers that follow from that choice can be calculated directly.  

Suppose we have `n = 53278`, `digits = [1, 4, 5, 9]`.        
For our first digit, our choices are as follows:
- Choose 1: the other 4 positions are free, giving us a cnt of 4 * 4 * 4 * 4 = 256
- Choose 4: similarly, the other 4 positions are free, giving us a cnt of 4 * 4 * 4 * 4 = 256
- Choose 5: since we choose the same digit as `n`, the other positions are not free, so we need to recurse deeper
- Choose 9: not allowed, since 9 is larger than 5

Knowing this, we can build a table before we run our dfs function, which prunes a lot of branches:
- `cnts[i]`: if there are `i` positions left to be filled and some earlier digit is already smaller than its counterpart in `n`, how many valid numbers are there? Every remaining position is free, so `cnts[i] = len(digits)^i`.

At the same time, we can also compute the cnt for Case 0 in our last approach (when the number of digits in `x` is smaller than that in `n`):
- For the example above, we have 4 + 4^2 + 4^3 + 4^4 valid numbers, since each of the 4 allowed digits can go in any position.
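
The arithmetic for `n = 53278`, `digits = [1, 4, 5, 9]` can be checked with a short script. This is a sketch for verification only: the variable names are made up, and the brute-force loop is just a cross-check, not part of the intended solution.

```python
n = 53278
digits = [1, 4, 5, 9]
m = len(str(n))  # 5 positions
d = len(digits)  # 4 choices for every free position

# Numbers with fewer digits than n: every position is free.
shorter = sum(d ** k for k in range(1, m))  # 4 + 4^2 + 4^3 + 4^4

# First digit 1 or 4 (both < 5): the other 4 positions are free.
first_smaller = 2 * d ** 4

# First digit 5 (== 5): the second digit is bounded by 3, so only 1 fits,
# which leaves 3 free positions. No allowed digit equals 3, so recursion stops.
first_equal = 1 * d ** 3

total = shorter + first_smaller + first_equal

# Brute-force cross-check over every positive integer up to n.
allowed = set("1459")
brute = sum(1 for x in range(1, n + 1) if set(str(x)) <= allowed)
```

Both `total` and `brute` come out to 916 (340 shorter numbers, 512 with a smaller first digit, 64 with first digit 5).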



In [35]:
class Solution:
    def atMostNGivenDigitSet(self, digits, n):
        number = list(map(int, str(n)))
        digits = list(map(int, digits))
        m = len(number)
        d = len(digits)
        cnts = [0] * m

        cnt = 0
        cnts[0] = 1     # 0 position left to be filled, meaning we find a valid number
        for i in range(1, m):
            cnts[i] = d ** i
            cnt += cnts[i]
            
        
        def dfs(i):
            if i == m:
                return 1
            cur = number[i]
            ans = 0
            for digit in digits:
                digit = int(digit)
                if digit < cur:
                    # if we are at position i, there are m - i - 1 positions left to be filled
                    ans += cnts[m - i - 1]
                elif digit == cur:
                    ans += dfs(i + 1)
                else:
                    break
            return ans

        return cnt + dfs(0)

---
### Q3. Count Of Integers (LC.2719)
*You are given two numeric strings `num1` and `num2` and two integers `max_sum` and `min_sum`. We denote an integer `x` to be good if:*
- *`num1 <= x <= num2`*
- *`min_sum <= digit_sum(x) <= max_sum.`*

*Return the number of good integers. Since the answer may be large, return it modulo `10^9 + 7`.*

*Note that `digit_sum(x)` denotes the sum of the digits of `x`.*

**Solution:**    
Following the common approach mentioned in the beginning, we need to find a function `dfs(n)` that tells us the number of good integers in `[0, n]`. Then return `dfs(num2) - dfs(num1 - 1)`, taken modulo `10^9 + 7`.

In [5]:
MOD = 10**9 + 7

class Solution:
    def count(self, num1: str, num2: str, min_sum: int, max_sum: int) -> int:
        def count_numbers(num: str, min_sum: int, max_sum: int):
            digits = list(map(int, num))  # Convert string to list of digits
            n = len(digits)

            # Memoization table dp[i][sumDigit][free]
            dp = [[[-1] * 2 for _ in range(401)] for _ in range(23)]  # MAXN = 23, MAXM = 401

            def dfs(i: int, digit_sum: int, free: int):
                if digit_sum > max_sum:  # Prune search if sum exceeds max_sum
                    return 0
                if digit_sum + (n - i) * 9 < min_sum:  # Prune if sum cannot reach min_sum
                    return 0
                if i == n:
                    return 1 if min_sum <= digit_sum <= max_sum else 0  # Valid number condition
                
                if dp[i][digit_sum][free] != -1:
                    return dp[i][digit_sum][free]

                ans = 0
                cur = digits[i]

                if free == 0:
                    # Case 1: Choose a digit < cur, next recursion is free
                    for digit in range(cur):
                        ans = (ans + dfs(i + 1, digit_sum + digit, 1)) % MOD
                    # Case 2: Choose cur as the digit, next recursion is not free
                    ans = (ans + dfs(i + 1, digit_sum + cur, 0)) % MOD
                else:
                    # Already free: choose any digit from 0 to 9
                    for digit in range(10):
                        ans = (ans + dfs(i + 1, digit_sum + digit, 1)) % MOD

                dp[i][digit_sum][free] = ans
                return ans

            return dfs(0, 0, 0)

        # Count numbers up to num2 and num1-1
        count2 = count_numbers(num2, min_sum, max_sum)
        count1 = count_numbers(str(int(num1) - 1), min_sum, max_sum)  # Convert to str and handle edge case

        return (count2 - count1) % MOD  # Ensure positive mod value

---
### Q3. Count Special Integers (LC.2376)
*We call a positive integer special if all of its digits are distinct.*

*Given a positive integer n, return the number of special integers that belong to the interval [1, n].*

**Solution:**     
Suppose number `n` has `m` digits.      
We first count all special integers with fewer digits, then find the cnt of special integers with exactly `m` digits.
Again, we construct a `cnts` array to help us deal with Case 1:
- Define `cnts[i]`: given that the first `m - i` digits are already fixed (and distinct), and we have `i` positions left to be filled, how many different numbers can we form?
- Suppose our `n` has 7 digits, and the first two digits are already fixed. Then 8 digits are still unused, so `cnts[5] = 8 * 7 * 6 * 5 * 4`. Similarly, `cnts[4] = 7 * 6 * 5 * 4` and `cnts[6] = 9 * 8 * 7 * 6 * 5 * 4`.
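
The `cnts` values are ordered selections of the unused digits, so they can be checked against `math.perm` (available in Python 3.8+). A small sketch, where `remaining_count` is a hypothetical helper name used only for this check:

```python
from math import perm


def remaining_count(m: int, i: int) -> int:
    # m - i distinct digits are already fixed, so 10 - (m - i) digits remain unused.
    # The i open positions take an ordered selection of those unused digits.
    return perm(10 - (m - i), i)


def build_cnts(m: int) -> list:
    # Same recurrence as the cnts[] table: each extra open position
    # multiplies by the number of digits still unused at that point.
    cnts = [1] * m
    for i in range(1, m):
        cnts[i] = cnts[i - 1] * (10 - m + i)
    return cnts
```

For `m = 7`, both forms give `cnts[5] = 6720`, `cnts[4] = 840` and `cnts[6] = 60480`.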

In [3]:
class Solution(object):
    def countSpecialNumbers(self, n):
        num = list(map(int, str(n)))
        m = len(num)
        ans = 0
        cnts = [1] * m
        for i in range(1, m):
            # Build cnts[] array
            cnts[i] = cnts[i - 1] * (10 - m + i)
            # Add cnt of smaller numbers to ans
            curDigitCnt = 9
            for j in range(9, 9 - i + 1, -1):
                curDigitCnt *= j
            ans += curDigitCnt

        def dfs(i, status):
            if i == m:
                return 1
            cur = num[i]
            res = 0
            for digit in range(cur):
                if status & (1 << digit) == 0:
                    res += cnts[m - i - 1]
            if status & (1 << cur) == 0:
                res += dfs(i + 1, status | (1 << cur))
            return res
        
        first = num[0]
        ans += (first - 1) * cnts[m - 1]
        return ans + dfs(1, 1 << first)

---
### Q4. Numbers With Repeated Digits (LC.1012)
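This is the complement of Count Special Integers (LC.2376) above: a number has a repeated digit exactly when it is not special, so the answer is `n` minus the count of special integers in `[1, n]`.  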
*Given an integer n, return the number of positive integers in the range [1, n] that have at least one repeated digit.*
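
The complement relation can be sanity-checked by brute force for small `n`. This sketch is only a check, with made-up helper names; it is far too slow for the real constraints.

```python
def has_repeated_digit(x: int) -> bool:
    s = str(x)
    return len(set(s)) < len(s)


def brute_repeated(n: int) -> int:
    # Count x in [1, n] with at least one repeated digit.
    return sum(1 for x in range(1, n + 1) if has_repeated_digit(x))


def brute_special(n: int) -> int:
    # Count x in [1, n] whose digits are all distinct.
    return sum(1 for x in range(1, n + 1) if not has_repeated_digit(x))
```

Every integer in `[1, n]` falls into exactly one of the two groups, so `brute_repeated(n) + brute_special(n) == n`.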

In [15]:
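# Reuse countSpecialNumbers from the Count Special Integers solution (LC.2376) above.
CountSpecialSolution = Solution

class Solution(CountSpecialSolution):
    def numDupDigitsAtMostN(self, n: int) -> int:
        return n - self.countSpecialNumbers(n)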

---
### Q5. Windy Number
*A Windy number is a positive integer that satisfies the following conditions:*
- *No leading zeros*
- *The absolute difference between any two adjacent digits is at least 2*

*Given an interval [a, b], count how many Windy numbers exist within this range.*

Test Link: https://www.luogu.com.cn/problem/P2657

In [26]:
def countWindyNumbers(a, b):
    # Count Windy numbers in [1, n]
    def countUpTo(n):
        num = list(map(int, str(n)))
        m = len(num)
        # dp: int[m][pre + 1][free]; pre is shifted by 1 so that pre == -1 gets its own slot
        dp = [[[-1] * 2 for _ in range(11)] for _ in range(m)]

        # pre: the digit we chose for the last position
        # if pre == -1, it means that we have never chosen a digit
        def dfs(i, pre, free):
            if i == m:
                # an empty number (no position used) is not a positive integer
                return 0 if pre == -1 else 1
            if dp[i][pre + 1][free] != -1:
                return dp[i][pre + 1][free]
            ans = 0
            cur = num[i]
            if pre == -1:
                # we haven't used a position before, so we don't need to care about pre
                # don't use the current position: x becomes shorter, so the rest is free
                ans += dfs(i + 1, -1, True)
                # use the current position as the leading digit, which cannot be 0
                for digit in range(1, 10 if free else cur):
                    ans += dfs(i + 1, digit, True)
                if not free and cur > 0:
                    ans += dfs(i + 1, cur, False)
            else:
                for digit in range(10):
                    if abs(digit - pre) < 2:
                        continue
                    if free or digit < cur:
                        ans += dfs(i + 1, digit, True)
                    elif digit == cur:
                        ans += dfs(i + 1, cur, False)
                    else:
                        # digits only get larger from here, so stop
                        break
            dp[i][pre + 1][free] = ans
            return ans

        return dfs(0, -1, False)

    return countUpTo(b) - countUpTo(a - 1)

---
### Q6. Cute Number
*A number is called a Cute Number if it contains at least one palindromic substring of length ≥ 2.*

*For example:*
- ✅ 101 (contains "101")
- ✅ 110 (contains "11")
- ✅ 111 (contains "11", "111")
- ✅ 1234321 (contains "23432", "1234321", etc.)
- ✅ 45568 (contains "55")
- ❌ 123456 (contains no palindromic substring of length ≥ 2)

*Given an integer range [l, r], count the number of cute numbers in this range.*

*Since the answer may be large, output the result modulo 1000000007.*
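
One possible approach, sketched under assumptions: any palindromic substring of length ≥ 2 contains a palindrome of length 2 or 3 at its center, so a number is *not* cute exactly when no digit equals either of the two digits before it. We can count the non-cute numbers with digit DP and subtract. The function names and the choice to take `l` and `r` as Python integers are assumptions; the original task may give them in another form.

```python
from functools import lru_cache

MOD = 1_000_000_007


def count_not_cute(n: int) -> int:
    """Count x in [0, n] with no palindromic substring of length >= 2."""
    if n < 0:
        return 0
    digits = list(map(int, str(n)))
    m = len(digits)

    @lru_cache(maxsize=None)
    def dfs(i: int, p1: int, p2: int, free: bool) -> int:
        # p1 / p2: the last and second-to-last chosen digits, -1 if absent.
        if i == m:
            return 1
        limit = 9 if free else digits[i]
        ans = 0
        for d in range(limit + 1):
            nfree = free or d < limit
            if p1 == -1 and d == 0:
                # Still in leading zeros: the number has not started yet.
                ans += dfs(i + 1, -1, -1, nfree)
            elif d != p1 and d != p2:
                # d creates neither "dd" nor "d?d", so the prefix stays non-cute.
                ans += dfs(i + 1, d, p1, nfree)
        return ans % MOD

    return dfs(0, -1, -1, False)


def count_cute(l: int, r: int) -> int:
    total = (r - l + 1) % MOD
    not_cute = (count_not_cute(r) - count_not_cute(l - 1)) % MOD
    return (total - not_cute) % MOD
```

The all-zero path stands for `x = 0`, which is a single digit and therefore not cute; it cancels out in the subtraction whenever `l >= 1`.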

---
### Q7. Non-negative Integers Without Consecutive Ones (LC.600)
*Given a positive integer n, return the number of the integers in the range [0, n] whose binary representations do not contain consecutive ones.*
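
The same `free` idea carries over to binary digits. A minimal sketch, assuming we run the DP over the bits of `n` and track whether the previous bit was 1; the function name `find_integers` and the state layout are choices made for this illustration.

```python
from functools import lru_cache


def find_integers(n: int) -> int:
    """Count x in [0, n] whose binary form has no two adjacent 1 bits."""
    bits = bin(n)[2:]
    m = len(bits)

    @lru_cache(maxsize=None)
    def dfs(i: int, prev_one: bool, free: bool) -> int:
        if i == m:
            return 1
        limit = 1 if free else int(bits[i])
        total = 0
        for b in range(limit + 1):
            if b == 1 and prev_one:
                # Two consecutive ones are not allowed.
                continue
            total += dfs(i + 1, b == 1, free or b < limit)
        return total

    return dfs(0, False, False)
```

Leading zero bits need no special handling here, because a leading 0 never creates a pair of consecutive ones.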

---
### Q8. Digit Count In Range (LC.1067, a generalization of LC.233)
*Given a single-digit integer `d` and two integers `low` and `high`, return the number of times that d occurs as a digit in all integers in the inclusive range [low, high].*
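
One way to approach this, sketched under assumptions: let the recursion return a pair (how many numbers, how many occurrences of `d`) and add a `started` flag so that leading zeros are not counted when `d = 0`. The names `count_digit` and `up_to` are made up for this illustration.

```python
from functools import lru_cache


def count_digit(d: int, low: int, high: int) -> int:
    def up_to(n: int) -> int:
        # Occurrences of digit d in all integers of [1, n].
        if n <= 0:
            return 0
        s = str(n)
        m = len(s)

        @lru_cache(maxsize=None)
        def dfs(i: int, free: bool, started: bool):
            # Returns (numbers formed, occurrences of d in positions i..m-1).
            if i == m:
                return 1, 0
            limit = 9 if free else int(s[i])
            nums = occ = 0
            for x in range(limit + 1):
                now_started = started or x != 0
                c, o = dfs(i + 1, free or x < limit, now_started)
                nums += c
                occ += o
                if x == d and now_started:
                    # This position shows d in every one of the c completions.
                    occ += c
            return nums, occ

        return dfs(0, False, False)[1]

    return up_to(high) - up_to(low - 1)
```

For instance, `count_digit(1, 1, 13)` counts the ones in 1, 10, 11, 12 and 13.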