Skip to content

Commit a0eb1b1

Browse files
committed
embed_self_id operation, by Claude Opus
Summary by Claude: I successfully implemented support for the Embed_self_id fetch operation. 1. In arrayjit/lib/assignments.ml: added an Embed_self_id case to the doc_of_fetch_op function to pretty-print it as "!@self_id", and an Embed_self_id case to the to_low_level function, where it converts to Low_level.Constant with the Tnode id converted to a float. 2. In lib/operation.ml: added an embed_self_id nullary operation function (line 405) that creates a tensor operation using the Embed_self_id fetch operation. 3. In lib/tensor.ml: added Embed_self_id to the pattern-matching case so it is handled properly. 4. In lib/shape.ml: added an Embed_self_id case to the pattern matching to handle shape inference. The implementation follows the same pattern as Embed_symbol but evaluates to the float representation of the Tnode id of the array field, as requested.
1 parent e121809 commit a0eb1b1

File tree

5 files changed

+73
-1
lines changed

5 files changed

+73
-1
lines changed

arrayjit/lib/assignments.ml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ type fetch_op =
3333
positions in a buffer instantiating the node can differ.) *)
3434
| Slice of { batch_idx : Indexing.static_symbol; sliced : Tn.t }
3535
| Embed_symbol of Indexing.static_symbol
36+
| Embed_self_id (** Embeds the id of the [array] field of the [Fetch] constructor. *)
3637
[@@deriving sexp_of, equal]
3738

