@@ -125,7 +125,6 @@ def _accumulate_common(
125125 if a1 != nd :
126126 out = dpt .permute_dims (out , perm )
127127
128- final_ev = dpctl .SyclEvent ()
129128 _manager = SequentialOrderManager [q ]
130129 depends = _manager .submitted_events
131130 if implemented_types :
@@ -144,12 +143,11 @@ def _accumulate_common(
144143 _manager .add_event_pair (ht_e , acc_ev )
145144 if not (orig_out is None or out is orig_out ):
146145 # Copy the out data from temporary buffer to original memory
147- ht_e_cpy , acc_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
146+ ht_e_cpy , cpy_e = ti ._copy_usm_ndarray_into_usm_ndarray (
148147 src = out , dst = orig_out , sycl_queue = q , depends = [acc_ev ]
149148 )
150- _manager .add_event_pair (ht_e_cpy , acc_ev )
149+ _manager .add_event_pair (ht_e_cpy , cpy_e )
151150 out = orig_out
152- final_ev = acc_ev
153151 else :
154152 if _dtype_supported (res_dt , res_dt ):
155153 tmp = dpt .empty (
@@ -160,21 +158,21 @@ def _accumulate_common(
160158 )
161159 _manager .add_event_pair (ht_e_cpy , cpy_e )
162160 if not include_initial :
163- ht_e , final_ev = _accumulate_fn (
161+ ht_e , acc_ev = _accumulate_fn (
164162 src = tmp ,
165163 trailing_dims_to_accumulate = 1 ,
166164 dst = out ,
167165 sycl_queue = q ,
168166 depends = [cpy_e ],
169167 )
170168 else :
171- ht_e , final_ev = _accumulate_include_initial_fn (
169+ ht_e , acc_ev = _accumulate_include_initial_fn (
172170 src = tmp ,
173171 dst = out ,
174172 sycl_queue = q ,
175173 depends = [cpy_e ],
176174 )
177- _manager .add_event_pair (ht_e , final_ev )
175+ _manager .add_event_pair (ht_e , acc_ev )
178176 else :
179177 buf_dt = _default_accumulation_type_fn (inp_dt , q )
180178 tmp = dpt .empty (
@@ -190,25 +188,25 @@ def _accumulate_common(
190188 if a1 != nd :
191189 tmp_res = dpt .permute_dims (tmp_res , perm )
192190 if not include_initial :
193- ht_e , a_e = _accumulate_fn (
191+ ht_e , acc_ev = _accumulate_fn (
194192 src = tmp ,
195193 trailing_dims_to_accumulate = 1 ,
196194 dst = tmp_res ,
197195 sycl_queue = q ,
198196 depends = [cpy_e ],
199197 )
200198 else :
201- ht_e , a_e = _accumulate_include_initial_fn (
199+ ht_e , acc_ev = _accumulate_include_initial_fn (
202200 src = tmp ,
203201 dst = tmp_res ,
204202 sycl_queue = q ,
205203 depends = [cpy_e ],
206204 )
207- _manager .add_event_pair (ht_e , a_e )
208- ht_e_cpy2 , final_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
209- src = tmp_res , dst = out , sycl_queue = q , depends = [a_e ]
205+ _manager .add_event_pair (ht_e , acc_ev )
206+ ht_e_cpy2 , cpy_e2 = ti ._copy_usm_ndarray_into_usm_ndarray (
207+ src = tmp_res , dst = out , sycl_queue = q , depends = [acc_ev ]
210208 )
211- _manager .add_event_pair (ht_e_cpy2 , final_ev )
209+ _manager .add_event_pair (ht_e_cpy2 , cpy_e2 )
212210
213211 if appended_axis :
214212 out = dpt .squeeze (out )
0 commit comments