Skip to content

Commit

Permalink
Added granularity option to iter_stream functions (needs doc).
Browse files Browse the repository at this point in the history
Somewhat improved error handling (looks like a "kill" function would be
nice)
  • Loading branch information
mjambon committed Nov 26, 2011
1 parent e4a9580 commit b4b82b7
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 17 deletions.
90 changes: 80 additions & 10 deletions nproc.ml
Expand Up @@ -4,6 +4,25 @@ let log_error = ref (fun s -> eprintf "[err] %s\n%!" s)
let log_info = ref (fun s -> eprintf "[info] %s\n%!" s)
let string_of_exn = ref Printexc.to_string

(* Get the n first elements of the stream as a reversed list. *)
let rec npop acc n strm =
if n > 0 then
match Stream.peek strm with
None -> acc
| Some x ->
Stream.junk strm;
npop (x :: acc) (n-1) strm
else
acc

(* Chunkify stream; each chunk is in reverse order. *)
let chunkify n strm =
Stream.from (
fun _ ->
match npop [] n strm with
[] -> None
| l -> Some l
)

module Full =
struct
Expand Down Expand Up @@ -54,7 +73,16 @@ struct
let result =
try
match Marshal.from_channel ic with
Worker_req (f, x) -> f central_service worker_data x
Worker_req (f, x) ->
(try f central_service worker_data x
with e ->
let msg =
sprintf "Exception raised by Nproc task: %s"
(!string_of_exn e)
in
!log_error msg;
exit 1
)
| Central_res _ -> assert false
with
End_of_file ->
Expand Down Expand Up @@ -103,12 +131,37 @@ struct
(read_from_worker g)
)
and read_from_worker g () =
Lwt.bind (Lwt_io.read_value ic) (handle_input g)
Lwt.try_bind
(fun () -> Lwt_io.read_value ic)
(handle_input g)
(fun e ->
let msg =
match e with
End_of_file ->
"Task failed (see error log)"
| e ->
sprintf "Cannot read from Nproc worker: exception %s"
(!string_of_exn e)
in
!log_error msg;
failwith msg
)

and handle_input g = function
Worker_res result ->
g result;
(try
g result
with e ->
let msg =
sprintf "Error while handling result of Nproc task: \
exception %s"
(!string_of_exn e)
in
!log_error msg;
failwith msg
);
pull ()

| Central_req x ->
Lwt.bind (central_service x) (
fun y ->
Expand Down Expand Up @@ -232,12 +285,28 @@ struct
Lwt.return elt
)

let iter_stream ~nproc ~serv ~env ~f ~g in_stream =
let task_stream = lwt_of_stream f g in_stream in
let p, t =
create_gen (task_stream, (fun _ -> assert false)) nproc serv env
in
Lwt_main.run t
let iter_stream
?(granularity = 1)
~nproc ~serv ~env ~f ~g in_stream =

if granularity <= 0 then
invalid_arg (sprintf "Nproc.iter_stream: granularity=%i" granularity)
else
let task_stream =
if granularity = 1 then
lwt_of_stream f g in_stream
else
let in_stream' = chunkify granularity in_stream in
let f' central_service worker_data l =
List.rev_map (f central_service worker_data) l
in
let g' l = List.iter g l in
lwt_of_stream f' g' in_stream'
in
let p, t =
create_gen (task_stream, (fun _ -> assert false)) nproc serv env
in
Lwt_main.run t
end


Expand All @@ -251,8 +320,9 @@ let close = Full.close
let submit p ~f x =
Full.submit p (fun _ _ x -> f x) x

let iter_stream ~nproc ~f ~g strm =
let iter_stream ?granularity ~nproc ~f ~g strm =
Full.iter_stream
?granularity
~nproc
~env: ()
~serv: (fun () -> Lwt.return ())
Expand Down
2 changes: 2 additions & 0 deletions nproc.mli
Expand Up @@ -39,6 +39,7 @@ val submit : t -> f: ('a -> 'b) -> 'a -> 'b Lwt.t
*)

val iter_stream :
?granularity: int ->
nproc: int ->
f: ('a -> 'b) ->
g: ('b -> unit) ->
Expand Down Expand Up @@ -134,6 +135,7 @@ sig
*)

val iter_stream :
?granularity: int ->
nproc: int ->
serv: ('serv_request -> 'serv_response Lwt.t) ->
env: 'env ->
Expand Down
46 changes: 39 additions & 7 deletions test_nproc.ml
@@ -1,6 +1,33 @@
open Printf

let test_error1 () =
let strm = Stream.from (fun i -> if i < 100 then Some i else None) in
try
Nproc.iter_stream
~nproc: 8
~f: (fun n -> failwith "oops")
~g: (fun _ -> assert false)
strm
with e ->
printf "OK - Caught exception as expected: %s\n"
(Printexc.to_string e)

let test_error2 () =
let strm = Stream.from (fun i -> if i < 100 then Some i else None) in
try
Nproc.iter_stream
~nproc: 8
~f: (fun n -> -n)
~g: (fun n' -> failwith "oops")
strm
with e ->
printf "OK - Caught exception as expected: %s\n"
(Printexc.to_string e)


let test1 () =
let l = Array.to_list (Array.init 1000 (fun i -> i)) in
let p, t = Nproc.create 100 in
let l = Array.to_list (Array.init 6 (fun i -> i)) in
let p, t = Nproc.create 2 in
List.iter (
fun x ->
ignore (
Expand All @@ -11,17 +38,22 @@ let test1 () =
) l;
Lwt_main.run (Nproc.close p)

let test2 () =
let strm = Stream.from (fun i -> if i < 1000 then Some i else None) in
let test2 ?granularity () =
let strm = Stream.from (fun i -> if i < 6 then Some i else None) in
Nproc.iter_stream
~nproc: 100
~nproc: 2
~f: (fun n -> Unix.sleep 1; (n, -n))
~g: (fun (x, y) -> Printf.printf "%i -> %i\n%!" x y)
strm

let () =
print_endline "*** test error (1) ***";
test_error1 ();
print_endline "*** test error (2) ***";
test_error2 ();
print_endline "*** test1 ***";
test1 ();
print_endline "*** test2 ***";
test2 ()

test2 ();
print_endline "*** test2 (granularity = 3) ***";
test2 ~granularity:3 ()

0 comments on commit b4b82b7

Please sign in to comment.