Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 162 additions & 80 deletions src/rankcorr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,121 +33,203 @@ corspearman(X::RealMatrix) = (Z = mapslices(tiedrank, X, dims=1); cor(Z, Z))
#
#######################################

# Knight JASA (1966)

function corkendall!(x::RealVector, y::RealVector)
# Knight, William R. “A Computer Method for Calculating Kendall's Tau with Ungrouped Data.”
# Journal of the American Statistical Association, vol. 61, no. 314, 1966, pp. 436–439.
# JSTOR, www.jstor.org/stable/2282833.
function corkendall!(x::RealVector, y::RealVector, permx::AbstractVector{<:Integer}=sortperm(x))
if any(isnan, x) || any(isnan, y) return NaN end
n = length(x)
if n != length(y) error("Vectors must have same length") end

# Initial sorting
pm = sortperm(y)
x[:] = x[pm]
y[:] = y[pm]
pm[:] = sortperm(x)
x[:] = x[pm]

# Counting ties in x and y
iT = 1
nT = 0
iU = 1
nU = 0
for i = 2:n
if x[i] == x[i-1]
iT += 1
else
nT += iT*(iT - 1)
iT = 1
end
if y[i] == y[i-1]
iU += 1
else
nU += iU*(iU - 1)
iU = 1
permute!(x, permx)
permute!(y, permx)

# Use widen to avoid overflows on both 32bit and 64bit
npairs = div(widen(n) * (n - 1), 2)
ntiesx = ndoubleties = nswaps = widen(0)
k = 0

@inbounds for i = 2:n
if x[i - 1] == x[i]
k += 1
elseif k > 0
# Sort the corresponding chunk of y, so the rows of hcat(x,y) are
# sorted first on x, then (where x values are tied) on y. Hence
# double ties can be counted by calling countties.
sort!(view(y, (i - k - 1):(i - 1)))
ntiesx += div(widen(k) * (k + 1), 2) # Must use wide integers here
ndoubleties += countties(y, i - k - 1, i - 1)
k = 0
end
end
if iT > 1 nT += iT*(iT - 1) end
nT = div(nT,2)
if iU > 1 nU += iU*(iU - 1) end
nU = div(nU,2)

# Sort y after x
y[:] = y[pm]

# Calculate double ties
iV = 1
nV = 0
jV = 1
for i = 2:n
if x[i] == x[i-1] && y[i] == y[i-1]
iV += 1
else
nV += iV*(iV - 1)
iV = 1
end
if k > 0
sort!(view(y, (n - k):n))
ntiesx += div(widen(k) * (k + 1), 2)
ndoubleties += countties(y, n - k, n)
end
if iV > 1 nV += iV*(iV - 1) end
nV = div(nV,2)

nD = div(n*(n - 1),2)
return (nD - nT - nU + nV - 2swaps!(y)) / (sqrt(nD - nT) * sqrt(nD - nU))
end
nswaps = merge_sort!(y, 1, n)
ntiesy = countties(y, 1, n)

# Calls to float below prevent possible overflow errors when
# length(x) exceeds 77_936 (32 bit) or 5_107_605_667 (64 bit)
(npairs + ndoubleties - ntiesx - ntiesy - 2 * nswaps) /
sqrt(float(npairs - ntiesx) * float(npairs - ntiesy))
end

"""
corkendall(x, y=x)

Compute Kendall's rank correlation coefficient, τ. `x` and `y` must both be either
matrices or vectors.
"""
corkendall(x::RealVector, y::RealVector) = corkendall!(float(copy(x)), float(copy(y)))
corkendall(x::RealVector, y::RealVector) = corkendall!(copy(x), copy(y))

corkendall(X::RealMatrix, y::RealVector) = Float64[corkendall!(float(X[:,i]), float(copy(y))) for i in 1:size(X, 2)]

corkendall(x::RealVector, Y::RealMatrix) = (n = size(Y,2); reshape(Float64[corkendall!(float(copy(x)), float(Y[:,i])) for i in 1:n], 1, n))
function corkendall(X::RealMatrix, y::RealVector)
permy = sortperm(y)
return([corkendall!(copy(y), X[:,i], permy) for i in 1:size(X, 2)])
end

corkendall(X::RealMatrix, Y::RealMatrix) = Float64[corkendall!(float(X[:,i]), float(Y[:,j])) for i in 1:size(X, 2), j in 1:size(Y, 2)]
function corkendall(x::RealVector, Y::RealMatrix)
n = size(Y, 2)
permx = sortperm(x)
return(reshape([corkendall!(copy(x), Y[:,i], permx) for i in 1:n], 1, n))
end

