Skip to content

Commit f18e6e4

Browse files
committed
Refactored type print_style (by Claude)
Also overlooked update to test_numerical_types.expected
1 parent 8826162 commit f18e6e4

File tree

10 files changed

+229
-26
lines changed

10 files changed

+229
-26
lines changed

arrayjit/test/test_numerical_types.expected

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,13 @@ FP8 array values:
3636
[1] = 0.500000
3737
[2] = 2.000000
3838
[3] = -1.000000
39+
40+
41+
Testing padding functionality:
42+
Padded array (dims 4x6, unpadded region 2x3):
43+
-999.0 -999.0 -999.0 -999.0 -999.0 -999.0
44+
-999.0 -999.0 1.0 2.0 3.0 -999.0
45+
-999.0 -999.0 4.0 5.0 6.0 -999.0
46+
-999.0 -999.0 -999.0 -999.0 -999.0 -999.0
47+
48+
Expected: padding value (-999.0) in margins, data values (1.0-6.0) in center region

lib/row.ml

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -75,26 +75,41 @@ type 'a dim_hashtbl = 'a Hashtbl.M(Dim_var).t [@@deriving sexp]
7575

7676
let dim_hashtbl () = Hashtbl.create (module Dim_var)
7777

78+
type print_style = Only_labels | Axis_size | Axis_number_and_size | Projection_and_size
79+
[@@deriving equal, compare, sexp]
80+
7881
let solved_dim_to_string style { d; label; proj_id; padding } =
79-
let label_prefix =
80-
match style with
81-
| `Only_labels -> ( match label with None -> "_" | Some l -> l)
82-
| _ -> ( match label with None -> "" | Some l -> l ^ "=")
83-
in
84-
match (proj_id, padding) with
85-
| _ when phys_equal style `Only_labels -> label_prefix
86-
| Some proj_id, None -> label_prefix ^ [%string "p%{Proj_id.to_string proj_id}"]
87-
| Some proj_id, Some p -> label_prefix ^ [%string "p%{Proj_id.to_string proj_id}+%{p#Int}"]
88-
| None, Some p -> label_prefix ^ [%string "%{d#Int}+%{p#Int}"]
89-
| None, None -> label_prefix ^ Int.to_string d
82+
match style with
83+
| Only_labels -> ( match label with None -> "_" | Some l -> l)
84+
| Axis_size | Axis_number_and_size ->
85+
let label_prefix = match label with None -> "" | Some l -> l ^ "=" in
86+
(match (proj_id, padding) with
87+
| None, None -> label_prefix ^ Int.to_string d
88+
| None, Some p -> label_prefix ^ [%string "%{d#Int}+%{p#Int}"]
89+
| Some _, None -> label_prefix ^ Int.to_string d
90+
| Some _, Some p -> label_prefix ^ [%string "%{d#Int}+%{p#Int}"])
91+
| Projection_and_size ->
92+
let label_part = match label with None -> "" | Some l -> l ^ "=" in
93+
let size_part = Int.to_string d in
94+
let padding_part = match padding with None -> "" | Some p -> "+" ^ Int.to_string p in
95+
let proj_part = match proj_id with None -> "" | Some pid -> "p" ^ Proj_id.to_string pid in
96+
let extra_parts =
97+
match (proj_id, padding) with
98+
| None, None -> ""
99+
| None, Some _ -> padding_part
100+
| Some _, None -> "[" ^ proj_part ^ "]"
101+
| Some _, Some _ -> "[" ^ proj_part ^ "]" ^ padding_part
102+
in
103+
label_part ^ size_part ^ extra_parts
90104

