From b4b82b72c76f8dc04d19b5ecc29d5e66be89a128 Mon Sep 17 00:00:00 2001 From: Martin Jambon Date: Sat, 26 Nov 2011 00:34:09 -0800 Subject: [PATCH] Added granularity option to iter_stream functions (needs doc). Somewhat improved error handling (looks like a "kill" function would be nice) --- nproc.ml | 90 +++++++++++++++++++++++++++++++++++++++++++++------ nproc.mli | 2 ++ test_nproc.ml | 46 ++++++++++++++++++++++---- 3 files changed, 121 insertions(+), 17 deletions(-) diff --git a/nproc.ml b/nproc.ml index 2b0ce7c..d291f94 100644 --- a/nproc.ml +++ b/nproc.ml @@ -4,6 +4,25 @@ let log_error = ref (fun s -> eprintf "[err] %s\n%!" s) let log_info = ref (fun s -> eprintf "[info] %s\n%!" s) let string_of_exn = ref Printexc.to_string +(* Get the n first elements of the stream as a reversed list. *) +let rec npop acc n strm = + if n > 0 then + match Stream.peek strm with + None -> acc + | Some x -> + Stream.junk strm; + npop (x :: acc) (n-1) strm + else + acc + +(* Chunkify stream; each chunk is in reverse order. *) +let chunkify n strm = + Stream.from ( + fun _ -> + match npop [] n strm with + [] -> None + | l -> Some l + ) module Full = struct @@ -54,7 +73,16 @@ struct let result = try match Marshal.from_channel ic with - Worker_req (f, x) -> f central_service worker_data x + Worker_req (f, x) -> + (try f central_service worker_data x + with e -> + let msg = + sprintf "Exception raised by Nproc task: %s" + (!string_of_exn e) + in + !log_error msg; + exit 1 + ) | Central_res _ -> assert false with End_of_file -> @@ -103,12 +131,37 @@ struct (read_from_worker g) ) and read_from_worker g () = - Lwt.bind (Lwt_io.read_value ic) (handle_input g) + Lwt.try_bind + (fun () -> Lwt_io.read_value ic) + (handle_input g) + (fun e -> + let msg = + match e with + End_of_file -> + "Task failed (see error log)" + | e -> + sprintf "Cannot read from Nproc worker: exception %s" + (!string_of_exn e) + in + !log_error msg; + failwith msg + ) and handle_input g = function Worker_res result -> - g result; + (try + g result + with e -> + let msg = + sprintf "Error while handling result of Nproc task: \ + exception %s" + (!string_of_exn e) + in + !log_error msg; + failwith msg + ); pull () + | Central_req x -> Lwt.bind (central_service x) ( fun y -> @@ -232,12 +285,28 @@ struct Lwt.return elt ) - let iter_stream ~nproc ~serv ~env ~f ~g in_stream = - let task_stream = lwt_of_stream f g in_stream in - let p, t = - create_gen (task_stream, (fun _ -> assert false)) nproc serv env - in - Lwt_main.run t + let iter_stream + ?(granularity = 1) + ~nproc ~serv ~env ~f ~g in_stream = + + if granularity <= 0 then + invalid_arg (sprintf "Nproc.iter_stream: granularity=%i" granularity) + else + let task_stream = + if granularity = 1 then + lwt_of_stream f g in_stream + else + let in_stream' = chunkify granularity in_stream in + let f' central_service worker_data l = + List.rev_map (f central_service worker_data) l + in + let g' l = List.iter g l in + lwt_of_stream f' g' in_stream' + in + let p, t = + create_gen (task_stream, (fun _ -> assert false)) nproc serv env + in + Lwt_main.run t end @@ -251,8 +320,9 @@ let close = Full.close let submit p ~f x = Full.submit p (fun _ _ x -> f x) x -let iter_stream ~nproc ~f ~g strm = +let iter_stream ?granularity ~nproc ~f ~g strm = Full.iter_stream + ?granularity ~nproc ~env: () ~serv: (fun () -> Lwt.return ()) diff --git a/nproc.mli b/nproc.mli index 09aadab..f461651 100644 --- a/nproc.mli +++ b/nproc.mli @@ -39,6 +39,7 @@ val submit : t -> f: ('a -> 'b) -> 'a -> 'b Lwt.t *) val iter_stream : + ?granularity: int -> nproc: int -> f: ('a -> 'b) -> g: ('b -> unit) -> @@ -134,6 +135,7 @@ sig *) val iter_stream : + ?granularity: int -> nproc: int -> serv: ('serv_request -> 'serv_response Lwt.t) -> env: 'env -> diff --git a/test_nproc.ml b/test_nproc.ml index 2133823..df77886 100644 --- a/test_nproc.ml +++ b/test_nproc.ml @@ -1,6 +1,33 @@ +open Printf + +let test_error1 () = + let strm = Stream.from (fun i -> if i < 100 then Some i else None) in + try + Nproc.iter_stream + ~nproc: 8 + ~f: (fun n -> failwith "oops") + ~g: (fun _ -> assert false) + strm + with e -> + printf "OK - Caught exception as expected: %s\n" + (Printexc.to_string e) + +let test_error2 () = + let strm = Stream.from (fun i -> if i < 100 then Some i else None) in + try + Nproc.iter_stream + ~nproc: 8 + ~f: (fun n -> -n) + ~g: (fun n' -> failwith "oops") + strm + with e -> + printf "OK - Caught exception as expected: %s\n" + (Printexc.to_string e) + + let test1 () = - let l = Array.to_list (Array.init 1000 (fun i -> i)) in - let p, t = Nproc.create 100 in + let l = Array.to_list (Array.init 6 (fun i -> i)) in + let p, t = Nproc.create 2 in List.iter ( fun x -> ignore ( @@ -11,17 +38,22 @@ let test1 () = ) l; Lwt_main.run (Nproc.close p) -let test2 () = - let strm = Stream.from (fun i -> if i < 1000 then Some i else None) in +let test2 ?granularity () = + let strm = Stream.from (fun i -> if i < 6 then Some i else None) in Nproc.iter_stream - ~nproc: 100 + ~nproc: 2 ~f: (fun n -> Unix.sleep 1; (n, -n)) ~g: (fun (x, y) -> Printf.printf "%i -> %i\n%!" x y) strm let () = + print_endline "*** test error (1) ***"; + test_error1 (); + print_endline "*** test error (2) ***"; + test_error2 (); print_endline "*** test1 ***"; test1 (); print_endline "*** test2 ***"; - test2 () - + test2 (); + print_endline "*** test2 (granularity = 3) ***"; + test2 ~granularity:3 ()