@@ -203,10 +203,20 @@ type grad_spec = Require_grad | Prohibit_grad | If_needed [@@deriving sexp, equa
203203let % track7_sexp op ~(label : string list ) ?(ternary_op = Shape. Pointwise_tern )
204204 ?(compose_op = Shape. Pointwise_bin ) ?(transpose_op = Shape. Pointwise_un ) ?terminal_op ~op_asn
205205 ~grad_asn ?(grad_spec = If_needed ) make_shape (orig_ts : t list ) : t =
206+ List. iter orig_ts ~f: (fun t ->
207+ if t.id > = session_state.next_id then
208+ raise
209+ @@ Session_error
210+ ( [% string
211+ " Tensor #%{t.id#Int} %{Tn.debug_name t.value} has an id greater than the last id \
212+ #%{session_state.next_id - 1#Int} -- check your uses of \
213+ Tensor.unsafe_reinitialize, if all your uses are valid, report this as a bug." ],
214+ Some t ));
206215 (* The code needs to be included in the order it was computed due to potential non-tree DAGs. *)
207216 let ordered_ts = List. dedup_and_sort orig_ts ~compare: (fun t1 t2 -> Int. ascending t1.id t2.id) in
208217 let id : int = session_state.next_id in
209218 session_state.next_id < - session_state.next_id + 1 ;
219+ let _session_state_next_id : int = session_state.next_id in
210220 let shape = make_shape ~debug_name: (Tn. get_debug_name ~id ~label () ) ~id in
211221 let default_prec =
212222 let lazy_v_precs = List. map orig_ts ~f: (fun ti -> ti.value.prec) in
@@ -247,8 +257,8 @@ let%track7_sexp op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
247257 | [ t1; t2 ] -> [ Shape. Broadcast (compose_op, t1.shape, t2.shape) ]
248258 | [ t1; t2; t3 ] -> [ Shape. Broadcast_tern (ternary_op, t1.shape, t2.shape, t3.shape) ]
249259 | _ ->
250- (* Let's implement what we need when we need it. *)
251- assert false
260+ (* Let's implement what we need when we need it. *)
261+ assert false
252262 in
253263 let local_shape_updates =
254264 List. map ~f: (fun logic -> Shape. { shape; logic; id = get_update_id () }) @@ shape_logics orig_ts
@@ -381,8 +391,8 @@ let%track7_sexp unop ?transpose_op ~op_asn ~grad_asn ?grad_spec t1 ?(label = [])
381391 () )
382392 [ t1 ]
383393
384- let % track7_sexp term ?init_data ?fetch_op ?grad_spec ?(label = [] ) ?batch_dims ?batch_axes ?input_dims
385- ?output_dims ?input_axes ?output_axes ?deduced () : t =
394+ let % track7_sexp term ?init_data ?fetch_op ?grad_spec ?(label = [] ) ?batch_dims ?batch_axes
395+ ?input_dims ? output_dims ?input_axes ?output_axes ?deduced () : t =
386396 let terminal_op =
387397 match (init_data, fetch_op) with
388398 | Some _ , Some _ -> invalid_arg " Tensor.term: both init_data and fetch_op are provided"
@@ -566,7 +576,7 @@ let rec get_random_seed () =
566576 set_random_seed () ;
567577 get_random_seed ()
568578
569- let % track5_sexp unsafe_reinitialize () =
579+ let % track5_sexp unsafe_reinitialize () : unit =
570580 session_state.next_id < - 0 ;
571581 session_state.forward_roots < - Map. empty (module Int );
572582 session_state.backprop_roots < - Map. empty (module Int );
0 commit comments