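- Efficient for computing an operation (sum, maximum, minimum, etc.) over the elements in a given range of an array-like data structure.
  - When an element changes, a prefix-sum array needs O(N) time to update, whereas a segment tree needs O(log N).
  - Each node of the tree holds the combined result for one specific range of the array.
- The array is split in half, each half is split in half again, and so on; the result for a range is computed from these pieces.
- The segment tree here is implemented recursively.
  - Note that node indices start at 1: the root is node 1, and the children of node i are 2i and 2i + 1.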
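A small sketch of the last two points, assuming the same midpoint split (`mid = left + (right - left) // 2`) and 1-based node numbering that the notebook's `build` and `update` use. It lists the node indices that an update of a single element passes through. `update_path` is a hypothetical helper, not part of the original code. Only the nodes on this root-to-leaf path are recomputed, and the path never has more than ceil(log2 N) + 1 nodes.

```python
import math
from typing import List


def update_path(idx: int, n: int) -> List[int]:
    """Return the node indices visited from the root (node 1) down to leaf idx.

    Hypothetical helper mirroring the midpoint split of a recursive
    segment tree over n elements.
    """
    node, left, right = 1, 0, n - 1
    path = [node]
    while left != right:
        mid = left + (right - left) // 2
        if idx <= mid:
            # Left child covers [left, mid].
            node, right = 2 * node, mid
        else:
            # Right child covers [mid + 1, right].
            node, left = 2 * node + 1, mid + 1
        path.append(node)
    return path


# With 9 elements, updating index 3 touches nodes 1, 2, 5 and the leaf 10.
print(update_path(3, 9))  # [1, 2, 5, 10]

# Longest path versus the ceil(log2(n)) + 1 bound.
n = 9
print(max(len(update_path(i, n)) for i in range(n)), math.ceil(math.log2(n)) + 1)  # 5 5
```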

In [1]:
arr = [1, 2, 5, 3, 9, 6, 5, 3, 2]
print(f'array : {arr}')

## Allocate the segment tree generously: 4 * len(arr) slots are always enough.
seg_tree = [0 for _ in range(4 * len(arr))]

array : [1, 2, 5, 3, 9, 6, 5, 3, 2]
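The `4 * len(arr)` allocation is safe but somewhat loose. A minimal sketch of a bound that is never larger, assuming the same 1-based recursive layout: the tree depth is at most ceil(log2 n), and every node at depth d has an index below 2^(d + 1), so a list of length 2^(ceil(log2 n) + 1) suffices. `tree_size` and `max_node_index` are hypothetical helpers for illustration.

```python
def tree_size(n: int) -> int:
    """A power-of-two list length that fits a 1-based recursive segment tree over n >= 1 leaves."""
    height = (n - 1).bit_length()  # equals ceil(log2(n)) for n >= 1
    return 2 ** (height + 1)


def max_node_index(node: int, left: int, right: int) -> int:
    """Largest node index the recursive build touches for the range [left, right]."""
    if left == right:
        return node
    mid = left + (right - left) // 2
    return max(max_node_index(2 * node, left, mid),
               max_node_index(2 * node + 1, mid + 1, right))


n = 9
print(tree_size(n), 4 * n, max_node_index(1, 0, n - 1))  # 32 36 17
```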


In [2]:
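def build(tree, node, left, right, func):
    ## Store func applied over arr[left..right] (inclusive) in tree[node].
    ## Note: leaf values are read from the global `arr` defined in In [1].

    if left == right:
        ## Leaf node: holds a single array element.
        tree[node] = arr[left]
        return tree[node]

    ## Split the range in half; the children of node are 2 * node and 2 * node + 1.
    mid        = left + (right - left) // 2
    left_val   = build(tree,     2 * node,    left,   mid, func)
    right_val  = build(tree, 2 * node + 1, mid + 1, right, func)

    ## Internal node: merge the results of the two halves.
    tree[node] = func(left_val, right_val)
    return tree[node]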


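def update(tree, idx, val, node, left, right, func):
    ## Set element idx to val and recompute every node whose range contains idx.
    ## Only `tree` changes; the global `arr` keeps its original values,
    ## so a later build starts from the unmodified array again.

    ## idx lies outside this node's range: nothing changes here.
    if (idx < left) or (idx > right):
        return tree[node]

    if left == right:
        ## Leaf for idx: overwrite the stored value.
        tree[node] = val
        return tree[node]

    mid        = left + (right - left) // 2
    left_val   = update(tree, idx, val,     2 * node,    left,   mid, func)
    right_val  = update(tree, idx, val, 2 * node + 1, mid + 1, right, func)

    ## Recompute this node from its (possibly updated) children.
    tree[node] = func(left_val, right_val)
    return tree[node]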


def prefix_sum(arr):
    ## sum_array[i] holds arr[0] + ... + arr[i - 1]; sum_array[0] is 0.
    sum_array = [0] * (len(arr) + 1)

    for idx, val in enumerate(arr):
        sum_array[idx + 1] = sum_array[idx] + val

    return sum_array


## What the segment tree stores depends on the merge function passed in.
## For example, merge_sum below builds a tree of range sums.
merge_sum = lambda left, right: left + right
merge_mul = lambda left, right: left * right
merge_min = lambda left, right: min(left, right)
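The cell above defines `build` and `update` but no range query, which is the operation the introduction motivates. A minimal sketch of one, assuming the same node layout and midpoint split. `build_from`, `query`, and the `identity` parameter (0 for sums, 1 for products, `float('inf')` for minimums) are hypothetical additions; `build_from` takes the data as a parameter instead of reading the global `arr`, so it does not clobber the notebook's `build`.

```python
import operator


def build_from(tree, data, node, left, right, func):
    # Same recursive build as In [2], but reading leaves from `data`.
    if left == right:
        tree[node] = data[left]
        return tree[node]
    mid = left + (right - left) // 2
    tree[node] = func(build_from(tree, data, 2 * node, left, mid, func),
                      build_from(tree, data, 2 * node + 1, mid + 1, right, func))
    return tree[node]


def query(tree, lo, hi, node, left, right, func, identity):
    """Combine data[lo..hi] (inclusive) using a tree built with func."""
    # No overlap between [left, right] and [lo, hi]: contribute the identity.
    if hi < left or right < lo:
        return identity
    # Node range lies entirely inside [lo, hi]: use the stored value as-is.
    if lo <= left and right <= hi:
        return tree[node]
    # Partial overlap: ask both children and merge their answers.
    mid = left + (right - left) // 2
    return func(query(tree, lo, hi, 2 * node, left, mid, func, identity),
                query(tree, lo, hi, 2 * node + 1, mid + 1, right, func, identity))


data = [1, 2, 5, 3, 9, 6, 5, 3, 2]
n = len(data)

sum_tree = [0] * (4 * n)
build_from(sum_tree, data, 1, 0, n - 1, operator.add)
print(query(sum_tree, 2, 6, 1, 0, n - 1, operator.add, 0))  # 28

min_tree = [0] * (4 * n)
build_from(min_tree, data, 1, 0, n - 1, min)
print(query(min_tree, 2, 6, 1, 0, n - 1, min, float('inf')))  # 3
```

Only nodes that partially overlap the query range recurse further, which keeps the number of visited nodes at O(log N).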

In [3]:
%%time
build(seg_tree, 1, 0, len(arr) - 1, merge_sum)

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 38.4 µs


36

In [4]:
%%time
prefix_sum(arr)

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 33.1 µs


[0, 1, 3, 8, 11, 20, 26, 31, 34, 36]
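Both cells above make a single O(N) pass, so building either structure costs about the same. The difference mentioned in the introduction shows up afterwards: with a prefix-sum array a range sum is one subtraction, but changing one element forces every later prefix to be rewritten. A short sketch; `range_sum` and `set_value` are hypothetical helpers.

```python
def prefix_sum(values):
    # sums[i] = values[0] + ... + values[i - 1]
    sums = [0] * (len(values) + 1)
    for i, v in enumerate(values):
        sums[i + 1] = sums[i] + v
    return sums


def range_sum(sums, lo, hi):
    # Sum of values[lo..hi] (inclusive) in O(1).
    return sums[hi + 1] - sums[lo]


def set_value(values, sums, idx, val):
    # Every prefix that includes idx shifts by the same delta: O(N) work.
    delta = val - values[idx]
    values[idx] = val
    for i in range(idx + 1, len(sums)):
        sums[i] += delta


values = [1, 2, 5, 3, 9, 6, 5, 3, 2]
sums = prefix_sum(values)
print(range_sum(sums, 2, 6))  # 28

set_value(values, sums, 3, 100)  # same change as In [6]
print(range_sum(sums, 0, 8))     # 133
```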

In [5]:
build(seg_tree, 1, 0, len(arr) - 1, merge_sum)
print(f'segment tree root node              : {seg_tree[1]}')
print(f'segment tree left, right node       : {seg_tree[2]} {seg_tree[3]}')
print(f'segment tree left left, right node  : {seg_tree[4]} {seg_tree[5]}')
print(f'segment tree right left, right node : {seg_tree[6]} {seg_tree[7]}')

segment tree root node              : 36
segment tree left, right node       : 20 16
segment tree left left, right node  : 8 12
segment tree right left, right node : 11 5
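To see which range each printed value covers, a small sketch that maps every node index to its inclusive `(left, right)` range under the same midpoint split. `node_ranges` is a hypothetical helper. For example, node 5 covers indices 3..4, and 3 + 9 = 12.

```python
def node_ranges(n):
    """Map each node index of a 1-based recursive segment tree over n elements to (left, right)."""
    ranges = {}

    def visit(node, left, right):
        ranges[node] = (left, right)
        if left == right:
            return
        mid = left + (right - left) // 2
        visit(2 * node, left, mid)
        visit(2 * node + 1, mid + 1, right)

    visit(1, 0, n - 1)
    return ranges


data = [1, 2, 5, 3, 9, 6, 5, 3, 2]
ranges = node_ranges(len(data))
# Prints node, range, and range sum; the sums match In [5]: 36, 20, 16, 8, 12, 11, 5.
for node in range(1, 8):
    left, right = ranges[node]
    print(node, (left, right), sum(data[left:right + 1]))
```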


In [6]:
update(seg_tree, 3, 100, 1, 0, len(arr) - 1, merge_sum)
print(f'segment tree root node              : {seg_tree[1]}')
print(f'segment tree left, right node       : {seg_tree[2]} {seg_tree[3]}')
print(f'segment tree left left, right node  : {seg_tree[4]} {seg_tree[5]}')
print(f'segment tree right left, right node : {seg_tree[6]} {seg_tree[7]}')

segment tree root node              : 133
segment tree left, right node       : 117 16
segment tree left left, right node  : 8 109
segment tree right left, right node : 11 5
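`update` assigns a new value and re-merges on the way back up. For a sum tree specifically, a common variant adds a delta to an element instead: every node on the root-to-leaf path changes by exactly that delta, so no merge call is needed. A sketch under that assumption (it is only valid for sums); `build_sum` and `add` are hypothetical helpers.

```python
def build_sum(tree, data, node, left, right):
    # Recursive range-sum build with the same layout as build(...).
    if left == right:
        tree[node] = data[left]
        return tree[node]
    mid = left + (right - left) // 2
    tree[node] = (build_sum(tree, data, 2 * node, left, mid)
                  + build_sum(tree, data, 2 * node + 1, mid + 1, right))
    return tree[node]


def add(tree, idx, delta, node, left, right):
    # Every node whose range contains idx grows by delta.
    if idx < left or idx > right:
        return
    tree[node] += delta
    if left == right:
        return
    mid = left + (right - left) // 2
    # Descend only into the half that contains idx.
    if idx <= mid:
        add(tree, idx, delta, 2 * node, left, mid)
    else:
        add(tree, idx, delta, 2 * node + 1, mid + 1, right)


data = [1, 2, 5, 3, 9, 6, 5, 3, 2]
tree = [0] * (4 * len(data))
build_sum(tree, data, 1, 0, len(data) - 1)

add(tree, 3, 97, 1, 0, len(data) - 1)  # 3 + 97 = 100, same result as In [6]
print(tree[1], tree[2], tree[5])       # 133 117 109
```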


In [7]:
build(seg_tree, 1, 0, len(arr) - 1, merge_min)
print(f'segment tree root node              : {seg_tree[1]}')
print(f'segment tree left, right node       : {seg_tree[2]} {seg_tree[3]}')
print(f'segment tree left left, right node  : {seg_tree[4]} {seg_tree[5]}')
print(f'segment tree right left, right node : {seg_tree[6]} {seg_tree[7]}')

segment tree root node              : 1
segment tree left, right node       : 1 2
segment tree left left, right node  : 1 3
segment tree right left, right node : 5 2
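The introduction also lists the maximum, which the notebook does not build. Only the merge function changes. A self-contained sketch; `build_from` is a hypothetical copy of `build` that reads leaves from a `data` parameter, and `merge_max` follows the notebook's lambda style.

```python
def build_from(tree, data, node, left, right, func):
    # Same recursive build as In [2], but reading leaves from `data`.
    if left == right:
        tree[node] = data[left]
        return tree[node]
    mid = left + (right - left) // 2
    tree[node] = func(build_from(tree, data, 2 * node, left, mid, func),
                      build_from(tree, data, 2 * node + 1, mid + 1, right, func))
    return tree[node]


merge_max = lambda left, right: max(left, right)

data = [1, 2, 5, 3, 9, 6, 5, 3, 2]
tree = [0] * (4 * len(data))
build_from(tree, data, 1, 0, len(data) - 1, merge_max)
print(tree[1], tree[2], tree[3])           # 9 9 6
print(tree[4], tree[5], tree[6], tree[7])  # 5 9 6 3
```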


In [8]:
update(seg_tree, 5, -5, 1, 0, len(arr) - 1, merge_min)
print(f'segment tree root node              : {seg_tree[1]}')
print(f'segment tree left, right node       : {seg_tree[2]} {seg_tree[3]}')
print(f'segment tree left left, right node  : {seg_tree[4]} {seg_tree[5]}')
print(f'segment tree right left, right node : {seg_tree[6]} {seg_tree[7]}')

segment tree root node              : -5
segment tree left, right node       : 1 -5
segment tree left left, right node  : 1 3
segment tree right left, right node : -5 2


In [9]:
build(seg_tree, 1, 0, len(arr) - 1, merge_mul)
print(f'segment tree root node              : {seg_tree[1]}')
print(f'segment tree left, right node       : {seg_tree[2]} {seg_tree[3]}')
print(f'segment tree left left, right node  : {seg_tree[4]} {seg_tree[5]}')
print(f'segment tree right left, right node : {seg_tree[6]} {seg_tree[7]}')

segment tree root node              : 48600
segment tree left, right node       : 270 180
segment tree left left, right node  : 10 27
segment tree right left, right node : 30 6
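With `merge_mul`, stored values grow quickly with the array length. Python integers do not overflow, but they get larger and slower to multiply. When only the product modulo some number is needed (a common setup in competitive programming), the merge can reduce each result. `MOD = 1_000_000_007`, `merge_mul_mod`, `build_from`, and the repeated input are assumptions for illustration; because multiplication modulo MOD is associative, reducing at every merge gives the same residue as reducing the full product once.

```python
import math

MOD = 1_000_000_007  # assumed modulus (a common choice), not from the notebook


def merge_mul_mod(left, right):
    # Product of the two halves, reduced so stored values stay below MOD.
    return left * right % MOD


def build_from(tree, data, node, left, right, func):
    # Same recursive build as In [2], but reading leaves from `data`.
    if left == right:
        tree[node] = data[left]
        return tree[node]
    mid = left + (right - left) // 2
    tree[node] = func(build_from(tree, data, 2 * node, left, mid, func),
                      build_from(tree, data, 2 * node + 1, mid + 1, right, func))
    return tree[node]


data = [1, 2, 5, 3, 9, 6, 5, 3, 2] * 10  # hypothetical 90-element input
tree = [0] * (4 * len(data))
build_from(tree, data, 1, 0, len(data) - 1, merge_mul_mod)
print(tree[1] == math.prod(data) % MOD)  # True
```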


In [10]:
update(seg_tree, 2, 3, 1, 0, len(arr) - 1, merge_mul)
print(f'segment tree root node              : {seg_tree[1]}')
print(f'segment tree left, right node       : {seg_tree[2]} {seg_tree[3]}')
print(f'segment tree left left, right node  : {seg_tree[4]} {seg_tree[5]}')
print(f'segment tree right left, right node : {seg_tree[6]} {seg_tree[7]}')

segment tree root node              : 29160
segment tree left, right node       : 162 180
segment tree left left, right node  : 6 27
segment tree right left, right node : 30 6