91105
let dim_to_string style = function
92-
| Dim { label = None; _ } when phys_equal style `Only_labels -> "_"
93-
| Dim { label = Some l; _ } when phys_equal style `Only_labels -> l
94-
| Dim { d; label = None; padding = None; _ } -> Int.to_string d
95-
| Dim { d; label = Some l; padding = None; _ } -> [%string "%{l}=%{d#Int}"]
96-
| Dim { d; label = None; padding = Some p; _ } -> [%string "%{d#Int}+%{p#Int}"]
97-
| Dim { d; label = Some l; padding = Some p; _ } -> [%string "%{l}=%{d#Int}+%{p#Int}"]
106+
| Dim { label = None; _ } when equal_print_style style Only_labels -> "_"
107+
| Dim { label = Some l; _ } when equal_print_style style Only_labels -> l
108+
| Dim { d; label = None; padding = None; proj_id = None } when equal_print_style style Axis_size -> Int.to_string d
109+
| Dim { d; label = Some l; padding = None; proj_id = None } when equal_print_style style Axis_size -> [%string "%{l}=%{d#Int}"]
110+
| Dim { d; label = None; padding = Some p; proj_id = None } when equal_print_style style Axis_size -> [%string "%{d#Int}+%{p#Int}"]
111+
| Dim { d; label = Some l; padding = Some p; proj_id = None } when equal_print_style style Axis_size -> [%string "%{l}=%{d#Int}+%{p#Int}"]
112+
| Dim solved_dim -> solved_dim_to_string style solved_dim
98113
| Var { id; label = Some l } -> [%string "$%{id#Int}:%{l}"]
99114
| Var { id; label = None } -> "$" ^ Int.to_string id
100115
| Affine { solved; unsolved } -> (
@@ -105,7 +120,8 @@ let dim_to_string style = function
105120
in
106121
let unsolved_terms =
107122
List.map unsolved ~f:(fun (coeff, v) ->
108-
if coeff = 1 then [%string "$%{v.id#Int}"] else [%string "%{coeff#Int}*$%{v.id#Int}"])
123+
let label_part = match v.label with None -> "" | Some l -> ":" ^ l in
124+
if coeff = 1 then [%string "$%{v.id#Int}%{label_part}"] else [%string "%{coeff#Int}*$%{v.id#Int}%{label_part}"])
109125
in
110126
let all_terms = solved_terms @ unsolved_terms in
111127
match all_terms with

lib/row.mli

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,11 @@ type dim =
3636

3737
val get_dim : d:int -> ?label:string -> unit -> dim
3838
val dim_to_int_exn : dim -> int
39-
val solved_dim_to_string : [> `Only_labels ] -> solved_dim -> string
40-
val dim_to_string : [> `Only_labels ] -> dim -> string
39+
type print_style = Only_labels | Axis_size | Axis_number_and_size | Projection_and_size
40+
[@@deriving equal, compare, sexp]
41+
42+
val solved_dim_to_string : print_style -> solved_dim -> string
43+
val dim_to_string : print_style -> dim -> string
4144

4245
type row_id [@@deriving sexp, compare, equal, hash]
4346
type row_cmp

lib/shape.ml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -851,7 +851,7 @@ let of_spec ?(deduced = Not_constrained) ~debug_name ~id spec =
851851
raise @@ Row.Shape_error ("of spec / " ^ s, Shape_mismatch [ result ] :: trace)));
852852
result
853853

854-
let to_string_hum ?(style = `Axis_size) (sh : t) =
854+
let to_string_hum ?(style = Row.Axis_size) (sh : t) =
855855
let n_outputs = List.length @@ sh.output.dims in
856856
let n_batch = List.length @@ sh.batch.dims in
857857
let dims_to_string kind =
@@ -865,8 +865,8 @@ let to_string_hum ?(style = `Axis_size) (sh : t) =
865865
| `Batch -> i
866866
in
867867
match style with
868-
| `Only_labels | `Axis_size -> Row.dim_to_string style d
869-
| `Axis_number_and_size -> Int.to_string num ^ ":" ^ Row.dim_to_string style d)
868+
| Row.Only_labels | Axis_size | Projection_and_size -> Row.dim_to_string style d
869+
| Axis_number_and_size -> Int.to_string num ^ ":" ^ Row.dim_to_string style d)
870870
in
871871
let batch_dims = dims_to_string `Batch in
872872
let input_dims = dims_to_string `Input in

lib/shape.mli

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ val make :
109109
inferred when provided, and must match whenever the axis sizes must match. *)
110110

