/
cross_validation.jl
112 lines (93 loc) · 3.77 KB
/
cross_validation.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
using MLBase
# Cross-validation (CV) selection rules, following Taddy (2015, distrom):
# CV1se chooses the largest λt whose mean out-of-sample (OOS) deviance is no
# more than one standard error above the minimum, and CVmin chooses the λt
# with the lowest mean OOS deviance.
"""
    CVmin(oosdevs)

Return the index of the segment with the lowest mean out-of-sample deviance.

`oosdevs` is a matrix with one row per segment (λ value) and one column per
fold. The mean of every row is taken across the folds, and the (first) row
attaining the smallest mean is returned as an integer index.
"""
function CVmin(oosdevs)
    nfolds = size(oosdevs, 2)
    # Row means via plain `sum`, so that neither `indmin` nor the positional
    # `mean(A, 2)` form (both removed after Julia 0.6) is needed.
    cvmeans = [sum(view(oosdevs, s, :)) / nfolds for s in 1:size(oosdevs, 1)]
    return findmin(cvmeans)[2]
end
"""
    CV1se(oosdevs)

Return the index of the first segment whose mean out-of-sample deviance is no
more than one standard error above the minimum mean deviance.

`oosdevs` is a matrix with one row per segment (λ value) and one column per
fold. The standard error is that of the minimizing segment: the sample standard
deviation of its fold deviances divided by `sqrt(nfolds - 1)`.

Rows are scanned in order, so the result is the largest λ only if segments are
ordered from largest to smallest λ (assumed here; confirm against the caller).

Throws an `ErrorException` when there are fewer than two folds (no standard
error can be estimated) or when no segment qualifies (e.g. NaN deviances).
"""
function CV1se(oosdevs)
    nλ, nfolds = size(oosdevs)
    # With fewer than two folds the standard error is NaN and no segment could
    # ever qualify; fail with an explicit message instead of the generic one.
    nfolds >= 2 || error("CV1se needs at least 2 folds to estimate a standard error, got $nfolds")

    # Row means via plain `sum`; avoids the positional `mean(A, 2)` form that
    # was removed after Julia 0.6.
    cvmeans = [sum(view(oosdevs, s, :)) / nfolds for s in 1:nλ]
    mincvmean, segCVmin = findmin(cvmeans)

    # standard error of the fold deviances of the minimizing segment
    mindevs = view(oosdevs, segCVmin, :)
    mincvstd = sqrt(sum(abs2, mindevs .- mincvmean) / (nfolds - 1))
    threshold = mincvmean + mincvstd / sqrt(nfolds - 1)

    for s in 1:nλ
        if cvmeans[s] <= threshold
            return s
        end
    end

    error("should have found the cv1se by now")
end
"""
    cross_validate_path(path::RegularizationPath; gen=Kfold(n, 10), select=:CVmin)

Convenience method: cross-validate `path` using the same data as in the
original fit.

The response, offset and design matrix are read from the model stored in
`path` (`path.m.rr.y`, `path.m.rr.offset`, `path.m.pp.X`) and forwarded to the
three-argument method with `standardize=false`.

# Keywords
- `gen`: folds generator (see MLBase); defaults to 10 folds over the `n`
  observations of the stored response.
- `select`: `:CVmin` or `:CV1se`.
"""
function cross_validate_path(path::RegularizationPath;            # fitted path
                             # The default must go through `path`: `y` is not an
                             # argument, so `length(y)` here would refer to a global.
                             gen=Kfold(length(path.m.rr.y),10),   # folds generator (see MLBase)
                             select=:CVmin)                       # :CVmin or :CV1se
    m = path.m
    y = m.rr.y
    offset = m.rr.offset
    # design matrix as stored in the model (presumably already standardized,
    # hence `standardize=false` below)
    Xstandardized = m.pp.X
    return cross_validate_path(path, Xstandardized, y;
                               gen=gen, select=select, offset=offset, standardize=false)
