Skip to content

Commit 2af41be

Browse files
committed
Fixes #295: always create new modules for fresh_backend to never leak any caches
1 parent 77b3395 commit 2af41be

File tree

4 files changed

+105
-137
lines changed

4 files changed

+105
-137
lines changed

arrayjit/lib/backends.ml

Lines changed: 17 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -442,8 +442,6 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
442442
(context, Some r))
443443
end
444444

445-
module Cuda_backend : Backend = Raise_backend ((Cuda_backend : Lowered_backend))
446-
447445
module Make_device_backend_from_lowered
448446
(Add_scheduler : functor
449447
(Impl : For_add_scheduler)
@@ -455,36 +453,6 @@ struct
455453
include Backend_device
456454
end
457455

458-
module Cc_multicore = Make_device_backend_from_lowered (Schedulers.Multicore) (Cc_backend)
459-
module Gcc_multicore = Make_device_backend_from_lowered (Schedulers.Multicore) (Gcc_backend)
460-
module Cc_sync = Make_device_backend_from_lowered (Schedulers.Sync) (Cc_backend)
461-
module Gcc_sync = Make_device_backend_from_lowered (Schedulers.Sync) (Gcc_backend)
462-
463-
let%track5_sexp reinitialize (module Backend : Backend) config =
464-
if not @@ Backend.is_initialized () then Backend.initialize config
465-
else (
466-
[%log "reinitialize: cleanup devices"];
467-
for ordinal = 0 to Backend.num_devices () - 1 do
468-
Backend.(sync_device @@ get_device ~ordinal)
469-
done;
470-
[%log "reinitialize: efore full_major"];
471-
Stdlib.Gc.full_major ();
472-
[%log "reinitialize: cleanup devices 2"];
473-
(* TODO: does this do anything? *)
474-
for ordinal = 0 to Backend.num_devices () - 1 do
475-
Backend.(sync_device @@ get_device ~ordinal)
476-
done;
477-
[%log "reinitialize: after cleanup 2"];
478-
(* This ensures cleanliness of the streams weak arrays. *)
479-
Stdlib.Gc.full_major ();
480-
[%log "reinitialize: after full_major 2"];
481-
for ordinal = 0 to Backend.num_devices () - 1 do
482-
let device = Backend.get_device ~ordinal in
483-
Utils.weak_iter device.streams ~f:(fun _stream -> assert false)
484-
done;
485-
[%log "reinitialize: after checking devices"];
486-
Backend.initialize config)
487-
488456
let finalize (type buffer_ptr dev runner event)
489457
(module Backend : Backend
490458
with type buffer_ptr = buffer_ptr
@@ -503,19 +471,20 @@ let finalize (type buffer_ptr dev runner event)
503471
&& not (Hashtbl.mem ctx.stream.device.cross_stream_candidates key)
504472
then mem_free ctx.stream data)))
505473

506-
let%track5_sexp fresh_backend ?backend_name ?(config = Only_devices_parallel) () =
507-
let backend =
508-
match
509-
Option.value_or_thunk backend_name ~default:(fun () ->
510-
Utils.get_global_arg ~arg_name:"backend" ~default:"cc")
511-
|> String.lowercase
512-
with
513-
| "cc" -> (module Cc_multicore : Backend)
514-
| "gccjit" -> (module Gcc_multicore : Backend)
515-
| "sync_cc" -> (module Cc_sync : Backend)
516-
| "sync_gccjit" -> (module Gcc_sync : Backend)
517-
| "cuda" -> (module Cuda_backend : Backend)
518-
| backend -> invalid_arg [%string "Backends.fresh_backend: unknown backend %{backend}"]
519-
in
520-
reinitialize backend config;
521-
backend
474+
let%track5_sexp fresh_backend ?backend_name () =
475+
Stdlib.Gc.full_major ();
476+
(* TODO: is running again needed to give time to weak arrays to become empty? *)
477+
Stdlib.Gc.full_major ();
478+
match
479+
Option.value_or_thunk backend_name ~default:(fun () ->
480+
Utils.get_global_arg ~arg_name:"backend" ~default:"cc")
481+
|> String.lowercase
482+
with
483+
| "cc" -> (module Make_device_backend_from_lowered (Schedulers.Multicore) (Cc_backend) : Backend)
484+
| "gccjit" ->
485+
(module Make_device_backend_from_lowered (Schedulers.Multicore) (Gcc_backend) : Backend)
486+
| "sync_cc" -> (module Make_device_backend_from_lowered (Schedulers.Sync) (Cc_backend) : Backend)
487+
| "sync_gccjit" ->
488+
(module Make_device_backend_from_lowered (Schedulers.Sync) (Gcc_backend) : Backend)
489+
| "cuda" -> (module Raise_backend ((Cuda_backend : Lowered_backend)) : Backend)
490+
| backend -> invalid_arg [%string "Backends.fresh_backend: unknown backend %{backend}"]

