Skip to content

Commit 06c1d82

Browse files
committed
Add Segment Tree implementation in Python
1 parent 7530a41 commit 06c1d82

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""
2+
A Segment Tree is a binary tree data structure used for efficiently answering
3+
range queries and updates on an array, such as sum, minimum, or maximum over
4+
a subrange. It offers O(log n) time complexity for both queries and updates,
5+
making it very efficient compared to a naive O(n) approach.
6+
7+
While building the tree takes O(n) time and the tree requires O(n) space,
8+
this preprocessing enables fast range queries that would otherwise be slow.
9+
Segment Trees are especially useful when the array is mutable and queries
10+
and updates are intermixed.
11+
12+
Time Complexity:
13+
- Build: O(n)
14+
- Query: O(log n)
15+
- Update: O(log n)
16+
17+
Example usage and doctests:
18+
19+
>>> data = [1, 2, 3, 4, 5]
20+
>>> st = SegmentTree(data)
21+
>>> st.query(1, 4)
22+
9
23+
>>> st.update(2, 10)
24+
>>> st.query(1, 4)
25+
16
26+
"""
27+
28+
29+
class SegmentTree:
30+
"""Segment Tree for efficient range sum queries."""
31+
32+
def __init__(self, data: list[int]):
33+
"""Initialize the segment tree with the input data.
34+
35+
Args:
36+
data (list[int]): List of integers to build the segment tree.
37+
"""
38+
self.n = len(data)
39+
self.tree = [0] * (2 * self.n)
40+
# Build the tree
41+
for i in range(self.n):
42+
self.tree[self.n + i] = data[i]
43+
for i in range(self.n - 1, 0, -1):
44+
self.tree[i] = self.tree[i << 1] + self.tree[i << 1 | 1]
45+
46+
def update(self, index: int, value: int) -> None:
47+
"""Update element at index with a new value.
48+
49+
Args:
50+
index (int): Index of the element to update.
51+
value (int): New value to set at the given index.
52+
"""
53+
if index < 0 or index >= self.n:
54+
raise ValueError("Index out of bounds")
55+
index += self.n
56+
self.tree[index] = value
57+
while index > 1:
58+
index >>= 1
59+
self.tree[index] = self.tree[index << 1] + self.tree[index << 1 | 1]
60+
61+
def query(self, left: int, right: int) -> int:
62+
"""Compute the sum of elements in the interval [left, right).
63+
64+
Args:
65+
left (int): Left index (inclusive).
66+
right (int): Right index (exclusive).
67+
68+
Returns:
69+
int: Sum of elements from left to right-1.
70+
71+
Raises:
72+
ValueError: If indices are out of bounds or left >= right.
73+
"""
74+
if left < 0 or right > self.n or left >= right:
75+
raise ValueError("Invalid query range")
76+
res = 0
77+
left += self.n
78+
right += self.n
79+
while left < right:
80+
if left & 1:
81+
res += self.tree[left]
82+
left += 1
83+
if right & 1:
84+
right -= 1
85+
res += self.tree[right]
86+
left >>= 1
87+
right >>= 1
88+
return res
89+
90+
91+
if __name__ == "__main__":
92+
import doctest
93+
94+
data = [1, 2, 3, 4, 5]
95+
st = SegmentTree(data)
96+
print("Initial sum 1-4:", st.query(1, 4))
97+
st.update(2, 10)
98+
print("Updated sum 1-4:", st.query(1, 4))
99+
100+
doctest.testmod()

0 commit comments

Comments
 (0)