## Q2.7 Intersection

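Given the head nodes of two singly linked lists, return the node at which the two lists intersect, i.e. the first node they share, or None if they share no node. Compare nodes by reference (identity), not by value.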

e.g. 
n1 = Node(1), 
n2 = Node(1), 
n3 = Node(2)
- input1: n1 -> n3
- input2: n2 -> n3
- output: n3 (n1 and n2 hold the same value but are different nodes, so the lists first meet at n3)

Input: 
head1, head2: the head Nodes of two SinglyLinkedLists (either may be None)

Return: 
Node: the first node shared by both lists, or None if they do not intersect

In [4]:
from myLinkedLists import SinglyLinkedList, Node

In [5]:
""" 
Solution1:

Put every node of ll1 in a hash set, then walk ll2 and return the first node that is already in the set (n = total number of nodes in both lists)

space O(n)
time O(n)

"""

def intersection(head1, head2):
    """Return the node where the lists starting at head1 and head2 meet.

    Nodes are matched by identity (the same object), never by value.
    Uses O(m) extra space and O(m + n) time for lists of lengths m and n.

    Args:
        head1: head node of the first list, or None.
        head2: head node of the second list, or None.

    Returns:
        The first node of the second list that also appears in the first
        list, or None if the lists share no node.
    """
    # Keep the ids of list 1's nodes rather than the nodes themselves.
    # id() is an object's identity, so membership is by reference even if
    # Node overrides __eq__/__hash__ (a value-based __eq__ would otherwise
    # make lookups compare values, or make Node unhashable). The ids stay
    # valid because the nodes are kept alive by the lists themselves.
    seen = set()
    node = head1
    while node is not None:
        seen.add(id(node))
        node = node.next

    # The first node of list 2 found in list 1 is the intersection point:
    # every node after it is shared as well.
    curr = head2
    while curr is not None:
        if id(curr) in seen:
            return curr
        curr = curr.next

    return None

def getNodeSet(node):
    """Return a set containing every node of the list starting at node.

    Membership in the returned set uses Node's __hash__/__eq__ (identity
    for plain objects). Returns an empty set when node is None.
    """
    nodeSet = set()
    while node is not None:
        nodeSet.add(node)
        node = node.next
    return nodeSet

In [18]:
""" 
Solution2:

Use two pointers: for each node of ll1, traverse the whole of ll2 and compare the nodes by reference

space O(1)
time O(n^2)

"""

def intersection2(head1, head2):
    """Return the node where the lists starting at head1 and head2 meet.

    Brute force: for each node of the first list, scan the whole second
    list for the very same node object. O(1) extra space, O(m * n) time.

    Args:
        head1: head node of the first list, or None.
        head2: head node of the second list, or None.

    Returns:
        The first node of the first list that also appears in the second
        list, or None if the lists share no node.
    """
    curr1 = head1
    while curr1 is not None:
        # traverse all nodes in list 2
        curr2 = head2
        while curr2 is not None:
            # `is` compares identity; `==` would defer to Node.__eq__,
            # which may compare values.
            if curr1 is curr2:
                return curr1
            curr2 = curr2.next
        curr1 = curr1.next
    return None

In [26]:
"""
Solution3

1. get lengths of two lists
2. if the tails are different nodes, there is no intersection: return None
3. use two pointers, one per list; first advance the pointer on the longer list by the difference in lengths
4. advance both pointers together until they point to the same node, and return that node

space O(1)
time O(n)

"""

def intersection3(head1, head2):
    """Return the node where the lists starting at head1 and head2 meet.

    Nodes are matched by identity (the same object), never by value.
    Uses O(1) extra space and O(m + n) time for lists of lengths m and n.

    Args:
        head1: head node of the first list, or None.
        head2: head node of the second list, or None.

    Returns:
        The first node shared by both lists, or None if they share none.
    """
    if head1 is None or head2 is None:
        return None

    len1, tail1 = getLengthAndTail(head1)
    len2, tail2 = getLengthAndTail(head2)

    # Intersecting lists share every node from the meeting point on, tail
    # included, so different tail objects mean there is no intersection.
    # `is not` compares identity; `!=` would defer to Node.__eq__.
    if tail1 is not tail2:
        return None

    shorter, longer = (head1, head2) if len1 < len2 else (head2, head1)

    # Skip the extra leading nodes of the longer list so both pointers are
    # the same distance from the shared tail.
    for _ in range(abs(len1 - len2)):
        longer = longer.next

    # Walk in lockstep; because the tails are the same node, the pointers
    # meet no later than the tail.
    while shorter is not longer:
        shorter = shorter.next
        longer = longer.next

    return shorter

def getLengthAndTail(head):
    """Return (length, tail) for the list starting at head.

    length is the number of nodes and tail is the last node.
    Returns (0, None) when head is None.
    """
    if head is None:
        return 0, None
    tail = head
    length = 1
    while tail.next is not None:
        tail = tail.next
        length += 1
    return length, tail
    

## Testing 

In [12]:
def getRefs(head):
    """Describe a list by node identity instead of value.

    Returns the id() of every node from head to the end, as strings joined
    by ' -> '. An empty (falsy) head gives ''.
    """
    def walk(node):
        # yield nodes until the chain ends (a falsy next)
        while node:
            yield node
            node = node.next

    return ' -> '.join(str(id(n)) for n in walk(head))

In [13]:
def append(sll, head):
    """Attach the chain of nodes starting at head to the end of sll, in place.

    The nodes are linked, not copied, so appending the same chain to two
    lists makes them share those nodes (they intersect at head).
    If sll is empty (sll.head is None), the chain becomes its head.
    """
    # An empty list has no last node to link from; the original code
    # crashed here with AttributeError on None.next.
    if sll.head is None:
        sll.head = head
        return
    curr = sll.head
    while curr.next is not None:
        curr = curr.next
    curr.next = head

def test(test_lists, func):
    """Run func on each (list1, list2, common) case and report the results.

    For every case, three SinglyLinkedLists are built and the nodes of the
    third are appended to both of the first two, so the expected answer is
    the third list's head. The node ids of both lists, the expected id, the
    returned id and a True/False verdict are printed, followed by a
    passed/total summary.
    """
    total = len(test_lists)
    passed = 0

    for vals1, vals2, shared_vals in test_lists:
        first = SinglyLinkedList(vals1)
        second = SinglyLinkedList(vals2)
        shared = SinglyLinkedList(shared_vals)

        # link the very same shared nodes onto the end of both lists
        for target in (first, second):
            append(target, shared.head)

        expected = id(shared.head)

        print('sll1: ', getRefs(first.head))
        print('sll2: ', getRefs(second.head))
        print('common: ', expected)

        result = func(first.head, second.head)
        print('res: ', id(result))

        ok = id(result) == expected
        passed += ok
        print(ok, '\n', '-' * 50, sep='')

    print(f'{passed}/{total}')
    if passed == total:
        print('All passed')

In [14]:
# Each case: (values of list1, values of list2, values of the shared tail
# that is appended to both list1 and list2).
test_lists = [
    ([1, 2, 3], [1, 2], [10, 1, 2]),
    ([2], [1, 2, 2, 2], [3, 2]),
]

In [27]:
# Exercise the O(1)-space, O(n)-time two-pointer solution.
test(test_lists, func=intersection3)

sll1:  140644467358016 -> 140644467357232 -> 140644467356504 -> 140644467357792 -> 140644467357736 -> 140644467357960
sll2:  140644467356448 -> 140644467357288 -> 140644467357792 -> 140644467357736 -> 140644467357960
common:  140644467357792
res:  140644467357792
True
--------------------------------------------------
sll1:  140644467358408 -> 140644467356448 -> 140644467357232
sll2:  140644467358016 -> 140644467356504 -> 140644467358184 -> 140644467358128 -> 140644467356448 -> 140644467357232
common:  140644467356448
res:  140644467356448
True
--------------------------------------------------
2/2
All passed
