class Solution:
    """Longest contiguous run with as many truthy as falsy entries."""

    def findMaxLength(self, nums: List[int]) -> int:
        """Return the length of the longest balanced contiguous subarray.

        Every truthy element counts as +1 and every falsy element as -1, so
        a subarray holds equally many of each exactly when the running
        balance is the same just before it starts and just after it ends.

        Args:
            nums: Sequence of values; only their truthiness is inspected
                (typically 0s and 1s).

        Returns:
            The length of the longest balanced contiguous subarray, or 0
            when no such subarray exists (including for empty input).
        """
        # Maps a balance value to the smallest prefix length at which it
        # occurred. The empty prefix (length 0) has balance 0.
        first_seen = {0: 0}
        balance = 0
        best = 0
        for length, value in enumerate(nums, start=1):
            balance += 1 if value else -1
            if balance in first_seen:
                # Same balance seen earlier: the elements in between cancel.
                best = max(best, length - first_seen[balance])
            else:
                # Record only the first occurrence so that later matches
                # span the widest possible window.
                first_seen[balance] = length
        return best