111111
val to_string_hum :
112-
?style:[< `Axis_number_and_size | `Axis_size | `Only_labels > `Axis_size `Only_labels ] ->
112+
?style:Row.print_style ->
113113
t ->
114114
string
115115

lib/tensor.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,7 @@ let to_doc ?(spy = false) ~with_grad ~with_code ?(with_low_level = false)
590590
let label = Tn.label t.value in
591591
let prefix_str =
592592
"[" ^ Int.to_string t.id ^ "]: " ^ label ^ " shape "
593-
^ Shape.to_string_hum ~style:`Axis_number_and_size sh
593+
^ Shape.to_string_hum ~style:Row.Axis_number_and_size sh
594594
^ " "
595595
in
596596
let grad_txt diff =
@@ -631,7 +631,7 @@ let to_doc ?(spy = false) ~with_grad ~with_code ?(with_low_level = false)
631631
Array.exists ~f:(Fn.non String.is_empty) labels
632632
|| Shape.(List.exists ~f:Row.(equal_dim @@ get_dim ~d:1 ()) sh.input.dims)
633633
in
634-
let axes_spec = if needs_spec then Some (Shape.to_string_hum ~style:`Only_labels sh) else None in
634+
let axes_spec = if needs_spec then Some (Shape.to_string_hum ~style:Row.Only_labels sh) else None in
635635
let num_batch_axes = List.length sh.batch.dims in
636636
let num_input_axes = List.length sh.input.dims in
637637
let num_output_axes = List.length sh.output.dims in

test/dune

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,26 @@
7070
"micrograd_demo_logging-%{read:config/ocannl_backend.txt}-0-0.log.expected"
7171
"%{read:config/ocannl_backend.txt}-0-0.log.actual")))
7272

73+
(executable
74+
(name test_print_style)
75+
(modules test_print_style)
76+
(libraries ocannl)
77+
(preprocess
78+
(pps ppx_ocannl)))
79+
80+
(rule
81+
(target test_print_style.output)
82+
(deps test_print_style.exe ocannl_config)
83+
(action
84+
(with-stdout-to
85+
%{target}
86+
(run %{deps}))))
87+
88+
(rule
89+
(alias runtest)
90+
(action
91+
(diff test_print_style.expected test_print_style.output)))
92+
7393
(library
7494
(name tutorials)
7595
(package neural_nets_lib)

test/test_print_style.expected

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
2+
Welcome to OCANNL! Reading configuration defaults from /Users/lukstafi/ocannl/_build/default/test/ocannl_config.
3+
Retrieving commandline, environment, or config file variable ocannl_log_level
4+
Found 0, in the config file
5+
Testing print_style functionality:
6+
7+
=== Testing solved_dim_to_string ===
8+
Full attributes (d=28, padding=2, label=height, proj_id):
9+
Only_labels: height
10+
Axis_size: height=28+2
11+
Axis_number_and_size: height=28+2
12+
Projection_and_size: height=28+2
13+
14+
Minimal attributes (d=64, no padding, no label, no proj_id):
15+
Only_labels: _
16+
Axis_size: 64
17+
Projection_and_size: 64
18+
19+
With padding only (d=32, padding=3, label=width, no proj_id):
20+
Axis_size: width=32+3
21+
Projection_and_size: width=32+3
22+
23+
With projection (d=32, label=width, proj_id):
24+
Axis_size: width=32
25+
Projection_and_size: width=32[p1]
26+
27+
=== Testing dim_to_string ===
28+
Solved dimensions:
29+
Only_labels (full): height
30+
Axis_size (full): height=28+2
31+
Projection_and_size (full): height=28+2
32+
Only_labels (minimal): _
33+
Axis_size (minimal): 64
34+
35+
Variable dimensions:
36+
Only_labels (labeled var): $1:channels
37+
Axis_size (labeled var): $1:channels
38+
Projection_and_size (labeled var): $1:channels
39+
Only_labels (unlabeled var): $2
40+
Axis_size (unlabeled var): $2
41+
42+
=== Testing Shape.to_string_hum ===
43+
Shape with batch=[1], input=[784], output=[10,5]:
44+
Only_labels: _|_->_,_
45+
Axis_size: 1|784->10,5
46+
Axis_number_and_size: 0:1|3:784->1:10,2:5
47+
Projection_and_size: 1|784->10,5
48+
Default style: 1|784->10,5

test/test_print_style.ml

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
let test_print_styles () =
2+
let open Ocannl.Row in
3+
4+
Stdio.printf "Testing print_style functionality:\n\n";
5+
6+
(* Create a solved dimension with all possible attributes *)
7+
let solved_dim_full = {
8+
d = 28;
9+
padding = Some 2;
10+
label = Some "height";
11+
proj_id = None
12+
} in
13+
14+
(* Create a dimension with projection by using fresh_row_proj *)
15+
let row_with_dim = {
16+
dims = [get_dim ~d:32 ~label:"width" ()];
17+
bcast = Broadcastable;
18+
id = row_id ~sh_id:1 ~kind:`Output
19+
} in
20+
let row_with_proj = fresh_row_proj row_with_dim in
21+
let solved_dim_with_proj = match row_with_proj.dims with
22+
| [Dim sd] -> sd
23+
| _ -> failwith "Expected single dimension"
24+
in
25+
26+
(* Create a solved dimension with minimal attributes *)
27+
let solved_dim_minimal = {
28+
d = 64;
29+
padding = None;
30+
label = None;
31+
proj_id = None
32+
} in
33+
34+
(* Create a solved dimension with only padding *)
35+
let solved_dim_padding = {
36+
d = 32;
37+
padding = Some 3;
38+
label = Some "width";
39+
proj_id = None
40+
} in
41+
42+
(* Create a variable dimension *)
43+
let var_dim_labeled = get_var ~label:"channels" () in
44+
let var_dim_unlabeled = get_var () in
45+
46+
Stdio.printf "=== Testing solved_dim_to_string ===\n";
47+
Stdio.printf "Full attributes (d=28, padding=2, label=height, proj_id):\n";
48+
Stdio.printf " Only_labels: %s\n" (solved_dim_to_string Only_labels solved_dim_full);
49+
Stdio.printf " Axis_size: %s\n" (solved_dim_to_string Axis_size solved_dim_full);
50+
Stdio.printf " Axis_number_and_size: %s\n" (solved_dim_to_string Axis_number_and_size solved_dim_full);
51+
Stdio.printf " Projection_and_size: %s\n" (solved_dim_to_string Projection_and_size solved_dim_full);
52+
53+
Stdio.printf "\nMinimal attributes (d=64, no padding, no label, no proj_id):\n";
54+
Stdio.printf " Only_labels: %s\n" (solved_dim_to_string Only_labels solved_dim_minimal);
55+
Stdio.printf " Axis_size: %s\n" (solved_dim_to_string Axis_size solved_dim_minimal);
56+
Stdio.printf " Projection_and_size: %s\n" (solved_dim_to_string Projection_and_size solved_dim_minimal);
57+
58+
Stdio.printf "\nWith padding only (d=32, padding=3, label=width, no proj_id):\n";
59+
Stdio.printf " Axis_size: %s\n" (solved_dim_to_string Axis_size solved_dim_padding);
60+
Stdio.printf " Projection_and_size: %s\n" (solved_dim_to_string Projection_and_size solved_dim_padding);
61+
62+
Stdio.printf "\nWith projection (d=32, label=width, proj_id):\n";
63+
Stdio.printf " Axis_size: %s\n" (solved_dim_to_string Axis_size solved_dim_with_proj);
64+
Stdio.printf " Projection_and_size: %s\n" (solved_dim_to_string Projection_and_size solved_dim_with_proj);
65+
66+
Stdio.printf "\n=== Testing dim_to_string ===\n";
67+
Stdio.printf "Solved dimensions:\n";
68+
Stdio.printf " Only_labels (full): %s\n" (dim_to_string Only_labels (Dim solved_dim_full));
69+
Stdio.printf " Axis_size (full): %s\n" (dim_to_string Axis_size (Dim solved_dim_full));
70+
Stdio.printf " Projection_and_size (full): %s\n" (dim_to_string Projection_and_size (Dim solved_dim_full));
71+
Stdio.printf " Only_labels (minimal): %s\n" (dim_to_string Only_labels (Dim solved_dim_minimal));
72+
Stdio.printf " Axis_size (minimal): %s\n" (dim_to_string Axis_size (Dim solved_dim_minimal));
73+
74+
Stdio.printf "\nVariable dimensions:\n";
75+
Stdio.printf " Only_labels (labeled var): %s\n" (dim_to_string Only_labels (Var var_dim_labeled));
76+
Stdio.printf " Axis_size (labeled var): %s\n" (dim_to_string Axis_size (Var var_dim_labeled));
77+
Stdio.printf " Projection_and_size (labeled var): %s\n" (dim_to_string Projection_and_size (Var var_dim_labeled));
78+
Stdio.printf " Only_labels (unlabeled var): %s\n" (dim_to_string Only_labels (Var var_dim_unlabeled));
79+
Stdio.printf " Axis_size (unlabeled var): %s\n" (dim_to_string Axis_size (Var var_dim_unlabeled))
80+
81+
let test_shape_to_string () =
82+
let open Ocannl in
83+
84+
Stdio.printf "\n=== Testing Shape.to_string_hum ===\n";
85+
86+
(* Create a simple shape *)
87+
let shape = Shape.make
88+
~batch_dims:[1]
89+
~input_dims:[784]
90+
~output_dims:[10; 5]
91+
~debug_name:"test_shape"
92+
~id:42
93+
() in
94+
95+
Stdio.printf "Shape with batch=[1], input=[784], output=[10,5]:\n";
96+
Stdio.printf " Only_labels: %s\n" (Shape.to_string_hum ~style:Row.Only_labels shape);
97+
Stdio.printf " Axis_size: %s\n" (Shape.to_string_hum ~style:Row.Axis_size shape);
98+
Stdio.printf " Axis_number_and_size: %s\n" (Shape.to_string_hum ~style:Row.Axis_number_and_size shape);
99+
Stdio.printf " Projection_and_size: %s\n" (Shape.to_string_hum ~style:Row.Projection_and_size shape);
100+
101+
(* Test default style *)
102+
Stdio.printf " Default style: %s\n" (Shape.to_string_hum shape)
103+
104+
let () =
105+
test_print_styles ();
106+
test_shape_to_string ()

test_print_style.ml

Whitespace-only changes.

0 commit comments

Comments
 (0)