# Removing Minimum Number of Magic Beans

**Problem**:
Given an array `beans` of positive integers, where each integer is the number of magic beans in one magic bag, remove any number of beans (possibly none) from each bag so that every remaining non-empty bag holds the same number of beans. Once a bean is removed, it cannot be returned to any bag. Find the minimum number of magic beans you need to remove.

**Examples**:

1. **Input**:
   `beans = [4,1,6,5]`
   
   **Output**: `4`
   
   **Explanation**:
   Remove beans to result in bags with [4,0,4,4]. A total of 1 + 2 + 1 = 4 beans are removed.

2. **Input**:
   `beans = [2,10,3,2]`
   
   **Output**: `7`
   
   **Explanation**:
   Remove beans to result in bags with [0,10,0,0]. A total of 2 + 3 + 2 = 7 beans are removed.

**Constraints**:
- `1 <= beans.length <= 10^5`
- `1 <= beans[i] <= 10^5`
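The examples hint at a closed form for the cost of a chosen common count. Let $b_0 \le b_1 \le \dots \le b_{n-1}$ be the bag sizes in ascending order (0-indexed to match the notebook's code) and $S$ their total; this notation is introduced here only for the derivation. Choosing a common count $t > 0$ means emptying every bag smaller than $t$ and trimming every other bag down to $t$:

$$
R(t) = \sum_{j:\, b_j < t} b_j + \sum_{j:\, b_j \ge t} (b_j - t) = S - t \cdot c(t), \qquad c(t) = \bigl|\{\, j : b_j \ge t \,\}\bigr|
$$

On each interval between consecutive distinct bag sizes, $c(t)$ is constant, so $R(t)$ only decreases as $t$ grows, and any $t > b_{n-1}$ keeps nothing. The optimum is therefore one of the bag sizes. At the first occurrence of each value $c(b_i) = n - i$, and later duplicates only give larger values of $S - (n-i)\,b_i$, so

$$
\min_{t > 0} R(t) = \min_{0 \le i < n} \bigl( S - (n - i)\, b_i \bigr)
$$

For `beans = [4,1,6,5]`: $S = 16$, the sorted sizes $1, 4, 5, 6$ give $16 - 4 \cdot 1 = 12$, $16 - 3 \cdot 4 = 4$, $16 - 2 \cdot 5 = 6$, $16 - 1 \cdot 6 = 10$, and the minimum is $4$.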


```python
from typing import List

def test(s):
    test_cases = [[4,1,6,5], [2,10,3,2]]
    ref = [4, 7]
    for i, (beans, expected) in enumerate(zip(test_cases, ref)):
        # Pass a copy: the solutions sort their input in place.
        assert s.minimumRemoval(list(beans)) == expected, f"wrong answer at test case {i + 1}: beans = {beans}"
    print("Succeed")

# Example usage
# s = Solution()
# test(s)
```
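The two sample cases exercise very little of the index arithmetic. A minimal sketch of a randomized cross-check, assuming small inputs so a brute force stays cheap; `brute_force` and `stress_test` are hypothetical helpers, not part of the original notebook. The brute force tries every distinct bag size as the common count, which suffices because raising the target up to the next bag size never increases the number of beans removed.

```python
import random
from typing import List

def brute_force(beans: List[int]) -> int:
    # Hypothetical reference: try every distinct bag size t as the common count.
    # Bags smaller than t are emptied; bags at or above t are trimmed down to t.
    return min(
        sum(b if b < t else b - t for b in beans)
        for t in set(beans)
    )

def stress_test(s, trials: int = 500, max_len: int = 8, max_val: int = 10, seed: int = 0) -> None:
    # Hypothetical helper: compare s.minimumRemoval against brute_force
    # on small random inputs generated from a fixed seed.
    rng = random.Random(seed)
    for _ in range(trials):
        beans = [rng.randint(1, max_val) for _ in range(rng.randint(1, max_len))]
        expected = brute_force(beans)
        # Pass a copy: the solutions sort their input in place.
        got = s.minimumRemoval(list(beans))
        assert got == expected, f"beans = {beans}: got {got}, expected {expected}"
    print("Stress test passed")
```

It is used like `test(s)`, for example `stress_test(Solution2())` once a solution class is defined.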

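```python
'''
    This is the approach that came to mind at first glance, but calc() walks
    the whole suffix on every iteration, so it is O(N^2) in time and
    unfortunately exceeds the time limit.
'''

class Solution1:
    def minimumRemoval(self, beans: List[int]) -> int:
        def calc(num):
            # Beans removed when every bag in num is trimmed down to num[0]
            # (num is sorted, so num[0] is its smallest bag).
            res = 0
            for i in range(1, len(num)):
                res += num[i] - num[0]
            return res

        beans.sort()
        min_list = []
        removed_before = 0  # beans in bags 0..i-1, which are all emptied
        for i in range(len(beans)):
            # Target beans[i]: empty the smaller bags, trim the rest down to beans[i].
            min_list.append(calc(beans[i:]) + removed_before)
            removed_before += beans[i]

        return min(min_list)

test(Solution1())
```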

Succeed
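The quadratic cost comes from `calc(beans[i:])`, which re-sums the whole suffix on every iteration. Its value equals `sum(num[1:]) - (len(num) - 1) * num[0]`, so keeping a running suffix sum makes each step O(1) and the method O(N log N) overall, dominated by the sort. A minimal sketch under that observation; `Solution1Fast` is a hypothetical name, and it sorts a copy rather than the caller's list.

```python
from typing import List

class Solution1Fast:
    # Hypothetical rewrite of Solution1: same loop, but calc(beans[i:]) is
    # replaced by O(1) arithmetic on a running suffix sum.
    def minimumRemoval(self, beans: List[int]) -> int:
        beans = sorted(beans)      # sorted copy; the caller's list is untouched
        n = len(beans)
        removed_before = 0         # beans emptied from bags 0..i-1
        suffix = sum(beans)        # sum of beans[i:]
        best = float("inf")
        for i, target in enumerate(beans):
            # calc(beans[i:]) == sum(beans[i+1:]) - (n - i - 1) * target
            trim = (suffix - target) - (n - i - 1) * target
            best = min(best, removed_before + trim)
            removed_before += target
            suffix -= target
        return best
```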


```python
'''
    Prefix sum. Time complexity is O(N log N) because of the sort; building
    the prefix sums and scanning the candidates are both O(N).

    Instead of thinking about what to remove, think about what to keep.
'''

class Solution2:
    def minimumRemoval(self, beans: List[int]) -> int:
        beans.sort()
        n = len(beans)
        ans = float("inf")
        # prefix_sum[k] is the total of the k smallest bags, beans[0..k-1]
        prefix_sum = [0] * (n+1)

        for i in range(1, n+1):
            prefix_sum[i] = beans[i-1] + prefix_sum[i-1]

        for i in range(n):
            # Make beans[i] the common count: the n - i bags from index i onward
            # each keep beans[i] beans, and everything else is removed.
            tmp = prefix_sum[n] - (n-i)*beans[i]
            # I learned a lot here. For each i, tmp is made up of three parts:
            # (1) The beans in the earlier bags, all of which are removed:
            #     prefix_sum[i].
            # (2) Bag i itself, already at the target: beans[i] - beans[i] = 0.
            # (3) The beans trimmed from the later bags to bring each down to beans[i].
            #     prefix_sum handles this indirectly (which is the clever part):
            #     (sum of the later bags) - (n-i-1)*beans[i].
            #
            # Written out term by term:
            #   tmp  = prefix_sum[i]
            #   tmp += beans[i] - beans[i]
            #   tmp += (prefix_sum[n] - prefix_sum[i+1]) - (n-i-1)*beans[i]
            # and these three terms sum to prefix_sum[n] - (n-i)*beans[i].
            ans = min(ans, tmp)

        return ans


test(Solution2())
```

Succeed
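Since the total is fixed, minimizing `prefix_sum[n] - (n-i)*beans[i]` is the same as maximizing the kept amount `(n-i)*beans[i]`, so only the total, not the full prefix array, is needed. A minimal sketch of that "keep" view; `minimum_removal_keep` is a hypothetical name, and it sorts a copy instead of mutating its argument.

```python
from typing import List

def minimum_removal_keep(beans: List[int]) -> int:
    # Hypothetical compact form of the "think what to keep" idea.
    # Choosing ordered[i] as the common count keeps ordered[i] beans in each
    # of the n - i bags from index i onward; everything else is removed.
    # With duplicate sizes, (n - i) undercounts at later equal indices, but the
    # first occurrence is exact, so the maximum is unaffected.
    ordered = sorted(beans)
    n = len(ordered)
    most_kept = max((n - i) * b for i, b in enumerate(ordered))
    return sum(ordered) - most_kept

print(minimum_removal_keep([4, 1, 6, 5]))   # 4
print(minimum_removal_keep([2, 10, 3, 2]))  # 7
```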
