# State Compression DP
## Core Idea
State compression DP uses a bitmask (an integer whose bits encode a set) to represent a state compactly:
- Each bit in the integer corresponds to whether an item or a condition is "active" or "used."
- This compact representation reduces memory usage and allows for efficient manipulation of states using bitwise operations.

## Typical Use Cases
Subset Problems:
  - Problems that involve checking or generating subsets of a given set, such as the Traveling Salesman Problem (TSP).
Assignment Problems:
  - Tasks that involve assigning a set of workers to a set of jobs with constraints.
Graph Problems:
  - Problems that involve paths or cycles in graphs, such as finding the shortest Hamiltonian path.
Games and Puzzles:
  - Problems where the current state of a game board can be encoded as a bitmask (e.g., solving N-puzzle-like problems).

## Limitations
Exponential Growth:
- The number of states grows exponentially with the size of the input: **O(2^n)** for n elements.
Memory Usage:
- Even though each state is stored compactly, the DP table or memoization array still needs one entry per state: **O(2^n)**, multiplied by any extra dimensions (e.g. **O(n · 2^n)** for TSP).
Thus, in interview problems, state compression DP is usually only practical when the input size is at most about 20 (2^20 ≈ 10^6 states).

## Common Bitwise Operations in State Compression DP
- Set a bit: `status |= (1 << i)`
- Clear a bit: `status &= ~(1 << i)`
- Check if a bit is set: `(status & (1 << i)) != 0`
- Count the number of bits set: `bin(status).count('1')` (or `status.bit_count()` in Python 3.10+)
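A quick sanity check of the operations above in Python. The toggle (`^`) and lowest-set-bit (`x & -x`) idioms are not in the list, but both appear in the solutions below (Q3, Q4, Q5), so they are included here as well.

```python
status = 0b0000          # empty set

status |= (1 << 0)       # set bit 0   -> 0b0001
status |= (1 << 2)       # set bit 2   -> 0b0101
status |= (1 << 3)       # set bit 3   -> 0b1101
status &= ~(1 << 3)      # clear bit 3 -> 0b0101

is_set_2 = (status & (1 << 2)) != 0   # True
is_set_1 = (status & (1 << 1)) != 0   # False
popcount = bin(status).count('1')     # 2

toggled = status ^ (1 << 1)           # toggle bit 1 -> 0b0111
lowest = status & -status             # lowest set bit -> 0b0001

# In Python, comparisons bind looser than &, so `status & (1 << i) == 0`
# means `(status & (1 << i)) == 0`. In C and Java, == binds tighter than &,
# so always parenthesize there.
print(bin(status), is_set_2, is_set_1, popcount, bin(toggled), bin(lowest))
# prints: 0b101 True False 2 0b111 0b1
```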

---
### Q1. Traveling Salesman Problem (TSP)
A salesman needs to visit every city exactly once and return to the starting city. Find the minimum total cost of such a tour. Note that every city is connected to every other city.

**Solution:**      
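We use a bitmask `status` to represent the visited cities, where bit `i` corresponds to city `i` (e.g., `0b10101` means cities 0, 2, and 4 have been visited). The tour starts at city 0, so the initial status is `1`.
Then we define `dp[status][i]` as the minimum cost to visit all remaining cities and return to city `0`, given that the salesman is currently at city `i` and has already visited the cities in `status`. It satisfies `dp[status][i] = min(graph[i][j] + dp[status | (1 << j)][j])` over every unvisited city `j`, with the base case `dp[full][i] = graph[i][0]` once every city has been visited.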

```python
import random

def tsp(graph, n):
    # dp[status][i] == -1 means "not computed yet"
    dp = [[-1] * n for _ in range(1 << n)]
    # the tour starts at city 0, so city 0 is already visited (status = 0b1)
    return dfs(graph, dp, 1, 0, n)

def dfs(graph, dp, status, i, n):
    # base case: we have visited every city, we just need to go back to city 0
    if status == (1 << n) - 1:
        return graph[i][0]
    if dp[status][i] != -1:
        return dp[status][i]
    ans = float('inf')
    for j in range(n):
        # try every unvisited city j as the next stop
        if status & (1 << j) == 0:
            ans = min(ans, dfs(graph, dp, status | (1 << j), j, n) + graph[i][j])
    dp[status][i] = ans
    return ans

def generate_graph(n):
    # random symmetric distance matrix with zeros on the diagonal
    graph = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            distance = random.randint(1, 100)
            graph[i][j] = distance
            graph[j][i] = distance
    return graph

n = 10
graph = generate_graph(n)

print("Generated Graph:")
for row in graph:
    print(row)

result = tsp(graph, n)
print("Minimum cost of TSP:", result)
```

Output of one run (the graph is random, so the numbers change on every run):

```text
Generated Graph:
[0, 5, 46, 69, 63, 24, 59, 51, 72, 2]
[5, 0, 69, 81, 92, 47, 58, 82, 83, 46]
[46, 69, 0, 20, 63, 97, 57, 17, 65, 100]
[69, 81, 20, 0, 96, 68, 73, 58, 12, 23]
[63, 92, 63, 96, 0, 38, 4, 80, 37, 40]
[24, 47, 97, 68, 38, 0, 78, 69, 85, 82]
[59, 58, 57, 73, 4, 78, 0, 91, 20, 18]
[51, 82, 17, 58, 80, 69, 91, 0, 79, 7]
[72, 83, 65, 12, 37, 85, 20, 79, 0, 95]
[2, 46, 100, 23, 40, 82, 18, 7, 95, 0]
Minimum cost of TSP: 172
```
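For comparison, here is a bottom-up sketch of the same `dp[status][i]` table (the iterative Held-Karp formulation, swapped in here rather than taken from the notes above), cross-checked against a brute force over all permutations on a small fixed graph. The function names `tsp_bottom_up` and `tsp_brute_force` are hypothetical. Because `status | (1 << j)` is always larger than `status`, filling statuses in decreasing order guarantees every value we read is already final.

```python
import itertools
import math

def tsp_bottom_up(graph):
    n = len(graph)
    full = (1 << n) - 1
    # dp[status][i]: min cost to visit the cities not in status and return to 0,
    # currently standing at city i (i is in status)
    dp = [[math.inf] * n for _ in range(1 << n)]
    for i in range(n):
        dp[full][i] = graph[i][0]          # everything visited: just go home
    for status in range(full - 1, 0, -1):
        if not status & 1:                 # city 0 is always visited first
            continue
        for i in range(n):
            if not status & (1 << i):
                continue
            best = math.inf
            for j in range(n):
                if not status & (1 << j):  # j is unvisited: try it next
                    best = min(best, graph[i][j] + dp[status | (1 << j)][j])
            dp[status][i] = best
    return dp[1][0]

def tsp_brute_force(graph):
    n = len(graph)
    best = math.inf
    for perm in itertools.permutations(range(1, n)):
        tour = (0,) + perm + (0,)
        best = min(best, sum(graph[a][b] for a, b in zip(tour, tour[1:])))
    return best

g = [[0, 10, 15, 20], [10, 0, 35, 25], [15, 35, 0, 30], [20, 25, 30, 0]]
print(tsp_bottom_up(g), tsp_brute_force(g))  # 80 80
```

The brute force is O(n!), so it is only useful for checking the DP on tiny inputs.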


---
### Q2. Can I Win (LC.464)
*In the "100 game" two players take turns adding, to a running total, any integer from 1 to 10. The player who first causes the running total to reach or exceed 100 wins.*

*What if we change the game so that players cannot re-use integers?*

*For example, two players might take turns drawing from a common pool of numbers from 1 to 15 without replacement until they reach a total >= 100.*

*Given two integers maxChoosableInteger and desiredTotal, return true if the first player to move can force a win, otherwise, return false. Assume both players play optimally.*

*1 <= maxChoosableInteger <= 20, 
0 <= desiredTotal <= 300*

**Solution:**    
Note that this is a 1D DP. Although our `dfs` function takes two parameters, `curSum` is completely determined by `status`: it is the sum of `i + 1` over every set bit `i`. Therefore, we only need `status` to index the DP table.
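To see why `curSum` carries no extra information, here is a hypothetical helper (not part of the solution below) that recomputes it from `status` alone, assuming the same encoding as the solution: bit `i` set means the number `i + 1` has already been chosen.

```python
def cur_sum_from_status(status: int) -> int:
    """Sum of the numbers already chosen, where bit i set means i + 1 was chosen."""
    total = 0
    i = 0
    while status:
        if status & 1:
            total += i + 1
        status >>= 1
        i += 1
    return total

# Choosing 4, 1 and 2 in any order always leads to status 0b1011 and curSum 7.
print(cur_sum_from_status(0b1011))  # 7
```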

```python
class Solution:
    def canIWin(self, maxChoosableInteger: int, desiredTotal: int) -> bool:
        if desiredTotal == 0:
            return True
        # the numbers 1..maxChoosableInteger sum to less than desiredTotal: nobody can reach it
        if (maxChoosableInteger + 1) * maxChoosableInteger // 2 < desiredTotal:
            return False

        # dp[status]: -1 = unknown, 1 = the player to move wins, 0 = loses
        dp = [-1] * (1 << maxChoosableInteger)

        # Can the player about to move win from this status?
        def dfs(status, curSum):
            # If curSum >= desiredTotal on our turn, the opponent has already won
            if curSum >= desiredTotal:
                return False
            if dp[status] != -1:
                return dp[status] == 1
            for i in range(maxChoosableInteger):
                # bit i represents the number i + 1, since our numbers start from 1
                if status & (1 << i) == 0:
                    # if the opponent cannot win after we pick i + 1, we win
                    if not dfs(status | (1 << i), curSum + i + 1):
                        dp[status] = 1
                        return True
            dp[status] = 0
            return False

        return dfs(0, 0)
```

---
### Q3. Matchstick to Square (LC.473)
*You are given an integer array matchsticks where matchsticks[i] is the length of the ith matchstick. You want to use all the matchsticks to make one square. You should not break any stick, but you can link them up, and each matchstick must be used exactly one time.*       

*Return true if you can make this square and false otherwise.*

*1 <= matchsticks.length <= 15, 1 <= matchsticks[i] <= 10^8*

```python
class Solution:
    def makesquare(self, matchsticks):
        circumference = sum(matchsticks)
        if circumference % 4 != 0:
            return False
        sideLen = circumference // 4
        n = len(matchsticks)
        # dp[status]: -1 = unknown, 1 = success, 0 = failure
        # bit i of status = 1 means matchstick i is still unused
        dp = [-1] * (1 << n)

        # curLen: length of the side being built; sideLeft: sides not yet completed
        def dfs(status, curLen, sideLeft):
            if sideLeft == 0:
                return status == 0
            if dp[status] != -1:
                return dp[status] == 1
            ans = False
            for i in range(n):
                if status & (1 << i) != 0 and curLen + matchsticks[i] <= sideLen:
                    if curLen + matchsticks[i] == sideLen:
                        # this stick completes the current side
                        ans = dfs(status ^ (1 << i), 0, sideLeft - 1)
                    else:
                        ans = dfs(status ^ (1 << i), curLen + matchsticks[i], sideLeft)
                    if ans:
                        break
            dp[status] = 1 if ans else 0
            return ans

        return dfs((1 << n) - 1, 0, 4)
```
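As in Q2, `dp` can be indexed by `status` alone because `curLen` and `sideLeft` follow from which sticks are already used: a side is closed exactly when it reaches `sideLen`, so the used length is `completedSides * sideLen + curLen` with `curLen < sideLen`. The helper `derived_params` below is a hypothetical illustration of that argument, assuming the total length is divisible by 4 and the same bit convention as the solution (bit `i` = 1 means stick `i` is unused).

```python
def derived_params(matchsticks, status):
    """Recover (curLen, sideLeft) from the set of unused sticks in status."""
    side_len = sum(matchsticks) // 4
    used = sum(m for i, m in enumerate(matchsticks) if not status & (1 << i))
    completed_sides, cur_len = divmod(used, side_len)
    return cur_len, 4 - completed_sides

sticks = [1, 1, 2, 2, 2]                    # side length 2
print(derived_params(sticks, 0b11111))      # (0, 4): nothing used yet
print(derived_params(sticks, 0b11010))      # (1, 3): sticks 0 and 2 used
```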

---
### Q4. Number of Ways To Wear Different Hats To Each Other (LC.1434)
*There are n people and 40 types of hats labeled from 1 to 40.*     

*Given a 2D integer array hats, where hats[i] is a list of all hats preferred by the ith person.*

*Return the number of ways that the n people wear different hats to each other.*

*Since the answer may be too large, return it modulo 10^9 + 7.*

**Solution:**        
First, let's analyze the problem size. There are 40 types of hats but only n <= 10 people, so the status should record which people still need a hat rather than which hats have been worn, since 2^40 states would be far too many.        

Therefore, we need to build a reversed map where the key is the hat and the value is the set of people who like it. Since there are at most 10 people, we can use a bitmask to represent the people who prefer each hat.
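A minimal sketch of that reversed map on a small example (the helper name `build_hats_map` is ours; the solution below builds the same array inline as `hatsmap`). Bit `p` of `hats_map[h]` is set when person `p` likes hat `h`.

```python
def build_hats_map(hats):
    """hats_map[h] = bitmask of the people who like hat h (bit p = person p)."""
    m = max(max(liked) for liked in hats)
    hats_map = [0] * (m + 1)
    for person, liked in enumerate(hats):
        for hat in liked:
            hats_map[hat] |= 1 << person
    return hats_map

hats = [[3, 4], [4, 5], [5]]
print([bin(x) for x in build_hats_map(hats)])
# ['0b0', '0b0', '0b0', '0b1', '0b11', '0b110']
```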

```python
class Solution:
    MOD = 1000000000 + 7

    def numberWays(self, hats):
        n = len(hats)
        # m is the max label of hats, since not all 40 labels may be used
        m = 0
        for i in range(n):
            m = max(m, max(hats[i]))

        # hatsmap[h]: bitmask of the people who like hat h
        hatsmap = [0] * (m + 1)
        for i in range(n):
            for hat in hats[i]:
                hatsmap[hat] = hatsmap[hat] | (1 << i)

        # dp[i][status]: ways to hand out hats i..m so that everyone in status
        # (bit = 1 means the person still needs a hat) gets one
        dp = [[-1] * (1 << n) for _ in range(m + 1)]

        def dfs(i, status):
            # every person has a hat
            if status == 0:
                return 1

            # status is not 0 but we are out of hats
            if i > m:
                return 0

            if dp[i][status] != -1:
                return dp[i][status]

            # case 1: we don't use hat number i at all
            ans = dfs(i + 1, status)
            # case 2: give hat i to someone who likes it and still needs a hat, then recurse
            curHat = hatsmap[i]
            for person in range(n):
                if curHat & (1 << person) != 0 and status & (1 << person) != 0:
                    ans = (ans + dfs(i + 1, status ^ (1 << person))) % self.MOD

            dp[i][status] = ans
            return ans

        # Optimization (drop-in replacement for dfs; it computes the same values):
        # Instead of using a for loop to iterate through all 10 people and check, we can just get all the 1s from curHat
        def dfs2(i, status):
            if status == 0:
                return 1

            if i > m:
                return 0

            if dp[i][status] != -1:
                return dp[i][status]

            ans = dfs2(i + 1, status)
            curHat = hatsmap[i]
            while curHat != 0:
                rightMostOne = curHat & -curHat  # lowest set bit of curHat
                if status & rightMostOne != 0:
                    ans = (ans + dfs2(i + 1, status ^ rightMostOne)) % self.MOD
                curHat ^= rightMostOne

            dp[i][status] = ans
            return ans

        return dfs(1, (1 << n) - 1)
```

---
### Q5. Optimal Account Balancing (LC.465)
*You are given an array of transactions `transactions` where `transactions[i] = [from_i, to_i, amount_i]` indicates that the person with ID = `from_i` gave `amount_i` $ to the person with ID = `to_i`.*         

*Return the minimum number of transactions required to settle the debt.*

**Solution:**         
**Analysis 1:**          
From the examples, we can tell that this problem is not really about individual debts: we only care about each person's net balance, not about who owes whom.      
The only thing we want to do is to make everyone's balance 0 again after all the transactions.         
Therefore, we can use an array to record each person's balance, and we want to make this array all 0s.

**Analysis 2:**         
How many transactions do we need?  
Suppose we have this balance array: `[4, -2, -2, 2, 2, -4]`. 
- One way is to group `4, -2, -2` together and `2, 2, -4` together. Then we need 2 + 2 = 4 transactions in total.
- Another way is to group `4, -4` together, `2, -2` together, and `2, -2` together. Then we need 1 + 1 + 1 = 3 transactions in total.

As we can see from this example, we are essentially **breaking the array into multiple sets. The sum of each set needs to be 0.**

**Analysis 3:**            
We want to break the array into **as many atomic sets with sum 0 as possible**. Atomic means the set cannot be broken into smaller zero-sum sets.          
From the example, we can see that an atomic set of `n` elements needs `n - 1` transactions.       
With `k` sets of sizes `n1, n2, ..., nk`, the total is `(n1 - 1) + (n2 - 1) + ... + (nk - 1) = (n1 + n2 + ... + nk) - k`. Because the number of elements `n1 + n2 + ... + nk` is fixed, the more sets we have, the fewer transactions we need.     
The total number of transactions needed is `numElements - numSets`.
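A small check of the `numElements - numSets` formula on the two groupings of the balance array `[4, -2, -2, 2, 2, -4]`. The helper `transactions_needed` is hypothetical and simply assumes that a zero-sum group of `k` people can be settled with `k - 1` transfers.

```python
grouping_a = [[4, -2, -2], [2, 2, -4]]
grouping_b = [[4, -4], [2, -2], [2, -2]]

def transactions_needed(groups):
    # a zero-sum group of k people can be settled with k - 1 transfers
    return sum(len(g) - 1 for g in groups)

for groups in (grouping_a, grouping_b):
    assert all(sum(g) == 0 for g in groups)
    num_elements = sum(len(g) for g in groups)
    print(transactions_needed(groups), num_elements - len(groups))
# prints "4 4", then "3 3"
```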

**dfs:**   
Now we have redefined the problem: **how can we break the array into as many atomic zero-sum sets as possible?**

Let's take the example array `6, -3, -3, 2, -2, 5, -5`, whose sum is 0.
We remove one element, then recurse to see how many zero-sum sets the remaining elements can form.
For an array whose sum is 0, removing any element gives the same result: the removed element belongs to the leftover part, and the leftover part always sums to 0 as well, so it forms exactly one more set.
Suppose we remove 6: the rest can form `2, -2` and `5, -5`, and the leftover `6, -3, -3` adds 1, for a total of 3 sets.
Suppose we remove 5: the rest can form `6, -3, -3` and `2, -2`, and the leftover `5, -5` adds 1, again 3 sets in total.

However, for a set whose sum is not 0, we need to try removing every element and keep the best result.

During the recursion, we use a bitmask `status` to keep track of which elements are still in the current set.

```python
from typing import List

class Solution:
    MAXN = 13  # person IDs must be smaller than MAXN

    def minTransfers(self, transactions: List[List[int]]) -> int:
        balances = self.buildBalance(transactions)
        m = len(balances)
        # dp[status]: max number of zero-sum sets the elements in status can form
        dp = [-1] * (1 << m)

        # arrSum is the sum of the balances in status
        def dfs(status, arrSum):
            if dp[status] != -1:
                return dp[status]
            ans = 0
            # if the current set has more than one element
            if status & (status - 1) != 0:
                # if the sum is 0, removing any single element gives the same answer
                if arrSum == 0:
                    for i in range(m):
                        if status & (1 << i) != 0:
                            ans = dfs(status ^ (1 << i), arrSum - balances[i]) + 1
                            break
                # if the current sum is not 0, we need to try every element
                else:
                    for i in range(m):
                        if status & (1 << i) != 0:
                            ans = max(ans, dfs(status ^ (1 << i), arrSum - balances[i]))
            # note that if the set has only one element, ans will be just 0
            dp[status] = ans
            return ans

        return m - dfs((1 << m) - 1, 0)

    # Build an array of the nonzero balances (people with balance 0 are already settled)
    def buildBalance(self, transactions):
        balanceWithZeros = [0] * self.MAXN
        for transaction in transactions:
            balanceWithZeros[transaction[0]] -= transaction[2]
            balanceWithZeros[transaction[1]] += transaction[2]
        balances = []
        for balance in balanceWithZeros:
            if balance != 0:
                balances.append(balance)
        return balances
```
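An equivalent bottom-up sketch, offered as an alternative rather than the original method: instead of the shortcut for zero-sum sets, it takes the maximum over every removed element and adds 1 whenever the current set sums to 0, and it precomputes every subset sum with the lowest-set-bit trick. The function names are hypothetical, and the input is assumed to be the balance array with zeros already removed.

```python
from typing import List

def max_zero_sum_groups(balances: List[int]) -> int:
    """Max number of zero-sum groups the (nonzero, zero-total) balances split into."""
    m = len(balances)
    total = [0] * (1 << m)  # total[mask]: sum of the balances in mask
    dp = [0] * (1 << m)     # dp[mask]: best group count for the elements in mask
    for mask in range(1, 1 << m):
        low = mask & -mask  # lowest set bit of mask
        total[mask] = total[mask ^ low] + balances[low.bit_length() - 1]
        best = 0
        for j in range(m):
            if mask & (1 << j):
                best = max(best, dp[mask ^ (1 << j)])
        # a mask that sums to 0 closes one more group
        dp[mask] = best + (1 if total[mask] == 0 else 0)
    return dp[(1 << m) - 1]

def min_transfers_from_balances(balances: List[int]) -> int:
    # numElements - numSets
    return len(balances) - max_zero_sum_groups(balances)

print(min_transfers_from_balances([4, -2, -2, 2, 2, -4]))      # 3
print(min_transfers_from_balances([6, -3, -3, 2, -2, 5, -5]))  # 4
```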

---
## Technique: Enumerating All Subsets of A Status
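Suppose we have the status `1011`. To iterate through all of its non-empty subsets (`1011`, `1010`, `1001`, `1000`, `0011`, `0010`, `0001`), use the loop below. Subtracting 1 clears the lowest set bit of `j` and sets every bit below it, and `& status` then drops the bits that are not in `status`, so each step jumps to the next smaller subset: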

```python
status = 0b0000001011
j = status
while j > 0:
    # Do something with j
    print(bin(j))
    j = (j - 1) & status
```

```text
0b1011
0b1010
0b1001
0b1000
0b11
0b10
0b1
```
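The loop above stops at 0, so it skips the empty subset. A common variant, sketched below (the generator name `submasks` is ours), also yields the empty set, which is handy when "use nothing" is a legal choice. The second part illustrates why applying this technique to every status of `n` bits costs O(3^n) in total: each bit is either outside the status, in the status but not in the subset, or in both.

```python
def submasks(status: int):
    """Yield every subset of status, from status itself down to 0 (inclusive)."""
    sub = status
    while True:
        yield sub
        if sub == 0:
            break
        sub = (sub - 1) & status

print([bin(s) for s in submasks(0b1011)])
# ['0b1011', '0b1010', '0b1001', '0b1000', '0b11', '0b10', '0b1', '0b0']

# Enumerating the subsets of every status with n bits visits 3^n (status, subset) pairs.
n = 4
pairs = sum(1 for status in range(1 << n) for _ in submasks(status))
print(pairs, 3 ** n)  # 81 81
```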


---
### Q6. Distribute Repeating Integers (LC.1655)
*You are given an array of n integers, nums, where there are at most 50 unique values in the array. You are also given an array of m customer order quantities, quantity, where quantity[i] is the amount of integers the ith customer ordered. Determine if it is possible to distribute nums such that:*

- *The ith customer gets exactly quantity[i] integers,*
- *The integers the ith customer gets are all equal, and*
- *Every customer is satisfied.*

*Return true if it is possible to distribute nums according to the above conditions.*
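The notes stop at the problem statement here. One possible approach, sketched under our own assumptions rather than taken from the original, combines the subset-enumeration technique above with a memoized DP: count each distinct value, precompute `need[s]` (the total quantity ordered by the customer set `s`), then let each distinct value either serve nobody or serve any non-empty subset of the still-unserved customers it has enough copies for. Assuming at most 10 customers (the LeetCode limit), this is roughly 50 · 3^10 subset checks in the worst case. The function name `can_distribute` is hypothetical.

```python
from collections import Counter
from functools import lru_cache
from typing import List

def can_distribute(nums: List[int], quantity: List[int]) -> bool:
    counts = list(Counter(nums).values())   # copies of each distinct value
    m = len(quantity)
    # need[s]: total quantity ordered by the customers in s
    need = [0] * (1 << m)
    for s in range(1, 1 << m):
        low = s & -s
        need[s] = need[s ^ low] + quantity[low.bit_length() - 1]

    @lru_cache(maxsize=None)
    def dfs(i: int, remaining: int) -> bool:
        # remaining: bitmask of customers who are not served yet
        if remaining == 0:
            return True
        if i == len(counts):
            return False
        # option 1: value i serves nobody
        if dfs(i + 1, remaining):
            return True
        # option 2: value i serves a non-empty subset of the remaining customers
        sub = remaining
        while sub > 0:
            if need[sub] <= counts[i] and dfs(i + 1, remaining ^ sub):
                return True
            sub = (sub - 1) & remaining
        return False

    return dfs(0, (1 << m) - 1)

print(can_distribute([1, 1, 2, 2], [2, 2]))  # True
print(can_distribute([1, 2, 3, 4], [2]))     # False
```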

---
### Q7. The Number Of Good Subsets (LC.1994)
*You are given an integer array nums. We call a subset of nums good if its product can be represented as a product of one or more distinct prime numbers.*

*For example, if nums = [1, 2, 3, 4]:*
- *[2, 3], [1, 2, 3], and [1, 3] are good subsets with products 6 = 2 × 3, 6 = 2 × 3, and 3 = 3 respectively.*
- *[1, 4] and [4] are not good subsets with products 4 = 2 × 2 and 4 = 2 × 2 respectively.*

*Return the number of different good subsets in nums modulo 10^9 + 7.*

*A subset of nums is any array that can be obtained by deleting some (possibly none or all) elements from nums. Two subsets are different if and only if the chosen indices to delete are different.*

```python
from typing import List

class Solution:
    MAXV = 30            # values in nums are between 1 and 30
    LIMIT = (1 << 10)    # 10 primes up to 30 -> 2^10 prime-set statuses
    MOD = 1000000007

    # own[v]: bitmask of v's prime factors (bit k <-> k-th prime among 2, 3, 5, ..., 29)
    # if v is square-free; 0 if v can never appear in a good subset (and for 0 and 1)
    own = [
        0b0000000000,  # 0
        0b0000000000,  # 1
        0b0000000001,  # 2
        0b0000000010,  # 3
        0b0000000000,  # 4
        0b0000000100,  # 5
        0b0000000011,  # 6
        0b0000001000,  # 7
        0b0000000000,  # 8
        0b0000000000,  # 9
        0b0000000101,  # 10
        0b0000010000,  # 11
        0b0000000000,  # 12
        0b0000100000,  # 13
        0b0000001001,  # 14
        0b0000000110,  # 15
        0b0000000000,  # 16
        0b0001000000,  # 17
        0b0000000000,  # 18
        0b0010000000,  # 19
        0b0000000000,  # 20
        0b0000001010,  # 21
        0b0000010001,  # 22
        0b0100000000,  # 23
        0b0000000000,  # 24
        0b0000000000,  # 25
        0b0000100001,  # 26
        0b0000000000,  # 27
        0b0000000000,  # 28
        0b1000000000,  # 29
        0b0000000111   # 30
    ]

    # Top-down version: f1(i, s) = number of ways to pick from the numbers 1..i
    # so that the primes used are exactly the set s
    def numberOfGoodSubsetsSimple(self, nums: List[int]) -> int:
        cnt = [0] * (self.MAXV + 1)
        for num in nums:
            cnt[num] += 1

        dp = [[-1] * self.LIMIT for _ in range(self.MAXV + 1)]

        def f1(i, s):
            if dp[i][s] != -1:
                return dp[i][s]
            ans = 0
            if i == 1:
                # only 1s are left: each 1 can be taken or not, but s must be empty
                if s == 0:
                    ans = 1
                    for _ in range(cnt[1]):
                        ans = (ans * 2) % self.MOD
            else:
                ans = f1(i - 1, s)
                cur = self.own[i]
                times = cnt[i]
                # number i is usable only if its primes are all in s
                if cur != 0 and times != 0 and (s & cur) == cur:
                    ans = (f1(i - 1, s ^ cur) * times + ans) % self.MOD
            dp[i][s] = ans
            return ans

        ans = 0
        for s in range(1, self.LIMIT):
            ans = (ans + f1(self.MAXV, s)) % self.MOD
        return ans

    # Bottom-up version of the same DP, rolled into one dimension
    def numberOfGoodSubsets(self, nums: List[int]) -> int:
        cnt = [0] * (self.MAXV + 1)
        dp = [0] * self.LIMIT

        for num in nums:
            cnt[num] += 1

        # dp[s]: number of ways whose prime set is exactly s; start with the 1s only
        dp[0] = 1
        for _ in range(cnt[1]):
            dp[0] = (dp[0] * 2) % self.MOD

        for i in range(2, self.MAXV + 1):
            cur = self.own[i]
            times = cnt[i]
            if cur != 0 and times != 0:
                # go downward so dp[status ^ cur] does not yet include number i (0/1 knapsack)
                for status in range(self.LIMIT - 1, -1, -1):
                    if (status & cur) == cur:
                        dp[status] = (dp[status] + dp[status ^ cur] * times) % self.MOD

        ans = 0
        for s in range(1, self.LIMIT):
            ans = (ans + dp[s]) % self.MOD
        return ans
```
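The hard-coded `own` table can also be generated. In this minimal sketch (the names `PRIMES` and `prime_mask` are ours), bit `k` stands for the `k`-th prime up to 29, and any value divisible by the square of a prime maps to 0, matching the 0 entries the solution treats as unusable. Note that 1 also maps to 0, but the solution handles it separately through `cnt[1]`.

```python
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

def prime_mask(v: int) -> int:
    """Bitmask of v's prime factors (bit k <-> PRIMES[k]); 0 if v is not square-free."""
    mask = 0
    for k, p in enumerate(PRIMES):
        if v % p == 0:
            if v % (p * p) == 0:
                return 0          # repeated prime factor: never part of a good subset
            mask |= 1 << k
    return mask

own = [prime_mask(v) for v in range(31)]
print(bin(own[30]), bin(own[26]), own[12])  # 0b111 0b100001 0
```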