
Commit cc95235

Merge pull request #408 from ahrefs/claude/issue-396-20250924-1048
Implement shape errors for parameters with unspecified dimensions; note: known failing tests that I'll address soon
2 parents: 3dffac6 + 3eb6f61

23 files changed: +929 −224 lines

CLAUDE.md

Lines changed: 3 additions & 2 deletions
@@ -79,6 +79,7 @@ opam install cudajit # for CUDA backend
 - Row variables (`..d..`) enable flexible axis handling and broadcasting
 - Einsum notation supports convolutions, reductions, and arbitrary permutations
 - "Principle of least commitment": use row variables where axis count doesn't matter
+- Shape inference completion is forced by lowering: via `Context.compile`, or wrappers such as `Train.to_routine`, `Train.run_once` or `Train.forward_once`
 
 3. **Backend Architecture**: Unified interface supporting CPU (multicore), CUDA, and Metal backends
 
@@ -90,8 +91,8 @@ opam install cudajit # for CUDA backend
 
 - Tests are implemented either as inline expectations using `ppx_expect`, or as cram-style tests using Dune's `test` stanza, where an `.ml` file is compiled, executed, and its output compared against an `.expected` file
 - The two approaches are exclusive: a test using an `.expected` file target cannot also use `%expect` inline expectations
-- `.expected` tests are easier to debug, `%expect` tests should only be used when the outputs are illustrative
-- Tutorial files, i.e. `%expect` tests, in `test/` serve as both documentation and integration tests
+- `.expected` tests, i.e. those using the `test` stanza, are easier to debug; use them for testing new features
+- Tutorial files, i.e. `%expect` tests, in `test/` serve as both documentation and integration tests, and should only be used when the outputs are illustrative
 
 **Running Tests**:
 - `dune runtest` - runs all tests including inline tests and cram-style tests
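
For reference, a minimal sketch of the inline-expectation style mentioned in the hunk above (test name and output are illustrative; assumes `ppx_expect` and `stdio` as dependencies):

```ocaml
(* A minimal %expect test: captured stdout is compared against the
   [%expect] block when running `dune runtest`. *)
let%expect_test "illustrative output" =
  Stdio.printf "2 + 2 = %d\n" (2 + 2);
  [%expect {| 2 + 2 = 4 |}]
```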

arrayjit/lib/context.ml

Lines changed: 2 additions & 1 deletion
@@ -24,11 +24,12 @@ type backend_wrapper =
   -> backend_wrapper
 
 type t = {
-  backend_wrapper : backend_wrapper;
+  backend_wrapper : (backend_wrapper [@sexp.opaque]);
   device_id : int;
   backend_name : string;
   initialized_nodes : Set.M(Tn).t; (* Track which nodes have been initialized *)
 }
+[@@deriving sexp_of]
 
 type routine = {
   (* TODO: Remove commented out fields if they prove to be unnecessary *)

arrayjit/lib/context.mli

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 module Backends_deprecated = Backends
 
-type t
+type t [@@deriving sexp_of]
 (** Execution context managing device, compilation, and buffers *)
 
 type routine
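
Both hunks above use a standard `ppx_sexp_conv` idiom: `[@sexp.opaque]` prints a field lacking its own `sexp_of` as `<opaque>`, so the enclosing record can still derive `sexp_of`. A standalone sketch with hypothetical types (not OCANNL's actual definitions):

```ocaml
open Core

(* The handle type has no sexp_of; [@sexp.opaque] renders it as <opaque>. *)
type backend_handle = { fd : int }

type ctx = {
  handle : (backend_handle [@sexp.opaque]);
  device_id : int;
  backend_name : string;
}
[@@deriving sexp_of]

let () =
  let c = { handle = { fd = 3 }; device_id = 0; backend_name = "metal" } in
  (* Prints: ((handle <opaque>) (device_id 0) (backend_name metal)) *)
  print_s (sexp_of_ctx c)
```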

arrayjit/lib/ndarray.ml

Lines changed: 4 additions & 5 deletions
@@ -672,11 +672,10 @@ let render_array ?(brief = false) ?(prefix = "") ?(entries_per_axis = 4) ?(label
         else
           concise_float ~prec:Utils.settings.print_decimals_precision (get_as_float arr indices)
       with Invalid_argument _ ->
-        raise
-        @@ Utils.User_error
-             [%string
-               "Invalid indices: %{int_dims_to_string indices} into array: \
-                %{(int_dims_to_string dims)}"])
+        failwith
+          [%string
+            "Invalid indices: %{int_dims_to_string indices} into array: %{(int_dims_to_string \
+             dims)}"])
   in
   let tag ?pos label ind =
     if ind = -1 then ""
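
The new code path reports errors via `failwith` with a `ppx_string` interpolation. The same pattern in isolation, with `dims_to_string` as a hypothetical stand-in for `int_dims_to_string`:

```ocaml
open Core

(* Hypothetical stand-in for the int_dims_to_string helper used above. *)
let dims_to_string dims = String.concat_array ~sep:"x" (Array.map dims ~f:Int.to_string)

(* Raises Failure with an interpolated message, mirroring the hunk above. *)
let check_in_bounds ~dims ~indices =
  Array.iter2_exn dims indices ~f:(fun d i ->
      if i < 0 || i >= d then
        failwith
          [%string "Invalid indices: %{dims_to_string indices} into array: %{dims_to_string dims}"])
```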

docs/shape_inference.md

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # Shape inference and projection inference
 
-To separate concerns, OCANNL is split into the `arrayjit` library, responsible for compilation of high-level n-D array operation sequences (`Assignments.comp`) via the gccjit and cuda backends, and the main `ocannl` library, responsible for deriving the operations computing the forward propagation and backpropagation from tensor expressions. In particular, `arrayjit` contains `Indexing`, which represents complex indexing into arrays, and the main library `ocannl` has `Row` and `Shape` modules, which do the most "heavy-lifting" in the translation from concise tensor expressions to sequences of assignments.
+To separate concerns, OCANNL is split into the `arrayjit` library, responsible for compilation of high-level n-D array operation sequences (`Assignments.comp`) via backends such as sync_cc, metal and cuda, and the main `ocannl` library, responsible for deriving the operations computing the forward propagation and backpropagation from tensor expressions. In particular, `arrayjit` contains `Indexing`, which represents complex indexing into arrays, and the main library `ocannl` has `Row` and `Shape` modules, which do the most "heavy-lifting" in the translation from concise tensor expressions to sequences of assignments.
 
 Broadly speaking, shape inference in OCANNL consists of inferring the `Shape.t` record -- shape inference proper -- and inferring the `Indexing.projections` record -- projections inference. `Shape.t` records are mutable, so that the partially inferred shapes can be observed by the user. Shape and projections inference is intended to be declarative -- independent of the order in which constraints are added. There is one aspect that is not declarative: when tensor expressions are compiled to assignments, i.e. jitted, still-unsolved shape variables in terminal nodes are substituted by their least upper bounds if any, or by dimension-1 / no-more-axes.

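A toy illustration of the single non-declarative step described above (not OCANNL's actual representation): at lowering time a still-unsolved dimension variable is replaced by its least upper bound when one is known, otherwise by dimension 1:

```ocaml
(* Toy model of forcing an unsolved dimension when jitting. *)
type dim = Solved of int | Unsolved of { lub : int option }

let force_at_lowering = function
  | Solved d -> d
  | Unsolved { lub = Some d } -> d (* substitute the least upper bound *)
  | Unsolved { lub = None } -> 1 (* default to a broadcastable dim-1 axis *)
```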
lib/nn_blocks.ml

Lines changed: 20 additions & 0 deletions
@@ -180,6 +180,26 @@ let%op transformer ~label ~num_encoder_layers ~num_decoder_layers ~num_heads ~d_
   let tgt_embedded = ({ tgt_embed; o = [ d_dec ] } * tgt) + pos_encoding_tgt in
   { w_out } * decoder ~train_step tgt_embedded ~enc_output ~mask
 
+(** Transformer with teacher forcing for autoregressive training.
+
+    TODO: Simplify once tensor shifting/slicing is better supported in shape inference. Currently
+    requires pre-shifted tgt_input (all but last token) and tgt_target (all but first token).
+    During training, the model learns to predict tgt_target given tgt_input. *)
+let%op transformer_with_loss ~label:_ ~model () ~train_step ~src ~tgt_input ~tgt_target ~mask =
+  (* Get model predictions for the input sequence. *)
+  let logits = model ~train_step ~src ~tgt:tgt_input ~mask in
+
+  (* Compute cross-entropy loss between predictions and target: softmax over the vocabulary
+     dimension. *)
+  let log_probs = log (softmax ~spec:"... | v" () logits) in
+
+  (* Negative log likelihood loss: -sum(target * log_probs); tgt_target should be one-hot
+     encoded or use label smoothing. *)
+  let loss = -(tgt_target *. log_probs) ++ "...|... => 0" in
+
+  (* Return both loss and logits for potential additional metrics. *)
+  (loss, logits)
+
 (** {2 Convolutional Neural Network Building Blocks} *)
 
 (** 2D convolution layer with flexible padding and stride options. *)
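
In math form, the loss computed by `transformer_with_loss` is standard cross-entropy with one-hot (or label-smoothed) targets; the `++ "...|... => 0"` einsum reduces over all remaining axes, so with positions `t` and vocabulary entries `v` (batch axes elided):

```latex
\mathcal{L} = -\sum_{t}\sum_{v} \mathrm{tgt\_target}_{t,v}\,\log\,\mathrm{softmax}(\mathrm{logits})_{t,v}
```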

lib/train.ml

Lines changed: 4 additions & 4 deletions
@@ -183,7 +183,7 @@ let every_non_literal_on_host =
 
 module Lazy = Utils.Lazy
 
-let to_routine (ctx : Context.t) ?(hosted = true) bindings comp =
+let%track7_sexp to_routine (ctx : Context.t) ?(hosted = true) bindings comp =
   if hosted then Set.iter (snd @@ Asgns.collect_nodes_guess_output comp.Asgns.asgns) ~f:set_hosted;
   let _ctx, routine = Context.compile ctx comp bindings in
   (* Return just the routine for backward compatibility - ctx is discarded here *)
@@ -234,7 +234,7 @@ type example_train_result = {
     true, and the update code is output to a file before shape inference potentially crashes at
     [init_params]. *)
 let%track3_sexp run_once ?(output_cd_file = false) ?(hosted = true) ?(skip_init = false) ?reinit_all
-    ?(bindings = IDX.empty) ~f ctx t =
+    ?(bindings = IDX.empty) ~f ctx (t : Tensor.t) : Context.t =
   if hosted then set_hosted t.Tensor.value;
   (* Compute the update early, to ensure the shape inference is done. *)
   let update = f t in
@@ -275,8 +275,8 @@ let update_once ?output_cd_file ?(hosted = true) ?(skip_init = false) ?reinit_al
 
 (** [printf] is a wrapper around {!Tensor.print} that assumes [~force:true], and by default sets
     [~with_code:false], [~with_grad:true], and [~style:`Default]. *)
-let printf ?here ?(with_grad = true) ?(with_code = false) ?(with_low_level = false)
-    ?(style = `Default) t =
+let%debug7_sexp printf ?here ?(with_grad = true) ?(with_code = false) ?(with_low_level = false)
+    ?(style = `Default) (t : Tensor.t) : unit =
   Tensor.print ?here ~force:true ~with_grad ~with_code ~with_low_level style t
 
 (** [printf_tree] is a wrapper around {!Tensor.print_tree} that assumes [~force:true], and by
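
The `%track7_sexp` / `%debug7_sexp` annotations introduced here come from `ppx_minidebug`, where the numeric suffix selects the log level at which the trace is emitted. A hedged standalone sketch, with an illustrative runtime setup rather than OCANNL's actual configuration:

```ocaml
(* With ppx_minidebug, an annotated binding logs its sexp-rendered
   arguments and result, provided the runtime's log level permits. *)
module Debug_runtime = (val Minidebug_runtime.debug ())

let%track_sexp scale (factor : int) (x : int) : int = factor * x

let () = ignore (scale 3 14 : int)
```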
