@@ -33,32 +33,34 @@ let () =
3333 (* For teacher forcing: create shifted versions of the target sequence *)
3434 (* tgt_input: positions 0 to tgt_seq_len-2 (all but last), so its sequence axis has tgt_seq_len - 1 entries *)
3535 let tgt_input =
36- TDSL. range_of_shape ~label: [ " tgt_input" ] ~batch_dims: [ batch_size; tgt_seq_len - 1 ]
36+ TDSL. range_of_shape ~label: [ " tgt_input" ]
37+ ~batch_dims: [ batch_size; tgt_seq_len - 1 ]
3738 ~input_dims: [] ~output_dims: [ tgt_vocab_size ] ()
3839 in
3940
4041 (* tgt_target: positions 1 to tgt_seq_len-1 (all but first) *)
4142 (* In practice, this would be shifted token IDs; here we use a one-hot encoding for simplicity *)
4243 let tgt_target =
43- NTDSL. init ~l: " tgt_target" ~prec: Ir.Ops. single ~b: [ batch_size; tgt_seq_len - 1 ] ~i: []
44- ~o: [ tgt_vocab_size ]
44+ NTDSL. init ~l: " tgt_target" ~prec: Ir.Ops. single
45+ ~b: [ batch_size; tgt_seq_len - 1 ]
46+ ~i: [] ~o: [ tgt_vocab_size ]
4547 ~f: (function
4648 | [| _b; s; v |] ->
4749 (* Simple one-hot pattern for testing: 1. where v equals (s + 1) % tgt_vocab_size, 0. elsewhere *)
4850 if v = Int. ((s + 1 ) % tgt_vocab_size) then 1. else 0.
49- | idcs ->
50- failwith @@ " Invalid indices: "
51- ^ Sexp. to_string_hum ([% sexp_of: int array ] idcs))
51+ | idcs -> failwith @@ " Invalid indices: " ^ Sexp. to_string_hum ([% sexp_of: int array ] idcs))
5252 ()
5353 in
5454
5556 (* Create a causal mask for the decoder input (shifted target sequence) *)
5657 (* Mask is 1. where s >= t (position kept) and 0. where s < t (position masked out) *)
5757 let mask =
58- NTDSL. init ~l: " mask" ~prec: Ir.Ops. single ~b: [ batch_size; tgt_seq_len - 1 ]
59- ~i: [ tgt_seq_len - 1 ] ~o: [ 1 ]
58+ NTDSL. init ~l: " mask" ~prec: Ir.Ops. single
59+ ~b: [ tgt_seq_len - 1 ]
60+ ~i: [ tgt_seq_len - 1 ]
61+ ~o: []
6062 ~f: (function
61- | [| _; s; _ ; t |] -> if s > = t then 1. else 0.
63+ | [| s ; t |] -> if s > = t then 1. else 0.
6264 | idcs ->
6365 failwith @@ " Invalid indices: expected [| s; t |], got "
6466 ^ Sexp. to_string_hum ([% sexp_of: int array ] idcs))
@@ -78,7 +80,6 @@ let () =
7880 let _ctx = Ocannl.Train. forward_once ~output_cd_file: false ~bindings ctx loss in
7981
8082 (* Verify shapes: print the loss and logits shapes as human-readable S-expressions *)
81- Stdio. printf " Loss shape:\n %s\n "
82- (Sexp. to_string_hum ([% sexp_of: Shape. t] loss.Tensor. shape));
83+ Stdio. printf " Loss shape:\n %s\n " (Sexp. to_string_hum ([% sexp_of: Shape. t] loss.Tensor. shape));
8384 Stdio. printf " Logits shape:\n %s\n %!"
8485 (Sexp. to_string_hum ([% sexp_of: Shape. t] logits.Tensor. shape))
0 commit comments