-
Notifications
You must be signed in to change notification settings - Fork 4
/
fit.jl
250 lines (212 loc) · 10.7 KB
/
fit.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
"""
    fit_iht(y, x, z; k=10, J=1, d=Normal(), l=IdentityLink(), group=Int[],
            weight=Float64[], est_r=:None, use_maf=false, debias=false,
            verbose=true, tol=1e-4, max_iter=200, min_iter=5, max_step=3,
            io=stdout, init_beta=false)

Fits a model on design matrix (genotype data) `x`, response (phenotype) `y`,
and non-genetic covariates `z` on a specific sparsity parameter `k`. Variables in
`x` and `z` will both be subject to sparsity constraint.

If `k` is a constant, then each group will have the same sparsity level. To run doubly
sparse IHT, construct `k` to be a vector where `k[i]` indicates the max number
of predictors for group `i`.

# Arguments:
+ `y`: Phenotype vector or matrix. Should be an `Array{T, 1}` (single traits) or
    `Array{T, 2}` (multivariate Gaussian traits). For multivariate traits, each
    column of `y` should be a sample.
+ `x`: Genotype matrix (an `Array{T, 2}` or `SnpLinAlg`). For univariate
    analysis, samples are rows of `x`. For multivariate analysis, samples are
    columns of `x` (i.e. input `Transpose(x)` for `SnpLinAlg`)
+ `z`: Matrix of non-genetic covariates of type `Array{T, 2}` or `Array{T, 1}`.
    For univariate analysis, sample covariates are rows of `z`. For multivariate
    analysis, sample covariates are columns of `z`. If this is not specified, an
    intercept term will be included automatically. If `z` is specified, make sure
    the first column (row) is all 1s to represent the intercept.

# Optional Arguments:
+ `k`: Number of non-zero predictors. Can be a constant or a vector (for group IHT).
+ `J`: The number of maximum groups (set as 1 if no group information available)
+ `d`: Distribution of phenotypes. Specify `Normal()` for quantitative traits,
    `Bernoulli()` for binary traits, `Poisson()` or `NegativeBinomial()` for
    count traits, and `MvNormal()` for multiple quantitative traits.
+ `l`: A link function. The recommended link functions are `l=IdentityLink()` for
    quantitative traits, `l=LogitLink()` for binary traits, `l=LogLink()` for Poisson
    distribution, and `l=LogLink()` for NegativeBinomial distribution.
+ `group`: vector storing (non-overlapping) group membership
+ `weight`: vector storing vector of weights containing prior knowledge on each SNP
+ `est_r`: Symbol (`:MM`, `:Newton` or `:None`) to estimate nuisance parameters for negative binomial regression
+ `use_maf`: boolean indicating whether we want to scale projection with minor allele frequencies (see paper)
+ `debias`: boolean indicating whether we debias at each iteration
+ `verbose`: boolean indicating whether we want to print intermediate results
+ `tol`: used to track convergence
+ `max_iter`: is the maximum IHT iteration for a model to converge. Defaults to 200, or 100 for cross validation
+ `min_iter`: is the minimum IHT iteration before checking for convergence. Defaults to 5.
+ `max_step`: is the maximum number of backtracking per IHT iteration. Defaults 3
+ `io`: An `IO` object for displaying intermediate results. Default `stdout`.
+ `init_beta`: Whether to initialize beta values to univariate regression values.
    Currently only Gaussian traits can be initialized. Default `false`.
"""
function fit_iht(
    y :: AbstractVecOrMat{T},
    x :: AbstractMatrix{T},
    z :: AbstractVecOrMat{T};
    k :: Union{Int, Vector{Int}} = 10,
    J :: Int = 1,
    # multivariate traits (multiple columns of y) default to MvNormal
    d :: Distribution = size(y, 2) > 1 ? MvNormal(T[]) : Normal(),
    l :: Link = IdentityLink(),
    group :: AbstractVector{Int} = Int[],
    weight :: AbstractVector{T} = T[],
    est_r :: Symbol = :None,
    use_maf :: Bool = false,
    debias :: Bool = false,
    verbose :: Bool = true, # print informative things to stdout
    tol :: T = convert(T, 1e-4), # tolerance for tracking convergence
    max_iter :: Int = 200, # maximum IHT iterations
    min_iter :: Int = 5, # minimum IHT iterations
    max_step :: Int = 3, # maximum backtracking for each iteration
    io :: IO = stdout,
    init_beta :: Bool = false
    ) where T <: Float
    # first handle errors
    @assert J ≥ 0 "Value of J (max number of groups) must be nonnegative!\n"
    @assert max_iter ≥ 0 "Value of max_iter must be nonnegative!\n"
    @assert max_step ≥ 0 "Value of max_step must be nonnegative!\n"
    @assert tol > eps(T) "Value of global tol must exceed machine precision!\n"
    checky(y, d) # make sure response data y is in the form compatible with specified GLM
    check_group(k, group) # make sure sparsity parameter `k` is reasonable.
    # nuisance-parameter estimation (est_r) only makes sense for NegativeBinomial
    !(d isa NegativeBinomial) && est_r != :None &&
        error("Only negative binomial regression currently supports nuisance parameter estimation")
    x isa AbstractSnpArray && error("x is a SnpArray! Please convert it to a SnpLinAlg first!")
    check_data_dim(y, x, z)
    # SnpLinAlg inputs must be centered; scaling/imputation are strongly recommended
    if x isa SnpLinAlg
        x.center || error("x is not centered! Please construct SnpLinAlg{Float64}(::SnpArray, center=true, scale=true)")
        x.scale || @warn("x is not scaled! We highly recommend `scale=true` in `SnpLinAlg` constructor")
        x.impute || @warn("x does not have impute flag! We highly recommend `impute=true` in `SnpLinAlg` constructor")
    end

    # initialize IHT variable
    v = initialize(x, z, y, J, k, d, l, group, weight, est_r, init_beta)

    # print information
    if verbose
        print_iht_signature(io)
        print_parameters(io, k, d, l, use_maf, group, debias, tol, max_iter, min_iter)
    end

    # run the IHT iterations on the initialized variable
    tot_time, best_logl, mm_iter = fit_iht!(v, debias=debias, verbose=verbose,
        tol=tol, max_iter=max_iter, min_iter=min_iter, max_step=max_step, io=io)

    # compute phenotype's proportion of variation explained
    σ2 = pve(v)

    return IHTResult(tot_time, best_logl, mm_iter, σ2, v)
