Skip to content

Commit a0f952d

Browse files
committed
Migrate MNIST and CIFAR10 to Genarrays
1 parent f7b2c47 commit a0f952d

File tree

2 files changed

+24
-37
lines changed

2 files changed

+24
-37
lines changed

datasets/cifar10.ml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ let read_cifar_batch filename =
3737
let images =
3838
Genarray.create int8_unsigned c_layout [| num_images; 32; 32; 3 |]
3939
in
40-
let labels = Array1.create int8_unsigned c_layout num_images in
40+
let labels = Genarray.create int8_unsigned c_layout [| num_images |] in
4141

4242
for i = 0 to num_images - 1 do
4343
let base_offset = i * bytes_per_image in
44-
labels.{i} <- Char.code s.[base_offset];
44+
Genarray.set labels [| i |] (Char.code s.[base_offset]);
4545
let r_offset = base_offset + 1 in
4646
let g_offset = r_offset + 1024 in
4747
let b_offset = g_offset + 1024 in
@@ -76,12 +76,12 @@ let load () =
7676
let train_images =
7777
Genarray.create int8_unsigned c_layout [| total_train_images; 32; 32; 3 |]
7878
in
79-
let train_labels = Array1.create int8_unsigned c_layout total_train_images in
79+
let train_labels = Genarray.create int8_unsigned c_layout [| total_train_images |] in
8080

8181
let current_offset = ref 0 in
8282
List.iter
8383
(fun (batch_images, batch_labels) ->
84-
let batch_size = Array1.dim batch_labels in
84+
let batch_size = (Genarray.dims batch_labels).(0) in
8585
let img_slice_dims = [| batch_size; 32; 32; 3 |] in
8686
let img_slice =
8787
Genarray.sub_left train_images !current_offset batch_size
@@ -98,9 +98,9 @@ let load () =
9898
(Array.to_list
9999
(Array.map string_of_int (Genarray.dims img_slice)))));
100100

101-
let lbl_slice = Array1.sub train_labels !current_offset batch_size in
101+
let lbl_slice = Genarray.sub_left train_labels !current_offset batch_size in
102102
Genarray.blit batch_images img_slice;
103-
Array1.blit batch_labels lbl_slice;
103+
Genarray.blit batch_labels lbl_slice;
104104
current_offset := !current_offset + batch_size)
105105
train_batches_data;
106106

datasets/mnist.ml

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,10 @@ module Config = struct
1919
{
2020
name = "MNIST";
2121
cache_subdir = "mnist/";
22-
train_images_url =
23-
"https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz";
24-
train_labels_url =
25-
"https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz";
26-
test_images_url =
27-
"https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz";
28-
test_labels_url =
29-
"https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz";
22+
train_images_url = "https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz";
23+
train_labels_url = "https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz";
24+
test_images_url = "https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz";
25+
test_labels_url = "https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz";
3026
image_magic_number = 2051;
3127
label_magic_number = 2049;
3228
}
@@ -79,8 +75,7 @@ let ensure_dataset config =
7975
let path = dataset_dir ^ base_filename in
8076

8177
if not (Sys.file_exists path) then (
82-
Printf.printf "File %s not found for %s dataset.\n%!" base_filename
83-
config.name;
78+
Printf.printf "File %s not found for %s dataset.\n%!" base_filename config.name;
8479
(* Ensure the .gz file is downloaded *)
8580
ensure_file url gz_path;
8681
(* Ensure it's decompressed *)
@@ -89,25 +84,21 @@ let ensure_dataset config =
8984
else Printf.printf "Found decompressed file %s.\n%!" path)
9085
files_to_process
9186

92-
let read_idx_file ~read_header ~create_array ~populate_array ~expected_magic
93-
config filename =
87+
let read_idx_file ~read_header ~create_array ~populate_array ~expected_magic config filename =
9488
Printf.printf "Reading %s file: %s\n%!" config.Config.name filename;
9589
let ic = open_in_bin filename in
9690
let s =
9791
try really_input_string ic (in_channel_length ic)
9892
with exn ->
9993
close_in_noerr ic;
100-
failwith
101-
(Printf.sprintf "Error reading file %s: %s" filename
102-
(Printexc.to_string exn))
94+
failwith (Printf.sprintf "Error reading file %s: %s" filename (Printexc.to_string exn))
10395
in
10496
close_in ic;
10597

10698
let magic = read_int32_be s 0 in
10799
if magic <> expected_magic then
108100
failwith
109-
(Printf.sprintf "Invalid magic number %d in %s (expected %d)" magic
110-
filename expected_magic);
101+
(Printf.sprintf "Invalid magic number %d in %s (expected %d)" magic filename expected_magic);
111102

112103
let dimensions, data_offset = read_header s in
113104
let total_items, data_len =
@@ -119,38 +110,34 @@ let read_idx_file ~read_header ~create_array ~populate_array ~expected_magic
119110
let expected_len = data_offset + data_len in
120111
if String.length s <> expected_len then
121112
failwith
122-
(Printf.sprintf
123-
"File %s has unexpected length: %d vs %d (header offset %d, data len \
124-
%d)"
113+
(Printf.sprintf "File %s has unexpected length: %d vs %d (header offset %d, data len %d)"
125114
filename (String.length s) expected_len data_offset data_len);
126115

127116
let arr = create_array dimensions in
128117
populate_array arr s data_offset total_items;
129118
arr
130119

131-
(* read_images and read_labels remain largely the same, just use the config
132-
passed in *)
120+
(* read_images and read_labels remain largely the same, just use the config passed in *)
133121
let read_images config filename =
134122
let read_header s =
135123
let num_images = read_int32_be s 4 in
136124
let num_rows = read_int32_be s 8 in
137125
let num_cols = read_int32_be s 12 in
138126
([| num_images; num_rows; num_cols |], 16)
139127
in
140-
let create_array dims =
141-
Array3.create int8_unsigned c_layout dims.(0) dims.(1) dims.(2)
142-
in
128+
let create_array dims = Genarray.create int8_unsigned c_layout dims in
143129
let populate_array arr s offset _ =
144-
let num_images = Array3.dim1 arr in
145-
let num_rows = Array3.dim2 arr in
146-
let num_cols = Array3.dim3 arr in
130+
let dims = Genarray.dims arr in
131+
let num_images = dims.(0) in
132+
let num_rows = dims.(1) in
133+
let num_cols = dims.(2) in
147134
let img_size = num_rows * num_cols in
148135
for i = 0 to num_images - 1 do
149136
let start_pos = offset + (i * img_size) in
150137
for r = 0 to num_rows - 1 do
151138
for c = 0 to num_cols - 1 do
152139
let pos = start_pos + (r * num_cols) + c in
153-
arr.{i, r, c} <- Char.code s.[pos]
140+
Genarray.set arr [| i; r; c |] (Char.code s.[pos])
154141
done
155142
done
156143
done
@@ -163,10 +150,10 @@ let read_labels config filename =
163150
let num_labels = read_int32_be s 4 in
164151
([| num_labels |], 8)
165152
in
166-
let create_array dims = Array1.create int8_unsigned c_layout dims.(0) in
153+
let create_array dims = Genarray.create int8_unsigned c_layout dims in
167154
let populate_array arr s offset total_items =
168155
for i = 0 to total_items - 1 do
169-
arr.{i} <- Char.code s.[offset + i]
156+
Genarray.set arr [| i |] (Char.code s.[offset + i])
170157
done
171158
in
172159
read_idx_file ~read_header ~create_array ~populate_array

0 commit comments

Comments
 (0)