# Segment Tree

1. [Range Sum I: point update, range query, static (tree built from an input array)](#Range-Sum-I)
2. [Range Sum II: range update, range query, static, lazy propagation](#Range-Sum-II)
3. [Range Sum III: point update, range query, dynamic (nodes created on demand)](#Range-Sum-III)
4. [Range Sum IV: range update, range query, dynamic, lazy propagation](#Range-Sum-IV)

# Range Sum I

In [1]:
class Node:
    def __init__(self, start: int, end: int, val, left: 'Node' = None, right: 'Node' = None):
        """
        Segment tree node covering the closed interval [start, end].
        :param start: interval start (inclusive)
        :param end: interval end (inclusive)
        :param val: value stored in the node (the sum of its interval)
        :param left: left child covering [start, mid], or None for a leaf
        :param right: right child covering [mid + 1, end], or None for a leaf
        """
        self.start = start
        self.end = end
        self.val = val
        self.left = left
        self.right = right

    def mid(self):
        """Return the split point; the left child covers [start, mid]."""
        return (self.start + self.end) // 2

class SegmentTree:
    """
    Static segment tree over a fixed-length array supporting point assignment
    and inclusive range-sum queries.
    """

    def __init__(self, nums):
        """
        Build the tree from an iterable of summable values.
        :param nums: initial values; an empty iterable yields an empty tree (root is None)
        """
        self.nums = list(nums)
        # Measure the materialized copy so one-shot iterables (generators) work too,
        # and skip building for empty input: _build(0, -1) would never terminate.
        self.root = self._build(0, len(self.nums) - 1) if self.nums else None

    def _build(self, l: int, r: int):
        """Recursively build the subtree for nums[l..r] (inclusive; requires l <= r)."""
        if l == r:
            return Node(l, r, self.nums[l])
        mid = (l + r) // 2
        left_child = self._build(l, mid)
        right_child = self._build(mid + 1, r)
        return Node(l, r, left_child.val + right_child.val, left_child, right_child)

    def update(self, node: Node, idx: int, val):
        """
        Set element idx to val and refresh the sums along the path.
        :param node: subtree root to start from (normally self.root)
        :param idx: index to assign
        :param val: new value
        :raises IndexError: if idx lies outside the node's interval or the tree is empty
        """
        # Without this check an out-of-range idx would silently overwrite the
        # nearest boundary leaf.
        if node is None or not node.start <= idx <= node.end:
            raise IndexError(f"index {idx} out of range")
        if node.start == node.end:
            node.val = val
            return
        if idx <= node.mid():
            self.update(node.left, idx, val)
        else:
            self.update(node.right, idx, val)
        # update parent
        node.val = node.left.val + node.right.val

    def sum(self, node: Node, l: int, r: int):
        """
        Return the sum of the elements whose indices lie in [l, r] (inclusive).
        Indices outside the node's interval contribute nothing, so an empty tree
        or a fully disjoint range yields 0.
        """
        if node is None or r < node.start or node.end < l:
            return 0
        if l <= node.start and node.end <= r:
            return node.val
        mid = node.mid()
        result = 0
        if l <= mid:
            result += self.sum(node.left, l, r)
        if r > mid:
            result += self.sum(node.right, l, r)
        return result

In [2]:
# Compare tree range sums with brute-force slice sums, before and after a point update.
nums = [1, 2, 3, 4, 5]
st = SegmentTree(nums)
ranges = [(0, 4), (1, 4), (2, 4), (2, 3)]

for lo, hi in ranges:
    assert st.sum(st.root, lo, hi) == sum(nums[lo:hi + 1])

st.update(st.root, 2, -1)
nums[2] = -1
for lo, hi in ranges:
    assert st.sum(st.root, lo, hi) == sum(nums[lo:hi + 1])

# Range Sum II

In [27]:
class Node:
    def __init__(self, start: int, end: int, val, left: 'Node' = None, right: 'Node' = None, delta=0):
        """
        Segment tree node with a lazy-propagation tag.
        :param start: interval start (inclusive)
        :param end: interval end (inclusive)
        :param val: value stored in the node (the sum of its interval, already
            including this node's own pending delta)
        :param left: left child covering [start, mid], or None for a leaf
        :param right: right child covering [mid + 1, end], or None for a leaf
        :param delta: per-element increment not yet pushed down to the children
        """
        self.start = start
        self.end = end
        self.val = val
        self.delta = delta
        self.left = left
        self.right = right

    def mid(self):
        """Return the split point; the left child covers [start, mid]."""
        return (self.start + self.end) // 2

class SegmentTree:
    """
    Static segment tree over a fixed-length array supporting range add and
    inclusive range sum, using lazy propagation.
    """

    def __init__(self, nums):
        """
        Build the tree from an iterable of summable values.
        :param nums: initial values; an empty iterable yields an empty tree (root is None)
        """
        self.nums = list(nums)
        # Measure the materialized copy so one-shot iterables (generators) work too,
        # and skip building for empty input: _build(0, -1) would never terminate.
        self.root = self._build(0, len(self.nums) - 1) if self.nums else None

    def _build(self, l: int, r: int):
        """Recursively build the subtree for nums[l..r] (inclusive; requires l <= r)."""
        if l == r:
            return Node(l, r, self.nums[l])
        mid = (l + r) // 2
        left_child = self._build(l, mid)
        right_child = self._build(mid + 1, r)
        return Node(l, r, left_child.val + right_child.val, left_child, right_child)

    def _push(self, node: Node):
        """Move node's pending delta onto both children (node must have children)."""
        delta = node.delta
        node.delta = 0
        if delta != 0:
            node.left.val += delta * (node.left.end - node.left.start + 1)
            node.right.val += delta * (node.right.end - node.right.start + 1)
            node.left.delta += delta
            node.right.delta += delta

    def add(self, node: Node, l: int, r: int, delta):
        """
        Add delta to every element whose index lies in [l, r] (inclusive).
        Indices outside the node's interval are ignored; an empty tree is left unchanged.
        """
        # Skip disjoint subtrees; otherwise recursion reaches a leaf and fails on
        # its missing children.
        if node is None or r < node.start or node.end < l:
            return
        if l <= node.start and node.end <= r:
            # add delta to all elements in the range; children are updated lazily
            node.val += (node.end - node.start + 1) * delta
            node.delta += delta
            return
        self._push(node)
        mid = node.mid()
        if l <= mid:
            self.add(node.left, l, r, delta)
        if r > mid:
            self.add(node.right, l, r, delta)
        # update parent
        node.val = node.left.val + node.right.val

    def sum(self, node: Node, l: int, r: int):
        """
        Return the sum of the elements whose indices lie in [l, r] (inclusive).
        Indices outside the node's interval contribute nothing, so an empty tree
        or a fully disjoint range yields 0.
        """
        if node is None or r < node.start or node.end < l:
            return 0
        if l <= node.start and node.end <= r:
            return node.val
        # Children must see any pending delta before they are read.
        self._push(node)
        mid = node.mid()
        result = 0
        if l <= mid:
            result += self.sum(node.left, l, r)
        if r > mid:
            result += self.sum(node.right, l, r)
        return result

In [30]:
# Apply overlapping range additions and compare against a brute-force list.
nums = [1, 2, 3, 4, 5]
st = SegmentTree(nums)
assert st.sum(st.root, 0, 4) == sum(nums[:5])

# Each step: (first index, last index, increment, ranges to verify afterwards).
steps = [
    (0, 1, 1, [(0, 4), (0, 1)]),
    (0, 2, 1, [(0, 4), (0, 1), (2, 4)]),
]
for first, last, inc, checks in steps:
    st.add(st.root, first, last, inc)
    nums[first:last + 1] = [x + inc for x in nums[first:last + 1]]
    for lo, hi in checks:
        assert st.sum(st.root, lo, hi) == sum(nums[lo:hi + 1])

# Range Sum III

In [40]:
class Node:
    def __init__(self, start: int, end: int, val, left: 'Node' = None, right: 'Node' = None):
        """
        Segment tree node covering the closed interval [start, end].
        :param start: interval start (inclusive)
        :param end: interval end (inclusive)
        :param val: value stored in the node (the sum of its interval)
        :param left: left child covering [start, mid], or None if not yet created
        :param right: right child covering [mid + 1, end], or None if not yet created
        """
        self.start = start
        self.end = end
        self.val = val
        self.left = left
        self.right = right

    def mid(self):
        """Return the split point; the left child covers [start, mid]."""
        return (self.start + self.end) // 2

class SegmentTree:
    """
    Dynamic segment tree over the index range [start, end]. Every element
    starts at 0, and child nodes are only created along paths touched by updates.
    """

    def __init__(self, start: int, end: int):
        self.root = Node(start, end, 0)

    def update(self, node: Node, idx: int, val):
        """
        Set element idx to val and refresh the sums along the path.
        :param node: subtree root to start from (normally self.root)
        :param idx: index to assign
        :param val: new value
        :raises IndexError: if idx lies outside the node's interval
        """
        # Without this check an out-of-range idx would silently overwrite the
        # nearest boundary leaf.
        if not node.start <= idx <= node.end:
            raise IndexError(f"index {idx} out of range [{node.start}, {node.end}]")
        if node.start == node.end:
            node.val = val
            return
        mid = node.mid()
        # create left and right child nodes on first visit; the untouched sibling keeps val 0
        if node.left is None:
            node.left = Node(node.start, mid, 0)
        if node.right is None:
            node.right = Node(mid + 1, node.end, 0)

        if idx <= mid:
            self.update(node.left, idx, val)
        else:
            self.update(node.right, idx, val)
        # update parent
        node.val = node.left.val + node.right.val

    def sum(self, node: Node, l: int, r: int):
        """
        Return the sum of the elements whose indices lie in [l, r] (inclusive).
        Missing (never-created) subtrees and indices outside the node's interval
        contribute 0.
        """
        if node is None or r < node.start or node.end < l:
            return 0
        if l <= node.start and node.end <= r:
            return node.val
        mid = node.mid()
        result = 0
        if l <= mid:
            result += self.sum(node.left, l, r)
        if r > mid:
            result += self.sum(node.right, l, r)
        return result

In [41]:
# Sparse point updates on a dynamic tree, checked against a plain list.
nums = [0] * 10
st = SegmentTree(0, 9)
assert st.sum(st.root, 0, 9) == sum(nums)
assert st.sum(st.root, 0, 0) == sum(nums[0:1])

st.update(st.root, 0, 1)
nums[0] = 1
for lo, hi in [(0, 4), (0, 1)]:
    assert st.sum(st.root, lo, hi) == sum(nums[lo:hi + 1])

for pos, value in [(1, 3), (5, -4)]:
    st.update(st.root, pos, value)
    nums[pos] = value
assert st.sum(st.root, 0, 9) == sum(nums)

# Range Sum IV

In [53]:
class Node:
    def __init__(self, start: int, end: int, val, left: 'Node' = None, right: 'Node' = None, delta=0):
        """
        Segment tree node with a lazy-propagation tag.
        :param start: interval start (inclusive)
        :param end: interval end (inclusive)
        :param val: value stored in the node (the sum of its interval, already
            including this node's own pending delta)
        :param left: left child covering [start, mid], or None if not yet created
        :param right: right child covering [mid + 1, end], or None if not yet created
        :param delta: per-element increment not yet pushed down to the children
        """
        self.start = start
        self.end = end
        self.val = val
        self.delta = delta
        self.left = left
        self.right = right

    def mid(self):
        """Return the split point; the left child covers [start, mid]."""
        return (self.start + self.end) // 2

class SegmentTree:
    """
    Dynamic segment tree over the index range [start, end] supporting range add
    and inclusive range sum with lazy propagation. Every element starts at 0.
    Children are created by _push, so both add and sum may allocate nodes.
    """

    def __init__(self, start, end):
        self.root = Node(start, end, 0)

    def _push(self, node: Node):
        """
        Create node's children if they are missing, then move its pending delta
        onto them. Callers only push nodes that a range partially overlaps; such
        nodes span at least two indices, so both children are well formed.
        """
        delta = node.delta
        node.delta = 0

        # create left and right child nodes on first visit
        mid = node.mid()
        if node.left is None:
            node.left = Node(node.start, mid, 0)
        if node.right is None:
            node.right = Node(mid + 1, node.end, 0)

        if delta != 0:
            node.left.val += delta * (node.left.end - node.left.start + 1)
            node.right.val += delta * (node.right.end - node.right.start + 1)
            node.left.delta += delta
            node.right.delta += delta

    def add(self, node: Node, l: int, r: int, delta):
        """
        Add delta to every element whose index lies in [l, r] (inclusive).
        Indices outside the node's interval are ignored.
        """
        # Skip disjoint subtrees. Descending into them would eventually push a leaf
        # and fabricate a child with start > end.
        if r < node.start or node.end < l:
            return
        if l <= node.start and node.end <= r:
            # add delta to all elements in the range; children are updated lazily
            node.val += (node.end - node.start + 1) * delta
            node.delta += delta
            return
        self._push(node)
        mid = node.mid()
        if l <= mid:
            self.add(node.left, l, r, delta)
        if r > mid:
            self.add(node.right, l, r, delta)
        # update parent
        node.val = node.left.val + node.right.val

    def sum(self, node: Node, l: int, r: int):
        """
        Return the sum of the elements whose indices lie in [l, r] (inclusive).
        Indices outside the node's interval contribute 0. No None check is needed:
        _push always creates both children before they are visited.
        """
        if r < node.start or node.end < l:
            return 0
        if l <= node.start and node.end <= r:
            return node.val
        self._push(node)
        mid = node.mid()
        result = 0
        if l <= mid:
            result += self.sum(node.left, l, r)
        if r > mid:
            result += self.sum(node.right, l, r)
        return result

In [54]:
# Range additions on a dynamic lazy tree, checked against a plain list.
nums = [0] * 5
st = SegmentTree(0, 4)
for lo, hi in [(1, 1), (0, 4)]:
    assert st.sum(st.root, lo, hi) == sum(nums[lo:hi + 1])

# Each step: (first index, last index, increment, ranges to verify afterwards).
steps = [
    (0, 1, 1, [(0, 4), (0, 1)]),
    (0, 2, 1, [(0, 4), (0, 1), (2, 4)]),
]
for first, last, inc, checks in steps:
    st.add(st.root, first, last, inc)
    for pos in range(first, last + 1):
        nums[pos] += inc
    for lo, hi in checks:
        assert st.sum(st.root, lo, hi) == sum(nums[lo:hi + 1])