end
"""
    fit_iht(y, x; kwargs...)

Convenience method when no non-genetic covariates are supplied: builds an
intercept-only covariate array `z` (all ones) and delegates to the full
`fit_iht(y, x, z; kwargs...)` method.
"""
function fit_iht(
    y::AbstractVecOrMat{T},
    x::AbstractMatrix{T};
    kwargs...
) where T
    # For multivariate traits samples are columns, so the intercept is a
    # 1 × n row of ones; for univariate traits it is a length-n vector.
    if is_multivariate(y)
        intercept = ones(T, 1, size(y, 2))
    else
        intercept = ones(T, length(y))
    end
    return fit_iht(y, x, intercept; kwargs...)
end
"""
    fit_iht!(v; kwargs...)

Fits a IHT variable `v`.

# Arguments:
+ `v`: A properly initialized `mIHTVariable` or `IHTVariable`. Users should run [`fit_iht`](@ref)

# Optional Arguments:
+ `debias`: boolean indicating whether we debias at each iteration
+ `verbose`: boolean indicating whether we want to print results if model does not converge.
+ `tol`: used to track convergence
+ `max_iter`: is the maximum IHT iteration for a model to converge. Defaults to 200, or 100 for cross validation
+ `min_iter`: is the minimum IHT iteration before checking for convergence. Defaults to 5.
+ `max_step`: is the maximum number of backtracking. Since l0 norm is not convex, we have no ascent guarantee
+ `io`: An `IO` object for displaying intermediate results. Default `stdout`.

# Returns
A tuple `(tot_time, best_logl, mm_iter)`: elapsed wall-clock time within this
function, best loglikelihood achieved, and the iteration count at termination.
"""
function fit_iht!(
    v :: Union{mIHTVariable{T, M}, IHTVariable{T, M}};
    debias :: Bool = false,
    verbose :: Bool = true, # print informative things
    tol :: T = convert(T, 1e-4), # tolerance for tracking convergence
    max_iter :: Int = 200, # maximum IHT iterations
    min_iter :: Int = 5, # minimum IHT iterations
    max_step :: Int = 3, # maximum backtracking for each iteration
    io :: IO = stdout
    ) where {T <: Float, M}

    #start timer
    start_time = time()

    # initialize constants
    mm_iter = 0 # number of iterations
    tot_time = 0.0 # compute time *within* fit!
    next_logl = typemin(T) # loglikelihood
    best_logl = typemin(T) # best loglikelihood achieved
    η_step = 0 # counts number of backtracking steps for η

    # Begin 'iterative' hard thresholding algorithm
    for iter in 1:max_iter
        # notify and return current model if maximum iteration exceeded
        # NOTE(review): this guard fires *before* taking a step at iter == max_iter,
        # so at most max_iter - 1 IHT steps are performed — confirm this is intended
        if iter ≥ max_iter
            best_logl = save_prev!(v, next_logl, best_logl)
            save_best_model!(v)
            mm_iter = iter
            tot_time = time() - start_time
            verbose && printstyled(io, "Did not converge after $max_iter " *
                "iterations! IHT run time was " * string(tot_time) *
                " seconds\n", color=:red)
            break
        end

        # save values from previous iterate and update loglikelihood
        best_logl = save_prev!(v, next_logl, best_logl)

        # take one IHT step in positive score direction
        (η, η_step, next_logl) = iht_one_step!(v, next_logl, max_step)

        # perform debiasing if support didn't change
        debias && iter ≥ 5 && v.idx == v.idx0 && debias!(v)

        # track convergence
        # Note: estimated beta in first few iterations can be very small, so scaled_norm is very small
        # Thus we force IHT to iterate at least 5 times
        scaled_norm = check_convergence(v)
        progr = "Iteration $iter: loglikelihood = $next_logl, backtracks = $η_step, tol = $scaled_norm"
        verbose && println(io, progr)
        # mirror progress to stdout when results are being written to another IO
        verbose && io != stdout && println(progr)
        # declare convergence only after min_iter iterations (see note above)
        if iter ≥ min_iter && scaled_norm < tol
            best_logl = save_prev!(v, next_logl, best_logl)
            save_best_model!(v)
            tot_time = time() - start_time
            mm_iter = iter
            break
        end
    end

    return tot_time, best_logl, mm_iter
