## Q2.4 Partition

Given a singly linked list, and a value x, partition the list by x.
Such that all values smaller than x appear before x, all values greater than x appear after x.
If x in the list, nodes with value x can appear in any order.

e.g. 
- input: 3->5->8->5->10->2->1, partition = 5
- result: 3->1->2->10->5->5->8

Input: 
Linked list ll

Return: 
nothing

In [22]:
""" 
Solution1:

1. traverse the list and get tail and length
2. traverse again
    - delete the node >= x
    - append the node after tail

space O(1)
time O(n)

"""

def partition(sll, x):
    if not sll or not sll.head:
        return False
    
    # get tail and length
    length = 1
    tail = sll.head
    while tail.next:
        length += 1
        tail = tail.next
    
    # delete and append node >= x
    head = sll.head
    if head.val >= x:
        
        # only 1 elem in the list, no need to partition
        if not head.next:
            return True
        
        # update head with next
        new_head = head.next
        sll.head = new_head
        
        # head becomes new tail 1. clear next, 2. connect with prev_tail, 3. set as tail
        head.next = None
        tail.next = head
        tail = head
    
    curr = sll.head
    for i in range(length-1):
        if curr.next.val >= x:
            if curr.next is not tail:
                _next = deleteNext(curr)
                tail = appendToTail(tail,_next)
        else:
            curr = curr.next
    
    return True

def deleteNext(node):
    _next = node.next
    node.next = _next.next
    _next.next = None
    return _next

def appendToTail(tail, node):
    tail.next = node
    # return new tail
    return tail.next

In [41]:
"""
Improved version of Sol1
"""
def partition1(sll, x):
    if not sll or not sll.head:
        return False
    
    # get tail and length
    length = 1
    tail = sll.head
    while tail.next:
        length += 1
        tail = tail.next
    
    # delete and append node >= x
    curr = sll.head
    for i in range(length):
        
        if curr.val >= x:
            if curr is not tail:
                next_ = deleteNode(curr)
                tail = appendToTail(tail,next_)
                tail.next = None
        else:
            curr = curr.next

    return True

def deleteNode(node):
    """
    Same idea as Q3
    1. switch val of node and next
    2. remove and return next
    
    Note: node is not tail, so no need to checknull for node.next
    """
    
    # switch vals
    next_ = node.next
    next_val = next_.val
    
    next_.val = node.val
    node.val = next_val
    
    # remove next
    node.next = next_.next
    
    return next_

def appendToTail(tail, node):
    tail.next = node
    # return new tail
    return tail.next

In [42]:
"""
Solution2: CCI sol

Rather than append elem to the end of list if elem >= x

1. Create another linked list
2.1 if elem < x, insert at head
2.2 else, insert at tail

"""

def partition2(sll, x):
    if not sll or not sll.head:
        return False
    
    head = sll.head
    tail = head
    
    curr = sll.head
    while curr:
        _next = curr.next
        if curr.val >= x:
            # insert to tail
            tail.next = curr
            tail = tail.next 
        else:
            # insert to head
            curr.next = head
            head = curr
        curr = _next
    
    tail.next = None
    sll.head = head
    return True 

## Testing 

In [38]:
from myLinkedLists import SinglyLinkedList

def checkPartition(retList, x):
    # has meet >=  
    passBorder = False
    for val in retList:
        if val >= x:
            passBorder = True
        else:
            if passBorder:
                return False
    return True
            

def test(test_lists, partition):
    total = len(test_lists)
    correct = 0
    
    for l,x in test_lists:
        sll = SinglyLinkedList(l)
        
        print('Before: ',sll)
        partition(sll,x)
        print('After: ', sll)
        
        curr = checkPartition(sll.toList(), x)
        correct += curr
        print(curr)
        print('-'*50)
    
    print(f'{correct}/{total}')
    if correct == total:
        print('All passed')

In [39]:
# (original_list, x)
test_lists = [
    ([3,5,8,5,10,2,1], 5),
    ([1,3,5,7,100], 10)
]

In [43]:
test(test_lists, partition1)

Before:  3 -> 5 -> 8 -> 5 -> 10 -> 2 -> 1
curr: 3, sll: 3 -> 5 -> 8 -> 5 -> 10 -> 2 -> 1
curr: 5, sll: 3 -> 5 -> 8 -> 5 -> 10 -> 2 -> 1
curr: 8, sll: 3 -> 8 -> 5 -> 10 -> 2 -> 1 -> 5
curr: 5, sll: 3 -> 5 -> 10 -> 2 -> 1 -> 5 -> 8
curr: 10, sll: 3 -> 10 -> 2 -> 1 -> 5 -> 8 -> 5
curr: 2, sll: 3 -> 2 -> 1 -> 5 -> 8 -> 5 -> 10
curr: 1, sll: 3 -> 2 -> 1 -> 5 -> 8 -> 5 -> 10
After:  3 -> 2 -> 1 -> 5 -> 8 -> 5 -> 10
True
--------------------------------------------------
Before:  1 -> 3 -> 5 -> 7 -> 100
curr: 1, sll: 1 -> 3 -> 5 -> 7 -> 100
curr: 3, sll: 1 -> 3 -> 5 -> 7 -> 100
curr: 5, sll: 1 -> 3 -> 5 -> 7 -> 100
curr: 7, sll: 1 -> 3 -> 5 -> 7 -> 100
curr: 100, sll: 1 -> 3 -> 5 -> 7 -> 100
After:  1 -> 3 -> 5 -> 7 -> 100
True
--------------------------------------------------
2/2
All passed
