# Segment Tree

1. [Range sum: point update, range query, static](#section_1)
2. Range sum: range update, range query, static, lazy propagation
3. Range sum: point update, range query, dynamic

[Section 1](#section_1)

# Section 1: range sum with point update, range query (static)
<a id='section_1'></a>

In [None]:
class Node:
    def __init__(self, start: int, end: int, val, left: 'Node' = None, right: 'Node' = None):
        """
        Segment tree node covering the index interval [start, end].
        :param start: interval start (inclusive)
        :param end: interval end (inclusive)
        :param val: value stored in the node (in SegmentTree below, the sum
            of the elements in [start, end])
        :param left: left child, covering [start, mid()]; None for a leaf
        :param right: right child, covering [mid() + 1, end]; None for a leaf
        """
        self.start = start
        self.end = end
        self.val = val
        self.left = left
        self.right = right

    def mid(self):
        """Return the split point: the left child covers [start, mid()]."""
        return (self.start + self.end) // 2

    def __repr__(self):
        return f"{type(self).__name__}(start={self.start}, end={self.end}, val={self.val!r})"


class SegmentTree:
    """
    Range sum.

    Point update, range query, static: the set of positions is fixed when
    the tree is built; individual values can be changed with ``update``.
    """
    def __init__(self, nums):
        """
        Build the tree over ``nums``.
        :param nums: iterable of summable values; an empty iterable yields an
            empty tree whose ``root`` is None
        """
        self.nums = list(nums)
        # Measure the materialised list so iterators/generators work as input.
        self.root = self._build(0, len(self.nums) - 1)

    def _build(self, l: int, r: int):
        """Recursively build the subtree for [l, r]; return None if l > r."""
        if l > r:
            # Only reachable for empty input; without this guard
            # _build(0, -1) would call _build(0, -1) forever.
            return None
        if l == r:
            return Node(l, r, self.nums[l])
        mid = (l + r) // 2
        left_child = self._build(l, mid)
        right_child = self._build(mid + 1, r)
        return Node(l, r, left_child.val + right_child.val, left_child, right_child)

    def update(self, node: Node, idx: int, val):
        """
        Set the element at position ``idx`` to ``val``.

        Sums on the path from ``node`` down to that leaf are refreshed.
        :param node: subtree to update (normally ``self.root``)
        :param idx: position to overwrite; must lie in [node.start, node.end]
        :param val: new value
        :raises IndexError: if ``idx`` is outside the node's interval or the
            tree is empty
        """
        if node is None or not node.start <= idx <= node.end:
            # Descending anyway would silently overwrite the nearest boundary leaf.
            raise IndexError(f"index {idx} out of range")
        if node.start == node.end:
            node.val = val
            self.nums[idx] = val  # keep the list copy consistent with the tree
            return
        if idx <= node.mid():
            self.update(node.left, idx, val)
        else:
            self.update(node.right, idx, val)
        # update parent
        node.val = node.left.val + node.right.val

    def sum(self, node: Node, l: int, r: int):
        """
        Return the sum of the elements at positions [l, r] (inclusive).

        Positions outside the tree are ignored, so empty or out-of-range
        queries contribute 0 instead of failing.
        """
        if node is None or r < node.start or node.end < l:
            return 0
        if l <= node.start and node.end <= r:
            return node.val
        # Partial overlap: a leaf either fully overlaps or not at all, so this
        # node is internal and has both children.
        return self.sum(node.left, l, r) + self.sum(node.right, l, r)

In [26]:
nums = [1, 2, 3, 4, 5]
st = SegmentTree(nums)
# Every inclusive tree query [lo, hi] must match the sum of the same slice.
assert all(st.sum(st.root, lo, hi) == sum(nums[lo:hi + 1])
           for lo, hi in ((0, 4), (1, 4), (2, 4), (2, 3)))

# Point update: change position 2 in the tree and in the reference list.
st.update(st.root, 2, -1)
nums[2] = -1
assert all(st.sum(st.root, lo, hi) == sum(nums[lo:hi + 1])
           for lo, hi in ((0, 4), (1, 4), (2, 4), (2, 3)))

In [None]:
class Node:
    def __init__(self, start: int, end: int, val, left: 'Node' = None, right: 'Node' = None):
        """
        Segment tree node covering the index interval [start, end].
        :param start: interval start (inclusive)
        :param end: interval end (inclusive)
        :param val: value stored in the node (in SegmentTree below, the sum
            of [start, end] with every add that reached this node applied)
        :param left: left child, covering [start, mid()]; None for a leaf
        :param right: right child, covering [mid() + 1, end]; None for a leaf
        """
        self.start = start
        self.end = end
        self.val = val
        self.left = left
        self.right = right
        # Per-element delta already counted in self.val but not yet pushed
        # down to the children (lazy propagation tag).
        self.lazy = 0

    def mid(self):
        """Return the split point: the left child covers [start, mid()]."""
        return (self.start + self.end) // 2


class SegmentTree:
    """
    Range sum.

    Range update (add a delta to every element of [l, r]), range query,
    static, lazy propagation: both operations run in O(log n).
    ``self.nums`` keeps the values the tree was built from and is not
    modified by ``add``.
    """
    def __init__(self, nums):
        """
        Build the tree over ``nums``.
        :param nums: iterable of summable values; an empty iterable yields an
            empty tree whose ``root`` is None
        """
        self.nums = list(nums)
        # Measure the materialised list so iterators/generators work as input.
        self.root = self._build(0, len(self.nums) - 1)

    def _build(self, l: int, r: int):
        """Recursively build the subtree for [l, r]; return None if l > r."""
        if l > r:
            # Only reachable for empty input; without this guard
            # _build(0, -1) would call _build(0, -1) forever.
            return None
        if l == r:
            return Node(l, r, self.nums[l])
        mid = (l + r) // 2
        left_child = self._build(l, mid)
        right_child = self._build(mid + 1, r)
        return Node(l, r, left_child.val + right_child.val, left_child, right_child)

    def _apply(self, node: Node, delta):
        """Add ``delta`` to every element under ``node`` without descending."""
        node.val += (node.end - node.start + 1) * delta
        # Remember the delta for the children; on a leaf the tag is never read.
        node.lazy += delta

    def _push_down(self, node: Node):
        """Move ``node``'s pending delta to its two children (internal nodes only)."""
        if node.lazy:
            self._apply(node.left, node.lazy)
            self._apply(node.right, node.lazy)
            node.lazy = 0

    def add(self, node: Node, l: int, r: int, delta):
        """
        Add ``delta`` to every element at positions [l, r] (inclusive).

        Positions outside the tree are ignored.
        :param node: subtree to update (normally ``self.root``)
        """
        if node is None or r < node.start or node.end < l:
            return
        if l <= node.start and node.end <= r:
            # Fully covered: update this node's sum and defer the children.
            self._apply(node, delta)
            return
        # Partial overlap implies an internal node; children must be current
        # before we recurse into them.
        self._push_down(node)
        self.add(node.left, l, r, delta)
        self.add(node.right, l, r, delta)
        # update parent
        node.val = node.left.val + node.right.val

    def sum(self, node: Node, l: int, r: int):
        """
        Return the sum of the elements at positions [l, r] (inclusive).

        Positions outside the tree are ignored, so empty or out-of-range
        queries contribute 0 instead of failing.
        """
        if node is None or r < node.start or node.end < l:
            return 0
        if l <= node.start and node.end <= r:
            return node.val
        # Children may still owe a pending delta; settle it before reading them.
        self._push_down(node)
        return self.sum(node.left, l, r) + self.sum(node.right, l, r)