Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support wide ints in tape #1002

Merged
merged 13 commits into from
Aug 31, 2023
17 changes: 17 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6037,6 +6037,15 @@ const TapeTypes = Dict{String, DataType}()
base_type(T::UnionAll) = base_type(T.body)
base_type(T::DataType) = T

const WideIntWidths = [256, 512, 1024, 2048]

let
for n ∈ WideIntWidths
let T = Symbol(:UInt,n)
eval(quote primitive type $T <: Unsigned $n end end)
end
end
end
# return result and if contains any
function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool}
tkind = LLVM.API.LLVMGetTypeKind(Type)
Expand Down Expand Up @@ -6105,6 +6114,14 @@ function to_tape_type(Type::LLVM.API.LLVMTypeRef)::Tuple{DataType,Bool}
return UInt64, false
elseif N == 128
return UInt128, false
elseif N == 256
return UInt256, false
elseif N == 512
return UInt512, false
elseif N == 1024
return UInt1024, false
elseif N == 2048
return UInt2048, false
else
error("Can't construct tape type for integer of width $N")
motabbara marked this conversation as resolved.
Show resolved Hide resolved
end
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Enzyme_jll = "7cc45869-7501-5eee-bdea-0790c847d4ef"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
27 changes: 27 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using ForwardDiff
using Aqua
using Statistics
using LinearAlgebra
using InlineStrings

using Enzyme_jll
@info "Testing against" Enzyme_jll.libEnzyme
Expand Down Expand Up @@ -2278,6 +2279,32 @@ end
@test ad_eta[1] ≈ 0.0
end

@testset "Tape Width" begin
struct Roo
x::Float64
bar::String63
end

struct Moo
x::Float64
bar::String63
end

function g(f)
return f.x*5.0
end

res = autodiff(Reverse, g, Active, Active(Roo(3.0, "a")))[1][1]

@test res.x == 5.0

if VERSION > v"1.10-"
res = autodiff(Reverse, g, Active, Active(Moo(3.0, "a")))[1][1]

@test res.x == 5.0
end
end

@testset "Type preservation" begin
# Float16 fails due to #870
for T in (Float64, Float32, #=Float16=#)
Expand Down
Loading