function corkendall(X::RealMatrix)
n = size(X, 2)
C = Matrix{eltype(X)}(I, n, n)
C = Matrix{Float64}(I, n, n)
for j = 2:n
for i = 1:j-1
C[i,j] = corkendall!(X[:,i],X[:,j])
C[j,i] = C[i,j]
permx = sortperm(X[:,j])
for i = 1:j - 1
C[j,i] = corkendall!(X[:,j], X[:,i], permx)
C[i,j] = C[j,i]
end
end
return C
end

function corkendall(X::RealMatrix, Y::RealMatrix)
nr = size(X, 2)
nc = size(Y, 2)
C = Matrix{Float64}(undef, nr, nc)
for j = 1:nr
permx = sortperm(X[:,j])
for i = 1:nc
C[j,i] = corkendall!(X[:,j], Y[:,i], permx)
end
end
return C
end

# Auxilliary functions for Kendall's rank correlation

function swaps!(x::RealVector)
n = length(x)
if n == 1 return 0 end
n2 = div(n, 2)
xl = view(x, 1:n2)
xr = view(x, n2+1:n)
nsl = swaps!(xl)
nsr = swaps!(xr)
sort!(xl)
sort!(xr)
return nsl + nsr + mswaps(xl,xr)
"""
countties(x::RealVector, lo::Integer, hi::Integer)

Return the number of ties within `x[lo:hi]`. Assumes `x` is sorted.
"""
function countties(x::AbstractVector, lo::Integer, hi::Integer)
# Use of widen below prevents possible overflow errors when
# length(x) exceeds 2^16 (32 bit) or 2^32 (64 bit)
thistiecount = result = widen(0)
checkbounds(x, lo:hi)
@inbounds for i = (lo + 1):hi
if x[i] == x[i - 1]
thistiecount += 1
elseif thistiecount > 0
result += div(thistiecount * (thistiecount + 1), 2)
thistiecount = widen(0)
end
end

if thistiecount > 0
result += div(thistiecount * (thistiecount + 1), 2)
end
result
end

function mswaps(x::RealVector, y::RealVector)
i = 1
j = 1
nSwaps = 0
n = length(x)
while i <= n && j <= length(y)
if y[j] < x[i]
nSwaps += n - i + 1
# Tests appear to show that a value of 64 is optimal,
# but note that the equivalent constant in base/sort.jl is 20.
const SMALL_THRESHOLD = 64

# merge_sort! copied from Julia Base
# (commit 28330a2fef4d9d149ba0fd3ffa06347b50067647, dated 20 Sep 2020)
"""
merge_sort!(v::AbstractVector, lo::Integer, hi::Integer, t::AbstractVector=similar(v, 0))

Mutates `v` by sorting elements `x[lo:hi]` using the merge sort algorithm.
This method is a copy-paste-edit of sort! in base/sort.jl, amended to return the bubblesort distance.
"""
function merge_sort!(v::AbstractVector, lo::Integer, hi::Integer, t::AbstractVector=similar(v, 0))
# Use of widen below prevents possible overflow errors when
# length(v) exceeds 2^16 (32 bit) or 2^32 (64 bit)
nswaps = widen(0)
@inbounds if lo < hi
hi - lo <= SMALL_THRESHOLD && return insertion_sort!(v, lo, hi)

m = midpoint(lo, hi)
(length(t) < m - lo + 1) && resize!(t, m - lo + 1)

nswaps = merge_sort!(v, lo, m, t)
nswaps += merge_sort!(v, m + 1, hi, t)

i, j = 1, lo
while j <= m
t[i] = v[j]
i += 1
j += 1
else
end

i, k = 1, lo
while k < j <= hi
if v[j] < t[i]
v[k] = v[j]
j += 1
nswaps += m - lo + 1 - (i - 1)
else
v[k] = t[i]
i += 1
end
k += 1
end
while k < j
v[k] = t[i]
k += 1
i += 1
end
end
return nSwaps
return nswaps
end

# insertion_sort! and midpoint copied from Julia Base
# (commit 28330a2fef4d9d149ba0fd3ffa06347b50067647, dated 20 Sep 2020)
midpoint(lo::T, hi::T) where T <: Integer = lo + ((hi - lo) >>> 0x01)
midpoint(lo::Integer, hi::Integer) = midpoint(promote(lo, hi)...)

"""
insertion_sort!(v::AbstractVector, lo::Integer, hi::Integer)

Mutates `v` by sorting elements `x[lo:hi]` using the insertion sort algorithm.
This method is a copy-paste-edit of sort! in base/sort.jl, amended to return the bubblesort distance.
"""
function insertion_sort!(v::AbstractVector, lo::Integer, hi::Integer)
if lo == hi return widen(0) end
nswaps = widen(0)
@inbounds for i = lo + 1:hi
j = i
x = v[i]
while j > lo
if x < v[j - 1]
nswaps += 1
v[j] = v[j - 1]
j -= 1
continue
end
break
end
v[j] = x
end
return nswaps
end
80 changes: 73 additions & 7 deletions test/rankcorr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,86 @@ c22 = corspearman(x2, x2)
@test corspearman(X, X) ≈ [c11 c12; c12 c22]
@test corspearman(X) ≈ [c11 c12; c12 c22]


