-
Notifications
You must be signed in to change notification settings - Fork 0
/
segmented tree.py
125 lines (96 loc) · 3.9 KB
/
segmented tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
The idea here is to build a segment tree. Each node stores the left and right
endpoints of an interval and the sum of the elements in that interval. The
leaves store the elements of the array, and each internal node stores the sum
of the leaves beneath it. Building the tree takes O(n) time; queries and
updates both take O(log n).
Implemented by tejask91.
"""
# Segment tree node
class Node(object):
    """One node of the segment tree, covering indices start..end (inclusive).

    Fields:
    - start, end: the inclusive index range this node covers.
    - total: sum of the array elements in that range (0 until assigned).
    - left, right: child nodes for the lower and upper halves of the range;
      both stay None for a leaf (start == end).
    """

    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.total = 0
        self.left = None
        self.right = None

    def __repr__(self):
        return "Node(start=%r, end=%r, total=%r)" % (
            self.start, self.end, self.total)


class NumArray(object):
    """Integer array with O(log n) point updates and inclusive range sums.

    The values live in a segment tree rooted at ``self.root`` (None when the
    input list is empty). Building the tree takes O(n) time.
    """

    def __init__(self, nums):
        """
        initialize your data structure here.
        :type nums: List[int]
        """
        # _build returns None for an empty range, so an empty list yields
        # root None; _check_index then rejects every index.
        self.root = self._build(nums, 0, len(nums) - 1)

    def update(self, i, val):
        """
        Set nums[i] to val.
        :type i: int
        :type val: int
        :rtype: int -- the new sum of the whole array
        :raises IndexError: if i is not a valid index into nums
        """
        # Without this check an out-of-range i silently overwrote the first
        # (i < 0) or last (i >= len(nums)) element.
        self._check_index(i)
        return self._update(self.root, i, val)

    def sumRange(self, i, j):
        """
        sum of elements nums[i..j], inclusive.
        :type i: int
        :type j: int
        :rtype: int
        :raises IndexError: if i or j is not a valid index into nums
        :raises ValueError: if i > j
        """
        self._check_index(i)
        self._check_index(j)
        if i > j:
            raise ValueError(
                "sumRange requires i <= j, got i=%r, j=%r" % (i, j))
        return self._range_sum(self.root, i, j)

    def _check_index(self, i):
        """Raise IndexError unless i lies inside the range covered by the tree."""
        root = self.root
        if root is None or not root.start <= i <= root.end:
            raise IndexError("index %r out of range" % (i,))

    @staticmethod
    def _build(nums, l, r):
        """Build the subtree for nums[l..r] and return its root (None if l > r)."""
        if l > r:
            return None
        node = Node(l, r)
        if l == r:
            # Leaf: holds exactly one array element.
            node.total = nums[l]
            return node
        mid = (l + r) // 2
        node.left = NumArray._build(nums, l, mid)
        node.right = NumArray._build(nums, mid + 1, r)
        # An internal node stores the sum of all leaves beneath it,
        # i.e. of the elements lying in [start, end].
        node.total = node.left.total + node.right.total
        return node

    @staticmethod
    def _update(node, i, val):
        """Store val at leaf i below node, refresh the totals on the way back
        up, and return node's new total. Requires node.start <= i <= node.end."""
        if node.start == node.end:
            # The actual value is written in the leaf; the parents recompute
            # their totals after the recursive call returns.
            node.total = val
            return node.total
        mid = (node.start + node.end) // 2
        # Indices up to mid live in the left subtree, the rest in the right.
        if i <= mid:
            NumArray._update(node.left, i, val)
        else:
            NumArray._update(node.right, i, val)
        # Propagate the change upwards.
        node.total = node.left.total + node.right.total
        return node.total

    @staticmethod
    def _range_sum(node, i, j):
        """Return the sum of elements i..j (inclusive) below node.
        Requires node.start <= i <= j <= node.end."""
        # The range exactly matches this node: its total is the answer.
        if node.start == i and node.end == j:
            return node.total
        mid = (node.start + node.end) // 2
        # The whole range lies in the left subtree.
        if j <= mid:
            return NumArray._range_sum(node.left, i, j)
        # The whole range lies in the right subtree.
        if i > mid:
            return NumArray._range_sum(node.right, i, j)
        # The range straddles mid: split it between the two children.
        return (NumArray._range_sum(node.left, i, mid)
                + NumArray._range_sum(node.right, mid + 1, j))
# Your NumArray object will be instantiated and called as such:
# numArray = NumArray(nums)
# numArray.sumRange(0, 1)
# numArray.update(1, 10)
# numArray.sumRange(1, 2)