3839
and t =
@@ -340,6 +341,9 @@ let%diagn2_sexp to_low_level code =
340341
| Fetch { array; fetch_op = Embed_symbol s; dims } ->
341342
Low_level.loop_over_dims (Lazy.force dims) ~body:(fun idcs ->
342343
set array idcs @@ Embed_index (Iterator s.static_symbol))
344+
| Fetch { array; fetch_op = Embed_self_id; dims } ->
345+
Low_level.loop_over_dims (Lazy.force dims) ~body:(fun idcs ->
346+
set array idcs @@ Constant (Float.of_int array.id))
343347
| Fetch { array; fetch_op = Range_over_offsets; dims = (lazy dims) } ->
344348
Low_level.loop_over_dims dims ~body:(fun idcs ->
345349
let offset = Indexing.reflect_projection ~dims ~projection:idcs in
@@ -434,6 +438,7 @@ let to_doc ?name ?static_indices () c =
434438
string (ident sliced ^ " @| " ^ Indexing.symbol_ident batch_idx.static_symbol)
435439
| Embed_symbol { static_symbol; static_range = _ } ->
436440
string ("!@" ^ Indexing.symbol_ident static_symbol)
441+
| Embed_self_id -> string "!@self_id"
437442
in
438443

439444
let rec doc_of_code = function

lib/operation.ml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,22 @@ let embed_symbol ?(label = []) static_sym : Tensor.t =
402402
(Shape.make ~batch_dims:[] ~input_dims:[] ~output_dims:[ 1 ] ())
403403
[]
404404

405+
let embed_self_id ?(label = []) () : Tensor.t =
406+
let module NTDSL = Initial_NTDSL in
407+
let op_asn ~v ~projections =
408+
Asgns.to_comp
409+
@@ Fetch
410+
{
411+
array = v;
412+
fetch_op = Embed_self_id;
413+
dims = lazy (Lazy.force projections).Idx.lhs_dims;
414+
}
415+
in
416+
let grad_asn ~t:_ ~g:_ ~projections:_ = Asgns.empty_comp in
417+
Tensor.op ~label:("!@self_id" :: label) ~op_asn ~grad_asn ~grad_spec:Prohibit_grad
418+
(Shape.make ~batch_dims:[] ~input_dims:[] ~output_dims:[ 1 ] ())
419+
[]
420+
405421
module DO = struct
406422
let ( * ) = matmul ~grad_spec:If_needed
407423
let ( *. ) = pointmul ~grad_spec:If_needed
@@ -436,13 +452,15 @@ module DO = struct
436452
let ( <> ) = ne ~grad_spec:Prohibit_grad
437453
let threefry4x32 = threefry4x32
438454
let uint4x32_to_prec_uniform = uint4x32_to_prec_uniform
455+
let embed_self_id = embed_self_id
439456
end
440457

441458
module NDO = struct
442459
include NDO_before_div
443460

444461
let ( /. ) = pointdiv ~grad_spec:Prohibit_grad
445462
let ( @| ) ?label t1 idx = slice ?label ~grad_spec:Prohibit_grad idx t1
463+
let ( !@ ) = embed_symbol
446464
let relu = relu ~grad_spec:Prohibit_grad
447465
let sat01 = sat01 ~grad_spec:Prohibit_grad
448466
let fma = fma ~grad_spec:Prohibit_grad
@@ -465,6 +483,7 @@ module NDO = struct
465483
let ( <> ) = ne ~grad_spec:Prohibit_grad
466484
let threefry4x32 = threefry4x32
467485
let uint4x32_to_prec_uniform = uint4x32_to_prec_uniform
486+
let embed_self_id = embed_self_id
468487
end
469488

470489
(** The input [i] dimensions default to empty. The batch and output dimensions will be inferred if
@@ -527,6 +546,7 @@ module TDSL = struct
527546
let rebatch = rebatch ~grad_spec:If_needed
528547
let threefry4x32 = threefry4x32
529548
let uint4x32_to_prec_uniform = uint4x32_to_prec_uniform
549+
let embed_self_id = embed_self_id
530550

531551
(** The input and output dimensions will be inferred if omitted. See {!reshape}. *)
532552
let reshape_param ~l ?i ?o ndarray =
@@ -561,6 +581,7 @@ module NTDSL = struct
561581
let rebatch = rebatch ~grad_spec:Prohibit_grad
562582
let threefry4x32 = threefry4x32
563583
let uint4x32_to_prec_uniform = uint4x32_to_prec_uniform
584+
let embed_self_id = embed_self_id
564585

565586
let counter ?(label = []) =
566587
let module NTDSL = Initial_NTDSL in

lib/shape.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,7 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
467467
:: mark_terminal () )
468468
else (Row.dim_map_empty, mark_terminal ())
469469
| Terminal (Fetch (Embed_symbol _)) -> (Row.dim_map_empty, mark_terminal ())
470+
| Terminal (Fetch Embed_self_id) -> (Row.dim_map_empty, mark_terminal ())
470471
| Transpose (Transpose, sh) ->
471472
( Row.dim_map_empty,
472473
[

lib/tensor.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ let term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?inp
370370
match fetch_op with
371371
| None -> Asgns.empty_comp
372372
| Some
373-
((Constant _ | Slice _ | Embed_symbol _ | Range_over_offsets | Constant_fill _) as fetch_op)
373+
((Constant _ | Slice _ | Embed_symbol _ | Embed_self_id | Range_over_offsets | Constant_fill _) as fetch_op)
374374
->
375375
Asgns.to_comp @@ Fetch { array = v; fetch_op; dims }
376376
in

test/operations/hello_world_op.ml

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,3 +768,48 @@ let%expect_test "Very big tensor" =
768768
│└──────┴─────────────────────────────────────────┴─────────────────────────────────────────┴──────┴─────────────────────────────────────────┴─────────────────────────────────────────┘│
769769
└───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘
770770
|}]
771+
772+
let%expect_test "Embed self id" =
773+
Tensor.unsafe_reinitialize ();
774+
let module Backend = (val Backends.fresh_backend ()) in
775+
let backend =
776+
(module Backend : Backend
777+
with type buffer_ptr = Backend.buffer_ptr
778+
and type dev = Backend.dev
779+
and type runner = Backend.runner
780+
and type event = Backend.event
781+
and type optimize_ctx = Backend.optimize_ctx)
782+
in
783+
let%op hey = embed_self_id () in
784+
let%op hoo = embed_self_id () in
785+
(* let%op bar = hoo + hey + embed_self_id () in *)
786+
Train.set_hosted hey.value;
787+
Train.set_hosted hoo.value;
788+
(* Train.set_hosted bar.value; *)
789+
ignore (Train.forward_once backend hey);
790+
ignore (Train.forward_once backend hoo);
791+
(* ignore (Train.forward_once backend bar); *)
792+
Train.printf ~here:[%here] ~with_code:false ~with_grad:false hey;
793+
Train.printf ~here:[%here] ~with_code:false ~with_grad:false hoo;
794+
(* Train.printf ~here:[%here] ~with_code:false ~with_grad:false bar; *)
795+
[%expect {|
796+
HERE: test/operations/hello_world_op.ml:792:23
797+
┌─────────────────────────┐
798+
│[0]: !@self_id shape 0:1
799+
│┌┬──────┐ │
800+
│││axis 0│ │
801+
│├┼──────┤ │
802+
│││ 0.00 │ │
803+
│└┴──────┘ │
804+
└─────────────────────────┘
805+
HERE: test/operations/hello_world_op.ml:793:23
806+
┌─────────────────────────┐
807+
│[1]: !@self_id shape 0:1
808+
│┌┬──────┐ │
809+
│││axis 0│ │
810+
│├┼──────┤ │
811+
│││ 1.00 │ │
812+
│└┴──────┘ │
813+
└─────────────────────────┘
814+
|}]
815+

0 commit comments

Comments (0)