@@ -7,6 +7,7 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
77[%% global_debug_log_level 9 ]
88[%% global_debug_log_level_from_env_var " OCANNL_LOG_LEVEL" ]
99
10+ include Backend_types. No_device_types
1011open Backend_types.Types
1112
1213let name = " cc"
@@ -18,8 +19,6 @@ let compiler_command () = Utils.get_global_arg ~default:"cc" ~arg_name:"cc_backe
1819
1920module Tn = Tnode
2021
21- type ctx_array = Ndarray .t [@@ deriving sexp_of ]
22- type ctx_arrays = ctx_array Map .M (Tn ).t [@@ deriving sexp_of ]
2322type context = { label : string ; arrays : ctx_arrays } [@@ deriving sexp_of ]
2423
2524let ctx_arrays context = context.arrays
@@ -36,7 +35,7 @@ let alloc_buffer ?old_buffer ~size_in_bytes () =
3635 | None -> assert false
3736
3837let to_buffer tn ~dst ~src =
39- let src = Map. find_exn src.arrays tn in
38+ let src = Map. find_exn src.arrays.ctx_arrays tn in
4039 Ndarray. map2 { f2 = Ndarray.A. blit } src dst
4140
4241let host_to_buffer src ~dst = Ndarray. map2 { f2 = Ndarray.A. blit } src dst
@@ -50,7 +49,9 @@ let is_initialized, initialize =
5049let finalize _ctx = ()
5150
5251let init ~label =
53- let result = { label; arrays = Map. empty (module Tn ) } in
52+ let result =
53+ { label; arrays = { used_memory = Atomic. make 0 ; ctx_arrays = Map. empty (module Tn ) } }
54+ in
5455 Stdlib.Gc. finalise finalize result;
5556 result
5657
@@ -61,7 +62,7 @@ type procedure = {
6162 name : string ;
6263 result : library ;
6364 params : (string * param_source ) list ;
64- opt_ctx_arrays : Ndarray .t Map .M ( Tn ).t option ;
65+ opt_ctx_arrays : ctx_arrays option ;
6566}
6667[@@ deriving sexp_of ]
6768
@@ -105,13 +106,14 @@ let c_compile_and_load ~f_name =
105106
106107module C_syntax_config (Input : sig
107108 val for_lowereds : Low_level .optimized array
108- val opt_ctx_arrays : ( Tn .t , buffer_ptr , Tn .comparator_witness ) Base.Map .t option
109+ val opt_ctx_arrays : ctx_arrays option
109110end ) =
110111struct
111- let for_lowereds = Input. for_lowereds
112-
113112 type nonrec ctx_array = ctx_array
113+ type nonrec ctx_arrays = ctx_arrays
114114
115+ let get_array = get_array
116+ let for_lowereds = Input. for_lowereds
115117 let opt_ctx_arrays = Input. opt_ctx_arrays
116118 let hardcoded_context_ptr = Some Ndarray. c_ptr_to_string
117119 let is_in_context = is_in_context
@@ -133,15 +135,15 @@ let%diagn_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_
133135 let opt_ctx_arrays =
134136 Option. map opt_ctx_arrays ~f: (fun ctx_arrays ->
135137 Hashtbl. fold lowered.traced_store ~init: ctx_arrays ~f: (fun ~key :tn ~data :node ctx_arrays ->
136- match Map. find ctx_arrays tn with
138+ match Map. find ctx_arrays.ctx_arrays tn with
137139 | None ->
138140 if is_in_context node then
139141 let debug = " CC compile-time ctx array for " ^ Tn. debug_name tn in
140142 let data =
141143 Ndarray. create_array ~debug (Lazy. force tn.Tn. prec) ~dims: (Lazy. force tn.dims)
142144 @@ Constant_fill { values = [| 0. |]; strict = false }
143145 in
144- Map. add_exn ctx_arrays ~key: tn ~data
146+ { ctx_arrays with ctx_arrays = Map. add_exn ctx_arrays.ctx_arrays ~key: tn ~data }
145147 else ctx_arrays
146148 | Some _ -> ctx_arrays))
147149 in
@@ -162,22 +164,25 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
162164 (lowereds : Low_level.optimized option array ) =
163165 let for_lowereds = Array. filter_map ~f: Fn. id lowereds in
164166 let opt_ctx_arrays =
165- Option. map opt_ctx_arrays ~f: (fun ctx_arrays ->
166- Array. fold for_lowereds ~init: ctx_arrays ~f: (fun ctx_arrays lowered ->
167- Hashtbl. fold lowered.traced_store ~init: ctx_arrays
168- ~f: (fun ~key :tn ~data :node ctx_arrays ->
169- match Map. find ctx_arrays tn with
170- | None ->
171- if is_in_context node then
172- let debug = " CC compile-time ctx array for " ^ Tn. debug_name tn in
173- let data =
174- Ndarray. create_array ~debug (Lazy. force tn.Tn. prec)
175- ~dims: (Lazy. force tn.dims)
176- @@ Constant_fill { values = [| 0. |]; strict = false }
177- in
178- Map. add_exn ctx_arrays ~key: tn ~data
179- else ctx_arrays
180- | Some _ -> ctx_arrays)))
167+ Option. map opt_ctx_arrays ~f: (fun arrays ->
168+ let ctx_arrays =
169+ Array. fold for_lowereds ~init: arrays.ctx_arrays ~f: (fun ctx_arrays lowered ->
170+ Hashtbl. fold lowered.traced_store ~init: ctx_arrays
171+ ~f: (fun ~key :tn ~data :node ctx_arrays ->
172+ match Map. find ctx_arrays tn with
173+ | None ->
174+ if is_in_context node then
175+ let debug = " CC compile-time ctx array for " ^ Tn. debug_name tn in
176+ let data =
177+ Ndarray. create_array ~debug (Lazy. force tn.Tn. prec)
178+ ~dims: (Lazy. force tn.dims)
179+ @@ Constant_fill { values = [| 0. |]; strict = false }
180+ in
181+ Map. add_exn ctx_arrays ~key: tn ~data
182+ else ctx_arrays
183+ | Some _ -> ctx_arrays))
184+ in
185+ { arrays with ctx_arrays })
181186 in
182187 let module Syntax = C_syntax. C_syntax (C_syntax_config (struct
183188 let for_lowereds = for_lowereds
@@ -186,7 +191,7 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
186191 (* FIXME: do we really want all of them, or only the used ones? *)
187192 let idx_params = Indexing. bound_symbols bindings in
188193 let global_ctx_arrays =
189- ref (match opt_ctx_arrays with Some ctx_arrays -> ctx_arrays | None -> Map. empty ( module Tn ) )
194+ ref (match opt_ctx_arrays with Some ctx_arrays -> ctx_arrays | None -> empty_ctx_arrays )
190195 in
191196 let base_name =
192197 String. (
@@ -206,7 +211,7 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
206211 let opt_ctx_arrays = Option. map opt_ctx_arrays ~f: (fun _ -> ! global_ctx_arrays) in
207212 ( opt_ctx_arrays,
208213 Array. mapi params ~f: (fun i params ->
209- Option. map names.(i) ~f: (fun name ->
214+ Option. map names.(i) ~f: (fun name ->
210215 {
211216 result;
212217 params = Option. value_exn ~here: [% here] params;
@@ -219,7 +224,7 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
219224 context * _ * _ * string =
220225 let label : string = prior_context.label in
221226 let name : string = code.name in
222- let arrays : Ndarray.t Base.Map.M(Tn).t =
227+ let arrays =
223228 match code with
224229 | { opt_ctx_arrays = Some arrays ; _ } -> arrays
225230 | { params; _ } ->
@@ -232,7 +237,7 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
232237 Ndarray. create_array ~debug (Lazy. force tn.Tn. prec) ~dims: (Lazy. force tn.dims)
233238 @@ Constant_fill { values = [| 0. |]; strict = false }
234239 in
235- Map. update ctx_arrays tn ~f
240+ { ctx_arrays with ctx_arrays = Map. update ctx_arrays.ctx_arrays tn ~f }
236241 | _ -> ctx_arrays)
237242 in
238243 let context = { label; arrays } in
@@ -258,7 +263,9 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
258263 let get_ptr (buffer , _ ) = Ndarray. get_voidptr_not_managed buffer in
259264 Param_2f (get_ptr, merge_buffer, link bs ps Ctypes. (ptr void @-> cs))
260265 | bs , Param_ptr tn :: ps ->
261- let nd = match Map. find arrays tn with Some nd -> nd | None -> assert false in
266+ let nd =
267+ match get_array (ctx_arrays context) tn with Some nd -> nd | None -> assert false
268+ in
262269 let c_ptr = Ndarray. get_voidptr_not_managed nd in
263270 Param_2 (ref (Some c_ptr), link bs ps Ctypes. (ptr void @-> cs))
264271 in
0 commit comments