## Segment Tree

Please cite this if you referenced this article, and share with your friends if you found this helpful!

Written by Brandon


##### Segment Trees allow for O(logn) range queries and updates on an ordered collection of elements. Quick access to a copy-pastable segment tree template is essential for competitive programming.

##### 1. An exact quickhand template

        > If you are looking for quick code to copy-paste or analyze, go to the Template.


##### 2. Notes on the nuances of this topic

        > If you are looking to better an existing understanding, go to the Notes.


##### 3. A Guide to this topic, with examples.

        > If you are looking to learn about this topic from a place of little understanding, go to the Guide.

## Template

Quick access to code.

* Call the constructor on a sequence to build the segment tree

* Call tree[i] or tree[l:r] to query an index or range respectively

* Call tree[i] += x or tree[l:r] += x to add to an index or range respectively

In [5]:

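#All operations are O(logn), other than the constructor, which is O(n)

#tree[l:r] += x first reads tree[l:r], then applies += to the result, then assigns it back
#RangeSum and RangeAdd let __setitem__ recognise that sequence and turn it into a range add
class RangeAdd:
    def __init__(self, delta):
        self.delta = delta

#Assumes the elements (and so the sums) are integers
class RangeSum(int):
    def __iadd__(self, x):
        return RangeAdd(x)

    def __isub__(self, x):
        return RangeAdd(-x)

class SegmentTree:
    def __init__(self, arr):
        self.n = len(arr)
        #Pad the leaf row up to a power of two, so a node j levels above the leaves covers exactly 2^j elements
        self.size = 1
        while self.size < self.n:
            self.size <<= 1
        #Number of levels above the leaves
        self.h = self.size.bit_length() - 1
        #tree[1] is the root; the children of tree[i] are tree[2i] and tree[2i + 1]
        self.tree = [0] * (2 * self.size)
        #lazy[i] is an amount added to every element under node i that hasn't reached i's children yet
        self.lazy = [0] * (2 * self.size)
        for i in range(self.n):
            self.tree[self.size + i] = arr[i]
        for i in range(self.size - 1, 0, -1):
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    #Add val to each of the `length` elements under node i
    def apply(self, i, val, length):
        self.tree[i] += val * length
        if i < self.size:
            self.lazy[i] += val

    #Lazy propagation: push pending additions down the path from the root to leaf i
    def propagate(self, i):
        for j in range(self.h, 0, -1):
            k = i >> j
            if self.lazy[k]:
                half = 1 << (j - 1)
                self.apply(k << 1, self.lazy[k], half)
                self.apply(k << 1 | 1, self.lazy[k], half)
                self.lazy[k] = 0

    #Recompute every ancestor of leaf i from its children
    def rebuild(self, i):
        length = 1
        while i > 1:
            i >>= 1
            length <<= 1
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1] + self.lazy[i] * length

    #Point update: set element i to val
    def pointUpdate(self, i, val):
        i += self.size
        self.propagate(i)
        self.tree[i] = val
        self.rebuild(i)

    #Range update: add val to every element in [l, r)
    def rangeUpdate(self, l, r, val):
        if l >= r:
            return
        l += self.size
        r += self.size
        l0, r0 = l, r
        length = 1
        while l < r:
            if l & 1:
                self.apply(l, val, length)
                l += 1
            if r & 1:
                r -= 1
                self.apply(r, val, length)
            l >>= 1
            r >>= 1
            length <<= 1
        self.rebuild(l0)
        self.rebuild(r0 - 1)

    #Range query: sum of the elements in [l, r)
    def query(self, l, r):
        if l >= r:
            return 0
        l += self.size
        r += self.size
        self.propagate(l)
        self.propagate(r - 1)
        ans = 0
        while l < r:
            if l & 1:
                ans += self.tree[l]
                l += 1
            if r & 1:
                r -= 1
                ans += self.tree[r]
            l >>= 1
            r >>= 1
        return ans

    #Print the current elements rather than the whole tree, which is hard to read
    def __str__(self):
        return str([self.query(i, i + 1) for i in range(self.n)])

    def __repr__(self):
        return str(self)

    #Convert a possibly negative index into an index in [0, n)
    def index(self, i):
        if i < 0:
            i += self.n
        if not 0 <= i < self.n:
            raise IndexError("SegmentTree index out of range")
        return i

    def __getitem__(self, i):
        if isinstance(i, slice):
            #slice.indices() fills in missing bounds and resolves negative ones; step is ignored
            l, r, _ = i.indices(self.n)
            return RangeSum(self.query(l, r))
        i = self.index(i)
        return self.query(i, i + 1)

    def __setitem__(self, i, val):
        if isinstance(i, slice):
            #Reached through tree[l:r] += x, which hands us a RangeAdd
            if not isinstance(val, RangeAdd):
                raise TypeError("a range can only be modified with += or -=")
            l, r, _ = i.indices(self.n)
            self.rangeUpdate(l, r, val.delta)
        else:
            self.pointUpdate(self.index(i), val)
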

In [None]:
#Example:
tree = SegmentTree([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

print(tree[0])
print(tree[5])

#Can query ranges with slices
print(tree[0:5])
print(tree[:5])
print(tree[5:10])
print(tree[5:])
print(tree[0:10])
print(tree[:])

#Can add to ranges with slices as well
tree[0:5] += 10
tree[0:5] -= 5

#Can only set values explicitly if it's one point
tree[0] = 10
#tree[0:5] = 10 doesn't mean anything if there isn't an external context

#Step is ignored
print(tree[0:10])
print(tree[0:10:2])

#Can set and get negative indices

## Notes

Segment trees are one of the most important data structures. Their place as a compromise in time complexity between a regular array and a prefix array (talked about in the introduction) may seem niche. But of these three options, the segment tree is very nearly the default choice.

This is largely because O(logn) time complexity can almost always be considered synonymous with constant time. For example, take a collection of 2^32 elements (around 4 billion elements). O(n) access takes 2^32 times longer than O(1), whereas O(logn) access only takes 32 times longer. Most of the time this might as well be considered constant. In the frequent situation where both range queries and element updates are needed, the choice between an array of original elements and a prefix array is untenable: the array makes every range query O(n), and the prefix array makes every update O(n). Unless the number of updates or the number of queries is bounded by a small constant (roughly, fewer than 32), a segment tree is faster than either. The overhead in terms of sheer code complexity, particularly when it comes to specialized queries, is the only salient drawback. This is similar to the trade-off between the simplicity of heaps and the more versatile multisets.
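To make the trade-off concrete, here is a minimal sketch of the two baselines being compared. The helper names (`array_range_sum`, `build_prefix`, `prefix_range_sum`, `prefix_point_add`) are illustrative only and are not part of the Template.

```python
def array_range_sum(arr, l, r):
    # Plain array: O(1) update (arr[i] = x), but an O(n) range query over [l, r)
    return sum(arr[l:r])

def build_prefix(arr):
    # prefix[i] holds the sum of arr[0:i], so prefix has len(arr) + 1 entries
    prefix = [0]
    for x in arr:
        prefix.append(prefix[-1] + x)
    return prefix

def prefix_range_sum(prefix, l, r):
    # Prefix array: O(1) range query over [l, r)
    return prefix[r] - prefix[l]

def prefix_point_add(prefix, i, delta):
    # ...but an O(n) update, because every prefix past index i changes
    for j in range(i + 1, len(prefix)):
        prefix[j] += delta

arr = [1, 2, 3, 4, 5]
prefix = build_prefix(arr)
print(array_range_sum(arr, 1, 4))      # 2 + 3 + 4 = 9
print(prefix_range_sum(prefix, 1, 4))  # 9

# Add 10 to element 2 in both structures
arr[2] += 10
prefix_point_add(prefix, 2, 10)
print(prefix_range_sum(prefix, 1, 4))  # 19
```

A segment tree sits between the two: both the query and the update cost O(logn).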

On the topic of multisets, I have long held the belief that the majority of problems that demand the use of a segment tree can be solved using a multiset. That is to say, most solutions that incorporate O(logn) range queries and updates to a structure can be rewritten to utilize O(logn) insertion/deletion in an ordered structure. While of course there are many problems that require the use of a segment tree, it's often worth spending a few seconds pondering the replacement of a segment tree with a simpler, built-in data structure.

Segment trees, as with other tree structures, can be written using pointers, or using an array. Here, an array structure is implemented, primarily because the size of the segment tree is often fixed, which allows for a standardized, heap-like array representation of a tree structure.

## FAQ:

As segment trees are intermediate data structures, they require a baseline of knowledge that extends beyond the scope of this article. Here is a list of useful questions to ask and answer before continuing.


* How do you implement a tree structure as an array?
    There are multiple ways to do this. This segment tree implementation uses a heap-like structure. To see another, more versatile example of a tree structure represented in an array, look at the introductory article on disjoint sets. For a heap-like structure, we choose index 0 to represent the root, indices 1 and 2 to be its left and right children respectively, indices 3, 4, 5, and 6 to be its grandchildren, 7, 8, 9, 10, 11, 12, 13, and 14 to be the next row, and so on. This allows us to store a value at any index, and follow the formula that for any node at index i, its children are at (i * 2)+1 and (i * 2)+2. If one of these "child" indices j extends beyond the length of the array (j >= n), then that child doesn't exist. This is the equivalent of a tree node containing a null pointer. (The Template uses a common variant of the same idea that leaves index 0 unused: the root is at index 1, and the children of node i are at i * 2 and (i * 2)+1.) For example, take the following array: [0,1,2,3,4,5,6]. If we wanted to represent this as a height-balanced binary search tree in this format, we could re-order this array as follows: [3,1,5,0,2,4,6]. Now the root is 3, all values to the "left" (i.e. all nodes stemming from index 1) are strictly less, and all values to the right are strictly more. This simple rearrangement, put succinctly, allows us to execute divide and conquer algorithms on our data that would otherwise be unattainable. Note that the above re-representation of array data is *not* the formulation of a segment tree. It is simply an example of the array representation of a binary tree that heaps and segment trees use.
    

* Why is this tree representation used?
    This format uses an array to represent a height balanced, fixed-size tree. This makes it a good fit for situations where we're electing to utilize a tree structure to represent an existing collection of flattened data, which allows us to choose its structure. This is as opposed to arbitrary tree structures, which map poorly to this tree syntax. For example, a tree with one node per level would take an array of length 2^(# of nodes) to represent. But the array above only took O(n) space to re-arrange the data, because we can choose the most useful and efficient way to re-arrange the elements into a tree.

* How is a segment tree implemented, using this structure?
    Segment trees are arranged such that each node represents a range. The root node, at tree[0], represents the range of the entire array. Each of its children, at tree[1] and tree[2], represents one half of that range, and their two ranges combine into the range represented by the root. This pattern continues down to the leaf nodes, which each represent one element of the array. For example, take the following array: [0,1,0,1,0,1,0,1]. A segment tree intended for range sum queries would start with a root that holds the cumulative sum of the array (4), and its children would hold the sums of the first and second halves of the array. This would be the underlying array of the segment tree: [4,2,2,1,1,1,1,0,1,0,1,0,1,0,1]. There are (2n)-1 nodes in the tree.

* What happens if the size of the initial array isn't a power of two?
    The tree is missing its rightmost nodes, and the parents of those nodes are themselves leaf nodes. For example, take the following array: [0,1,0,1,0,1]. The top three levels of the array for a segment tree made for sum queries of these elements would look like this: [3,1,2,1,0,1,1, ...]
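The heap-like layout described in the FAQ can be sketched in a few lines. This is a minimal illustration that assumes the 0-indexed layout from the FAQ (root at index 0) and an input whose length is a power of two; `build_sum_tree` is a hypothetical helper name, not part of the Template.

```python
def build_sum_tree(arr):
    # Assumes len(arr) is a power of two, so the tree is perfectly balanced
    n = len(arr)
    tree = [0] * (2 * n - 1)
    # The leaves occupy the last n slots of the array
    for i, x in enumerate(arr):
        tree[n - 1 + i] = x
    # Each internal node is the sum of its children at 2i + 1 and 2i + 2
    for i in range(n - 2, -1, -1):
        tree[i] = tree[2 * i + 1] + tree[2 * i + 2]
    return tree

print(build_sum_tree([0, 1, 0, 1, 0, 1, 0, 1]))
# [4, 2, 2, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1]
```

Filling the internal nodes from the last one back to the root guarantees that both children are already computed when a parent is visited.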

## Time Complexities:

* Creation
    O(n) time and space. For an array of length n, the segment tree has n leaves, n//2 parents for those leaves, n//4 nodes at the next level, and so on, all the way up to a single root node at tree[0]. The sum n + n/2 + n/4 + ... + 1 is less than 2n, so an array of length 2n is enough to hold every range in the segment tree.

* Update
    O(logn) time, O(1) space. This applies to both range updates and updates to single elements. For range updates, this is predicated on lazy propagation, which the Template implements.

* Query
    O(logn) time, O(1) space. This applies to both range queries and single element access.
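One way to see the O(logn) query bound is to count how many nodes the bottom-up query loop actually reads. The sketch below mirrors the loop shape used by the Template's query (leaves stored at indices n to 2n-1, half-open range [l, r)); `count_query_nodes` is a hypothetical helper written only for this illustration.

```python
def count_query_nodes(n, l, r):
    # Walk the same loop as a range query over [l, r),
    # counting the nodes that would be read instead of summing them
    visited = 0
    l += n
    r += n
    while l < r:
        if l & 1:
            visited += 1
            l += 1
        if r & 1:
            r -= 1
            visited += 1
        # Move both boundaries up one level
        l >>= 1
        r >>= 1
    return visited

n = 1 << 20  # about a million leaves
print(count_query_nodes(n, 1, n - 1))  # nodes read for an almost-full range
print(2 * n.bit_length())              # upper bound: two nodes per level
```

Each pass of the loop climbs one level and reads at most two nodes, so the count can never exceed twice the number of levels.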

## Guide

Below is a detailed explanation of exactly what a Segment Tree is, as well as its related terms. This is meant to acquaint you with Segment Trees, if you're not familiar with them.

There are two extremely common questions that get asked about graphs:

* What is the shortest path between two nodes?
* Are two nodes reachable from one another?

Disjoint sets are what we use to find the answer to the second of those questions (the answer to the first being BFS/Dijkstra's algorithm). For this reason they're among the most commonly used approaches for answering questions about graphs. Among other things, disjoint sets can be used to answer:

1. Do two nodes belong to the same component?
2. How many components are in a graph?
3. How many nodes belong to a given component in a graph?

Disjoint sets are almost always used to describe the "components" of a graph. In order to talk about disjoint sets further, we need to define what a "component" is, along with a couple of other important terms.

* Component:
    A group of nodes that are all connected to one another, directly or indirectly.

* Direct vs. Indirect connection:
    Two nodes are directly connected if they share a direct edge. Two nodes are indirectly connected if they share edges with other directly or indirectly connected nodes.

Let's build a disjoint set from the ground up.

Take the following graph, in the form of an adjacency list:

In [5]:
#There are 5 nodes: 0, 1, 2, 3, and 4
n = 5
#The graph is bidirectional;
#If adj[i] contains j, then adj[j] contains i
adj = {0:[1,2], 1:[0], 2:[0], 3:[4], 4:[3]}

We can see that node 0 is connected to nodes 1 and 2 (and vice versa). We can also see that 3 and 4 are connected to each other.

Even though node 1 and node 2 aren't directly connected, we can see that they're indirectly connected through node 0. So nodes 0, 1 and 2 are all reachable from one another. But none of them can reach nodes 3 or 4. So there are two different islands, or components, in the graph. There's nodes 0, 1 and 2, and then nodes 3 and 4. We can reason through this, but how can we find this out about much larger, more complicated graphs? This is what a disjoint set is for.

Before going further into its uses, let's go over the anatomy and mechanism of a disjoint set.

In [6]:
#Root array
root = [i for i in range(n)]

The root array, sometimes called the parent array, is the central data structure. For every node i, root[i] is the root node of node i. This represents that nodes i and root[i] are connected. For example, if root[1] is 0, and root[2] is 0, then both 1 and 2 are connected to root 0 (and each other). Then, if we ever want to check if two nodes are connected, we can check to see if their root is the same. If root[1] is 0, and root[2] is 0, then we know they're connected to the same node, and therefore nodes 1 and 2 are connected.

Initially, each index in the root array points to itself. root[0] == 0, root[1] == 1, etc. This represents the initial state, that each node exists in its own component, before any edges have been introduced. The next step is to create a way to populate and access the root array. Let's start with a simplified version of union(a, b), which is how we introduce an edge between a and b:

In [7]:
#represent an edge between a and b
def union(a, b):
    root[b] = a

Now we have a way to connect two nodes. If we call this union on each edge, we can create an initial root array:

In [10]:
for a in adj:
    for b in adj[a]:
        if b > a:
            union(a, b)
#Root:
#[0, 0, 0, 3, 3]

Now there are two groups of nodes. The nodes connected to node 0, and nodes connected to node 3. Now it appears that we have two distinct components in this graph.

There are a couple things to notice about this function. The order is arbitrary here. We could just as easily have said "root[a] = b." We'll return to that later. 

Another thing to notice about this function is that it can create trees with a large depth. Say we change our graph above slightly:

In [12]:
#1 connects 0 and 2, instead of 0 connecting 1 and 2
adj = {0:[1], 1:[0, 2], 2:[1], 3:[4], 4:[3]}
#Reset root
root = [i for i in range(n)]

Now 0 and 2 are connected to 1, instead of 1 and 2 being connected to 0. If we again call union, the root array will be different.

In [14]:
for a in adj:
    for b in adj[a]:
        if b > a:
            union(a, b)
#Root:
#[0, 0, 1, 3, 3]

Now node 2 leads to node 1, which leads to node 0, which leads to itself, as it's the root node of its tree of connected nodes. This isn't what we want. We want every node to point to its original root node, rather than just the node above it in the tree. This is where we need a find(a) function, in order to find and update the original root node for node a.

We always want to use find(a), rather than root[a], regardless of the context. Let's say root[a] gets set to node c, but then root[c] gets changed to node d. Even if we initially used find(a) and find(c) to set root[a] to c, root[a] is now stale: node c points to node d instead of itself, so we need to call find(a) again to reach d.

In [15]:
#find(a): find the flag node for node a
def find(a):
    while root[a] != a:
        a = root[a]
    return a

We also potentially don't want to change anything when we call union(a, b). For example, if root[a] is b, we don't want to change root[b] to be a (this would arise, for example, if we blindly iterated through the adjacency list of a bidirectional graph and accounted for each edge twice). If we do this, each node is the other's root, and if we look for the root "flag-node" of a given node by repeatedly looking at root[a], then root[root[a]], and so on, we'll enter an infinite cycle. This is solved by returning from the function early if the nodes are already connected. Now we can update our union function to a workable set of functions:

In [17]:
root = [i for i in range(n)]

def find(a):
    while root[a] != a:
        a = root[a]
    return a

def union(a, b):
    #Find the root nodes for the components that contain  nodes a and b
    a, b = find(a), find(b)
    #Return early if they're already part of the same component
    if a == b: return
    #Connect these components
    root[b] = a

This is a working disjoint set. However we can introduce a couple other things to this code, not only to optimize the time complexity significantly, but also to garner more data about a given graph than just the root array.

The most unequivocal improvement to this code is *path compression*. Path compression is the single biggest optimization to the time complexity of find() and union(), and it costs almost nothing extra to type, so there's no reason not to include it in every disjoint set. The idea is that, for any node a, we never care about any root node in its tree (root[a], root[root[a]], etc.) except the very top root node x, where root[x] == x. Therefore, if we ever call find(a), we might as well change root[a] to be x itself, so that we don't have to iterate up the entire height of the tree to find out that the flag node for a is x. This has cascading impacts, since then any child node b of node a will travel that much shorter of a path to find root node x, before setting its own root[b] to x. Path compression changes find(a) to this:

In [None]:
def find(a):
    while root[root[a]] != root[a]:
        root[a] = root[root[a]]
    return root[a]

It may look more confusing, but it means that every call to find(a) leaves node a pointing directly at its flag node, so later calls that pass through a do much less work. This is the find() function done.
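For comparison, a common variant repoints every node on the path, not only the node that was queried. This is a sketch of that two-pass variant, not the version used in this guide; `find_full` is a hypothetical name, and the six-node chain is built by hand purely for illustration.

```python
root = list(range(6))
# Build a chain by hand: 5 -> 4 -> 3 -> 2 -> 1 -> 0
for i in range(1, 6):
    root[i] = i - 1

def find_full(a):
    # First pass: walk up to the flag node
    top = a
    while root[top] != top:
        top = root[top]
    # Second pass: repoint every node on the path directly at the flag node
    while root[a] != top:
        nxt = root[a]
        root[a] = top
        a = nxt
    return top

print(find_full(5))  # 0
print(root)          # [0, 0, 0, 0, 0, 0]
```

After a single call on the deepest node, every node that was on the path points straight at the flag node.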

A worst-case scenario that this helps with significantly would be a graph that is essentially a linked list. Imagine that every node gets linked directly to its parent node. So root[a] == b, root[b] == c, root[c] == d... root[y] == x. Now in order to find out what component a node belongs to (for example to find out that node a belongs to component x), this is potentially an O(n) operation, to iteratively look up root[a], root[root[a]], and so on. With path compression, this process shrinks to anywhere between logarithmic and constant time, depending on how many times find(a) is called compared to the number of edges in the graph.
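The linked-list worst case can be reproduced directly. In this sketch the chain is built by hand, `find` is the path-compressing version from this guide, and `find_steps` is a hypothetical helper that only counts how many links a plain lookup follows.

```python
n = 1000
root = list(range(n))

# Link each node to the next one: 0 -> 1 -> 2 -> ... -> n - 1
for i in range(n - 1):
    root[i] = i + 1

def find_steps(a):
    # Plain lookup with no compression; also reports how many links were followed
    steps = 0
    while root[a] != a:
        a = root[a]
        steps += 1
    return a, steps

def find(a):
    # The path-compressing find from this guide
    while root[root[a]] != root[a]:
        root[a] = root[root[a]]
    return root[a]

print(find_steps(0))  # (999, 999): the whole chain is walked
find(0)               # repoints node 0 at the flag node
print(find_steps(0))  # (999, 1)
print(find_steps(1))  # (999, 998): node 1 has not been queried yet
```

Only the queried node is repointed by this version of find(), which is why node 1 still has a long walk until find(1) is called.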

There is room for another significant optimization, when union(a, b) sets root[b] = a. Since we could set root[b] = a, or root[a] = b, we may as well make this choice intelligently. If we think of the components in the root[] array as trees that we're building with union(), and trimming and searching with find(), we can make sure that each tree stays height-balanced. This means that there's no "linked list" worst-case scenario mentioned above, where the height of the tree is the number of nodes in the tree. The height of a height-balanced tree is always O(logn), where n is the number of nodes in the tree. Here, we do that by choosing whether to set root[b] = a, or root[a] = b, depending on which would displace fewer nodes. For this we keep track of a size array, as well as a root array. Then, when we change root[b] or root[a], we update size[a] or size[b] respectively to absorb the other tree's size:

In [4]:
#Each node starts out as its own components, with a size of 1 (itself)
root = [i for i in range(n)]
size = [1 for i in range(n)]
def find(a):
    while root[root[a]] != root[a]:
        root[a] = root[root[a]]
    return root[a]
def union(a, b):
    a, b = find(a), find(b)
    if a == b: return
    #Set the root of the smaller tree to the root of the larger tree
    if size[b] > size[a]:
        root[a] = b
        size[b]+=size[a]
    else:
        root[b] = a
        size[a]+=size[b]


With both of these optimizations, the time complexity is essentially constant. Not only can we find out if two nodes are part of the same connected component, but we can immediately find how many nodes are in a given component, by looking at size[x] for a given root node x.
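Putting the pieces together, the three questions from the start of the guide can be answered on the example graph. This sketch repeats the size-based find() and union() from above so that it runs on its own; the final few lines are the only new part.

```python
n = 5
adj = {0: [1, 2], 1: [0], 2: [0], 3: [4], 4: [3]}
root = list(range(n))
size = [1] * n

def find(a):
    while root[root[a]] != root[a]:
        root[a] = root[root[a]]
    return root[a]

def union(a, b):
    a, b = find(a), find(b)
    if a == b:
        return
    # Attach the smaller tree under the larger one
    if size[b] > size[a]:
        root[a] = b
        size[b] += size[a]
    else:
        root[b] = a
        size[a] += size[b]

# Every edge appears twice in a bidirectional adjacency list;
# the early return in union() makes the second copy harmless
for a in adj:
    for b in adj[a]:
        union(a, b)

# 1. Do two nodes belong to the same component?
print(find(1) == find(2))  # True
print(find(0) == find(3))  # False
# 2. How many components are in the graph?
print(sum(1 for i in range(n) if find(i) == i))  # 2
# 3. How many nodes are in the component that contains node 4?
print(size[find(4)])  # 2
```

Note that size[x] is only meaningful when x is a flag node, which is why the lookup goes through find() first.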

The only other thing to mention is that components can be compared by their size (number of nodes in the component) or their rank (the maximum height of the tree). Both lead to the same time complexity for find() and union(), which is O(α(n)), the inverse Ackermann function of n, which can safely be called constant time for any conceivable number n.
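For completeness, here is a minimal sketch of the rank-based alternative. The `rank` array and the name `union_by_rank` are assumptions made for this illustration; find() is the same function used throughout the guide.

```python
n = 5
root = list(range(n))
# rank[x] is an upper bound on the height of the tree rooted at x
rank = [0] * n

def find(a):
    while root[root[a]] != root[a]:
        root[a] = root[root[a]]
    return root[a]

def union_by_rank(a, b):
    a, b = find(a), find(b)
    if a == b:
        return
    # Make a the root with the larger rank, then attach b under it
    if rank[a] < rank[b]:
        a, b = b, a
    root[b] = a
    # The height bound only grows when two trees of equal rank are merged
    if rank[a] == rank[b]:
        rank[a] += 1

union_by_rank(0, 1)
union_by_rank(2, 3)
union_by_rank(1, 3)
union_by_rank(4, 0)
print([find(i) for i in range(n)])  # [0, 0, 0, 0, 0]
print(rank[0])                      # 2
```

The trade-off against the size-based version is that rank does not tell you how many nodes a component contains.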

With this understanding, one can fully utilize disjoint sets!

Resources:
https://cp-algorithms.com/data_structures/segment_tree.html
