Skip to content

Commit

Permalink
Merge pull request #25241 from JuliaLang/jb/vectorstring
Browse files Browse the repository at this point in the history
safer vector<->string conversions, fixing #24388
  • Loading branch information
JeffBezanson authored Jan 3, 2018
2 parents 3d1886f + 7bce3b1 commit 2043060
Show file tree
Hide file tree
Showing 24 changed files with 143 additions and 50 deletions.
15 changes: 8 additions & 7 deletions base/c.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ cconvert(::Type{Cstring}, s::AbstractString) =
cconvert(Cstring, String(s)::String)

function cconvert(::Type{Cwstring}, s::AbstractString)
v = transcode(Cwchar_t, Vector{UInt8}(String(s)))
v = transcode(Cwchar_t, String(s))
!isempty(v) && v[end] == 0 || push!(v, 0)
return v
end
Expand All @@ -140,7 +140,7 @@ containsnul(p::Ptr, len) =
containsnul(s::String) = containsnul(unsafe_convert(Ptr{Cchar}, s), sizeof(s))
containsnul(s::AbstractString) = '\0' in s

function unsafe_convert(::Type{Cstring}, s::Union{String,Vector{UInt8}})
function unsafe_convert(::Type{Cstring}, s::Union{String,AbstractVector{UInt8}})
p = unsafe_convert(Ptr{Cchar}, s)
containsnul(p, sizeof(s)) &&
throw(ArgumentError("embedded NULs are not allowed in C strings: $(repr(s))"))
Expand Down Expand Up @@ -174,7 +174,7 @@ same argument.
This is only available on Windows.
"""
function cwstring(s::AbstractString)
bytes = Vector{UInt8}(String(s))
bytes = codeunits(String(s))
0 in bytes && throw(ArgumentError("embedded NULs are not allowed in C strings: $(repr(s))"))
return push!(transcode(UInt16, bytes), 0)
end
Expand Down Expand Up @@ -202,19 +202,20 @@ Only conversion to/from UTF-8 is currently supported.
"""
function transcode end

transcode(::Type{T}, src::Vector{T}) where {T<:Union{UInt8,UInt16,UInt32,Int32}} = src
transcode(::Type{T}, src::AbstractVector{T}) where {T<:Union{UInt8,UInt16,UInt32,Int32}} = src
transcode(::Type{T}, src::String) where {T<:Union{Int32,UInt32}} = T[T(c) for c in src]
transcode(::Type{T}, src::Vector{UInt8}) where {T<:Union{Int32,UInt32}} = transcode(T, String(src))
transcode(::Type{T}, src::Union{Vector{UInt8},CodeUnits{UInt8,String}}) where {T<:Union{Int32,UInt32}} =
transcode(T, String(src))
function transcode(::Type{UInt8}, src::Vector{<:Union{Int32,UInt32}})
buf = IOBuffer()
for c in src; print(buf, Char(c)); end
take!(buf)
end
transcode(::Type{String}, src::String) = src
transcode(T, src::String) = transcode(T, Vector{UInt8}(src))
transcode(T, src::String) = transcode(T, codeunits(src))
transcode(::Type{String}, src) = String(transcode(UInt8, src))

function transcode(::Type{UInt16}, src::Vector{UInt8})
function transcode(::Type{UInt16}, src::Union{Vector{UInt8},CodeUnits{UInt8,String}})
dst = UInt16[]
i, n = 1, length(src)
n > 0 || return dst
Expand Down
15 changes: 13 additions & 2 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1555,8 +1555,19 @@ export hex2num
@deprecate ctranspose adjoint
@deprecate ctranspose! adjoint!