end
"""
    iht_one_step!(v, old_logl, nstep)

Performs 1 iteration of the IHT algorithm, backtracking a maximum of `nstep` times.
We allow loglikelihood to potentially decrease to avoid bad boundary cases.

Returns `(η, η_step, new_logl)`: the final step size, the number of backtracking
steps taken, and the loglikelihood at the accepted iterate.
"""
function iht_one_step!(
    v::Union{IHTVariable{T, M}, mIHTVariable{T, M}},
    old_logl::T,
    nstep::Int
    ) where {T <: Float, M <: AbstractMatrix}

    # first calculate step size
    η = iht_stepsize!(v)

    # update b and c by taking gradient step v.b = P_k(β + ηv) where v is the score direction
    _iht_gradstep!(v, η)

    # update the linear predictors `xb`, `μ`, and residuals with the new proposed b
    update_xb!(v)
    update_μ!(v)

    # for multivariate IHT, also update precision matrix Γ = 1/n * (Y-BX)(Y-BX)'
    # (idiom: `isa` instead of `typeof(v) <: ...`)
    if v isa mIHTVariable
        solve_Σ!(v)
    end

    # update r (nuisance parameter for negative binomial)
    if v isa IHTVariable && v.est_r != :None
        v.d = mle_for_r(v)
    end

    # calculate current loglikelihood with the new computed xb and zc
    new_logl = loglikelihood(v)

    # backtrack via step-halving while the proposal is rejected, at most nstep times
    η_step = 0
    while _iht_backtrack_(new_logl, old_logl, η_step, nstep)
        # stephalving
        η /= 2

        # compute new loglikelihood after linesearch
        new_logl = backtrack!(v, η)

        # increment the counter
        η_step += 1
    end

    # compute score with the new mean
    score!(v)

    # check for finiteness before moving to the next iteration
    # (`error(...)` already throws an ErrorException; the original wrapped it in
    # a redundant `throw(...)` that could never execute)
    isnan(new_logl) && error("Loglikelihood function is NaN, aborting...")
    isinf(new_logl) && error("Loglikelihood function is Inf, aborting...")

    return η::T, η_step::Int, new_logl::T
end