Skip to content

Commit f571d9e

Browse files
committed
Fix sexp_of_device/stream to break cyclicity
1 parent 72bf7ec commit f571d9e

File tree

1 file changed

+49
-12
lines changed

1 file changed

+49
-12
lines changed

arrayjit/lib/backend_intf.ml

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -79,42 +79,78 @@ module type Device_config = sig
7979
val name : string
8080
end
8181

82-
type ('buffer_ptr, 'dev, 'runner, 'event) device = {
82+
83+
type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
84+
dev : 'dev;
85+
ordinal : int;
86+
mutable shared_merge_buffer : 'buffer_ptr buffer option;
87+
mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option;
88+
mutable latest_stream_id : int;
89+
released : Utils.atomic_bool;
90+
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
91+
owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Hashtbl.M(Tnode).t;
92+
shared_writer_streams :
93+
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
94+
host_reading_streams :
95+
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
96+
host_writing_streams :
97+
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
98+
}
99+
100+
and ('buffer_ptr, 'dev, 'runner, 'event) stream_ref = {
101+
device : ('buffer_ptr, 'dev, 'runner, 'event) device_ref;
102+
runner : 'runner;
103+
merge_buffer : 'buffer_ptr buffer option ref;
104+
mutable scheduled_merge_node : Tnode.t option;
105+
stream_id : int;
106+
mutable allocated_buffer : 'buffer_ptr buffer option;
107+
updating_for : 'event Hashtbl.M(Tnode).t;
108+
mutable updating_for_merge_buffer : (Tnode.t * 'event) option;
109+
reader_streams :
110+
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
111+
}
112+
113+
let sexp_of_device_ref _ _ _ _ device = [%sexp_of: string * int] ("ordinal", device.ordinal)
114+
let sexp_of_stream_ref _ _ _ _ stream = [%sexp_of: string * int] ("stream_id", stream.stream_id)
115+
let equal_stream_ref s1 s2 = s1.stream_id = s2.stream_id && s1.device.ordinal = s2.device.ordinal
116+
117+
type ('buffer_ptr, 'dev, 'runner, 'event) device =
118+
('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
83119
dev : 'dev;
84120
ordinal : int;
85121
mutable shared_merge_buffer : 'buffer_ptr buffer option;
86122
(** Depending on backend implementations, either the currently used cross-stream merge buffer,
87123
or the one most recently scheduled. *)
88124
mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option;
89-
(** The tensor node that was most recently scheduled to be in the cross-stream merge buffer,
90-
and its readiness event. *)
125+
(** The tensor node that was most recently scheduled to be in the cross-stream merge buffer. *)
91126
mutable latest_stream_id : int;
92127
released : Utils.atomic_bool;
93128
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
94129
(** Freshly created arrays that might be shared across streams. The map can both grow and
95130
shrink. *)
96-
owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream Hashtbl.M(Tnode).t;
131+
owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream_ref Hashtbl.M(Tnode).t;
97132
(** The stream owning a given node. This map can only grow. Currently, if the memory mode of a
98133
node is inferred, only this stream will modify a cross-stream shared array. But memory
99134
modes can also be set manually. *)
100135
shared_writer_streams :
101-
(('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
136+
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
102137
(** The streams that most recently have been scheduled to update (write to) a
103138
cross-stream-shared node, and the associated update completion event. The completed events
104139
are removed opportunistically. *)
105140
host_reading_streams :
106-
(('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
141+
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
107142
(** The streams that most recently have been reading from a node's on-host array. The
108143
completed events are removed opportunistically. *)
109144
host_writing_streams :
110-
(('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
145+
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
111146
(** The streams that most recently have been writing to a node's on-host array. The completed
112147
events are removed opportunistically. *)
113148
}
114149
[@@deriving sexp_of]
115150

116-
and ('buffer_ptr, 'dev, 'runner, 'event) stream = {
117-
device : ('buffer_ptr, 'dev, 'runner, 'event) device;
151+
type ('buffer_ptr, 'dev, 'runner, 'event) stream =
152+
('buffer_ptr, 'dev, 'runner, 'event) stream_ref = {
153+
device : ('buffer_ptr, 'dev, 'runner, 'event) device_ref;
118154
runner : 'runner;
119155
merge_buffer : 'buffer_ptr buffer option ref;
120156
(** Depending on backend implementations, either the currently used merge buffer, or the one
@@ -124,17 +160,18 @@ and ('buffer_ptr, 'dev, 'runner, 'event) stream = {
124160
stream_id : int; (** An ID unique within the device. *)
125161
mutable allocated_buffer : 'buffer_ptr buffer option;
126162
updating_for : 'event Hashtbl.M(Tnode).t;
127-
(* The completion event for updating (writing to) a node via this stream, if any. *)
163+
(* The completion event for updating (writing to) a node via this stream, if any. *)
128164
mutable updating_for_merge_buffer : (Tnode.t * 'event) option;
129165
(** Like {!field-updating_for}, but for the merge buffer. *)
130-
reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
166+
reader_streams :
167+
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
131168
(** The streams, other than this stream, that most recently have been reading from a node in
132169
this stream's context, and the associated use completion events. The completed events are
133170
removed opportunistically. *)
134171
}
135172
[@@deriving sexp_of]
136173

137-
let equal_stream s1 s2 = s1.stream_id = s2.stream_id && s1.device.ordinal = s2.device.ordinal
174+
let equal_stream = equal_stream_ref
138175

139176
type ('buffer_ptr, 'stream) context = {
140177
stream : 'stream;

0 commit comments

Comments
 (0)