In [53]:
from typing import List, Tuple
from sortedcontainers import SortedList

class Solution:
    # given a sorted list of integers, find all unique pairs that sum to target
    # assume that the list has at least 2 elements
    def twoSum(self, nums: List[int], target: int) -> List[List[int]]:
        result = set()
        nums_sort = SortedList(nums)
        print(nums_sort)

        for i in range(len(nums_sort)):
            val = target - nums_sort[i]
            idx = nums_sort.bisect_left(val)
            if idx < len(nums_sort) and nums_sort[idx] == val:
                result.add(tuple(sorted([nums_sort[i], val])))
        return list(result)
        

    # simple implementation of 4sum which is O(n^4). we can improve this
    # by using a hashmap to store the sum of two numbers and then checking
    # if the remaining sum exists in the hashmap. this would reduce the time
    # complexity to O(n^2) average case
    #
    # alternatively, we can get the two number and use two pointers to find
    # the remaining two numbers. this would reduce the time complexity to
    # O(n^3) average case O(1) space complexity
    def fourSumO4(self, nums: List[int], target: int) -> List[List[int]]:
        indices = set()
        nums.sort()

        for i in range(0, len(nums) - 3):
            for j in range(i+1, len(nums) - 2):
                for k in range(j+1, len(nums) - 1):
                    for l in range(k+1, len(nums) - 0):
                        if nums[i] + nums[j] + nums[k] + nums[l] == target:
                            indices.add((nums[i], nums[j], nums[k], nums[l]))
        return list(indices)
    
    def removeDuplicates(self, nums: List[int], max_count: int) -> List[int]:
        counts = dict()
        for num in nums:
            if num in counts:
                counts[num] += 1
            else:
                counts[num] = 1

        results = []
        for key, value in counts.items():
            value = min(value, max_count)
            results.extend([key] * value)
        
        return results

    # I added in the removeDuplicates function to improve the worse case
    # time complexity. Since we need to look through every combination that
    # sums to the target, if we have a lot of duplicates, it can get expensive
    # fast. Note that the most a given number can be repeated is 4 times, so
    # we can remove any number of repeats greater than 4
    def fourSumO2(self, nums: List[int], target: int) -> List[List[int]]:
        print(nums)
        indices = set()
        twoSums = dict()

        nums = self.removeDuplicates(nums, 4)
        nums.sort()
        print(nums)

        for i in range(0, len(nums) - 1):
            for j in range(i+1, len(nums)):
                sum = nums[i] + nums[j]
                if sum in twoSums:
                    twoSums[sum].append((i, j))
                else:
                    twoSums[sum] = [(i, j)]
        
        for i in range(0, len(nums) -3):
            for j in range(i+1, len(nums) - 2):
                residual = target - nums[i] - nums[j]
                if residual in twoSums:
                    for k, l in twoSums[residual]:
                        if i != k and i != l and j != k and j != l:
                            indices.add(tuple(sorted([nums[i], nums[j], nums[k], nums[l]])))
        return list(indices)

        
    
    def fourSum(self, nums: List[int], target: int) -> List[List[int]]:
        #return self.fourSumO4(nums, target)
        return self.fourSumO2(nums, target)


In [56]:
# test 1
nums = [1,0,-1,0,-2,2,10]
target = 0
print(Solution().fourSum(nums, target)) 
print()

# test 2
nums = [2,2,2,2,2]
target = 8
print(Solution().fourSum(nums, target))
print()

# test 2
nums = [2,2,2,2,2,3,3,3,3,3,3,3,3,10,10,10,10,10,10,10]
target = 8
print(Solution().fourSum(nums, target))
print()

[1, 0, -1, 0, -2, 2, 10]
[-2, -1, 0, 0, 1, 2, 10]
[(-2, -1, 1, 2), (-1, 0, 0, 1), (-2, 0, 0, 2)]

[2, 2, 2, 2, 2]
[2, 2, 2, 2]
[(2, 2, 2, 2)]

[2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 10, 10, 10, 10, 10, 10, 10]
[2, 2, 2, 2, 3, 3, 3, 3, 10, 10, 10, 10]
[(2, 2, 2, 2)]