arrayjit/lib/backends.mli

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22

33
open Base
44

5-
val reinitialize : (module Backend_intf.Backend) -> Backend_intf.config -> unit
6-
(** Initializes the backend, and if it was already initialized, performs garbage collection. *)
7-
85
val finalize :
96
'buffer_ptr 'dev 'runner 'event.
107
(module Backend_intf.Backend
@@ -21,6 +18,6 @@ val finalize :
2118
Note: this type will get simpler with modular explicits. *)
2219

2320
val fresh_backend :
24-
?backend_name:string -> ?config:Backend_intf.config -> unit -> (module Backend_intf.Backend)
25-
(** Reinitializes and returns a backend corresponding to [backend_name], or if omitted, selected via
26-
the global [backend] setting. See {!reinitialize}. *)
21+
?backend_name:string -> unit -> (module Backend_intf.Backend)
22+
(** Creates a new backend corresponding to [backend_name], or if omitted, selected via the global
23+
[backend] setting. *)

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ let is_initialized, initialize =
9393
((fun () -> !initialized), init)
9494

9595
let num_devices = Cu.Device.get_count
96+
97+
(* TODO: this doesn't need to be weak array. *)
9698
let devices = ref @@ Stdlib.Weak.create 0
9799

98100
(* Unlike [devices] above, [initialized_devices] never forgets its entries. *)

test/zero2hero_1of7.ml

Lines changed: 83 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -274,22 +274,22 @@ let%expect_test "Simple gradients hosted" =
274274
Tensor.print_tree ~with_grad:true ~depth:9 l;
275275
[%expect
276276
{|
277-
#12 *._l Host&dev/41
278-
<not-in-yet>
279-
#13 grad_*._l Host&dev/41
280-
<not-in-yet>
281-
#8 +_d Host&dev/41 │#10 f Host-non-const/24
282-
<not-in-yet><not-in-yet>
283-
#9 grad_+_d Host&dev/41 │#11 grad_f Host&dev/41
284-
<not-in-yet><not-in-yet>
285-
#4 *._e Host&dev/41 │#6 c Host-non-const/24
286-
<not-in-yet><not-in-yet>
287-
#5 grad_*._e Host&dev/41 │#7 grad_c Host&dev/41
288-
<not-in-yet><not-in-yet>
289-
#0 a Host-non-const/24│#2 b Host-non-const/24
290-
<not-in-yet><not-in-yet>
291-
#1 grad_a Host&dev/41 │#3 grad_b Host&dev/41
292-
<not-in-yet><not-in-yet>
277+
#12 *._l Host&stream/41
278+
<not-in-yet>
279+
#13 grad_*._l Host&stream/41
280+
<not-in-yet>
281+
#8 +_d Host&stream/41 │#10 f Host&shared/39
282+
<not-in-yet> <not-in-yet>
283+
#9 grad_+_d Host&stream/41 │#11 grad_f Host&stream/41
284+
<not-in-yet> <not-in-yet>
285+
#4 *._e Host&stream/41 │#6 c Host&shared/39
286+
<not-in-yet> <not-in-yet>
287+
#5 grad_*._e Host&stream/41 │#7 grad_c Host&stream/41
288+
<not-in-yet> <not-in-yet>
289+
#0 a Host&shared/39 │#2 b Host&shared/39
290+
<not-in-yet> <not-in-yet>
291+
#1 grad_a Host&stream/41│#3 grad_b Host&stream/41
292+
<not-in-yet> <not-in-yet>
293293
|}];
294294
(* Do not update the params: all values and gradients will be at initial points, which are
295295
specified in the tensor in the brackets. *)
@@ -411,45 +411,45 @@ let%expect_test "Simple gradients virtual" =
411411
Tensor.print_tree ~with_grad:true ~depth:9 l;
412412
[%expect
413413
{|
414-
#12 *._l Host&dev/41
415-
<not-in-yet>
416-
#13 grad_*._l Virt/40
417-
<not-in-yet>
418-
#8 +_d Local/50 │#10 f Host-non-const/24
419-
<not-in-yet><not-in-yet>
420-
#9 grad_+_d Virt/40 │#11 grad_f On-dev/50
421-
<not-in-yet><not-in-yet>
422-
#4 *._e Virt/152 │#6 c Host-non-const/24
423-
<not-in-yet><not-in-yet>
424-
#5 grad_*._e Virt/40 │#7 grad_c On-dev/50
425-
<not-in-yet><not-in-yet>
426-
#0 a Host-non-const/24│#2 b Host-non-const/24
427-
<not-in-yet><not-in-yet>
428-
#1 grad_a On-dev/50 │#3 grad_b On-dev/50
429-
<not-in-yet><not-in-yet>
414+
#12 *._l Host&stream/41
415+
<not-in-yet>
416+
#13 grad_*._l Virt/40
417+
<not-in-yet>
418+
#8 +_d Local/46 │#10 f Host&shared/39
419+
<not-in-yet> <not-in-yet>
420+
#9 grad_+_d Virt/40 │#11 grad_f Dev-stream/41
421+
<not-in-yet> <not-in-yet>
422+
#4 *._e Virt/152 │#6 c Host&shared/39
423+
<not-in-yet> <not-in-yet>
424+
#5 grad_*._e Virt/40 │#7 grad_c Dev-stream/41
425+
<not-in-yet> <not-in-yet>
426+
#0 a Host&shared/39 │#2 b Host&shared/39
427+
<not-in-yet> <not-in-yet>
428+
#1 grad_a Dev-stream/41│#3 grad_b Dev-stream/41
429+
<not-in-yet> <not-in-yet>
430430
|}];
431431
(* Do not update the params: all values and gradients will be at initial points, which are
432432
specified in the tensor in the brackets. *)
433433
Train.sync_run backend grad_routine l;
434434
Tensor.print_tree ~with_grad:true ~depth:9 l;
435435
[%expect
436436
{|
437-
#12 *._l
438-
-8.00e+0
439-
#13 grad_*._l Virt/40
440-
<void>
441-
#8 +_d Local/50 │#10 f
442-
<void>-2.00e+0
443-
#9 grad_+_d Virt/40 │#11 grad_f On-dev/50
444-
<void><void>
445-
#4 *._e Virt/152 │#6 c │
446-
<void>1.00e+1
447-
#5 grad_*._e Virt/40 │#7 grad_c On-dev/50
448-
<void><void>
449-
#0 a │#2 b
450-
2.00e+0-3.00e+0
451-
#1 grad_a On-dev/50│#3 grad_b On-dev/50
452-
<void><void>
437+
#12 *._l
438+
-8.00e+0
439+
#13 grad_*._l Virt/40
440+
<void>
441+
#8 +_d Local/46 │#10 f
442+
<void> -2.00e+0
443+
#9 grad_+_d Virt/40 │#11 grad_f Dev-stream/41
444+
<void> <void>
445+
#4 *._e Virt/152 │#6 c
446+
<void> 1.00e+1
447+
#5 grad_*._e Virt/40 │#7 grad_c Dev-stream/41
448+
<void> <void>
449+
#0 a │#2 b
450+
2.00e+0 -3.00e+0
451+
#1 grad_a Dev-stream/41│#3 grad_b Dev-stream/41
452+
<void> <void>
453453
|}];
454454
(* Only now compile the SGD update. *)
455455
let sgd_routine = Train.to_routine (module Backend) grad_routine.context IDX.empty sgd in
@@ -460,45 +460,45 @@ let%expect_test "Simple gradients virtual" =
460460
Tensor.print_tree ~with_grad:true ~depth:9 l;
461461
[%expect
462462
{|
463-
#12 *._l
464-
-8.00e+0
465-
#13 grad_*._l Virt/40
466-
<void>
467-
#8 +_d Local/50 │#10 f
468-
<void>-2.40e+0
469-
#9 grad_+_d Virt/40 │#11 grad_f On-dev/50
470-
<void><void>
471-
#4 *._e Virt/152 │#6 c │
472-
<void>1.02e+1
473-
#5 grad_*._e Virt/40 │#7 grad_c On-dev/50
474-
<void><void>
475-
#0 a │#2 b
476-
1.40e+0-2.60e+0
477-
#1 grad_a On-dev/50│#3 grad_b On-dev/50
478-
<void><void>
463+
#12 *._l
464+
-8.00e+0
465+
#13 grad_*._l Virt/40
466+
<void>
467+
#8 +_d Local/46 │#10 f
468+
<void> -2.40e+0
469+
#9 grad_+_d Virt/40 │#11 grad_f Dev-stream/41
470+
<void> <void>
471+
#4 *._e Virt/152 │#6 c
472+
<void> 1.02e+1
473+
#5 grad_*._e Virt/40 │#7 grad_c Dev-stream/41
474+
<void> <void>
475+
#0 a │#2 b
476+
1.40e+0 -2.60e+0
477+
#1 grad_a Dev-stream/41│#3 grad_b Dev-stream/41
478+
<void> <void>
479479
|}];
480480
(* Now the params will remain as above, but both param gradients and the values and gradients of
481481
other nodes will change thanks to the forward and backward passes. *)
482482
Train.sync_run backend grad_routine l;
483483
Tensor.print_tree ~with_grad:true ~depth:9 l;
484484
[%expect
485485
{|
486-
#12 *._l
487-
-1.57e+1
488-
#13 grad_*._l Virt/40
489-
<void>
490-
#8 +_d Local/50 │#10 f
491-
<void>-2.40e+0
492-
#9 grad_+_d Virt/40 │#11 grad_f On-dev/50
493-
<void><void>
494-
#4 *._e Virt/152 │#6 c │
495-
<void>1.02e+1
496-
#5 grad_*._e Virt/40 │#7 grad_c On-dev/50
497-
<void><void>
498-
#0 a │#2 b
499-
1.40e+0-2.60e+0
500-
#1 grad_a On-dev/50│#3 grad_b On-dev/50
501-
<void><void>
486+
#12 *._l
487+
-1.57e+1
488+
#13 grad_*._l Virt/40
489+
<void>
490+
#8 +_d Local/46 │#10 f
491+
<void> -2.40e+0
492+
#9 grad_+_d Virt/40 │#11 grad_f Dev-stream/41
493+
<void> <void>
494+
#4 *._e Virt/152 │#6 c
495+
<void> 1.02e+1
496+
#5 grad_*._e Virt/40 │#7 grad_c Dev-stream/41
497+
<void> <void>
498+
#0 a │#2 b
499+
1.40e+0 -2.60e+0
500+
#1 grad_a Dev-stream/41│#3 grad_b Dev-stream/41
501+
<void> <void>
502502
|}]
503503

504504
let%expect_test "tanh plot" =
@@ -565,12 +565,12 @@ let%expect_test "2D neuron virtual" =
565565
7.00e-1
566566
#9 grad_+_v Virt/40
567567
<void>
568-
#6 * Local/50 │#0 b
568+
#6 * Local/46 │#0 b
569569
<void>6.70e+0
570-
#7 grad_* Virt/40 │#1 grad_b Local/50
570+
#7 grad_* Virt/40 │#1 grad_b Local/46
571571
<void><void>
572572
#2 w │#4 x │
573573
-3.00e+0 1.00e+02.00e+0 0.00e+0
574-
#3 grad_w Local/50 │#5 grad_x Local/50
574+
#3 grad_w Local/46 │#5 grad_x Local/46
575575
<void><void>
576576
|}]

0 commit comments

Comments (0)