@@ -19,14 +19,10 @@ module Config = struct
1919 {
2020 name = " MNIST" ;
2121 cache_subdir = " mnist/" ;
22- train_images_url =
23- " https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz" ;
24- train_labels_url =
25- " https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz" ;
26- test_images_url =
27- " https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz" ;
28- test_labels_url =
29- " https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz" ;
22+ train_images_url = " https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz" ;
23+ train_labels_url = " https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz" ;
24+ test_images_url = " https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz" ;
25+ test_labels_url = " https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz" ;
3026 image_magic_number = 2051 ;
3127 label_magic_number = 2049 ;
3228 }
@@ -79,8 +75,7 @@ let ensure_dataset config =
7975 let path = dataset_dir ^ base_filename in
8076
8177 if not (Sys. file_exists path) then (
82- Printf. printf " File %s not found for %s dataset.\n %!" base_filename
83- config.name;
78+ Printf. printf " File %s not found for %s dataset.\n %!" base_filename config.name;
8479 (* Ensure the .gz file is downloaded *)
8580 ensure_file url gz_path;
8681 (* Ensure it's decompressed *)
@@ -89,25 +84,21 @@ let ensure_dataset config =
8984 else Printf. printf " Found decompressed file %s.\n %!" path)
9085 files_to_process
9186
92- let read_idx_file ~read_header ~create_array ~populate_array ~expected_magic
93- config filename =
87+ let read_idx_file ~read_header ~create_array ~populate_array ~expected_magic config filename =
9488 Printf. printf " Reading %s file: %s\n %!" config.Config. name filename;
9589 let ic = open_in_bin filename in
9690 let s =
9791 try really_input_string ic (in_channel_length ic)
9892 with exn ->
9993 close_in_noerr ic;
100- failwith
101- (Printf. sprintf " Error reading file %s: %s" filename
102- (Printexc. to_string exn ))
94+ failwith (Printf. sprintf " Error reading file %s: %s" filename (Printexc. to_string exn ))
10395 in
10496 close_in ic;
10597
10698 let magic = read_int32_be s 0 in
10799 if magic <> expected_magic then
108100 failwith
109- (Printf. sprintf " Invalid magic number %d in %s (expected %d)" magic
110- filename expected_magic);
101+ (Printf. sprintf " Invalid magic number %d in %s (expected %d)" magic filename expected_magic);
111102
112103 let dimensions, data_offset = read_header s in
113104 let total_items, data_len =
@@ -119,38 +110,34 @@ let read_idx_file ~read_header ~create_array ~populate_array ~expected_magic
119110 let expected_len = data_offset + data_len in
120111 if String. length s <> expected_len then
121112 failwith
122- (Printf. sprintf
123- " File %s has unexpected length: %d vs %d (header offset %d, data len \
124- %d)"
113+ (Printf. sprintf " File %s has unexpected length: %d vs %d (header offset %d, data len %d)"
125114 filename (String. length s) expected_len data_offset data_len);
126115
127116 let arr = create_array dimensions in
128117 populate_array arr s data_offset total_items;
129118 arr
130119
131- (* read_images and read_labels remain largely the same, just use the config
132- passed in *)
120+ (* read_images and read_labels remain largely the same, just use the config passed in *)
133121let read_images config filename =
134122 let read_header s =
135123 let num_images = read_int32_be s 4 in
136124 let num_rows = read_int32_be s 8 in
137125 let num_cols = read_int32_be s 12 in
138126 ([| num_images; num_rows; num_cols |], 16 )
139127 in
140- let create_array dims =
141- Array3. create int8_unsigned c_layout dims.(0 ) dims.(1 ) dims.(2 )
142- in
128+ let create_array dims = Genarray. create int8_unsigned c_layout dims in
143129 let populate_array arr s offset _ =
144- let num_images = Array3. dim1 arr in
145- let num_rows = Array3. dim2 arr in
146- let num_cols = Array3. dim3 arr in
130+ let dims = Genarray. dims arr in
131+ let num_images = dims.(0 ) in
132+ let num_rows = dims.(1 ) in
133+ let num_cols = dims.(2 ) in
147134 let img_size = num_rows * num_cols in
148135 for i = 0 to num_images - 1 do
149136 let start_pos = offset + (i * img_size) in
150137 for r = 0 to num_rows - 1 do
151138 for c = 0 to num_cols - 1 do
152139 let pos = start_pos + (r * num_cols) + c in
153- arr.{i, r, c} < - Char. code s.[pos]
140+ Genarray. set arr [| i; r; c |] ( Char. code s.[pos])
154141 done
155142 done
156143 done
@@ -163,10 +150,10 @@ let read_labels config filename =
163150 let num_labels = read_int32_be s 4 in
164151 ([| num_labels |], 8 )
165152 in
166- let create_array dims = Array1 . create int8_unsigned c_layout dims.( 0 ) in
153+ let create_array dims = Genarray . create int8_unsigned c_layout dims in
167154 let populate_array arr s offset total_items =
168155 for i = 0 to total_items - 1 do
169- arr.{i} < - Char. code s.[offset + i]
156+ Genarray. set arr [| i |] ( Char. code s.[offset + i])
170157 done
171158 in
172159 read_idx_file ~read_header ~create_array ~populate_array
0 commit comments