-
Notifications
You must be signed in to change notification settings - Fork 3
/
scan.py
105 lines (79 loc) · 3.46 KB
/
scan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from __future__ import annotations
from typing import Callable
import pytreeclass as pytc
from jax import lax
from jax import numpy as jnp
from kernex._src.base import kernelOperation
from kernex._src.utils import ZIP, _offset_to_padding, cached_property, ix_, roll_view
@pytc.treeclass
class baseKernelScan(kernelOperation):
    """Kernel operation whose views are evaluated sequentially with ``lax.scan``.

    The input is padded with ``jnp.pad(array, self.pad_width)`` and the padded
    array is used as the scan carry. For every view in ``self.views`` a kernel
    function is applied to the patch ``carry[ix_(*view)]``. Its value is written
    back into the carry at ``self.index_from_view(view)``, so later views read
    values already updated by earlier ones. The value written for each view is
    emitted as that step's output, and the stacked outputs are reshaped to
    ``self.output_shape``.
    """

    def __post_init__(self):
        # Choose the implementation once, at construction time. With a single
        # function the scan step calls it directly. With several functions the
        # multi-call path must dispatch through ``lax.switch``, which the
        # single-function case can avoid.
        self.__call__ = (
            self.__single_call__ if len(self.funcs) == 1 else self.__multi_call__
        )

    def reduce_scan_func(self, func, *args, **kwargs) -> Callable:
        """Wrap ``func`` as a scan step ``(view, array) -> array``.

        Args:
            func: kernel function, called as ``func(patch, *args, **kwargs)``.
            *args: extra positional arguments forwarded to ``func``.
            **kwargs: extra keyword arguments forwarded to ``func``.

        Returns:
            A callable ``(view, array)`` with the following behaviour:

            * It evaluates ``func`` on the patch ``array[ix_(*view)]``. When
              ``self.relative`` is set, the patch is first passed through
              ``roll_view``.
            * It returns a copy of ``array`` with the location
              ``self.index_from_view(view)`` replaced by that value.

            The array's shape is unchanged.
        """
        if self.relative:
            # Relative mode: ``func`` receives the patch as transformed by
            # ``roll_view`` instead of the raw slice.
            return lambda view, array: array.at[self.index_from_view(view)].set(
                func(roll_view(array[ix_(*view)]), *args, **kwargs)
            )
        return lambda view, array: array.at[self.index_from_view(view)].set(
            func(array[ix_(*view)], *args, **kwargs)
        )

    def _scan_views(self, array, step):
        """Pad ``array`` and fold ``step`` over ``self.views`` with ``lax.scan``.

        Args:
            array: un-padded input array.
            step: callable ``(view, carry) -> carry`` returning the updated
                padded array.

        Returns:
            The per-view outputs, i.e. ``carry[self.index_from_view(view)]``
            after each step, reshaped to ``self.output_shape``.
        """
        padded_array = jnp.pad(array, self.pad_width)

        def scan_body(carry, view):
            # The reshape is defensive: the steps built by ``reduce_scan_func``
            # already preserve the carry shape, as ``lax.scan`` requires.
            updated = step(view, carry).reshape(carry.shape)
            return updated, updated[self.index_from_view(view)]

        _, outputs = lax.scan(scan_body, padded_array, self.views)
        return outputs.reshape(self.output_shape)

    def __single_call__(self, array, *args, **kwargs):
        """Scan ``array`` with the only function in ``self.funcs``."""
        step = self.reduce_scan_func(self.funcs[0], *args, **kwargs)
        return self._scan_views(array, step)

    def __multi_call__(self, array, *args, **kwargs):
        """Scan ``array``, choosing per view which function of ``self.funcs`` applies."""
        # NOTE(review): the functions are reversed before being indexed by
        # ``lax.switch``. This is presumably to match the ordering produced by
        # ``func_index_from_view``; confirm against ``kernelOperation``.
        branches = tuple(
            self.reduce_scan_func(func, *args, **kwargs) for func in self.funcs[::-1]
        )

        def switched_step(view, carry):
            return lax.switch(self.func_index_from_view(view), branches, view, carry)

        return self._scan_views(array, switched_step)
@pytc.treeclass
class kernelScan(baseKernelScan):
    """Callable scan kernel.

    Calling the instance runs ``__single_call__`` when ``self.funcs`` holds a
    single function, and ``__multi_call__`` otherwise.
    """

    def __init__(self, func_dict, shape, kernel_size, strides, padding, relative):
        super().__init__(func_dict, shape, kernel_size, strides, padding, relative)

    def __call__(self, array, *args, **kwargs):
        """Run the scan on ``array``, forwarding ``*args``/``**kwargs`` to the kernel functions."""
        # Dispatch explicitly, using the same rule as ``__post_init__``, rather
        # than returning ``self.__call__(...)``. If the instance attribute
        # assigned in ``__post_init__`` were ever missing, ``self.__call__``
        # would resolve to this very method and recurse without end.
        if len(self.funcs) == 1:
            return self.__single_call__(array, *args, **kwargs)
        return self.__multi_call__(array, *args, **kwargs)
@pytc.treeclass
class offsetKernelScan(kernelScan):
    """Scan kernel whose output is written back into the original input array.

    ``offset`` holds one ``(start, end)`` pair per dimension. It is converted
    to padding via ``_offset_to_padding(offset, kernel_size)``. It also selects
    which positions of the input receive the scan results (see
    ``__set_indices__``).
    """

    def __init__(self, func_dict, shape, kernel_size, strides, offset, relative):
        self.offset = offset
        super().__init__(
            func_dict,
            shape,
            kernel_size,
            strides,
            _offset_to_padding(offset, kernel_size),
            relative,
        )

    @cached_property
    def __set_indices__(self):
        """Per-dimension indices ``jnp.arange(x0, di - xf, si)`` that receive results.

        Here ``di`` is the dimension size, ``si`` the stride and ``(x0, xf)``
        that dimension's offset pair.
        """
        # ``kernel_size`` does not enter the computation. It is still zipped so
        # that ``ZIP`` receives the same sequences as before (presumably ``ZIP``
        # checks the lengths agree; confirm in ``kernex._src.utils``).
        return tuple(
            jnp.arange(x0, di - xf, si)
            for di, _, si, (x0, xf) in ZIP(
                self.shape, self.kernel_size, self.strides, self.offset
            )
        )

    def __call__(self, array, *args, **kwargs):
        """Run the scan and write its output into ``array`` at ``__set_indices__``.

        Raises:
            ValueError: if the scan output does not fit inside ``array``, i.e.
                it has a different rank or a larger size along any dimension.
        """
        # Dispatch explicitly instead of calling ``self.__call__``. If the
        # instance attribute assigned in ``__post_init__`` were ever missing,
        # that name would resolve back to this method and recurse forever.
        if len(self.funcs) == 1:
            result = self.__single_call__(array, *args, **kwargs)
        else:
            result = self.__multi_call__(array, *args, **kwargs)

        # Compare shapes dimension by dimension. A plain tuple ``<=`` is
        # lexicographic (``(3, 10) <= (4, 2)`` is True), and an ``assert``
        # would be stripped under ``python -O``.
        fits = len(result.shape) == len(array.shape) and all(
            out_dim <= in_dim for out_dim, in_dim in zip(result.shape, array.shape)
        )
        if not fits:
            raise ValueError(
                "scan operation output must be scalar. "
                f"Got output shape {result.shape} for input shape {array.shape}."
            )
        return array.at[ix_(*self.__set_indices__)].set(result)