@@ -491,340 +491,6 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
491491 let get_buffer tn context = Backend. get_buffer tn context.ctx
492492end
493493
494- module Pipes_multicore_backend (Backend : No_device_backend ) : Backend = struct
495- module Domain = Domain [@ warning " -3" ]
496-
497- type task_list = Tnode .task Utils .mutable_list [@@ deriving sexp_of ]
498-
499- type device_state = {
500- mutable keep_spinning : bool ;
501- mutable device_error : exn option ;
502- mutable host_pos : task_list ;
503- mutable dev_pos : task_list ;
504- mutable dev_previous_pos : task_list ;
505- dev_wait : (Utils .waiter [@ sexp.opaque]);
506- }
507- [@@ deriving sexp_of ]
508-
509- type buffer_ptr = Backend .buffer_ptr [@@ deriving sexp_of ]
510-
511- type device = {
512- state : device_state ;
513- host_wait_for_idle : (Utils .waiter [@ sexp.opaque]);
514- merge_buffer : (buffer_ptr * Tnode .t ) option ref ;
515- mutable allocated_buffer : (buffer_ptr * int ) option ;
516- ordinal : int ;
517- domain : (unit Domain .t [@ sexp.opaque]);
518- }
519- [@@ deriving sexp_of ]
520-
521- let alloc_buffer ?old_buffer ~size_in_bytes _device =
522- Backend. alloc_buffer ?old_buffer ~size_in_bytes ()
523-
524- type physical_device = device [@@ deriving sexp_of ]
525- type code = Backend .code [@@ deriving sexp_of ]
526- type code_batch = Backend .code_batch [@@ deriving sexp_of ]
527-
528- let expected_merge_node (code : code ) = Backend. expected_merge_node code
529- let expected_merge_nodes (codes : code_batch ) = Backend. expected_merge_nodes codes
530- let is_dev_queue_empty state = Utils. (is_empty @@ tl_exn state.dev_previous_pos)
531- let is_idle device = is_dev_queue_empty device.state && device.state.dev_wait.is_waiting ()
532- let name = " multicore " ^ Backend. name
533-
534- let await device =
535- assert (Domain. is_main_domain () );
536- let d = device.state in
537- let keep_waiting () =
538- if d.keep_spinning && not (is_dev_queue_empty d) then (
539- ignore (d.dev_wait.release_if_waiting () : bool );
540- true )
541- else d.keep_spinning && not (d.dev_wait.is_waiting () )
542- in
543- while not (is_dev_queue_empty d) do
544- ignore (device.host_wait_for_idle.await ~keep_waiting () : bool )
545- done ;
546- Option. iter d.device_error ~f: (fun e ->
547- Exn. reraise e @@ name ^ " device " ^ Int. to_string device.ordinal)
548-
549- let % track2_sexp schedule_task device task =
550- assert (Domain. is_main_domain () );
551- let d = device.state in
552- Option. iter d.device_error ~f: (fun e ->
553- Exn. reraise e @@ name ^ " device " ^ Int. to_string device.ordinal);
554- if not d.keep_spinning then invalid_arg " Multicore_backend: device not available" ;
555- d.host_pos < - Utils. insert ~next: task d.host_pos;
556- ignore (d.dev_wait.release_if_waiting () : bool )
557-
558- let global_run_no = ref 0
559-
560- let spinup_device ~(ordinal : int ) : device =
561- Int. incr global_run_no;
562- let init_pos =
563- Utils. Cons
564- {
565- hd =
566- Tnode. Task
567- { context_lifetime = () ; description = " root of task queue" ; work = (fun () -> () ) };
568- tl = Empty ;
569- }
570- in
571- let state =
572- {
573- keep_spinning = true ;
574- device_error = None ;
575- host_pos = init_pos;
576- dev_pos = Empty ;
577- dev_previous_pos = init_pos;
578- dev_wait = Utils. waiter ~name: " dev" () ;
579- }
580- in
581- let host_wait_for_idle = Utils. waiter ~name: " host" () in
582- let keep_waiting () =
583- state.keep_spinning && is_dev_queue_empty state && not (host_wait_for_idle.is_waiting () )
584- in
585- let wait_by_dev = state.dev_wait.await ~keep_waiting in
586- let % diagn_l_sexp worker (() : unit ) : unit =
587- try
588- while state.keep_spinning do
589- match state.dev_pos with
590- | Empty ->
591- let _host_released : bool = host_wait_for_idle.release_if_waiting () in
592- let _could_wait : bool = wait_by_dev () in
593- (* not _host_released && not _could_wait: we busy-loop until host processes its
594- release. *)
595- state.dev_pos < - Utils. tl_exn state.dev_previous_pos
596- | Cons { hd; tl } ->
597- Tnode. run hd;
598- state.dev_previous_pos < - state.dev_pos;
599- state.dev_pos < - tl
600- done
601- with e ->
602- state.device_error < - Some e;
603- state.keep_spinning < - false ;
604- [% log " Device" , (ordinal : int ), " exception" , Exn. to_string e];
605- ignore (host_wait_for_idle.release_if_waiting () : bool );
606- (* TODO: we risk raising this error multiple times because await and schedule_task raise
607- device_error. But this is fine if we assume all exceptions are fatal. *)
608- raise e
609- in
610- {
611- state;
612- host_wait_for_idle;
613- ordinal;
614- domain = Domain. spawn worker;
615- merge_buffer = ref None ;
616- allocated_buffer = None ;
617- }
618-
619- let % diagn_sexp make_work device (Tnode. Task { context_lifetime; description; _ } as task) =
620- let % diagn_l_sexp work () = schedule_task device task in
621- Tnode. Task
622- {
623- context_lifetime;
624- description = " schedules {" ^ description ^ " } on device " ^ Int. to_string device.ordinal;
625- work;
626- }
627-
628- type context = { device : device ; ctx : Backend .context ; expected_merge_node : Tnode .t option }
629- [@@ deriving sexp_of ]
630-
631- type nonrec routine = context routine [@@ deriving sexp_of ]
632-
633- let init device =
634- {
635- device;
636- ctx = Backend. init ~label: (name ^ " " ^ Int. to_string device.ordinal);
637- expected_merge_node = None ;
638- }
639-
640- let initialize = Backend. initialize
641- let is_initialized = Backend. is_initialized
642-
643- let finalize { device; ctx; expected_merge_node = _ } =
644- await device;
645- Backend. finalize ctx
646-
647- let compile = Backend. compile
648- let compile_batch = Backend. compile_batch
649-
650- let link ?from_prior_context { ctx; device; expected_merge_node = _ } code =
651- let task = Backend. link ?from_prior_context ~merge_buffer: device.merge_buffer ctx code in
652- {
653- task with
654- context =
655- { ctx = task.context; device; expected_merge_node = Backend. expected_merge_node code };
656- schedule = make_work device task.schedule;
657- }
658-
659- let link_batch ?from_prior_context { ctx; device; expected_merge_node } code_batch =
660- let ctx, routines =
661- Backend. link_batch ?from_prior_context ~merge_buffer: device.merge_buffer ctx code_batch
662- in
663- let merge_nodes = Backend. expected_merge_nodes code_batch in
664- ( { ctx; device; expected_merge_node },
665- Array. mapi routines ~f: (fun i ->
666- Option. map ~f: (fun task ->
667- {
668- task with
669- context = { ctx = task.context; device; expected_merge_node = merge_nodes.(i) };
670- schedule = make_work device task.schedule;
671- })) )
672-
673- let from_host (context : context ) (tn : Tnode.t ) =
674- Option. value ~default: false
675- @@ Option. map (Backend. get_buffer tn context.ctx) ~f: (fun c_arr ->
676- match tn.Tnode. array with
677- | (lazy (Some h_arr )) ->
678- let % diagn_l_sexp work () =
679- Backend. host_to_buffer h_arr ~dst: c_arr;
680- [% log_block
681- " from_host " ^ Tnode. debug_name tn;
682- [% log " copied" , Tnode. debug_name tn, " from host" ];
683- [% log2_printbox
684- let indices =
685- Array. init (Array. length @@ Lazy. force tn.dims) ~f: (fun i -> i - 5 )
686- in
687- Ndarray. render_array ~indices h_arr]]
688- in
689- schedule_task context.device
690- (Tnode. Task
691- {
692- context_lifetime = context;
693- description =
694- " from_host " ^ Tnode. debug_name tn ^ " dst "
695- ^ Int. to_string context.device.ordinal;
696- work;
697- });
698- true
699- | (lazy None) ->
700- [% diagn_sexp
701- [% log_block
702- " nothing to copy from host" ;
703- [% log " for" , Tnode. debug_name tn]]];
704- false )
705-
706- let to_host (context : context ) (tn : Tnode.t ) =
707- Option. value ~default: false
708- @@ Option. map (Backend. get_buffer tn context.ctx) ~f: (fun c_arr ->
709- match tn.Tnode. array with
710- | (lazy (Some h_arr )) ->
711- let % diagn_l_sexp work () =
712- Backend. buffer_to_host h_arr ~src: c_arr;
713- [% log_block
714- " to_host " ^ Tnode. debug_name tn;
715- [% log " copied to host" ];
716- [% log2_printbox
717- let indices =
718- Array. init (Array. length @@ Lazy. force tn.dims) ~f: (fun i -> i - 5 )
719- in
720- Ndarray. render_array ~indices h_arr]]
721- in
722- schedule_task context.device
723- (Tnode. Task
724- {
725- context_lifetime = context;
726- description =
727- " from_host " ^ Tnode. debug_name tn ^ " dst "
728- ^ Int. to_string context.device.ordinal;
729- work;
730- });
731- true
732- | (lazy None) ->
733- [% diagn_sexp
734- [% log_block
735- " nothing to copy to host" ;
736- [% log " for" , Tnode. debug_name tn]]];
737- false )
738-
739- let device_to_device tn ~into_merge_buffer ~dst ~src =
740- let dev = dst.device in
741- if
742- (not (equal_merge_buffer_use into_merge_buffer No ))
743- && not (Option. equal Tnode. equal (Some tn) dst.expected_merge_node)
744- then
745- raise
746- @@ Utils. User_error
747- (" Multicore_backend.device_to_device: merge node mismatch, expected "
748- ^ Option. (value ~default: " none" @@ map ~f: Tnode. debug_name dst.expected_merge_node)
749- ^ " , actual " ^ Tnode. debug_name tn);
750- let schedule dst_ptr =
751- let work =
752- match into_merge_buffer with
753- | No -> fun () -> Backend. to_buffer tn ~dst: dst_ptr ~src: src.ctx
754- | Streaming ->
755- fun () ->
756- dev.merge_buffer :=
757- Option. map ~f: (fun ptr -> (ptr, tn)) @@ Backend. get_buffer tn src.ctx
758- | Copy ->
759- fun () ->
760- let size_in_bytes = Tnode. size_in_bytes tn in
761- let allocated_capacity =
762- Option. value ~default: 0 @@ Option. map dev.allocated_buffer ~f: snd
763- in
764- if allocated_capacity < size_in_bytes then
765- dev.allocated_buffer < -
766- Some
767- ( Backend. alloc_buffer ?old_buffer:dev.allocated_buffer ~size_in_bytes ( ),
768- size_in_bytes );
769- let merge_ptr = fst @@ Option. value_exn dev.allocated_buffer in
770- dev.merge_buffer := Some (merge_ptr, tn);
771- Backend. to_buffer tn ~dst: merge_ptr ~src: src.ctx
772- in
773- schedule_task dev
774- (Tnode. Task
775- {
776- context_lifetime = (dst, src);
777- description =
778- " device_to_device " ^ Tnode. debug_name tn ^ " dst " ^ Int. to_string dev.ordinal
779- ^ " src " ^ Int. to_string src.device.ordinal;
780- work;
781- })
782- in
783- match (Backend. get_buffer tn dst.ctx, Backend. get_buffer tn src.ctx) with
784- | Some dst , Some _ ->
785- schedule dst;
786- true
787- | _ -> false
788-
789- let num_physical_devices () = Domain. recommended_domain_count () - 1
790- let suggested_num_virtual_devices _device = 1
791- let devices = Array. create ~len: (num_physical_devices () ) None
792-
793- let % track2_sexp unsafe_cleanup () =
794- assert (Domain. is_main_domain () );
795- let wait_for_finish device =
796- await device;
797- device.state.keep_spinning < - false ;
798- ignore (device.state.dev_wait.release_if_waiting () : bool )
799- in
800- Array. iter devices ~f: (Option. iter ~f: wait_for_finish);
801- let cleanup ordinal device =
802- Domain. join device.domain;
803- device.host_wait_for_idle.finalize () ;
804- device.state.dev_wait.finalize () ;
805- devices.(ordinal) < - None
806- in
807- Array. iteri devices ~f: (fun ordinal -> Option. iter ~f: (cleanup ordinal));
808- Backend. unsafe_cleanup ()
809-
810- let get_device ~ordinal =
811- Option. value_or_thunk devices.(ordinal) ~default: (fun () ->
812- let dev = spinup_device ~ordinal in
813- devices.(ordinal) < - Some dev;
814- dev)
815-
816- let new_virtual_device device = device
817- let get_physical_device device = device
818- let get_ctx_device { device; _ } = device
819- let get_name device = Int. to_string device.ordinal
820- let to_ordinal { ordinal; _ } = ordinal
821- let to_subordinal _ = 0
822- let to_buffer tn ~dst ~src = Backend. to_buffer tn ~dst ~src: src.ctx
823- let host_to_buffer = Backend. host_to_buffer
824- let buffer_to_host = Backend. buffer_to_host
825- let get_buffer tn context = Backend. get_buffer tn context.ctx
826- end
827-
828494(* * For debugging, allow [Sync_backend(...).suggested_num_virtual_devices] calls to return >1
829495 numbers. *)
830496let sync_suggested_num_virtual_devices = ref 1
@@ -1198,14 +864,12 @@ module C_device : No_device_backend = Simple_no_device_backend ((
1198864
1199865module Cc_backend = Multicore_backend (C_device )
1200866module Sync_cc_backend = Sync_backend (C_device )
1201- module Pipes_cc_backend = Pipes_multicore_backend (C_device )
1202867
1203868module Gccjit_device : No_device_backend = Simple_no_device_backend ((
1204869 Gcc_backend : Simple_backend with type context = Gcc_backend. context))
1205870
1206871module Gccjit_backend = Multicore_backend (Gccjit_device )
1207872module Sync_gccjit_backend = Sync_backend (Gccjit_device )
1208- module Pipes_gccjit_backend = Pipes_multicore_backend (Gccjit_device )
1209873
1210874module Cuda_backend : Backend = struct
1211875 include Cuda_backend
@@ -1304,15 +968,13 @@ let fresh_backend ?backend_name ?(config = Physical_devices_only) () =
1304968 let backend =
1305969 match
1306970 Option. value_or_thunk backend_name ~default: (fun () ->
1307- Utils. get_global_arg ~arg_name: " backend" ~default: " pipes_cc " )
971+ Utils. get_global_arg ~arg_name: " backend" ~default: " cc " )
1308972 |> String. lowercase
1309973 with
1310974 | "cc" -> (module Cc_backend : Backend )
1311975 | "gccjit" -> (module Gccjit_backend : Backend )
1312976 | "sync_cc" -> (module Sync_cc_backend : Backend )
1313977 | "sync_gccjit" -> (module Sync_gccjit_backend : Backend )
1314- | "pipes_cc" -> (module Pipes_cc_backend : Backend )
1315- | "pipes_gccjit" -> (module Pipes_gccjit_backend : Backend )
1316978 | "cuda" -> (module Cuda_backend : Backend )
1317979 | backend -> invalid_arg [% string " Backends.fresh_backend: unknown backend %{backend}" ]
1318980 in
0 commit comments