end
# Map a fitted path object to its path type, so that a new path of the same
# kind can be requested from `fit`.
function pathtype(::LassoPath)
    return LassoPath
end

function pathtype(::GammaLassoPath)
    return GammaLassoPath
end
"""
    cross_validate_path(path, X, y; gen=Kfold(length(y),10), select=:CVmin,
                        offset=T[], fitargs...)

Cross-validate the fitted regularization `path` on the (potentially new) data
`X`, `y` and return the index of the selected segment.

For every fold produced by `gen` (each item is used as the training row
indices) a path of the same type as `path` is refit on the training rows with
the λ sequence of `path` and equal observation weights. The mean deviance of
the held-out rows is then computed for every segment, and the resulting
segments-by-folds matrix is handed to the selection rule.

# Keywords
- `gen`: folds generator (see MLBase).
- `select`: selection rule, `:CVmin` or `:CV1se`.
- `offset`: must have length `size(X, 1)` if `path` was fit with an offset, and
  must be left empty otherwise.
- `fitargs...`: forwarded to `fit` for every fold.

# Throws
- `ArgumentError` if `select` is not one of the two rules, or if `offset` is
  inconsistent with how `path` was fit.
- `AssertionError` if `size(X, 1) != length(y)`.
"""
function cross_validate_path(path::RegularizationPath,     # fitted path
                             X::AbstractMatrix{T}, y::V;   # potentially new data
                             gen=Kfold(length(y),10),      # folds generator (see MLBase)
                             select=:CVmin,                # :CVmin or :CV1se
                             offset::FPVector=T[],
                             fitargs...) where {T<:AbstractFloat,V<:FPVector}
    @extractfields path m λ

    # Resolve the selection rule explicitly (no `eval`) and before any model is
    # fit, so a bad `select` fails immediately rather than after all the folds.
    if select === :CVmin
        CVfun = CVmin
    elseif select === :CV1se
        CVfun = CV1se
    else
        throw(ArgumentError("`select` must be :CVmin or :CV1se, got $(repr(select))"))
    end

    n = size(X, 1)
    @assert n == length(y) "size(X,1) != length(y)"
    nfolds = length(gen)
    nλ = length(λ)

    # passed on to `fit` (both) and `devresid` (d) below
    d = distfun(path)
    l = linkfun(path)

    # valid offset given?
    hasoffset = length(offset) > 0
    if length(m.rr.offset) > 0
        length(offset) == n ||
            throw(ArgumentError("fit with offset, so `offset` kw arg must be an offset of length $n"))
    else
        hasoffset && throw(ArgumentError("fit without offset, so value of `offset` kw arg does not make sense"))
    end

    # EQUAL WEIGHTS ONLY!
    wts = ones(T, n)

    # results array: one row per segment, one column per fold
    oosdevs = zeros(T, nλ, nfolds)

    for (f, train_inds) in enumerate(gen)
        test_inds = setdiff(1:n, train_inds)
        nis = length(test_inds)

        # an empty offset is passed through unchanged
        trainoffset = hasoffset ? offset[train_inds] : offset
        testoffset = hasoffset ? offset[test_inds] : offset

        # fit model to train_inds
        foldpath = fit(pathtype(path), X[train_inds,:], y[train_inds], d, l;
                       λ=λ, wts=wts[train_inds], offset=trainoffset, fitargs...)

        # predictions for each held-out obs x segment
        μ = predict(foldpath, X[test_inds,:]; offset=testoffset, select=:all)

        # calculate deviances on the test set one segment at a time (not much mem)
        for s in 1:nλ
            # accumulator for the sum of observation deviances in segment s
            devs = zero(T)
            for ip in 1:nis
                devs += devresid(d, y[test_inds[ip]], μ[ip,s])
            end
            # store the mean deviance over the held-out observations
            oosdevs[s,f] = devs/nis
        end
    end

    return CVfun(oosdevs)
end