# Coding Exercise

You must correctly implement the function described in the prompt below.

Feel free to test out pieces of code to help you write the solution.

Please thoroughly test that the final code implements the function correctly.

## Prompt

**Function signature:** `getAllKeys(N: int, M: int) -> str`

    There are N boxes numbered from 1 to N and N keys numbered from 1 to N. 
    The i-th key can only be used to open the i-th box.
    Now, we randomly put exactly one key into each of the boxes.
    We assume that all configurations of keys in boxes occur with the same probability. Then we lock all the boxes.
    You have M bombs, each of which can be used to open one locked box.
    Once you open a locked box, you can get the key in it and perhaps open another locked box with that key.
    Your strategy is to select a box, open it with a bomb, take the key and open all the boxes you can and then repeat with another bomb.

    Return the probability that you can get all the keys.
    The return value must be a string formatted as "A/B" (quotes for clarity), representing the probability as a fraction.  A and B must both be positive integers with no leading zeroes, and the greatest common divisor of A and B must be 1.

    Constraints
    -N will be between 1 and 20, inclusive.
    -M will be between 1 and N, inclusive.
 
    Examples
    0)
        2
    1

    Returns: "1/2"
    When box 1 contains key 2, you can get all the keys.

    1)
        2
    2

    Returns: "1/1"
    When N=M, you can always get all the keys.

    2)
        3
    1

    Returns: "1/3"
    There are 6 possible configurations of keys in boxes. Using 1 bomb, you can open all the boxes in 2 of them:

    box 1 - key 2, box 2 - key 3, box 3 - key 1;
    box 1 - key 3, box 2 - key 1, box 3 - key 2.

    3)
        3
    2

    Returns: "5/6"
    Now, when you have 2 bombs, you are only unable to get all the keys in the following configuration: box 1 - key 1, box 2 - key 2, box 3 - key 3.

    4)
        4
    2

    Returns: "17/24"

    

For now, I have no idea how to solve the problem except by brute force, that is, by testing all possible configurations.
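For small N, the brute-force idea can at least serve as a reference. The following is a minimal sketch, assuming boxes and keys are renumbered from 0; the helper names `bombs_needed` and `brute_force_probability` are hypothetical and do not come from the prompt.

```python
from fractions import Fraction
from itertools import permutations
from math import factorial


def bombs_needed(config):
    """Return how many bombs are required for one configuration.

    config[i] is the key stored in box i (0-based, an assumption of this sketch).
    """
    opened = [False] * len(config)
    bombs = 0
    for start in range(len(config)):
        if opened[start]:
            continue
        # This box is still locked and no key in hand opens it: use a bomb.
        bombs += 1
        box = start
        # Follow the chain of keys until it leads back to an opened box.
        while not opened[box]:
            opened[box] = True
            box = config[box]
    return bombs


def brute_force_probability(n, m):
    """Enumerate all n! configurations and count those solvable with m bombs."""
    good = sum(1 for config in permutations(range(n)) if bombs_needed(config) <= m)
    return Fraction(good, factorial(n))


print(brute_force_probability(3, 2))
```

Because it enumerates all N! configurations, this is only practical for single-digit N, but it gives trustworthy values to compare against.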

Testing all configurations is infeasible, since there are $20! = 2432902008176640000 \approx 2.4 \times 10^{18}$ possibilities.
However, I think this problem has a symmetry that allows me to solve the problem with N-M bombs and derive the result for M bombs from it. This means that I would need to test only $10! = 3628800 \approx 3.6 \times 10^{6}$ different configurations.

The initial idea is to use this symmetry with memoization to speed up the algorithm.
To use memoization, I will store the number of configurations for the remaining keys and boxes.
In other words, I will store the results for the remaining boxes (containing keys) and keys (True if available or False if used).

In [None]:
import copy

def solve_config(N, M, boxes, key):
    
    n_configs = 0
    
    # no key and no bombs and there are open boxes
    if key == -1 and M == 0 and boxes.count(-1) != len(boxes):
        return 0
    elif boxes.count(-1) == len(boxes):
        return 1
    
    # cannot open a new box with key (must use bomb)
    if boxes[key] == -1:
        for i in range(len(boxes)):
            if boxes[i] == -1:
                continue
                
            new_boxes = copy.copy(boxes)
            
            # use bomb
            new_key = new_boxes[i]
            new_boxes[new_key] = -1
            n_configs += solve_config(N, M-1, new_boxes, new_key)
    
    else:
        new_boxes = copy.copy(boxes)

        # use key
        new_key = new_boxes[key]
        
        # box was already opened
        # must use bomb
        if new_key == -1:
            for i in range(len(boxes)):
                if boxes[i] == -1:
                    continue

                new_boxes = copy.copy(boxes)

                # use bomb
                new_key = new_boxes[i]
                new_boxes[new_key] = -1
                n_configs += solve_config(N, M-1, new_boxes, new_key)
        else:
            new_boxes[new_key] = -1
            n_configs += solve_config(N, M-1, new_boxes, new_key)
        
    
    new_key = boxes[key]
    boxes[key] = -1
    
    
    


def getAllKeys(N: int, M: int) -> str:
    
    if M > N//2:
        res = getAllKeys(N, N-M).split('/')
        return f"{res[0]-res[1]}/{res[0]}"
    
    
    
    
    
    

Before writing the code, I thought I could solve the problem by brute force by solving the equivalent problem with N-M bombs;
however, this makes no sense.
While writing the code, I also realized that my initial memoization idea would not work due to the huge number of different states.

After thinking for some time, I found an equivalent problem that solves this one:
consider the same setup, but allow boxes to be empty.

How can empty boxes help solve this problem?
Whenever we use a bomb to open a box, one of the keys becomes useless, or equivalently, a box becomes empty.
This by itself does not change anything.
However, since the keys are distributed uniformly at random,
the idea of empty boxes allows us to ignore which keys are in the boxes.
Now we just need to know how many boxes are empty and how many have keys.

Let us consider an example of N = 10 and M = 3.

- Initially, we need to use a bomb. Whenever we use a bomb, the number of boxes decreases (N-=1), the number of bombs decreases, and the number of empty boxes increases (the key that would have opened the exploded box is now useless, so the box holding it counts as empty).
- In the next iteration, we can select from N-1 boxes: N-2 with a key or 1 that is empty. Here, it does not matter which box we select; it only matters whether the box is empty or has a key. Hence, we do not need to keep track of boxes or keys, only of the number of empty boxes.

Now I will implement this idea.
I will start by implementing a function that takes N, M, and the number of empty boxes as input and returns the number of configurations.
Whenever I call the function recursively, I will also multiply the result by the number of boxes that I can select.

Finally, I will keep track of whether the box that I opened in the last step has a key, and I will use memoization to speed it up.

PS: the next function will calculate the number of ways we can finish (the numerator).

In [3]:
def solve_with_empty_boxes(N, M, empty, has_key=False, memory=None):
    
    # cannot open all boxes
    if M > empty:
        return 0
    
    if N == 0:
        return 1
    
    if memory is None:
        memory = dict()
        
    h = hash(str([N, M, empty, has_key]))
    if h in memory.keys():
        return memory[h]
    
    count = 0
    
    # open box for free
    if has_key:
        # box can have a key
        if N-empty != 0:
            count += (N-empty)*solve_with_empty_boxes(N-1, M, empty, has_key=True, memory=memory)
        
        # or can be empty
        if empty != 0:
            count += empty*solve_with_empty_boxes(N-1, M, empty, memory=memory)
        
    # use bomb
    elif M > 0:
        # box can have a key
        if N-empty != 0:
            count += (N-empty)*solve_with_empty_boxes(N-1, M-1, empty, has_key=True, memory=memory)
        
        # or can be empty
        if empty != 0:
            count += empty*solve_with_empty_boxes(N-1, M-1, empty, memory=memory)
    
    memory[h] = count
    
    return count
        
        
solve_with_empty_boxes(2, 1, 0)

0

The result above is incorrect. It should be 1.

I forgot to increase/decrease empty boxes.

The number of empty boxes decreases whenever I open an empty one, and increases when I use a bomb.

In [7]:
def solve_with_empty_boxes(N, M, empty, has_key=False, memory=None):
    
    # cannot open all boxes
    if M > empty:
        return 0
    
    if N == 0:
        return 1
    
    if memory is None:
        memory = dict()
        
    h = hash(str([N, M, empty, has_key]))
    if h in memory.keys():
        return memory[h]
    
    count = 0
    
    # open box for free
    if has_key:
        # box can have a key
        if N-empty != 0:
            count += (N-empty)*solve_with_empty_boxes(N-1, M, empty, has_key=True, memory=memory)
        
        # or can be empty
        if empty != 0:
            count += empty*solve_with_empty_boxes(N-1, M, empty-1, memory=memory)
        
    # use bomb
    elif M > 0:
        # box can have a key
        if N-empty != 0:
            count += (N-empty)*solve_with_empty_boxes(N-1, M-1, empty+1, has_key=True, memory=memory)
        
        # or can be empty
        if empty != 0:
            count += empty*solve_with_empty_boxes(N-1, M-1, empty, memory=memory)
    
    memory[h] = count
    
    return count
        
        
solve_with_empty_boxes(2, 1, 0)

0

A box can contain its own key, so I need to consider the case where I explode a box containing its own key.
In this case, the number of empty boxes does not increase.

In [20]:
def solve_with_empty_boxes(N, M, empty, has_key=False, memory=None):
    
    if N == 0:
        return 1
    
    # cannot open all boxes
    if M > empty and not has_key:
        return 0
    
    if memory is None:
        memory = dict()
        
    h = hash(str([N, M, empty, has_key]))
    if h in memory.keys():
        return memory[h]
    count = 0
    
    # open box for free
    if has_key:
        # box can have a key
        if N-empty != 0:
            count += (N-empty)*solve_with_empty_boxes(N-1, M, empty, has_key=True, memory=memory)
        
        # or can be empty
        if empty != 0:
            count += empty*solve_with_empty_boxes(N-1, M, empty-1, has_key=False, memory=memory)
        
    # use bomb
    elif M > 0:
        # box can have the key of another box
        if N-empty-1 > 0:
            count += (N-empty-1)*solve_with_empty_boxes(N-1, M-1, empty+1, has_key=True, memory=memory)
        
        # box can have its own key
        count += solve_with_empty_boxes(N-1, M-1, empty+1, has_key=False, memory=memory)
        
        # or can be empty
        if empty != 0:
            count += empty*solve_with_empty_boxes(N-1, M-1, empty, has_key=False, memory=memory)
    
    memory[h] = count
    
    return count
        
        
solve_with_empty_boxes(2, 1, 0)

2 1 0 False


0

The `M > empty` condition is wrong.

The right condition is `M < empty`. It means that we cannot open all the remaining boxes because we ran out of keys and bombs.

In [23]:
def solve_with_empty_boxes(N, M, empty, has_key=False, memory=None):
    
    if N == 0:
        return 1
    
    # cannot open all boxes
    if M < empty and not has_key:
        return 0
    
    if memory is None:
        memory = dict()
        
    h = hash(str([N, M, empty, has_key]))
    if h in memory.keys():
        return memory[h]
    count = 0
    
    # open box for free
    if has_key:
        # box can have a key
        if N-empty != 0:
            count += (N-empty)*solve_with_empty_boxes(N-1, M, empty, has_key=True, memory=memory)
        
        # or can be empty
        if empty != 0:
            count += empty*solve_with_empty_boxes(N-1, M, empty-1, has_key=False, memory=memory)
        
    # use bomb
    elif M > 0:
        # box can have the key of another box
        if N-empty-1 > 0:
            count += (N-empty-1)*solve_with_empty_boxes(N-1, M-1, empty+1, has_key=True, memory=memory)
        
        # box can have its own key
        count += solve_with_empty_boxes(N-1, M-1, empty+1, has_key=False, memory=memory)
        
        # or can be empty
        if empty != 0:
            count += empty*solve_with_empty_boxes(N-1, M-1, empty, has_key=False, memory=memory)
    
    memory[h] = count
    
    return count
        
        
solve_with_empty_boxes(2, 1, 0)

1

Now the result is correct.

Let me implement the final function to output the result according to the prompt.

For the final function, I will call solve_with_empty_boxes to calculate the numerator.
The denominator is simply the total number of configurations, $N!$.

To simplify the fraction, I will divide both by the greatest common divisor (gcd), until gcd == 1.

In [26]:
from math import factorial, gcd

def solve_with_empty_boxes(N, M, empty, has_key=False, memory=None):
    
    if N == 0:
        return 1
    
    # cannot open all boxes
    if M < empty and not has_key:
        return 0
    
    if memory is None:
        memory = dict()
        
    h = hash(str([N, M, empty, has_key]))
    if h in memory.keys():
        return memory[h]
    
    count = 0
    
    # open box for free
    if has_key:
        # box can have a key
        if N-empty != 0:
            count += (N-empty)*solve_with_empty_boxes(N-1, M, empty, has_key=True, memory=memory)
        
        # or can be empty
        if empty != 0:
            count += empty*solve_with_empty_boxes(N-1, M, empty-1, has_key=False, memory=memory)
        
    # use bomb
    elif M > 0:
        # box can have the key of another box
        if N-empty-1 > 0:
            count += (N-empty-1)*solve_with_empty_boxes(N-1, M-1, empty+1, has_key=True, memory=memory)
        
        # box can have its own key
        count += solve_with_empty_boxes(N-1, M-1, empty+1, has_key=False, memory=memory)
        
        # or can be empty
        if empty != 0:
            count += empty*solve_with_empty_boxes(N-1, M-1, empty, has_key=False, memory=memory)
    
    memory[h] = count
    
    return count

def getAllKeys(N: int, M: int) -> str:
    
    denominator = factorial(N)
    numerator = solve_with_empty_boxes(N, M, 0)
    
    div = gcd(numerator, denominator)
    while div != 1:
        numerator //= div
        denominator //= div
        
    return f"{numerator}/{denominator}"
        
getAllKeys(2, 1)

'1/2'

In [27]:
def test():
    assert (a:=getAllKeys(2, 1)) == '1/2', a
    assert (a:=getAllKeys(2, 2)) == '1/1', a
    print('done')

test()

AssertionError: 3/2

It looks like I am overcounting.

I will print the inputs to help me debug.

In [28]:
from math import factorial, gcd

def solve_with_empty_boxes(N, M, empty, has_key=False, memory=None):
    
    print(N, M, empty, has_key)
    
    if N == 0:
        return 1
    
    # cannot open all boxes
    if M < empty and not has_key:
        return 0
    
    if memory is None:
        memory = dict()
        
    h = hash(str([N, M, empty, has_key]))
    if h in memory.keys():
        return memory[h]
    
    count = 0
    
    # open box for free
    if has_key:
        # box can have a key
        if N-empty != 0:
            count += (N-empty)*solve_with_empty_boxes(N-1, M, empty, has_key=True, memory=memory)
        
        # or can be empty
        if empty != 0:
            count += empty*solve_with_empty_boxes(N-1, M, empty-1, has_key=False, memory=memory)
        
    # use bomb
    elif M > 0:
        # box can have the key of another box
        if N-empty-1 > 0:
            count += (N-empty-1)*solve_with_empty_boxes(N-1, M-1, empty+1, has_key=True, memory=memory)
        
        # box can have its own key
        count += solve_with_empty_boxes(N-1, M-1, empty+1, has_key=False, memory=memory)
        
        # or can be empty
        if empty != 0:
            count += empty*solve_with_empty_boxes(N-1, M-1, empty, has_key=False, memory=memory)
    
    memory[h] = count
    
    return count

def getAllKeys(N: int, M: int) -> str:
    
    denominator = factorial(N)
    numerator = solve_with_empty_boxes(N, M, 0)
    
    div = gcd(numerator, denominator)
    while div != 1:
        numerator //= div
        denominator //= div
        
    return f"{numerator}/{denominator}"
        
def test():
#     assert (a:=getAllKeys(2, 1)) == '1/2', a
    assert (a:=getAllKeys(2, 2)) == '1/1', a
    print('done')

test()

2 2 0 False
1 1 1 True
0 1 0 False
1 1 1 False
0 0 2 False
0 0 1 False


AssertionError: 3/2

The second-to-last row shows that there are more empty boxes than there should be.

Reviewing the code, I realized that whenever I explode a box containing its own key, the number of empty boxes does not increase.

Also, in the while loop, I forgot to recalculate div.
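As a side note on the reduction step: dividing both numbers by their gcd once already leaves them coprime, so a loop is not strictly required. A minimal sketch, where `reduce_fraction` is a hypothetical helper name and not part of the notebook's code:

```python
from math import gcd


def reduce_fraction(numerator, denominator):
    """Reduce a fraction to lowest terms with a single gcd division."""
    div = gcd(numerator, denominator)
    # After this division the two results share no common factor.
    return numerator // div, denominator // div


print(reduce_fraction(704, 720))
```

With this form there is no loop variable to forget to update, which is the mistake noted above.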

In [30]:
from math import factorial, gcd

def solve_with_empty_boxes(N, M, empty, has_key=False, memory=None):
    
    print(N, M, empty, has_key)
    
    if N == 0:
        return 1
    
    # cannot open all boxes
    if M < empty and not has_key:
        return 0
    
    if memory is None:
        memory = dict()
        
    h = hash(str([N, M, empty, has_key]))
    if h in memory.keys():
        return memory[h]
    
    count = 0
    
    # open box for free
    if has_key:
        # box can have a key
        if N-empty != 0:
            count += (N-empty)*solve_with_empty_boxes(N-1, M, empty, has_key=True, memory=memory)
        
        # or can be empty
        if empty != 0:
            count += empty*solve_with_empty_boxes(N-1, M, empty-1, has_key=False, memory=memory)
        
    # use bomb
    elif M > 0:
        # box can have the key of another box
        if N-empty-1 > 0:
            count += (N-empty-1)*solve_with_empty_boxes(N-1, M-1, empty+1, has_key=True, memory=memory)
        
        # box can have its own key
        count += solve_with_empty_boxes(N-1, M-1, empty, has_key=False, memory=memory)
        
        # or can be empty
        if empty != 0:
            count += empty*solve_with_empty_boxes(N-1, M-1, empty, has_key=False, memory=memory)
    
    memory[h] = count
    
    return count

def getAllKeys(N: int, M: int) -> str:
    
    denominator = factorial(N)
    numerator = solve_with_empty_boxes(N, M, 0)
    
    div = gcd(numerator, denominator)
    while div != 1:
        numerator //= div
        denominator //= div
        div = gcd(numerator, denominator)
        
    return f"{numerator}/{denominator}"
        
def test():
#     assert (a:=getAllKeys(2, 1)) == '1/2', a
    assert (a:=getAllKeys(2, 2)) == '1/1', a
    print('done')

test()

2 2 0 False
1 1 1 True
0 1 0 False
1 1 0 False
0 0 0 False
done


I will remove the print and keep testing.

In [32]:
from math import factorial, gcd

def solve_with_empty_boxes(N, M, empty, has_key=False, memory=None):
    
    if N == 0:
        return 1
    
    # cannot open all boxes
    if M < empty and not has_key:
        return 0
    
    if memory is None:
        memory = dict()
        
    h = hash(str([N, M, empty, has_key]))
    if h in memory.keys():
        return memory[h]
    
    count = 0
    
    # open box for free
    if has_key:
        # box can have a key
        if N-empty != 0:
            count += (N-empty)*solve_with_empty_boxes(N-1, M, empty, has_key=True, memory=memory)
        
        # or can be empty
        if empty != 0:
            count += empty*solve_with_empty_boxes(N-1, M, empty-1, has_key=False, memory=memory)
        
    # use bomb
    elif M > 0:
        # box can have the key of another box
        if N-empty-1 > 0:
            count += (N-empty-1)*solve_with_empty_boxes(N-1, M-1, empty+1, has_key=True, memory=memory)
        
        # box can have its own key
        count += solve_with_empty_boxes(N-1, M-1, empty, has_key=False, memory=memory)
        
        # or can be empty
        if empty != 0:
            count += empty*solve_with_empty_boxes(N-1, M-1, empty, has_key=False, memory=memory)
    
    memory[h] = count
    
    return count

def getAllKeys(N: int, M: int) -> str:
    
    denominator = factorial(N)
    numerator = solve_with_empty_boxes(N, M, 0)
    
    div = gcd(numerator, denominator)
    while div != 1:
        numerator //= div
        denominator //= div
        div = gcd(numerator, denominator)
        
    return f"{numerator}/{denominator}"
        
def test():
    assert (a:=getAllKeys(2, 1)) == '1/2', a
    assert (a:=getAllKeys(2, 2)) == '1/1', a
    assert (a:=getAllKeys(3, 1)) == '1/3', a
    print('done')

test()

done
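As an aside, the same recurrence can be memoized with `functools.lru_cache` rather than a hand-managed dictionary keyed by `hash(str(...))`. This is only a sketch of an alternative: the name `count_ways` is hypothetical, and the body is meant to mirror `solve_with_empty_boxes` from the cell above branch by branch.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def count_ways(n, m, empty, has_key=False):
    """Count configurations in which all n remaining boxes can be opened."""
    if n == 0:
        return 1
    # No key in hand and fewer bombs than empty boxes: cannot finish.
    if m < empty and not has_key:
        return 0

    count = 0
    if has_key:
        # Open the next box for free: it either holds a useful key...
        if n - empty != 0:
            count += (n - empty) * count_ways(n - 1, m, empty, True)
        # ...or it is one of the empty boxes, which ends the chain.
        if empty != 0:
            count += empty * count_ways(n - 1, m, empty - 1, False)
    elif m > 0:
        # Use a bomb on a box holding the key of another box.
        if n - empty - 1 > 0:
            count += (n - empty - 1) * count_ways(n - 1, m - 1, empty + 1, True)
        # Use a bomb on a box holding its own key.
        count += count_ways(n - 1, m - 1, empty, False)
        # Use a bomb on an empty box.
        if empty != 0:
            count += empty * count_ways(n - 1, m - 1, empty, False)
    return count


print(count_ways(4, 2, 0))
```

The cache key is the argument tuple itself, so no string conversion or manual hashing is needed, and the cache persists across top-level calls.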


In [33]:
def test():
    assert (a:=getAllKeys(2, 1)) == '1/2', a
    assert (a:=getAllKeys(2, 2)) == '1/1', a
    assert (a:=getAllKeys(3, 1)) == '1/3', a
    assert (a:=getAllKeys(3, 2)) == '5/6', a
    print('done')

test()

done


In [34]:
def test():
    assert (a:=getAllKeys(2, 1)) == '1/2', a
    assert (a:=getAllKeys(2, 2)) == '1/1', a
    assert (a:=getAllKeys(3, 1)) == '1/3', a
    assert (a:=getAllKeys(3, 2)) == '5/6', a
    assert (a:=getAllKeys(4, 2)) == '17/24', a
    print('done')

test()

done
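An independent cross-check is possible because each bomb opens exactly one cycle of the key permutation, so all keys can be collected exactly when the configuration has at most M cycles. The number of permutations of n elements with exactly k cycles is the unsigned Stirling number of the first kind, which satisfies c(n, k) = c(n-1, k-1) + (n-1) c(n-1, k). The sketch below uses that recurrence, which is a different technique from the empty-box recursion used in this notebook; `stirling_first` and `get_all_keys_stirling` are hypothetical names.

```python
from fractions import Fraction
from functools import lru_cache
from math import factorial


@lru_cache(maxsize=None)
def stirling_first(n, k):
    """Unsigned Stirling number of the first kind: permutations of n with k cycles."""
    if n == 0 and k == 0:
        return 1
    if n == 0 or k == 0:
        return 0
    # Element n either forms its own cycle or is inserted after one of the
    # other n-1 elements inside an existing cycle.
    return stirling_first(n - 1, k - 1) + (n - 1) * stirling_first(n - 1, k)


def get_all_keys_stirling(n, m):
    """Probability that a random configuration has at most m cycles, as 'A/B'."""
    good = sum(stirling_first(n, k) for k in range(1, m + 1))
    prob = Fraction(good, factorial(n))  # Fraction reduces to lowest terms
    return f"{prob.numerator}/{prob.denominator}"


print(get_all_keys_stirling(4, 2))
```

Agreement between this function and `getAllKeys` on the prompt's examples is a useful sanity check, since the two computations share no code.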


Test execution time

In [35]:
from datetime import datetime

for N in range(1, 21):
    for M in range(1, N):
        
        t1 = datetime.now()
        results = getAllKeys(N, M)
        t2 = datetime.now()
        
        print(N, M, t2-t1, results)

2 1 0:00:00.000031 1/2
3 1 0:00:00.000032 1/3
3 2 0:00:00.000018 5/6
4 1 0:00:00.000016 1/4
4 2 0:00:00.000027 17/24
4 3 0:00:00.000030 23/24
5 1 0:00:00.000019 1/5
5 2 0:00:00.000036 37/60
5 3 0:00:00.000053 109/120
5 4 0:00:00.000051 119/120
6 1 0:00:00.000024 1/6
6 2 0:00:00.000048 197/360
6 3 0:00:00.000066 619/720
6 4 0:00:00.000077 44/45
6 5 0:00:00.000078 719/720
7 1 0:00:00.000025 1/7
7 2 0:00:00.000055 69/140
7 3 0:00:00.000080 1027/1260
7 4 0:00:00.000096 4843/5040
7 5 0:00:00.000108 2509/2520
7 6 0:00:00.000086 5039/5040
8 1 0:00:00.000021 1/8
8 2 0:00:00.000050 503/1120
8 3 0:00:00.000073 781/1008
8 4 0:00:00.000091 38009/40320
8 5 0:00:00.000106 4441/4480
8 6 0:00:00.000117 40291/40320
8 7 0:00:00.000115 40319/40320
9 1 0:00:00.000025 1/9
9 2 0:00:00.000057 347/840
9 3 0:00:00.000087 67007/90720
9 4 0:00:00.000109 20957/22680
9 5 0:00:00.000128 357761/362880
9 6 0:00:00.000139 362297/362880
9 7 0:00:00.000144 362843/362880
9 8 0:00:00.000168 362879/362880
10 1 0:00:00.0000

It is pretty fast!

Lastly, there is a special case when we only have one bomb.
In this case, the bomb is used at the beginning, and we cannot select the empty box until the very end.

- Starting with N boxes, we explode any of them
    - if the exploded box contains its own key, we count 0
    - if the exploded box contains another key, there are N-1 possibilities.
- The next box can be any of the remaining N-1, except the empty one (multiply by N-2)
- The next box can be any of the remaining N-2, except the empty one (multiply by N-3)
- $\ldots$
- The next box can be any of the remaining 2, except the empty one (multiply by 1)
- Finally, choose the empty box

This gives us a probability of $(N-1)!/N! = 1/N$.
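The count behind this argument can be checked directly for small N: with one bomb, the successful configurations are exactly those in which following the keys from any box visits every box before returning. A minimal sketch, assuming 0-based numbering and using the hypothetical helpers `is_single_cycle` and `count_single_cycles`:

```python
from itertools import permutations
from math import factorial


def is_single_cycle(config):
    """True if following keys from box 0 visits every box before returning."""
    box, steps = config[0], 1
    while box != 0:
        box = config[box]
        steps += 1
    return steps == len(config)


def count_single_cycles(n):
    """Count configurations of n boxes that can be fully opened with one bomb."""
    return sum(1 for config in permutations(range(n)) if is_single_cycle(config))


for n in range(1, 7):
    # The derivation above predicts (n-1)! successful configurations.
    print(n, count_single_cycles(n), factorial(n - 1))
```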

In [36]:
from datetime import datetime

for N in range(1, 21):
    for M in range(1, 2):
        
        t1 = datetime.now()
        results = getAllKeys(N, M)
        t2 = datetime.now()
        
        assert results == f'1/{N}', results
        
        print(N, M, t2-t1, results)

1 1 0:00:00.000018 1/1
2 1 0:00:00.000014 1/2
3 1 0:00:00.000014 1/3
4 1 0:00:00.000016 1/4
5 1 0:00:00.000018 1/5
6 1 0:00:00.000021 1/6
7 1 0:00:00.000024 1/7
8 1 0:00:00.000028 1/8
9 1 0:00:00.000031 1/9
10 1 0:00:00.000034 1/10
11 1 0:00:00.000037 1/11
12 1 0:00:00.000041 1/12
13 1 0:00:00.000044 1/13
14 1 0:00:00.000047 1/14
15 1 0:00:00.000051 1/15
16 1 0:00:00.000054 1/16
17 1 0:00:00.000056 1/17
18 1 0:00:00.000061 1/18
19 1 0:00:00.000064 1/19
20 1 0:00:00.000067 1/20
