3737"""
3838from __future__ import annotations
3939
40- from typing import Callable , TypeVar
40+ from typing import Any , Callable , Generic , TypeVar
4141
4242T = TypeVar ("T" )
4343
4444
45- class SegmentTree :
45+ class SegmentTree ( Generic [ T ]) :
4646 def __init__ (self , arr : list [T ], fnc : Callable [[T , T ], T ]) -> None :
4747 """
4848 Segment Tree constructor, it works just with commutative combiner.
@@ -55,8 +55,10 @@ def __init__(self, arr: list[T], fnc: Callable[[T, T], T]) -> None:
5555 ... lambda a, b: (a[0] + b[0], a[1] + b[1])).query(0, 2)
5656 (6, 9)
5757 """
58- self .N = len (arr )
59- self .st = [None for _ in range (len (arr ))] + arr
58+ any_type : Any | T = None
59+
60+ self .N : int = len (arr )
61+ self .st : list [T ] = [any_type for _ in range (self .N )] + arr
6062 self .fn = fnc
6163 self .build ()
6264
@@ -83,7 +85,7 @@ def update(self, p: int, v: T) -> None:
8385 p = p // 2
8486 self .st [p ] = self .fn (self .st [p * 2 ], self .st [p * 2 + 1 ])
8587
86- def query (self , l : int , r : int ) -> T : # noqa: E741
88+ def query (self , l : int , r : int ) -> T | None : # noqa: E741
8789 """
8890 Get range query value in log(N) time
8991 :param l: left element index
@@ -101,7 +103,8 @@ def query(self, l: int, r: int) -> T: # noqa: E741
101103 7
102104 """
103105 l , r = l + self .N , r + self .N # noqa: E741
104- res = None
106+
107+ res : T | None = None
105108 while l <= r : # noqa: E741
106109 if l % 2 == 1 :
107110 res = self .st [l ] if res is None else self .fn (res , self .st [l ])
@@ -135,7 +138,7 @@ def query(self, l: int, r: int) -> T: # noqa: E741
135138 max_segment_tree = SegmentTree (test_array , max )
136139 sum_segment_tree = SegmentTree (test_array , lambda a , b : a + b )
137140
138- def test_all_segments ():
141+ def test_all_segments () -> None :
139142 """
140143 Test all possible segments
141144 """
0 commit comments