@deprecate convert(::Type{Vector{UInt8}}, s::AbstractString) Vector{UInt8}(s)
@deprecate convert(::Type{Array{UInt8}}, s::AbstractString) Vector{UInt8}(s)
function convert(::Union{Type{Vector{UInt8}}, Type{Array{UInt8}}}, s::AbstractString)
depwarn("Strings can no longer be `convert`ed to byte arrays. Use `unsafe_wrap` or `codeunits` instead.", :Type)
unsafe_wrap(Vector{UInt8}, String(s))
end
function (::Type{Vector{UInt8}})(s::String)
depwarn("Vector{UInt8}(s::String) will copy data in the future. To avoid copying, use `unsafe_wrap` or `codeunits` instead.", :Type)
unsafe_wrap(Vector{UInt8}, s)
end
function (::Type{Array{UInt8}})(s::String)
depwarn("Array{UInt8}(s::String) will copy data in the future. To avoid copying, use `unsafe_wrap` or `codeunits` instead.", :Type)
unsafe_wrap(Vector{UInt8}, s)
end

@deprecate convert(::Type{Vector{Char}}, s::AbstractString) Vector{Char}(s)
@deprecate convert(::Type{Symbol}, s::AbstractString) Symbol(s)
@deprecate convert(::Type{String}, s::Symbol) String(s)
Expand Down
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,7 @@ export
chomp,
chop,
codeunit,
codeunits,
dec,
digits,
digits!,
Expand Down
2 changes: 1 addition & 1 deletion base/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ function readuntil(io::IO, target::AbstractString)
# decide how we can index target
if target isa String
# convert String to a utf8-byte-iterator
target = Vector{UInt8}(target)
target = codeunits(target)
#elseif applicable(codeunit, target)
# TODO: a more general version of above optimization
# would be to permit accessing any string via codeunit
Expand Down
2 changes: 1 addition & 1 deletion base/iobuffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function GenericIOBuffer(data::T, readable::Bool, writable::Bool, seekable::Bool
end

# allocate Vector{UInt8}s for IOBuffer storage that can efficiently become Strings
StringVector(n::Integer) = Vector{UInt8}(_string_n(n))
StringVector(n::Integer) = unsafe_wrap(Vector{UInt8}, _string_n(n))

# IOBuffers behave like Files. They are typically readable and writable. They are seekable. (They can be appendable).

Expand Down
4 changes: 2 additions & 2 deletions base/loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ elseif Sys.isapple()
break
end
# Hack to compensate for inability to create a string from a subarray with no allocations.
Vector{UInt8}(path_basename) == casepreserved_basename && return true
codeunits(path_basename) == casepreserved_basename && return true

# If there is no match, it's possible that the file does exist but HFS+
# performed unicode normalization. See https://developer.apple.com/library/mac/qa/qa1235/_index.html.
isascii(path_basename) && return false
Vector{UInt8}(Unicode.normalize(path_basename, :NFD)) == casepreserved_basename
codeunits(Unicode.normalize(path_basename, :NFD)) == casepreserved_basename
end
else
# Generic fallback that performs a slow directory listing.
Expand Down
2 changes: 1 addition & 1 deletion base/repl/LineEdit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ function edit_splice!(s, r::Region=region(s), ins::AbstractString = ""; rigid_ma
elseif buf.mark >= B
buf.mark += sizeof(ins) - B + A
end
ret = splice!(buf.data, A+1:B, Vector{UInt8}(ins)) # position(), etc, are 0-indexed
ret = splice!(buf.data, A+1:B, codeunits(String(ins))) # position(), etc, are 0-indexed
buf.size = buf.size + sizeof(ins) - B + A
adjust_pos && seek(buf, position(buf) + sizeof(ins))
String(ret)
Expand Down
2 changes: 1 addition & 1 deletion base/replutil.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ function showerror(io::IO, ex::ErrorException)
print(io, ex.msg)
if ex.msg == "type String has no field data"
println(io)
print(io, "Use `Vector{UInt8}(str)` instead.")
print(io, "Use `codeunits(str)` instead.")
end
end
showerror(io::IO, ex::KeyError) = print(io, "KeyError: key $(repr(ex.key)) not found")
Expand Down
42 changes: 39 additions & 3 deletions base/strings/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ end
getindex(s::AbstractString, i::Colon) = s
# TODO: handle other ranges with stride ±1 specially?
# TODO: add more @propagate_inbounds annotations?
getindex(s::AbstractString, r::UnitRange{<:Integer}) = SubString(s, r)
getindex(s::AbstractString, v::AbstractVector{<:Integer}) =
sprint(io->(for i in v; write(io, s[i]) end), sizehint=length(v))
getindex(s::AbstractString, v::AbstractVector{Bool}) =
Expand Down Expand Up @@ -185,8 +184,8 @@ checkbounds(s::AbstractString, I::Union{Integer,AbstractArray}) =
string() = ""
string(s::AbstractString) = s

(::Type{Vector{UInt8}})(s::AbstractString) = Vector{UInt8}(String(s))
(::Type{Array{UInt8}})(s::AbstractString) = Vector{UInt8}(s)
(::Type{Vector{UInt8}})(s::AbstractString) = unsafe_wrap(Vector{UInt8}, String(s))
(::Type{Array{UInt8}})(s::AbstractString) = unsafe_wrap(Vector{UInt8}, String(s))
(::Type{Vector{Char}})(s::AbstractString) = collect(s)

Symbol(s::AbstractString) = Symbol(String(s))
Expand Down Expand Up @@ -629,3 +628,40 @@ next(r::Iterators.Reverse{<:AbstractString}, i) = (r.itr[i], prevind(r.itr, i))
start(r::Iterators.Reverse{<:EachStringIndex}) = endof(r.itr.s)
done(r::Iterators.Reverse{<:EachStringIndex}, i) = i < start(r.itr.s)
next(r::Iterators.Reverse{<:EachStringIndex}, i) = (i, prevind(r.itr.s, i))

## code unit access ##

"""
CodeUnits(s::AbstractString)
Wrap a string (without copying) in an immutable vector-like object that accesses the code units
of the string's representation.
"""
struct CodeUnits{T,S<:AbstractString} <: DenseVector{T}
s::S
CodeUnits(s::S) where {S<:AbstractString} = new{codeunit(s),S}(s)
end

length(s::CodeUnits) = ncodeunits(s.s)
sizeof(s::CodeUnits{T}) where {T} = ncodeunits(s.s) * sizeof(T)
size(s::CodeUnits) = (length(s),)
strides(s::CodeUnits) = (1,)
@propagate_inbounds getindex(s::CodeUnits, i::Int) = codeunit(s.s, i)
IndexStyle(::Type{<:CodeUnits}) = IndexLinear()
start(s::CodeUnits) = 1
next(s::CodeUnits, i) = (@_propagate_inbounds_meta; (s[i], i+1))
done(s::CodeUnits, i) = (@_inline_meta; i == length(s)+1)

write(io::IO, s::CodeUnits) = write(io, s.s)

unsafe_convert(::Type{Ptr{T}}, s::CodeUnits{T}) where {T} = unsafe_convert(Ptr{T}, s.s)
unsafe_convert(::Type{Ptr{Int8}}, s::CodeUnits{UInt8}) = unsafe_convert(Ptr{Int8}, s.s)

"""
codeunits(s::AbstractString)
Obtain a vector-like object containing the code units of a string.
Returns a `CodeUnits` wrapper by default, but `codeunits` may optionally be defined
for new string types if necessary.
"""
codeunits(s::AbstractString) = CodeUnits(s)
9 changes: 6 additions & 3 deletions base/strings/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ julia> String(take!(io))
"Haho"
```
"""
IOBuffer(str::String) = IOBuffer(Vector{UInt8}(str))
IOBuffer(s::SubString{String}) = IOBuffer(view(Vector{UInt8}(s.string), s.offset + 1 : s.offset + sizeof(s)))
IOBuffer(str::String) = IOBuffer(unsafe_wrap(Vector{UInt8}, str))
IOBuffer(s::SubString{String}) = IOBuffer(view(unsafe_wrap(Vector{UInt8}, s.string), s.offset + 1 : s.offset + sizeof(s)))

# join is implemented using IO

Expand Down Expand Up @@ -373,7 +373,10 @@ function unescape_string(io, s::AbstractString)
end
end

macro b_str(s); :(Vector{UInt8}($(unescape_string(s)))); end
macro b_str(s)
v = Vector{UInt8}(codeunits(unescape_string(s)))
QuoteNode(v)
end

"""
@raw_str -> String
Expand Down
6 changes: 3 additions & 3 deletions base/strings/search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function search(a::ByteArray, b::Char, i::Integer = 1)
if isascii(b)
search(a,UInt8(b),i)
else
search(a,Vector{UInt8}(string(b)),i).start
search(a,unsafe_wrap(Vector{UInt8},string(b)),i).start
end
end

Expand Down Expand Up @@ -62,7 +62,7 @@ function rsearch(a::ByteArray, b::Char, i::Integer = length(a))
if isascii(b)
rsearch(a,UInt8(b),i)
else
rsearch(a,Vector{UInt8}(string(b)),i).start
rsearch(a,unsafe_wrap(Vector{UInt8},string(b)),i).start
end
end

Expand Down Expand Up @@ -147,7 +147,7 @@ function _search_bloom_mask(c)
end

_nthbyte(s::String, i) = codeunit(s, i)
_nthbyte(a::ByteArray, i) = a[i]
_nthbyte(a::Union{AbstractVector{UInt8},AbstractVector{Int8}}, i) = a[i]

function _searchindex(s::Union{String,ByteArray}, t::Union{String,ByteArray}, i)
n = sizeof(t)
Expand Down
6 changes: 5 additions & 1 deletion base/strings/string.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ This representation is often appropriate for passing strings to C.
String(s::AbstractString) = print_to_string(s)
String(s::Symbol) = unsafe_string(unsafe_convert(Ptr{UInt8}, s))

(::Type{Vector{UInt8}})(s::String) = ccall(:jl_string_to_array, Ref{Vector{UInt8}}, (Any,), s)
unsafe_wrap(::Type{Vector{UInt8}}, s::String) = ccall(:jl_string_to_array, Ref{Vector{UInt8}}, (Any,), s)

(::Type{Vector{UInt8}})(s::CodeUnits{UInt8,String}) = copyto!(Vector{UInt8}(uninitialized, length(s)), s)

String(s::CodeUnits{UInt8,String}) = s.s

## low-level functions ##

Expand Down
1 change: 0 additions & 1 deletion base/strings/strings.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

include("strings/substring.jl")
include("strings/basic.jl")
include("strings/search.jl")
include("strings/unicode.jl")
include("strings/util.jl")
Expand Down
2 changes: 2 additions & 0 deletions base/strings/substring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,5 @@ function reverse(s::Union{String,SubString{String}})::String
end
end
end

getindex(s::AbstractString, r::UnitRange{<:Integer}) = SubString(s, r)
21 changes: 13 additions & 8 deletions base/strings/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -475,24 +475,29 @@ julia> hex2bytes(a)
"""
function hex2bytes end

hex2bytes(s::AbstractString) = hex2bytes(Vector{UInt8}(String(s)))
hex2bytes(s::AbstractVector{UInt8}) = hex2bytes!(Vector{UInt8}(uninitialized, length(s) >> 1), s)
hex2bytes(s::AbstractString) = hex2bytes(String(s))
hex2bytes(s::Union{String,AbstractVector{UInt8}}) = hex2bytes!(Vector{UInt8}(uninitialized, length(s) >> 1), s)

_firstbyteidx(s::String) = 1
_firstbyteidx(s::AbstractVector{UInt8}) = first(eachindex(s))
_lastbyteidx(s::String) = sizeof(s)
_lastbyteidx(s::AbstractVector{UInt8}) = endof(s)

"""
hex2bytes!(d::AbstractVector{UInt8}, s::AbstractVector{UInt8})
hex2bytes!(d::AbstractVector{UInt8}, s::Union{String,AbstractVector{UInt8}})
Convert an array `s` of bytes representing a hexadecimal string to its binary
representation, similar to [`hex2bytes`](@ref) except that the output is written in-place
in `d`. The length of `s` must be exactly twice the length of `d`.
"""
function hex2bytes!(d::AbstractVector{UInt8}, s::AbstractVector{UInt8})
if 2length(d) != length(s)
isodd(length(s)) && throw(ArgumentError("input hex array must have even length"))
function hex2bytes!(d::AbstractVector{UInt8}, s::Union{String,AbstractVector{UInt8}})
if 2length(d) != sizeof(s)
isodd(sizeof(s)) && throw(ArgumentError("input hex array must have even length"))
throw(ArgumentError("output array must be half length of input array"))
end
j = first(eachindex(d)) - 1
for i = first(eachindex(s)):2:endof(s)
@inbounds d[j += 1] = number_from_hex(s[i]) << 4 + number_from_hex(s[i+1])
for i = _firstbyteidx(s):2:_lastbyteidx(s)
@inbounds d[j += 1] = number_from_hex(_nthbyte(s,i)) << 4 + number_from_hex(_nthbyte(s,i+1))
end
return d
end
Expand Down
7 changes: 4 additions & 3 deletions base/sysimg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,6 @@ macro MIME_str(s)
:(MIME{$(Expr(:quote, Symbol(s)))})
end

include("char.jl")
include("strings/string.jl")

# SIMD loops
include("simdloop.jl")
using .SimdLoop
Expand All @@ -212,6 +209,10 @@ include("iterators.jl")
using .Iterators: zip, enumerate
using .Iterators: Flatten, product # for generators

include("char.jl")
include("strings/basic.jl")
include("strings/string.jl")

# Definition of StridedArray
StridedReshapedArray{T,N,A<:Union{DenseArray,FastContiguousSubArray}} = ReshapedArray{T,N,A}
StridedReinterpretArray{T,N,A<:Union{DenseArray,FastContiguousSubArray}} = ReinterpretArray{T,N,S,A} where S
Expand Down
1 change: 1 addition & 0 deletions doc/src/stdlib/strings.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Base.transcode
Base.unsafe_string
Base.ncodeunits(::AbstractString)
Base.codeunit
Base.codeunits
Base.ascii
Base.@r_str
Base.@raw_str
Expand Down
5 changes: 5 additions & 0 deletions src/array.c
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,11 @@ JL_DLLEXPORT jl_value_t *jl_array_to_string(jl_array_t *a)
if (jl_is_string(o)) {
a->flags.isshared = 1;
*(size_t*)o = jl_array_len(a);
a->nrows = 0;
#ifdef STORE_ARRAY_LEN
a->length = 0;
#endif
a->maxsize = 0;
return o;
}
}
Expand Down
2 changes: 1 addition & 1 deletion test/char.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ end
@testset "read incomplete character at end of stream or file" begin
local file = tempname()
local iob = IOBuffer([0xf0])
local bytes(c::Char) = Vector{UInt8}(string(c))
local bytes(c::Char) = codeunits(string(c))
@test bytes(read(iob, Char)) == [0xf0]
@test eof(iob)
try
Expand Down
2 changes: 1 addition & 1 deletion test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4114,7 +4114,7 @@ b = "aaa"
c = [0x2, 0x1, 0x3]

@test check_nul(a)
@test check_nul(Vector{UInt8}(b))
@test check_nul(unsafe_wrap(Vector{UInt8},b))
@test check_nul(c)
d = [0x2, 0x1, 0x3]
@test check_nul(d)
Expand Down
2 changes: 1 addition & 1 deletion test/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ end

let s = "abcα🐨\0x\0"
for T in (UInt8, UInt16, UInt32, Int32)
@test transcode(T, s) == transcode(T, Vector{UInt8}(s))
@test transcode(T, s) == transcode(T, codeunits(s))
@test transcode(String, transcode(T, s)) == s
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ let b = ['0':'9';'A':'Z';'a':'z']
@test length(randstring(rng...)) == 8
@test length(randstring(rng..., 20)) == 20
@test issubset(randstring(rng...), b)
for c = ['a':'z', "qwèrtï", Set(Vector{UInt8}("gcat"))],
for c = ['a':'z', "qwèrtï", Set(codeunits("gcat"))],
len = [8, 20]
s = len == 8 ? randstring(rng..., c) : randstring(rng..., c, len)
@test length(s) == len
Expand Down
Loading

0 comments on commit 2043060

Please sign in to comment.