### segment tree

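    Can be used to answer range min/max queries, range sum queries and range updates in O(log n) time per operation (range updates additionally need lazy propagation). The implementation below supports range-sum queries and point updates.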

In [1]:
# Time complexity of construction: O(n)
# Leaf slots allocated = 2^h, where h = ceil(log2(n)) is the tree height
# Total nodes = 2 * leaf_slots - 1
# Range-sum query: O(log n)
# Point update: O(log n)

import math

def sum_helper(st, low, high, rs, re, i):
    """Return the sum of the query range [rs, re] restricted to node i's segment.

    Node i of the flat tree `st` covers array positions [low, high]; its
    children are stored at 2*i + 1 (left half) and 2*i + 2 (right half).
    """
    # Segment lies completely inside the query: the stored total is the answer.
    fully_covered = rs <= low and re >= high
    if fully_covered:
        return st[i]

    # Segment and query range do not intersect: contributes nothing.
    no_overlap = rs > high or re < low
    if no_overlap:
        return 0

    # Partial overlap: split the segment and add up both halves.
    split = (low + high) // 2
    left_part = sum_helper(st, low, split, rs, re, 2 * i + 1)
    right_part = sum_helper(st, split + 1, high, rs, re, 2 * i + 2)
    return left_part + right_part


def calc_sum(st, n, start, end):
    """Return the sum of the original array's elements at positions start..end (inclusive).

    `st` is the segment tree built over `n` elements. Returns -1 when the
    range is empty or falls outside [0, n - 1].
    """
    in_bounds = 0 <= start <= end < n
    if not in_bounds:
        return -1
    # Query starts at the root (node 0), which covers the whole array [0, n - 1].
    return sum_helper(st, 0, n - 1, start, end, 0)

def update_helper(st, arr, low, high, index, diff, i):
    """Add `diff` to every node on the path from node i down to the leaf for `index`.

    Node i covers array positions [low, high]; children live at 2*i + 1 and
    2*i + 2. Nothing is changed when `index` lies outside [low, high].
    `arr` is part of the signature but is not read here.
    """
    if not low <= index <= high:
        return

    node = i
    # Walk down toward the leaf; only the half containing `index` is affected.
    while low != high:
        st[node] += diff
        mid = (low + high) // 2
        if index <= mid:
            high, node = mid, 2 * node + 1
        else:
            low, node = mid + 1, 2 * node + 2
    # Leaf node for `index`.
    st[node] += diff
    

def update(st, arr, n, index, value):
    """Set arr[index] = value and propagate the change through the segment tree `st`.

    Returns -1 when `index` is outside [0, n - 1]; otherwise returns None
    after updating both `arr` and `st` in place.
    """
    if not 0 <= index < n:
        return -1

    # The tree stores sums, so only the difference needs to be pushed down.
    delta = value - arr[index]
    arr[index] = value
    update_helper(st, arr, 0, n - 1, index, delta, 0)
    

def construct_helper(arr, st, low, high, i):
    """Fill node i (covering arr[low..high]) and its subtree in `st`; return the node's sum.

    Children of node i are stored at 2*i + 1 (left half) and 2*i + 2 (right half).
    """
    if low != high:
        split = (low + high) // 2
        left_total = construct_helper(arr, st, low, split, 2 * i + 1)
        right_total = construct_helper(arr, st, split + 1, high, 2 * i + 2)
        st[i] = left_total + right_total
    else:
        # Leaf: holds the single array element it covers.
        st[i] = arr[high]
    return st[i]
    

def construct_segment_tree(arr, n):
    """Build a sum segment tree over the first `n` elements of `arr`.

    The tree is a flat list: node i's children are at 2*i + 1 and 2*i + 2,
    and the root (index 0) holds sum(arr[0:n]).

    Args:
        arr: sequence of numbers.
        n: number of leading elements of `arr` to index (normally len(arr)).

    Returns:
        A list of length 2 * 2**h - 1, where h = ceil(log2(n)); an empty list
        when n == 0 (every query/update on it reports an invalid index).

    Raises:
        ValueError: if n is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative")
    if n == 0:
        # construct_helper cannot handle the empty range [0, -1].
        return []
    # h = ceil(log2(n)), computed exactly with integers. A float math.log2 can
    # round down just above large powers of two and under-size the array.
    height = (n - 1).bit_length()
    total_nodes = 2 * (2 ** height) - 1
    st = [0] * total_nodes
    construct_helper(arr, st, 0, n - 1, 0)
    return st
    
    
# Demo: build a tree, query a range sum, then apply two point updates.
arr = [1, 3, 5, 7, 9, 11, 15, 16]
n = len(arr)
st = construct_segment_tree(arr, n)
start, end = 1, 4
print(calc_sum(st, n, start, end))  # 3 + 5 + 7 + 9 = 24

# arr[3] changes 7 -> 17; position 3 is inside [start, end], so the sum grows by 10.
index, value = 3, 17
update(st, arr, n, index, value)
print(calc_sum(st, n, start, end))  # 34

# Position 7 is outside the queried range [start, end], so the sum is unchanged.
index, value = 7, 26
update(st, arr, n, index, value)
print(calc_sum(st, n, start, end))  # 34


24
34
34
