Commit 982b813

Allow Uint4x32_to_prec_uniform to be virtual or local; tiny formatting & cleanup
Parent: da61756

3 files changed: +17 / -26 lines

lib/shape.ml

Lines changed: 8 additions & 17 deletions
@@ -394,10 +394,8 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
       [ Terminal_row cur_sh.batch; Terminal_row cur_sh.input; Terminal_row cur_sh.output ]
     in
     match logic with
-    | Terminal Range_over_offsets ->
-        (Row.dim_map_empty, mark_terminal ())
-    | Terminal (Constant _c) ->
-        (Row.dim_map_empty, mark_terminal ())
+    | Terminal Range_over_offsets -> (Row.dim_map_empty, mark_terminal ())
+    | Terminal (Constant _c) -> (Row.dim_map_empty, mark_terminal ())
     | Terminal (Constant_fill values) ->
         let len = Array.length values in
         let io_dims =
@@ -416,14 +414,10 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
              constr = Total_elems { nominator = batch_elems; divided_by = dim_var_set_empty };
            }
          :: mark_terminal () )
-    | Terminal (Access (C_function _)) ->
-        (Row.dim_map_empty, mark_terminal ())
-    | Terminal (Access (External_unsafe _)) ->
-        (Row.dim_map_empty, mark_terminal ())
-    | Terminal (Access (Merge_buffer _)) ->
-        (Row.dim_map_empty, mark_terminal ())
-    | Terminal (Access (Uint4x32_to_prec_uniform _)) ->
-        (Row.dim_map_empty, mark_terminal ())
+    | Terminal (Access (C_function _)) -> (Row.dim_map_empty, mark_terminal ())
+    | Terminal (Access (External_unsafe _)) -> (Row.dim_map_empty, mark_terminal ())
+    | Terminal (Access (Merge_buffer _)) -> (Row.dim_map_empty, mark_terminal ())
+    | Terminal (Access (Uint4x32_to_prec_uniform _)) -> (Row.dim_map_empty, mark_terminal ())
     | Terminal (Access (File_mapped (filename, prec))) ->
         let fd = Unix.openfile filename [ Unix.O_RDONLY ] 0o640 in
         let len = Unix.lseek fd 0 Unix.SEEK_END / Ir.Ops.prec_in_bytes prec in
@@ -444,10 +438,8 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
              constr = Total_elems { nominator = batch_elems; divided_by = dim_var_set_empty };
            }
          :: mark_terminal () )
-    | Terminal (Slice _) ->
-        (Row.dim_map_empty, mark_terminal ())
-    | Terminal (Embed_symbol _) ->
-        (Row.dim_map_empty, mark_terminal ())
+    | Terminal (Slice _) -> (Row.dim_map_empty, mark_terminal ())
+    | Terminal (Embed_symbol _) -> (Row.dim_map_empty, mark_terminal ())
     | Transpose (Transpose, sh) ->
         ( Row.dim_map_empty,
           [
@@ -773,7 +765,6 @@ let fresh_proj_ids update =
 (** Computes the indexing into subtensors given the shape information of a tensor.
     [derive_projections] should only be invoked when the shapes are fully inferred already! *)
 let%debug4_sexp derive_projections (update_step : update_step) : Idx.projections =
-  Stdio.printf "derive_projections\n%!";
   finish_inference ();
   let resolved_padding, inferred_padding = fresh_proj_ids update_step in
   let _debug_update_step : update_step = update_step in
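
The cleanup above only puts each terminal case on a single line; arms that share the same body could be merged even further with an or-pattern. A minimal sketch of that style, using toy stand-in types rather than the library's actual logic and Row definitions:

(* Toy stand-ins for illustration only; not the library's real types. *)
type access = C_function of string | Merge_buffer of int

type logic =
  | Range_over_offsets
  | Constant of float
  | Access of access
  | Slice of int

let mark_terminal () = [ "terminal" ]

(* Arms with an identical body can share it via an or-pattern instead of being
   repeated once per constructor. *)
let get_constraints = function
  | Range_over_offsets | Constant _ | Slice _ | Access (C_function _ | Merge_buffer _) ->
      ([], mark_terminal ())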

lib/tensor.ml

Lines changed: 9 additions & 7 deletions
@@ -299,9 +299,7 @@ let op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
       session_state.backprop_roots <- Map.remove session_state.backprop_roots ti.id);
   (* The order is not relevant, we keep the same order as in backprop for readability. *)
   let diff = Some { grad = g; zero_grads; backprop } in
-  let tensor =
-    { params = Set.empty (module T); forward; diff; id; value = v; shape; children }
-  in
+  let tensor = { params = Set.empty (module T); forward; diff; id; value = v; shape; children } in
   session_state.forward_roots <- Map.add_exn session_state.forward_roots ~key:id ~data:tensor;
   session_state.backprop_roots <- Map.add_exn session_state.backprop_roots ~key:id ~data:tensor;
   tensor
@@ -331,7 +329,11 @@ let term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?inp
   | Some fetch_op_fn ->
       let fetch_op = fetch_op_fn ~v in
       (match fetch_op with
-      | Constant _ | Slice _ | Embed_symbol _ | Range_over_offsets | Constant_fill _ -> ()
+      | Constant _ | Slice _ | Embed_symbol _ | Range_over_offsets | Constant_fill _
+      | Access (Uint4x32_to_prec_uniform _) ->
+          (* For these operations it makes sense to have a local / virtual tensor if the result is
+             consumed in the same computation. *)
+          ()
       | Access _ ->
           (* Note: [Access] can be used for merging across devices. But, some use cases of
              [Access] will require a hosted tensor node. *)
@@ -363,7 +365,7 @@ let number ?(label = []) ?axis_label ?(grad_spec = Prohibit_grad) c =
   t
 
 let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?output_dims
-    ?batch_axes ?input_axes ?output_axes ?(_strict = true) values =
+    ?batch_axes ?input_axes ?output_axes values =
   let to_dim_list dims axes =
     Option.value ~default:[] @@ Option.first_some dims @@ Option.map axes ~f:(List.map ~f:snd)
   in
@@ -407,8 +409,8 @@ let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?
       Tn.update_prec ~only_if:is_up_to_fp16 t.value single);
   t
 
-let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ?(_strict = true) ?values
-    label =
+let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced
+    ?values label =
   let fetch_op_fn ~v:_ =
     match values with Some values -> Asgns.Constant_fill values | None -> Asgns.Range_over_offsets
   in
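
The substantive change here is that `Access (Uint4x32_to_prec_uniform _)` is now grouped with the constant-like fetch ops, so its result may stay a local / virtual tensor node when it is consumed within the same computation. A hedged sketch of that classification idea, with toy constructors standing in for the real fetch-op type rather than the library's actual definitions:

(* Illustrative only; a toy fetch-op type, not the library's real one. *)
type access = Uint4x32_to_prec_uniform | Merge_buffer | File_mapped of string

type fetch_op =
  | Constant of float
  | Range_over_offsets
  | Constant_fill of float array
  | Access of access

(* True when the fetched result can remain a local / virtual node, provided it is
   consumed inside the same computation; other accesses may need a hosted node. *)
let may_stay_virtual = function
  | Constant _ | Range_over_offsets | Constant_fill _ | Access Uint4x32_to_prec_uniform -> true
  | Access _ -> false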

lib/tensor.mli

Lines changed: 0 additions & 2 deletions
@@ -200,7 +200,6 @@ val ndarray :
   ?batch_axes:(string * int) list ->
   ?input_axes:(string * int) list ->
   ?output_axes:(string * int) list ->
-  ?_strict:bool ->
   float array ->
   t
 (** A tensor with an explicit shape, initialized to the given values. Omitted shape rows default to
@@ -215,7 +214,6 @@ val param :
   ?input_axes:(string * int) list ->
   ?output_axes:(string * int) list ->
   ?deduced:Shape.deduce_within_shape ->
-  ?_strict:bool ->
   ?values:float array ->
   string ->
   t
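
With `?_strict` removed from both signatures, call sites simply stop passing that argument. A hedged usage sketch, assuming the interface above is exposed as a module named `Tensor` and that the dimension arguments are `int list`s (neither detail is confirmed by this diff):

(* Hypothetical call sites; only arguments visible in the diffs above are used. *)
let xs = Tensor.ndarray ~label:[ "xs" ] ~output_dims:[ 3 ] [| 1.; 2.; 3. |]
let w = Tensor.param ~output_dims:[ 3 ] ~values:[| 0.1; 0.2; 0.3 |] "w"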