# corkendall

@test corkendall(x1, y) ≈ -0.105409255338946
@test corkendall(x2, y) ≈ -0.117851130197758
# Check error, handling of NaN, Inf etc
@test_throws ErrorException("Vectors must have same length") corkendall([1,2,3,4], [1,2,3])
@test isnan(corkendall([1,2], [3,NaN]))
@test isnan(corkendall([1,1,1], [1,2,3]))
@test corkendall([-Inf,-0.0,Inf],[1,2,3]) == 1.0

# Test, with exact equality, some known results.
# RealVector, RealVector
@test corkendall(x1, y) == -1/sqrt(90)
@test corkendall(x2, y) == -1/sqrt(72)
# RealMatrix, RealVector
@test corkendall(X, y) == [-1/sqrt(90), -1/sqrt(72)]
# RealVector, RealMatrix
@test corkendall(y, X) == [-1/sqrt(90) -1/sqrt(72)]

# n = 78_000 tests for overflow errors on 32 bit
# Testing for overflow errors on 64bit would require n be too large for practicality
# This also tests merge_sort! since n is (much) greater than SMALL_THRESHOLD.
n = 78_000
# Test with many repeats
@test corkendall(repeat(x1, n), repeat(y, n)) ≈ -1/sqrt(90)
@test corkendall(repeat(x2, n), repeat(y, n)) ≈ -1/sqrt(72)
@test corkendall(repeat(X, n), repeat(y, n)) ≈ [-1/sqrt(90), -1/sqrt(72)]
@test corkendall(repeat(y, n), repeat(X, n)) ≈ [-1/sqrt(90) -1/sqrt(72)]
@test corkendall(repeat([0,1,1,0], n), repeat([1,0,1,0], n)) == 0.0

# Test with no repeats, note testing for exact equality
@test corkendall(collect(1:n), collect(1:n)) == 1.0
@test corkendall(collect(1:n), reverse(collect(1:n))) == -1.0

@test corkendall(X, y) ≈ [-0.105409255338946, -0.117851130197758]
@test corkendall(y, X) ≈ [-0.105409255338946 -0.117851130197758]
# All elements identical should yield NaN
@test isnan(corkendall(repeat([1], n), collect(1:n)))

c11 = corkendall(x1, x1)
c12 = corkendall(x1, x2)
c22 = corkendall(x2, x2)

@test c11 ≈ 1.0
@test c22 ≈ 1.0
# RealMatrix, RealMatrix
@test corkendall(X, X) ≈ [c11 c12; c12 c22]
# RealMatrix
@test corkendall(X) ≈ [c11 c12; c12 c22]

@test c11 == 1.0
@test c22 == 1.0
@test c12 == 3/sqrt(20)

# Finished testing for overflow, so redefine n for speedier tests
n = 100

@test corkendall(repeat(X, n), repeat(X, n)) ≈ [c11 c12; c12 c22]
@test corkendall(repeat(X, n)) ≈ [c11 c12; c12 c22]

# All eight three-element permutations
z = [1 1 1;
1 1 2;
1 2 2;
1 2 2;
1 2 1;
2 1 2;
1 1 2;
2 2 2]

@test corkendall(z) == [1 0 1/3; 0 1 0; 1/3 0 1]
@test corkendall(z, z) == [1 0 1/3; 0 1 0; 1/3 0 1]
@test corkendall(z[:,1], z) == [1 0 1/3]
@test corkendall(z, z[:,1]) == [1; 0; 1/3]

z = float(z)
@test corkendall(z) == [1 0 1/3; 0 1 0; 1/3 0 1]
@test corkendall(z, z) == [1 0 1/3; 0 1 0; 1/3 0 1]
@test corkendall(z[:,1], z) == [1 0 1/3]
@test corkendall(z, z[:,1]) == [1; 0; 1/3]

w = repeat(z, n)
@test corkendall(w) == [1 0 1/3; 0 1 0; 1/3 0 1]
@test corkendall(w, w) == [1 0 1/3; 0 1 0; 1/3 0 1]
@test corkendall(w[:,1], w) == [1 0 1/3]
@test corkendall(w, w[:,1]) == [1; 0; 1/3]

StatsBase.midpoint(1,10) == 5
StatsBase.midpoint(1,widen(10)) == 5