@@ -79,42 +79,78 @@ module type Device_config = sig
7979 val name : string
8080end
8181
82- type ('buffer_ptr, 'dev, 'runner, 'event) device = {
82+
(** Base declarations of the device and stream records. The two records refer to each other (a
    stream holds its [device]; a device holds tables of streams), hence the single recursive
    definition. The [device] and [stream] types declared further down in this file re-export
    these records; the [device] declaration carries the per-field documentation. *)
type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
  dev : 'dev;
  ordinal : int;
  mutable shared_merge_buffer : 'buffer_ptr buffer option;
  mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option;
  mutable latest_stream_id : int;
  released : Utils.atomic_bool;
  cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
  owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Hashtbl.M(Tnode).t;
  shared_writer_streams :
    (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
  host_reading_streams :
    (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
  host_writing_streams :
    (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
}

(** A stream of a device; see {!device_ref} for why the two records are declared together. *)
and ('buffer_ptr, 'dev, 'runner, 'event) stream_ref = {
  device : ('buffer_ptr, 'dev, 'runner, 'event) device_ref;
  runner : 'runner;
  merge_buffer : 'buffer_ptr buffer option ref;
  mutable scheduled_merge_node : Tnode.t option;
  stream_id : int;
  mutable allocated_buffer : 'buffer_ptr buffer option;
  updating_for : 'event Hashtbl.M(Tnode).t;
  mutable updating_for_merge_buffer : (Tnode.t * 'event) option;
  reader_streams :
    (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
}
112+
(** [sexp_of_device_ref _ _ _ _ device] renders [device] as the pair
    [("ordinal", device.ordinal)] and nothing else: no other field of the record is visited.
    The four ignored arguments stand for the converters of the record's four type parameters,
    following the calling convention of [ppx_sexp_conv] for polymorphic types. *)
let sexp_of_device_ref _ _ _ _ device = [%sexp_of: string * int] ("ordinal", device.ordinal)
(** [sexp_of_stream_ref _ _ _ _ stream] renders [stream] as the pair
    [("stream_id", stream.stream_id)] and nothing else: in particular the [device] field is not
    visited. The four ignored arguments stand for the converters of the record's four type
    parameters, following the calling convention of [ppx_sexp_conv] for polymorphic types. *)
let sexp_of_stream_ref _ _ _ _ stream = [%sexp_of: string * int] ("stream_id", stream.stream_id)
(** [equal_stream_ref s1 s2] holds when the two streams carry the same [stream_id] and their
    devices carry the same [ordinal]. No other field takes part in the comparison. *)
let equal_stream_ref s1 s2 =
  let same_id = s1.stream_id = s2.stream_id in
  let same_device = s1.device.ordinal = s2.device.ordinal in
  same_id && same_device
116+
(** A device. This is the same record type as [device_ref] (the declaration re-exports it), given
    here with documented fields and a derived [sexp_of] converter. Stream-typed fields are
    declared through [stream_ref]. *)
type ('buffer_ptr, 'dev, 'runner, 'event) device =
      ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
  dev : 'dev;
  ordinal : int;
  mutable shared_merge_buffer : 'buffer_ptr buffer option;
      (** The cross-stream merge buffer: depending on the backend implementation, either the one
          currently in use or the one most recently scheduled. *)
  mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option;
      (** The tensor node most recently scheduled to occupy the cross-stream merge buffer,
          together with its readiness event, if any. *)
  mutable latest_stream_id : int;
  released : Utils.atomic_bool;
  cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
      (** Freshly created arrays that might become shared across streams. Entries can be both
          added to and removed from this map. *)
  owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Hashtbl.M(Tnode).t;
      (** The stream that owns a given node. Entries are only ever added to this map. Currently,
          when a node's memory mode is inferred, only the owning stream modifies a cross-stream
          shared array; memory modes can, however, also be set manually. *)
  shared_writer_streams :
    (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
      (** The streams most recently scheduled to update (write to) a cross-stream-shared node,
          each paired with the completion event of that update. Completed events are removed
          opportunistically. *)
  host_reading_streams :
    (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
      (** The streams that have most recently been reading from a node's on-host array.
          Completed events are removed opportunistically. *)
  host_writing_streams :
    (('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
      (** The streams that have most recently been writing to a node's on-host array. Completed
          events are removed opportunistically. *)
}
[@@deriving sexp_of]
115150
116- and ('buffer_ptr, 'dev, 'runner, 'event) stream = {
117- device : ('buffer_ptr , 'dev , 'runner , 'event ) device ;
151+ type ('buffer_ptr, 'dev, 'runner, 'event) stream =
152+ ('buffer_ptr , 'dev , 'runner , 'event ) stream_ref = {
153+ device : ('buffer_ptr , 'dev , 'runner , 'event ) device_ref ;
118154 runner : 'runner ;
119155 merge_buffer : 'buffer_ptr buffer option ref ;
120156 (* * Depending on backend implementations, either the currently used merge buffer, or the one
@@ -124,17 +160,18 @@ and ('buffer_ptr, 'dev, 'runner, 'event) stream = {
124160 stream_id : int ; (* * An ID unique within the device. *)
125161 mutable allocated_buffer : 'buffer_ptr buffer option ;
126162 updating_for : 'event Hashtbl .M (Tnode ).t;
127- (* The completion event for updating (writing to) a node via this stream, if any. *)
163+ (* The completion event for updating (writing to) a node via this stream, if any. *)
128164 mutable updating_for_merge_buffer : (Tnode .t * 'event ) option ;
129165 (* * Like {!field-updating_for}, but for the merge buffer. *)
130- reader_streams : (('buffer_ptr , 'dev , 'runner , 'event ) stream * 'event ) list Hashtbl .M (Tnode ).t;
166+ reader_streams :
167+ (('buffer_ptr , 'dev , 'runner , 'event ) stream_ref * 'event ) list Hashtbl .M (Tnode ).t;
131168 (* * The streams, other than this stream, that most recently have been reading from a node in
132169 this stream's context, and the associated use completion events. The completed events are
133170 removed opportunistically. *)
134171}
135172[@@ deriving sexp_of ]
136173
137- let equal_stream s1 s2 = s1.stream_id = s2.stream_id && s1.device.ordinal = s2.device.ordinal
(** [equal_stream s1 s2] is [equal_stream_ref s1 s2]; the name is kept as an alias for callers
    that work with the [stream] type. *)
let equal_stream s1 s2 = equal_stream_ref s1 s2
138175
139176type ('buffer_ptr, 'stream) context = {
140177 stream : 'stream ;
0 commit comments