In [12]:
# Time: O(1) for each get and put call
# Space: O(n) for the hash map + O(n) for the linked list ~ O(n)

from collections import defaultdict

class Node:
    """Doubly linked list node; in LRUCache, ``val`` stores the cache key."""

    def __init__(self, val):
        self.val = val
        # Neighbour links start detached; the owning list wires them up.
        self.prev = None
        self.next = None

class LRUCache:
    """Fixed-capacity key/value cache with least-recently-used eviction.

    ``self.cache`` maps each key to a ``[node, value]`` pair. The nodes form a
    doubly linked list ordered from most recently used (``self.head``) to
    least recently used (``self.tail``), so both ``get`` and ``put`` run in
    O(1) time.
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.current_capacity = 0  # number of keys currently stored
        self.head = self.tail = None
        # Plain dict rather than defaultdict: looking up a missing key must not
        # insert an empty entry, and evicted keys must really be removed.
        self.cache = {}

    def update(self, key):
        """Move ``key``'s node to the head (most recently used position).

        ``key`` must already be present in the cache.
        """
        node = self.cache[key][0]
        if node is self.head:
            return
        # Unlink. node is not the head, so node.prev is not None; if node is
        # the tail, node.next is None and prev becomes the new end.
        node.prev.next = node.next
        if node is self.tail:
            self.tail = node.prev
        else:
            node.next.prev = node.prev
        # Relink at the front. The list has at least two nodes here, so
        # self.head is not None.
        node.prev = None
        node.next = self.head
        self.head.prev = node
        self.head = node

    def get(self, key: int) -> int:
        """Return the value stored for ``key`` and mark it most recently used.

        Returns -1 when ``key`` is not in the cache.
        """
        entry = self.cache.get(key)
        if entry is None:
            return -1
        self.update(key)
        return entry[1]

    def put(self, key: int, value: int) -> None:
        """Insert or overwrite ``key``, making it the most recently used key.

        When the cache is full, the least recently used key is evicted first.
        With a non-positive capacity nothing can be stored, so the call is a
        no-op.
        """
        entry = self.cache.get(key)
        if entry is not None:
            self.update(key)
            entry[1] = value
            return
        if self.capacity <= 0:
            return
        if self.current_capacity >= self.capacity:
            self._evict()
        node = Node(key)
        node.next = self.head
        if self.head is not None:
            self.head.prev = node
        else:
            self.tail = node  # list was empty: the new node is also the tail
        self.head = node
        self.cache[key] = [node, value]
        self.current_capacity += 1

    def _evict(self):
        """Remove the least recently used entry (the tail) from list and map."""
        node = self.tail
        self.tail = node.prev
        if self.tail is not None:
            self.tail.next = None
        else:
            # Evicted the only node: clear head too, otherwise the next insert
            # would link to this stale node while tail stays None.
            self.head = None
        node.prev = None
        del self.cache[node.val]
        self.current_capacity -= 1

if __name__ == '__main__':
    # Small demo with capacity 2; expected output: 1, -1, -1, 3, 4.
    obj = LRUCache(2)
    obj.put(1, 1)
    obj.put(2, 2)
    print(obj.get(1))  # 1 (key 1 becomes most recently used)
    obj.put(3, 3)      # cache full: evicts key 2
    print(obj.get(2))  # -1
    obj.put(4, 4)      # cache full: evicts key 1
    for key in (1, 3, 4):
        print(obj.get(key))  # -1, then 3, then 4

1
-1
-1
3
4
