# Median of 2 sorted lists

## Setup

### Imports

In [1]:
from pydantic import BaseModel
import numpy as np
import pandas as pd

### Utils

In [2]:
def pad(x):
    """Return a new list: ``x`` with -inf prepended and +inf appended.

    ``x`` must be a ``list`` because the sentinels are attached with list
    concatenation; ``x`` itself is not modified.
    """
    neg_inf, pos_inf = float("-inf"), float("inf")
    return [neg_inf] + x + [pos_inf]

class Pointers:
    """Mutable holder for the three indices of a search window.

    Note the positional argument order: ``high``, then ``target``, then
    ``low``. Callers in this file pass them by keyword.
    """

    def __init__(self, high, target, low):
        self.high, self.target, self.low = high, target, low

class BorderValues:
    """The two values adjacent to a split position.

    ``max_low`` is the value just below the split and ``min_high`` the value
    just above it.
    """

    def __init__(self, max_low, min_high):
        self.max_low, self.min_high = max_low, min_high

class OrderedList:
    """A list wrapped with -inf/+inf sentinels and a movable split position.

    The split position ``search_idx.target`` is an index into the padded
    list: everything before it is the "low" side, everything from it onwards
    is the "high" side. ``split`` always holds the two values adjacent to the
    current position.

    Attributes:
        list: the original ``nums`` object (not copied). It is presumably
            expected to be sorted ascending; this class does not check.
        padded_list: ``pad(nums)``, i.e. ``nums`` framed by -inf and +inf.
        n: length of ``padded_list`` (so ``len(nums) + 2``).
        search_idx: ``Pointers`` with the search window ``low``/``high`` and
            the current split position ``target``.
        split: ``BorderValues`` for the current split position.
    """

    def __init__(self, nums, target=None):
        """Wrap ``nums``; split at ``target`` if given, else at ``n // 2``."""
        self.list = nums
        self.padded_list = pad(nums)
        self.n = len(self.padded_list)
        self.search_idx = Pointers(low=0, target=self.n // 2, high=self.n)
        # An explicit target replaces the default midpoint split. Identity
        # check, so a target of 0 (or any falsy/odd-__eq__ value) still counts.
        if target is not None:
            self.set_target(target)
        else:
            self.set_border_values()

    def set_target(self, target):
        """Move the split position to ``target`` and refresh ``split``."""
        self.search_idx.target = target
        self.set_border_values()

    def set_border_values(self):
        """Recompute ``split`` from the current ``search_idx.target``.

        A split at 0 has nothing on the low side and a split at ``n`` has
        nothing on the high side; those edges report -inf / +inf instead of
        wrapping around to the opposite sentinel (index -1) or raising
        ``IndexError``. Targets outside ``[0, n]`` are not validated.
        """
        position = self.search_idx.target
        if position > 0:
            below = self.padded_list[position - 1]
        else:
            below = -float("inf")
        if position < self.n:
            above = self.padded_list[position]
        else:
            above = float("inf")
        self.split = BorderValues(max_low=below, min_high=above)

    def next_target(self, up=True):
        """Halve the search window and move the split to its midpoint.

        With ``up=True`` the current target becomes the new lower bound,
        otherwise it becomes the new upper bound. The new target is the floor
        midpoint of the resulting window, and ``split`` is refreshed.
        """
        if up:
            self.search_idx.low = self.search_idx.target
        else:
            self.search_idx.high = self.search_idx.target
        self.search_idx.target = (self.search_idx.low + self.search_idx.high) // 2
        self.set_border_values()

    def __repr__(self):
        """Show the original, unpadded list."""
        return str(self.list)
        
def find_median(a, b, target_index):
    """Move the splits of ``a`` and ``b`` until they are mutually consistent.

    The splits are consistent when the value below each list's split is not
    greater than the value above the other list's split. Until then, the
    split of ``a`` is moved by its own ``next_target`` (down when ``a``'s low
    side is too large, up when ``b``'s low side is too large) and the split of
    ``b`` is set so that both split positions add up to ``target_index``.

    Both objects are modified in place; the same two objects are returned as
    ``(a, b)``.
    """
    a_low_too_big = not (a.split.max_low <= b.split.min_high)
    b_low_too_big = not (b.split.max_low <= a.split.min_high)

    # Nothing left to adjust: the current splits already agree.
    if not (a_low_too_big or b_low_too_big):
        return a, b

    # "a too big" takes precedence, exactly as in an if/elif chain.
    a.next_target(up=not a_low_too_big)
    b.search_idx.target = target_index - a.search_idx.target
    b.set_border_values()
    return find_median(a, b, target_index)
    
def calculate_median(a, b):
    """Return the median given two lists whose splits are already aligned.

    Uses the largest value below the splits and the smallest value above
    them. If ``a.n + b.n`` is even the result is their mean (a float);
    otherwise it is the smallest value above the splits, returned as-is.
    For ``OrderedList`` inputs ``n`` counts the padded list, which adds two
    sentinels per list and therefore does not change the parity.
    """
    below = max(a.split.max_low, b.split.max_low)
    above = min(a.split.min_high, b.split.min_high)

    combined_is_odd = (a.n + b.n) % 2 != 0
    if combined_is_odd:
        return above
    return (below + above) / 2.0
    
def main(nums1, nums2):
    """Find the median of two lists via ``OrderedList`` splits.

    Returns ``(a, b, med)`` where ``a`` wraps the shorter input (``nums1`` on
    a tie), ``b`` wraps the other one, both left at their final split
    positions, and ``med`` is the value from ``calculate_median``.
    """
    first, second = OrderedList(nums1), OrderedList(nums2)

    # The search moves the split of `a`, so `a` gets the shorter list.
    if first.n <= second.n:
        a, b = first, second
    else:
        a, b = second, first

    # Both split positions must always add up to half the padded total.
    target_index = (a.n + b.n) // 2
    b.set_target(target_index - a.search_idx.target)

    a, b = find_median(a, b, target_index)
    return a, b, calculate_median(a, b)

In [3]:
def highlight_ol(ol: OrderedList):
    """Render the padded list as a one-row styled table.

    Columns from ``search_idx.low`` up to (not including) ``search_idx.target``
    get a light-green background; columns from ``search_idx.target`` up to
    (not including) ``search_idx.high`` get a light-blue one. Returns the
    pandas Styler object.
    """
    window = ol.search_idx
    low_side = list(range(window.low, window.target))
    high_side = list(range(window.target, window.high))

    # One row, one column per padded element.
    row = pd.Series(ol.padded_list).to_frame().T

    styled = row.style
    styled = styled.set_properties(subset=low_side, **{"background-color": "lightgreen"})
    styled = styled.set_properties(subset=high_side, **{"background-color": "lightblue"})
    return styled

In [4]:
# Build two random sorted lists, merge them, and compute the reference
# median with NumPy so the custom implementation can be checked against it.
# Set either flag to True (see the commented toggles below) to replace the
# corresponding random list with a fixed one.
override_a = False
override_b = False

na_ = np.random.randint(1, 100)
nb_ = np.random.randint(1, 100)

a_ = np.random.randint(0, 60, na_).tolist()
a_.sort()

# override a
#override_a = True
if override_a:
    a_ = [1]
    na_ = len(a_)

b_ = np.random.randint(40, 100, nb_).tolist()
b_.sort()

# override b
#override_b = True
if override_b:
    b_ = [1]
    nb_ = len(b_)

# a_ and b_ are already plain lists of Python ints, so list concatenation
# plus sorted() gives the merged list directly. This also avoids np.concat,
# which NumPy versions before 2.0 do not provide.
c_ = sorted(a_ + b_)
n_ = len(c_)
med_calc = np.median(c_)

display(highlight_ol(OrderedList(c_)))
print(f"calculated median = {med_calc}")

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108
0,-inf,0.0,2.0,2.0,3.0,3.0,3.0,7.0,8.0,9.0,9.0,10.0,12.0,13.0,13.0,14.0,14.0,15.0,15.0,16.0,18.0,18.0,19.0,20.0,21.0,23.0,25.0,25.0,27.0,27.0,27.0,28.0,28.0,30.0,30.0,32.0,32.0,33.0,36.0,36.0,37.0,38.0,39.0,39.0,40.0,40.0,41.0,42.0,43.0,44.0,44.0,44.0,44.0,45.0,45.0,47.0,49.0,49.0,49.0,50.0,50.0,51.0,51.0,51.0,51.0,52.0,52.0,53.0,53.0,54.0,55.0,55.0,56.0,56.0,58.0,58.0,58.0,58.0,58.0,59.0,63.0,63.0,64.0,65.0,72.0,73.0,74.0,74.0,75.0,76.0,76.0,77.0,77.0,79.0,80.0,82.0,82.0,84.0,85.0,86.0,87.0,88.0,89.0,89.0,91.0,93.0,95.0,98.0,inf


calculated median = 45.0


In [5]:
# Run the search, then show both lists with their final splits highlighted.
a, b, med = main(a_, b_)
try:
    display(highlight_ol(a))
    display(highlight_ol(b))
except Exception as exc:
    # Rendering stays best-effort (the median is still shown below), but the
    # reason is reported instead of being silently discarded. Catching
    # Exception rather than everything lets KeyboardInterrupt through.
    print(f"could not render highlighted lists: {exc!r}")
finally:
    display(med)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43
0,-inf,40.0,41.0,43.0,44.0,44.0,44.0,45.0,49.0,49.0,50.0,53.0,56.0,58.0,59.0,63.0,63.0,64.0,65.0,72.0,73.0,74.0,74.0,75.0,76.0,76.0,77.0,77.0,79.0,80.0,82.0,82.0,84.0,85.0,86.0,87.0,88.0,89.0,89.0,91.0,93.0,95.0,98.0,inf


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66
0,-inf,0.0,2.0,2.0,3.0,3.0,3.0,7.0,8.0,9.0,9.0,10.0,12.0,13.0,13.0,14.0,14.0,15.0,15.0,16.0,18.0,18.0,19.0,20.0,21.0,23.0,25.0,25.0,27.0,27.0,27.0,28.0,28.0,30.0,30.0,32.0,32.0,33.0,36.0,36.0,37.0,38.0,39.0,39.0,40.0,42.0,44.0,45.0,47.0,49.0,50.0,51.0,51.0,51.0,51.0,52.0,52.0,53.0,54.0,55.0,55.0,56.0,58.0,58.0,58.0,58.0,inf


45

In [6]:
# Check the custom result against the NumPy reference median.
print("calculated med == med: " + str(med_calc == med))

calculated med == med: True
