-
Notifications
You must be signed in to change notification settings - Fork 2
/
basic.jl
79 lines (54 loc) · 1.44 KB
/
basic.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
using Reactant
using Test
using Enzyme
"""
    fastmax(x::AbstractArray{T}) where {T}

Return the maximum of `x` along the first dimension (`dims = 1`), keeping that
dimension as a singleton.

The reduction is seeded with `-Inf` converted to `float(T)`, so:
- an empty first dimension yields `-Inf` entries;
- for integer `T` the result holds floating-point values.
"""
function fastmax(x::AbstractArray{T}) where {T}
    lowest = float(T)(-Inf)
    return reduce(max, x; dims = 1, init = lowest)
end
@testset "Basic reduce max" begin
    # Reference result computed on an ordinary Julia array.
    expected = fastmax(ones(2, 10))

    device_input = Reactant.ConcreteRArray(ones(2, 10))

    # Eager (uncompiled) call directly on the concrete Reactant array.
    eager = fastmax(device_input)
    @test eager ≈ expected

    # Compiled version of the same reduction.
    compiled = Reactant.compile(fastmax, (device_input,))
    from_compiled = compiled(device_input)
    @test from_compiled ≈ expected
end
"""
    softmax!(x)

Return a new array equal to `x` with its column-wise maximum (as computed by
`fastmax`) subtracted from every element, via broadcasting.

NOTE(review): despite the name, this neither mutates `x` nor applies `exp` or
normalization. It is only (presumably) the max-shift step of a softmax.
"""
softmax!(x) = x .- fastmax(x)
@testset "Basic softmax" begin
    # Use distinct, descriptive names rather than `in`. A local `in` shadows
    # `Base.in`, and the original also rebound it from an `Array` to a
    # `ConcreteRArray` midway through the testset.
    host_x = ones(2, 10)
    expected = softmax!(host_x)

    device_x = Reactant.ConcreteRArray(ones(2, 10))
    compiled = Reactant.compile(softmax!, (device_x,))
    result = compiled(device_x)
    @test result ≈ expected
end
@testset "Basic cos" begin
    device = Reactant.ConcreteRArray(ones(3, 2))

    # `cos` is compiled un-broadcast, but the expectation below is the
    # elementwise `cos.`, so this test expects elementwise results.
    # NOTE(review): on a plain Julia matrix, `cos(A)` is the matrix cosine and
    # requires a square matrix. Presumably this relies on Reactant treating
    # `cos` elementwise on traced arrays — confirm.
    compiled_cos = Reactant.compile(cos, (device,))
    result = compiled_cos(device)
    @test result ≈ cos.(ones(3, 2))
end
"""
    sumcos(x)

Return the sum of `cos` applied elementwise to `x`. This is the scalar objective
differentiated by `grad_ip` and `resgrad_ip`.
"""
function sumcos(x)
    cosines = cos.(x)
    return sum(cosines)
end
"""
    grad_ip(x)

Compute the gradient of `sumcos` at `x` with Enzyme reverse mode.

A zero-initialized shadow of `x` (from `Enzyme.make_zero`) is paired with `x` in
a `Duplicated` annotation. Enzyme accumulates the gradient into that shadow,
which is then returned. The return value of `autodiff` itself is discarded.
"""
function grad_ip(x)
    shadow = Enzyme.make_zero(x)
    annotated = Duplicated(x, shadow)
    Enzyme.autodiff(Reverse, sumcos, Active, annotated)
    return shadow
end
"""
    resgrad_ip(x)

Like `grad_ip`, but run Enzyme in `ReverseWithPrimal` mode so the value of
`sumcos(x)` is also returned.

Returns a tuple `(autodiff_result, shadow)`:
- `autodiff_result` is exactly what `Enzyme.autodiff` returned; the
  "Basic grad cos" testset reads the primal from its second element.
- `shadow` holds the accumulated gradient with respect to `x`.
"""
function resgrad_ip(x)
    shadow = Enzyme.make_zero(x)
    annotated = Duplicated(x, shadow)
    autodiff_result = Enzyme.autodiff(ReverseWithPrimal, sumcos, Active, annotated)
    return autodiff_result, shadow
end
@testset "Basic grad cos" begin
    # d/dx sum(cos.(x)) == -sin.(x), elementwise.
    host = ones(3, 2)
    expected_grad = -sin.(host)

    device = Reactant.ConcreteRArray(ones(3, 2))

    # Gradient only.
    grad_fn = Reactant.compile(grad_ip, (device,))
    grad = grad_fn(device)
    @test grad ≈ expected_grad

    # Gradient together with the primal value; the primal sits at index 2 of
    # the `autodiff` result.
    both_fn = Reactant.compile(resgrad_ip, (device,))
    autodiff_result, grad_with_primal = both_fn(device)
    @test autodiff_result[2] ≈ sum(cos.(host))
    @test grad_with_primal ≈ expected_grad
end