Skip to content

Commit c88f239

Browse files
committed
Fixes #387: normal distribution
1 parent e3a898f commit c88f239

33 files changed

+53
-37
lines changed

bin/compilation_speed.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ open Ocannl
33
module Tn = Ir.Tnode
44
module IDX = Train.IDX
55
module CDSL = Train.CDSL
6-
open Operation.DSL_modules
6+
open Nn_blocks.DSL_modules
77

88
(* FIXME: expose backend by name from Context *)
99

lib/nn_blocks.ml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,22 @@ open! Base
2020
open Ocannl_tensor.Operation.DSL_modules
2121
module Tn = Ir.Tnode
2222

23+
let%op box_muller grad_spec init_f () =
24+
let epsilon = [%oc Float.ldexp 1. (-24)] in
25+
let u1 = init_f () in
26+
let u2 = init_f () in
27+
Ocannl_tensor.Operation.pointmul ~grad_spec
28+
(sqrt (-2. *. log (u1 + (!.epsilon *. (1. - u1)))))
29+
(cos (2. *. !.Float.pi *. u2))
30+
31+
[%%extend_dsls
32+
let normal () = [%oc box_muller grad_spec uniform ()]
33+
let normal1 () = [%oc box_muller grad_spec uniform1 ()]
34+
let normal_at counter = [%oc box_muller grad_spec (fun () -> uniform_at counter) ()]
35+
let normal_at1 counter = [%oc box_muller grad_spec (fun () -> uniform_at1 counter) ()]]
36+
37+
open DSL_modules
38+
2339
let%op mlp_layer ~label ~hid_dim () x = relu (({ w } * x) + { b = 0.; o = [ hid_dim ] })
2440

2541
(** Masks and scales by 1/keep_prob to maintain expected value. When [train_step = None], the

test/einsum/einsum_trivia.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ open Ocannl
33
module Tn = Ir.Tnode
44
module IDX = Train.IDX
55
module CDSL = Train.CDSL
6-
open Operation.DSL_modules
6+
open Nn_blocks.DSL_modules
77

88
module type Backend = Ir.Backend_intf.Backend
99

test/einsum/einsum_trivia_exec.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ open Ocannl
33
module Tn = Ir.Tnode
44
module IDX = Train.IDX
55
module CDSL = Train.CDSL
6-
open Operation.DSL_modules
6+
open Nn_blocks.DSL_modules
77

88
module type Backend = Ir.Backend_intf.Backend
99

test/einsum/inline_permuted_view.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ open Base
22
open Ocannl
33
module IDX = Train.IDX
44
module CDSL = Train.CDSL
5-
open Operation.DSL_modules
5+
open Nn_blocks.DSL_modules
66

77
module type Backend = Ir.Backend_intf.Backend
88

test/einsum/moons_demo_variant.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
open Base
22
open Ocannl
33
module IDX = Train.IDX
4-
open Operation.DSL_modules
4+
open Nn_blocks.DSL_modules
55
module CDSL = Train.CDSL
66
module Asgns = Ir.Assignments
77
module Tn = Ir.Tnode

test/einsum/surjectivity.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ open Base
22
open Ocannl
33
module IDX = Train.IDX
44
module CDSL = Train.CDSL
5-
open Operation.DSL_modules
5+
open Nn_blocks.DSL_modules
66

77
module type Backend = Ir.Backend_intf.Backend
88

test/einsum/test_accumulation_semantics.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ open Base
22
open Ocannl
33
module IDX = Train.IDX
44
module CDSL = Train.CDSL
5-
open Operation.DSL_modules
5+
open Nn_blocks.DSL_modules
66

77
module type Backend = Ir.Backend_intf.Backend
88

test/einsum/test_einsum_capture.ml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ open Base
22
open Ocannl
33

44
let capture_for_computation () =
5-
let open Operation.DSL_modules in
5+
let open Nn_blocks.DSL_modules in
66
Tensor.unsafe_reinitialize ();
77
let ctx = Context.auto () in
88
let%op x = { x = uniform1 (); o = [ 2; 3 ] } in
@@ -51,7 +51,7 @@ let capture_for_computation () =
5151
Train.printf ~here:[%here] ~with_code:false ~with_grad:false dim_calc
5252

5353
let test_set_dim_and_set_equal () =
54-
let open Operation.DSL_modules in
54+
let open Nn_blocks.DSL_modules in
5555
Tensor.unsafe_reinitialize ();
5656
Stdio.printf "\n=== Testing set_dim and set_equal functionality ===\n";
5757

@@ -139,7 +139,7 @@ let test_set_dim_and_set_equal () =
139139
Stdio.printf "=== All tests completed ===\n"
140140

141141
let capture_for_shape_validation () =
142-
let open Operation.DSL_modules in
142+
let open Nn_blocks.DSL_modules in
143143
Tensor.unsafe_reinitialize ();
144144
Shape.unsafe_reinitialize ();
145145
let ctx = Context.auto () in
@@ -253,7 +253,7 @@ let capture_for_shape_validation () =
253253
Stdio.printf "=== Shape inference integration tests completed ===\n"
254254

255255
let capture_for_shape_inference () =
256-
let open Operation.DSL_modules in
256+
let open Nn_blocks.DSL_modules in
257257
Tensor.unsafe_reinitialize ();
258258
Shape.unsafe_reinitialize ();
259259
let ctx = Context.auto () in

test/einsum/test_surjectivity.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ open Base
22
open Ocannl
33
module IDX = Train.IDX
44
module CDSL = Train.CDSL
5-
open Operation.DSL_modules
5+
open Nn_blocks.DSL_modules
66

77
module type Backend = Ir.Backend_intf.Backend
88

0 commit comments

Comments
 (0)