Skip to content

Commit

Permalink
Merge 6d090c9 into cf9bc00
Browse files Browse the repository at this point in the history
  • Loading branch information
kleinhenz committed Aug 21, 2019
2 parents cf9bc00 + 6d090c9 commit a5c7e3f
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 11 deletions.
54 changes: 43 additions & 11 deletions src/HDF5.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,13 @@ cset(::Type{ASCIIChar}) = H5T_CSET_ASCII

hdf5_type_id(::Type{C}) where {C<:CharType} = H5T_C_S1

# global configuration for complex support
const COMPLEX_SUPPORT = Ref(true)
const COMPLEX_FIELD_NAMES = Ref(("r", "i"))
enable_complex_support() = COMPLEX_SUPPORT[] = true
disable_complex_support() = COMPLEX_SUPPORT[] = false
set_complex_field_names(real::AbstractString, imag::AbstractString) = COMPLEX_FIELD_NAMES[] = ((real, imag))

## HDF5 uses a plain integer to refer to each file, group, or
## dataset. These are wrapped into special types in order to allow
## method dispatch.
Expand Down Expand Up @@ -1175,6 +1182,16 @@ datatype(dset::HDF5Attribute) = HDF5Datatype(h5a_get_type(checkvalid(dset).id),
datatype(x::HDF5Scalar) = HDF5Datatype(hdf5_type_id(typeof(x)), false)
datatype(::Type{T}) where {T<:HDF5Scalar} = HDF5Datatype(hdf5_type_id(T), false)
datatype(A::AbstractArray{T}) where {T<:HDF5Scalar} = HDF5Datatype(hdf5_type_id(T), false)
function datatype(::Type{Complex{T}}) where {T<:HDF5Scalar}
COMPLEX_SUPPORT[] || error("complex support disabled. call HDF5.enable_complex_support() to enable")
dtype = h5t_create(H5T_COMPOUND, 2*sizeof(T))
h5t_insert(dtype, COMPLEX_FIELD_NAMES[][1], 0, hdf5_type_id(T))
h5t_insert(dtype, COMPLEX_FIELD_NAMES[][2], sizeof(T), hdf5_type_id(T))
return HDF5Datatype(dtype)
end
datatype(x::Complex{T}) where {T<:HDF5Scalar} = datatype(typeof(x))
datatype(A::AbstractArray{Complex{T}}) where {T<:HDF5Scalar} = datatype(eltype(A))

function datatype(str::String)
type_id = h5t_copy(hdf5_type_id(typeof(str)))
h5t_set_size(type_id, max(sizeof(str), 1))
Expand Down Expand Up @@ -1203,7 +1220,8 @@ dataspace(dset::HDF5Dataset) = HDF5Dataspace(h5d_get_space(checkvalid(dset).id))
dataspace(attr::HDF5Attribute) = HDF5Dataspace(h5a_get_space(checkvalid(attr).id))

# Create a dataspace from in-memory types
dataspace(x::T) where {T<:HDF5Scalar} = HDF5Dataspace(h5s_create(H5S_SCALAR))
dataspace(x::Union{T, Complex{T}}) where {T<:HDF5Scalar} = HDF5Dataspace(h5s_create(H5S_SCALAR))

function _dataspace(sz::Tuple{Vararg{Int}}, max_dims::Union{Dims, Tuple{}}=())
dims = Vector{Hsize}(undef,length(sz))
any_zero = false
Expand Down Expand Up @@ -1301,18 +1319,20 @@ function read(obj::DatasetOrAttribute)
read(obj, T)
end
# Read scalars
function read(obj::DatasetOrAttribute, ::Type{T}) where {T<:HDF5Scalar}
function read(obj::DatasetOrAttribute, ::Type{T}) where {T<:Union{HDF5Scalar, Complex{<:HDF5Scalar}}}
x = read(obj, Array{T})
x[1]
end
# Read array of scalars
function read(obj::DatasetOrAttribute, ::Type{Array{T}}) where {T<:HDF5Scalar}
function read(obj::DatasetOrAttribute, ::Type{Array{T}}) where {T<:Union{HDF5Scalar, Complex{<:HDF5Scalar}}}
if isnull(obj)
return T[]
end
dims = size(obj)
data = Array{T}(undef,dims)
readarray(obj, hdf5_type_id(T), data)
dtype = datatype(data)
readarray(obj, dtype.id, data)
close(dtype)
data
end
# Empty arrays
Expand Down Expand Up @@ -1700,7 +1720,7 @@ for (privatesym, fsym, ptype) in
obj, dtype
end
# Scalar types
($fsym)(parent::$ptype, name::String, data::Union{T, AbstractArray{T}}, plists...) where {T<:ScalarOrString} =
($fsym)(parent::$ptype, name::String, data::Union{T, AbstractArray{T}}, plists...) where {T<:Union{ScalarOrString, Complex{<:HDF5Scalar}}} =
($privatesym)(parent, name, data, plists...)
# VLEN types
($fsym)(parent::$ptype, name::String, data::HDF5Vlen{T}, plists...) where {T<:Union{HDF5Scalar,CharType}} =
Expand All @@ -1723,7 +1743,7 @@ for (privatesym, fsym, ptype, crsym) in
end
end
# Scalar types
($fsym)(parent::$ptype, name::String, data::Union{T, AbstractArray{T}}, plists...) where {T<:ScalarOrString} =
($fsym)(parent::$ptype, name::String, data::Union{T, AbstractArray{T}}, plists...) where {T<:Union{ScalarOrString, Complex{<:HDF5Scalar}}} =
($privatesym)(parent, name, data, plists...)
# VLEN types
($fsym)(parent::$ptype, name::String, data::HDF5Vlen{T}, plists...) where {T<:Union{HDF5Scalar,CharType}} =
Expand All @@ -1732,7 +1752,7 @@ for (privatesym, fsym, ptype, crsym) in
end
# Write to already-created objects
# Scalars
function write(obj::DatasetOrAttribute, x::Union{T, Array{T}}) where {T<:ScalarOrString}
function write(obj::DatasetOrAttribute, x::Union{T, Array{T}}) where {T<:Union{ScalarOrString, Complex{<:HDF5Scalar}}}
dtype = datatype(x)
try
writearray(obj, dtype.id, x)
Expand All @@ -1750,7 +1770,7 @@ function write(obj::DatasetOrAttribute, data::HDF5Vlen{T}) where {T<:Union{HDF5S
end
end
# For plain files and groups, let "write(obj, name, val)" mean "d_write"
write(parent::Union{HDF5File, HDF5Group}, name::String, data::Union{T, AbstractArray{T}}, plists...) where {T<:ScalarOrString} =
write(parent::Union{HDF5File, HDF5Group}, name::String, data::Union{T, AbstractArray{T}}, plists...) where {T<:Union{ScalarOrString, Complex{<:HDF5Scalar}}} =
d_write(parent, name, data, plists...)
# For datasets, "write(dset, name, val)" means "a_write"
write(parent::HDF5Dataset, name::String, data::Union{T, AbstractArray{T}}, plists...) where {T<:ScalarOrString} = a_write(parent, name, data, plists...)
Expand Down Expand Up @@ -1993,7 +2013,19 @@ function hdf5_to_julia_eltype(objtype)
T = HDF5Vlen{hdf5_to_julia_eltype(HDF5Datatype(super_id))}
elseif class_id == H5T_COMPOUND
N = Int(h5t_get_nmembers(objtype.id))
T = HDF5Compound{N}
# check if should be interpreted as complex
if COMPLEX_SUPPORT[] && N == 2
membernames = ntuple(N) do i
h5t_get_member_name(objtype.id, i-1)
end
membertypes = ntuple(N) do i
h5t_get_member_type(objtype.id, i-1) |> HDF5Datatype |> hdf5_to_julia_eltype
end
iscomplex = membernames == COMPLEX_FIELD_NAMES[] && membertypes[1] == membertypes[2] && membertypes[1] <: HDF5.HDF5Scalar
T = iscomplex ? Complex{membertypes[1]} : HDF5Compound{N}
else
T = HDF5Compound{N}
end
elseif class_id == H5T_ARRAY
T = hdf5array(objtype)
else
Expand All @@ -2007,7 +2039,7 @@ end
# These supply default values where possible
# See also the "special handling" section below
h5a_write(attr_id::Hid, mem_type_id::Hid, buf::String) = h5a_write(attr_id, mem_type_id, unsafe_wrap(Vector{UInt8}, pointer(buf), ncodeunits(buf)))
function h5a_write(attr_id::Hid, mem_type_id::Hid, x::T) where {T<:HDF5Scalar}
function h5a_write(attr_id::Hid, mem_type_id::Hid, x::T) where {T<:Union{HDF5Scalar, Complex{<:HDF5Scalar}}}
tmp = Ref{T}(x)
h5a_write(attr_id, mem_type_id, tmp)
end
Expand All @@ -2034,7 +2066,7 @@ end
function h5d_write(dataset_id::Hid, memtype_id::Hid, str::String, xfer::Hid=H5P_DEFAULT)
ccall((:H5Dwrite, libhdf5), Herr, (Hid, Hid, Hid, Hid, Hid, Cstring), dataset_id, memtype_id, H5S_ALL, H5S_ALL, xfer, str)
end
function h5d_write(dataset_id::Hid, memtype_id::Hid, x::T, xfer::Hid=H5P_DEFAULT) where {T<:HDF5Scalar}
function h5d_write(dataset_id::Hid, memtype_id::Hid, x::T, xfer::Hid=H5P_DEFAULT) where {T<:Union{HDF5Scalar, Complex{<:HDF5Scalar}}}
tmp = Ref{T}(x)
h5d_write(dataset_id, memtype_id, H5S_ALL, H5S_ALL, xfer, tmp)
end
Expand Down
48 changes: 48 additions & 0 deletions test/plain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,54 @@ end

end # testset plain

@testset "complex" begin
HDF5.enable_complex_support()

fn = tempname()
f = h5open(fn, "w")

f["ComplexF64"] = 1.0 + 2.0im
attrs(f["ComplexF64"])["ComplexInt64"] = 1im

Acmplx = rand(ComplexF64, 3, 5)
write(f, "Acmplx64", convert(Matrix{ComplexF64}, Acmplx))
write(f, "Acmplx32", convert(Matrix{ComplexF32}, Acmplx))

HDF5.disable_complex_support()
@test_throws ErrorException f["_ComplexF64"] = 1.0 + 2.0im
@test_throws ErrorException write(f, "_Acmplx64", convert(Matrix{ComplexF64}, Acmplx))
@test_throws ErrorException write(f, "_Acmplx32", convert(Matrix{ComplexF32}, Acmplx))
HDF5.enable_complex_support()

close(f)

fr = h5open(fn)
z = read(fr, "ComplexF64")
@test z == 1.0 + 2.0im && isa(z, ComplexF64)
z_attrs = attrs(fr["ComplexF64"])
@test read(z_attrs["ComplexInt64"]) == 1im

Acmplx32 = read(fr, "Acmplx32")
@test convert(Matrix{ComplexF32}, Acmplx) == Acmplx32
@test eltype(Acmplx32) == ComplexF32
Acmplx64 = read(fr, "Acmplx64")
@test convert(Matrix{ComplexF64}, Acmplx) == Acmplx64
@test eltype(Acmplx64) == ComplexF64

HDF5.disable_complex_support()
z = read(fr, "ComplexF64")
@test isa(z, HDF5.HDF5Compound{2})

Acmplx32 = read(fr, "Acmplx32")
@test eltype(Acmplx32) == HDF5.HDF5Compound{2}
Acmplx64 = read(fr, "Acmplx64")
@test eltype(Acmplx64) == HDF5.HDF5Compound{2}

close(fr)

HDF5.enable_complex_support()
end

# test strings with null and undefined references
@testset "undefined and null" begin
fn = tempname()
Expand Down

0 comments on commit a5c7e3f

Please sign in to comment.