From 1149823dc8c328a01bd72c0a0060c516898bafc5 Mon Sep 17 00:00:00 2001 From: Jake Bolewski Date: Wed, 27 May 2015 15:22:41 -0400 Subject: [PATCH] add bounds checking for get/setindex --- src/pyarray.jl | 50 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/src/pyarray.jl b/src/pyarray.jl index f3db2582..48e91685 100644 --- a/src/pyarray.jl +++ b/src/pyarray.jl @@ -212,31 +212,43 @@ Base.summary{T}(a::PyArray{T}) = string(Base.dims2string(size(a)), " ", #TODO: is this correct for all buffer types other than contig/dense? #TODO: get rid of this, should be copy! but copy! uses similar under the hood function Base.copy{T,N}(a::PyArray{T,N}) - if N > 1 && a.c_contig # equivalent to f_contig with reversed dims - B = pointer_to_array(a.data, ntuple(N, n -> a.dims[N - n + 1])) - return N == 2 ? transpose(B) : permutedims(B, (N:-1:1)) + if N > 1 && a.c_contig i + # equivalent to f_contig with reversed dims + B = pointer_to_array(a.data, (Int[a.dims[N - d + 1] for d in 1:N]...)) + if N == 2 + return transpose(B) + else + return permutedims(B, N:-1:1) + end end A = Array(T, a.dims) if a.f_contig ccall(:memcpy, Void, (Ptr{T}, Ptr{T}, Int), A, a, sizeof(T) * length(a)) return A - else - return copy!(A, a) end + return copy!(A, a) end -#TODO: Bounds checking is needed Base.getindex{T}(a::PyArray{T,0}) = unsafe_load(a.data) -Base.getindex{T}(a::PyArray{T,1}, i::Integer) = +function Base.getindex{T}(a::PyArray{T,1}, i::Integer) + 1 <= i <= length(a) || throw(BoundsError()) unsafe_load(a.data, 1 + (i-1) * a.strides[1]) +end -Base.getindex{T}(a::PyArray{T,2}, i::Integer, j::Integer) = +function Base.getindex{T}(a::PyArray{T,2}, i::Integer, j::Integer) + 1 <= i <= size(a,1) || throw(BoundsError()) + 1 <= j <= size(a,2) || throw(BoundsError()) unsafe_load(a.data, 1 + (i-1) * a.strides[1] + (j-1) * a.strides[2]) +end -Base.getindex(a::PyArray, i::Integer) = - a.f_contig ? unsafe_load(a.data, i) : - getindex(a, ind2sub(a.dims, i)...) +function Base.getindex(a::PyArray, i::Integer) + if a.f_contig + 1 <= i <= length(a) || throw(BoundsError()) + return unsafe_load(a.data, i) + end + return getindex(a, ind2sub(a.dims, i)...) +end function Base.getindex(a::PyArray, is::Integer...) index = 1 @@ -265,14 +277,22 @@ end Base.setindex!{T}(a::PyArray{T,0}, v) = (unsafe_store!(pointer(a), v, 1); v) -Base.setindex!{T}(a::PyArray{T,1}, v, i::Integer) = - (unsafe_store!(pointer(a), v, 1 + (i-1) * a.strides[1]); v) +function Base.setindex!{T}(a::PyArray{T,1}, v, i::Integer) + 1 <= i <= length(a) || throw(BoundsError()) + unsafe_store!(pointer(a), v, 1 + (i-1) * a.strides[1]) + return v +end -Base.setindex!{T}(a::PyArray{T,2}, v, i::Integer, j::Integer) = - (unsafe_store!(pointer(a), v, 1 + (i-1) * a.strides[1] + (j-1) * a.strides[2]); v) +function Base.setindex!{T}(a::PyArray{T,2}, v, i::Integer, j::Integer) + 1 <= i <= size(a,1) || throw(BoundsError()) + 1 <= j <= size(a,2) || throw(BoundsError()) + unsafe_store!(pointer(a), v, 1 + (i-1) * a.strides[1] + (j-1) * a.strides[2]) + return v +end function Base.setindex!(a::PyArray, v, i::Integer) if a.f_contig + 1 <= i <= length(a) || throw(BoundsError()) unsafe_store!(pointer(a), v, i) return v end