|
open Base
(* NOTE: shadows [Base.Lazy] and provides the project's debug runtime. *)
module Lazy = Utils.Lazy
module Debug_runtime = Utils.Debug_runtime

(* Required at use sites by the debug-logging extension points below to
   locate the local debug runtime. *)
let _get_local_debug_runtime = Utils._get_local_debug_runtime

(* Default to the most verbose log level (9); the env-var directive below
   presumably overrides it when OCANNL_LOG_LEVEL is set -- TODO confirm the
   precedence between the two directives. *)
[%%global_debug_log_level 9]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
| 9 | + |
(** Types shared across the backend interfaces below. *)
module Types = struct
  (** A routine linked into a context, ready for scheduling. *)
  type 'context routine = {
    context : 'context;  (** The context the routine was linked into. *)
    schedule : Tnode.task;  (** The task that runs the routine's procedure. *)
    bindings : Indexing.lowered_bindings;
        (** Settable bindings for the routine's static indices. *)
    name : string;  (** Identifier of the routine. *)
  }
  [@@deriving sexp_of]

  (** Device-selection strategy, passed to {!No_device_backend.initialize};
      see also {!Backend.suggested_num_virtual_devices}. *)
  type config = Physical_devices_only | For_parallel_copying | Most_parallel_devices
  [@@deriving equal, sexp, variants]

  (** How a tensor node is transferred in {!Backend.device_to_device}: not via
      the merge buffer ([No]), by remembering the source node's buffer pointer
      without copying ([Streaming]), or by copying into the destination
      device's merge buffer ([Copy]). *)
  type merge_buffer_use = No | Streaming | Copy [@@deriving equal, sexp]

  (** The source of a parameter passed to a compiled procedure. *)
  type param_source =
    | Log_file_name
    | Merge_buffer
    | Param_ptr of Tnode.t
    | Static_idx of Indexing.static_symbol
  [@@deriving sexp_of]
end
| 31 | + |
(** The part of a backend interface that does not involve device management. *)
module type No_device_backend = sig
  type code [@@deriving sexp_of]
  (** The result of {!compile}; with [~shared:false] compilation may be
      postponed until {!link}. *)

  type code_batch [@@deriving sexp_of]
  (** The result of {!compile_batch}. *)

  type buffer_ptr [@@deriving sexp_of]
  (* Presumably a pointer to a backend-managed buffer -- see [alloc_buffer]. *)

  type context [@@deriving sexp_of]
  type routine = context Types.routine [@@deriving sexp_of]

  val name : string
  (** The name of the backend. *)

  val initialize : Types.config -> unit
  (** Initializes a backend before first use or (on some backends) after {!unsafe_cleanup}. Does
      nothing if the backend is already initialized. *)

  val is_initialized : unit -> bool
  (** Returns false if there was no previous {!initialize} call, or, on some backends, the most
      recent call was followed by {!unsafe_cleanup}. If it returns false, one must call
      {!initialize} before using the backend. *)

  val init : label:string -> context
  (** [label] is usually the backend name concatenated with the device number. *)

  val finalize : context -> unit
  (** Finalizes (just) the context. *)

  val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> unit -> buffer_ptr
  (** Allocates a buffer of [size_in_bytes]. [old_buffer] is a previous buffer
      with its size; presumably it can be reused when big enough -- TODO
      confirm reuse semantics. *)

  val expected_merge_node : code -> Tnode.t option
  (** The tensor node the code expects in the merge buffer, if any. *)

  val expected_merge_nodes : code_batch -> Tnode.t option array
  (** Per-procedure merge-buffer expectations; see {!expected_merge_node}. *)

  val compile : ?shared:bool -> ?name:string -> Indexing.unit_bindings -> Assignments.comp -> code
  (** If [~shared:true] (default [false]), the backend should prefer to do more compile work in a
      device-agnostic way. If [~shared:false], the backend can opt to postpone compiling altogether
      until [link] is called, to benefit from more optimizations. *)

  val compile_batch :
    ?shared:bool ->
    ?names:string array ->
    ?occupancy:(name:string -> src_n:int -> bool) ->
    Indexing.unit_bindings ->
    Assignments.comp array ->
    code_batch
  (** Unlike the [~shared] parameter, [compile_batch] vs. [compile] is mostly about improving the
      compile time and debugging convenience by generating fewer files -- ideally does not affect
      execution, but there can be backend-specific differences. Only array entries for which
      [occupancy] returns true are included. *)

  val link : merge_buffer:(buffer_ptr * Tnode.t) option ref -> context -> code -> routine
  (** Returns the routine for the code's procedure, in a new context derived from the given context. *)

  val link_batch :
    merge_buffer:(buffer_ptr * Tnode.t) option ref ->
    context ->
    code_batch ->
    context * routine option array
  (** Returns the routines for the procedures included in the code batch. The returned context is
      downstream of all the returned routines (in particular, the routines' contexts are not
      independent). *)

  val unsafe_cleanup : unit -> unit
  (** Cleans up all work on a backend, releases resources. All previously retrieved values
      (contexts, virtual and physical devices) become invalid. The backend needs to be initialized
      again to be used again. *)

  val to_buffer : Tnode.t -> dst:buffer_ptr -> src:context -> unit
  (** Copies the tensor node's data from the [src] context into [dst]. *)

  val host_to_buffer : Ndarray.t -> dst:buffer_ptr -> unit
  (** Copies the host array's data into [dst]. *)

  val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
  (** Copies data from [src] into the host array. *)

  val get_buffer : Tnode.t -> context -> buffer_ptr option
  (** The buffer backing the tensor node in the context, if present. *)
end
| 99 | + |
(** The full backend interface: {!No_device_backend} extended with events and
    (virtual and physical) device management. *)
module type Backend = sig
  include No_device_backend

  val link : context -> code -> routine
  (** Returns the routine for the code's procedure, in a new context derived from the given context. *)

  val link_batch : context -> code_batch -> context * routine option array
  (** Returns the routines for the procedures included in the code batch. The returned context is
      downstream of all the returned routines. *)

  type event
  (** An event tracks if a device finished computing past a particular point in its schedule. These
      values are used internally for scheduling across devices of the backend, and can be used for
      explicit scheduling. *)

  val sync : event -> unit
  (** Blocks till the event completes, if it's not done already. *)

  val is_done : event -> bool
  (** Whether the event completed. *)

  val work_for : context -> Tnode.t -> event option
  (** If the tensor node is in the context, returns the event indicating if currently running or
      scheduled computations modifying that node on the context's device have completed.

      NOTE: [work_for ctx tn], if work tracking was not yet registered for [tn], will register work
      tracking for [tn] and return the [all_work] event for [ctx]'s device. *)

  val will_wait_for : context -> event -> unit
  (** Schedules waiting for the given event on the context's device.

      NOTE: it should rarely be needed to call [will_wait_for] explicitly, because it is typically
      called internally when necessary. But there is one exception, see {!device_to_device} when
      [into_merge_buffer=Streaming]. *)

  val from_host : context -> Tnode.t -> bool
  (** If the tensor node is both hosted and in-context, schedules a copy from host to context and
      returns true, otherwise returns false. NOTE: it's the caller's responsibility to synchronize
      the device (via [await ctx.device] or [sync (work_for ctx tn)]) before the host's data is
      overwritten. *)

  val to_host : context -> Tnode.t -> bool
  (** If the tensor node is both hosted and in-context, schedules a copy from context to host and
      returns true, otherwise returns false. NOTE: it's the caller's responsibility to synchronize
      the device (via [await ctx.device] or [sync (work_for ctx tn)]) before the host's data is
      read. *)

  val device_to_device :
    Tnode.t -> into_merge_buffer:Types.merge_buffer_use -> dst:context -> src:context -> bool
  (** [device_to_device tn ~into_merge_buffer ~dst ~src] proceeds as follows:
      - If the node is absent from the [src] context and either it is present in the [dst] context
        or [into_merge_buffer] is different from [No]: raises an error.
      - If the node is absent from [dst] and [into_merge_buffer=No]: returns false.
      - Executes [will_wait_for dst (work_for src tn)].
      - If [into_merge_buffer=No]: schedules a copy of the tensor node from the device of [src] to
        the device of [dst].
      - If [into_merge_buffer] is different from [No]: sets on [dst] the merge buffer source to the
        given node. If [into_merge_buffer=Streaming], remembers the buffer pointer of the source
        node to use for streaming, without blocking. If [into_merge_buffer=Copy], schedules copying
        from [src] to the merge buffer of [dst]'s device.
      - If the [dst] context resulted from a compilation with [Streaming] or [Copy] specific merge
        buffer code, the [device_to_device] call should fail immediately if there's a mismatch with
        [into_merge_buffer].

      NOTE: If [into_merge_buffer=Streaming], after scheduling the work on [dst] using the merge
      buffer but before scheduling work on [src] that modifies [tn], execute
      [will_wait_for src (all_work (get_ctx_device dst))]. *)

  type physical_device
  (** A physical device of the backend. *)

  type device
  (** A virtual device; multiple virtual devices can be created on one physical
      device, see {!new_virtual_device} and {!suggested_num_virtual_devices}. *)

  val init : device -> context
  (** Creates a context on the given (virtual) device. *)

  val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> device -> buffer_ptr
  (** As {!No_device_backend.alloc_buffer}, but allocating on the given device. *)

  val await : device -> unit
  (** Blocks till the device becomes idle, i.e. synchronizes the device. *)

  val all_work : device -> event
  (** Returns the event indicating if any currently running or scheduled computations on the device
      have completed. *)

  val is_idle : device -> bool
  (** Whether the device is currently waiting for work. *)

  val sexp_of_device : device -> Sexp.t
  val get_device : ordinal:int -> physical_device
  val num_physical_devices : unit -> int

  val suggested_num_virtual_devices : physical_device -> int
  (** The optimal number of virtual devices for the given physical device to follow the
      {!Types.config} strategy passed to {!No_device_backend.initialize}. *)

  val new_virtual_device : physical_device -> device
  val get_ctx_device : context -> device
  val get_physical_device : device -> physical_device

  val to_ordinal : physical_device -> int
  (** The ordinal of the physical device, as used by {!get_device}. *)

  val to_subordinal : device -> int
  (** Presumably identifies the virtual device among the virtual devices of its
      physical device -- TODO confirm against implementations. *)

  val get_name : device -> string
end
| 199 | + |
(** A lower-level interface for backends whose procedures are linked one at a
    time; presumably adapted into a {!No_device_backend} elsewhere -- TODO
    confirm the adapting functor. *)
module type Simple_backend = sig
  type context [@@deriving sexp_of]

  type procedure [@@deriving sexp_of]
  (** A compiled procedure; counterpart of {!No_device_backend.code}. *)

  type ctx_array [@@deriving sexp_of]
  (* Presumably a buffer together with backend-specific metadata -- see
     [buffer_ptr] below. *)

  type buffer_ptr [@@deriving sexp_of]

  type ctx_arrays = ctx_array Map.M(Tnode).t [@@deriving sexp_of]
  (** The per-context arrays, keyed by tensor node. *)

  val buffer_ptr : ctx_array -> buffer_ptr
  (** The buffer pointer underlying a context array. *)

  val ctx_arrays : context -> ctx_arrays
  (** The arrays stored in the given context. *)

  val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> unit -> buffer_ptr
  val expected_merge_node : procedure -> Tnode.t option

  val is_in_context : Low_level.traced_array -> bool
  (** If true, the node is required to be in the contexts linked with code that uses it.

      Should return false for nodes that are virtual, local, or which the backend prefers to access
      directly from the host. *)

  val compile :
    name:string ->
    opt_ctx_arrays:ctx_arrays option ->
    Indexing.unit_bindings ->
    Low_level.optimized ->
    procedure

  val compile_batch :
    names:string option array ->
    opt_ctx_arrays:ctx_arrays option ->
    Indexing.unit_bindings ->
    Low_level.optimized option array ->
    ctx_arrays option * procedure option array

  (* Returns the new context, the lowered bindings, the task that runs the
     procedure, and the routine's name. *)
  val link_compiled :
    merge_buffer:(buffer_ptr * Tnode.t) option ref ->
    context ->
    procedure ->
    context * Indexing.lowered_bindings * Tnode.task * string

  (* The following mirror their {!No_device_backend} counterparts, except that
     [initialize] takes no configuration. *)
  val name : string
  val initialize : unit -> unit
  val is_initialized : unit -> bool
  val init : label:string -> context
  val finalize : context -> unit
  val unsafe_cleanup : unit -> unit
  val to_buffer : Tnode.t -> dst:buffer_ptr -> src:context -> unit
  val host_to_buffer : Ndarray.t -> dst:buffer_ptr -> unit
  val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
end
| 248 | + |
(** A device-aware interface for backends that compile {!Low_level.optimized}
    code directly; presumably adapted into a {!Backend} elsewhere -- TODO
    confirm the adapting functor. *)
module type Lowered_backend = sig
  type context [@@deriving sexp_of]
  type code [@@deriving sexp_of]
  type code_batch [@@deriving sexp_of]
  type ctx_array [@@deriving sexp_of]

  type event
  (** See {!Backend.event}. *)

  val sync : event -> unit
  val is_done : event -> bool
  val work_for : context -> Tnode.t -> event option
  val will_wait_for : context -> event -> unit

  open Types

  val initialize : config -> unit
  val is_initialized : unit -> bool
  val finalize : context -> unit

  (* NOTE(review): uses [Sexplib.Sexp.t] here while {!Backend} uses Base's
     [Sexp.t] -- consider unifying the spelling. *)
  val sexp_of_context : context -> Sexplib.Sexp.t
  val compile : name:string -> Indexing.unit_bindings -> Low_level.optimized -> code

  val compile_batch :
    names:string option array ->
    Indexing.unit_bindings ->
    Low_level.optimized option array ->
    code_batch

  val is_in_context : Low_level.traced_array -> bool
  (** See {!Simple_backend.is_in_context}. *)

  val ctx_arrays : context -> ctx_array Map.M(Tnode).t
  val link : context -> code -> context * Indexing.lowered_bindings * Tnode.task

  val link_batch :
    context -> code_batch -> context * Indexing.lowered_bindings * Tnode.task option array

  val unsafe_cleanup : unit -> unit

  val from_host : context -> Tnode.t -> bool
  (** If the array is both hosted and in-context, copies from host to context. *)

  val to_host : context -> Tnode.t -> bool
  (** If the array is both hosted and in-context, copies from context to host. *)

  val device_to_device :
    Tnode.t -> into_merge_buffer:merge_buffer_use -> dst:context -> src:context -> bool
  (** If the array is in both contexts, copies from [src] to [dst]; see
      {!Backend.device_to_device} for the [into_merge_buffer] semantics. *)

  type buffer_ptr [@@deriving sexp_of]

  val to_buffer : Tnode.t -> dst:buffer_ptr -> src:context -> unit
  val host_to_buffer : Ndarray.t -> dst:buffer_ptr -> unit
  val buffer_to_host : Ndarray.t -> src:buffer_ptr -> unit
  val get_buffer : Tnode.t -> context -> buffer_ptr option

  type physical_device
  type device

  (* Device management; see the {!Backend} counterparts for documentation. *)
  val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> device -> buffer_ptr
  val init : device -> context
  val await : device -> unit
  val is_idle : device -> bool
  val all_work : device -> event
  val sexp_of_device : device -> Sexplib.Sexp.t
  val num_physical_devices : unit -> int
  val suggested_num_virtual_devices : physical_device -> int
  val get_device : ordinal:int -> physical_device
  val get_physical_device : device -> physical_device
  val new_virtual_device : physical_device -> device
  val get_ctx_device : context -> device
  val get_name : device -> string
  val to_ordinal : physical_device -> int
  val to_subordinal : device -> int
  val name : string
end
0 commit comments