-
Notifications
You must be signed in to change notification settings - Fork 3
/
wh.jl
106 lines (93 loc) · 2.97 KB
/
wh.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
    wh_learn(X, Y; kwargs...) -> Matrix{Float64}

Widrow-Hoff (delta-rule) learning: iteratively update a weight matrix `W`
mapping rows of `X` to rows of `Y`, one learning event at a time.

# Obligatory Arguments
- `X`: the input matrix (one row per learning event)
- `Y`: the output matrix (one row per learning event; must have the same number of rows as `X`)

# Optional Arguments
- `eta::Float64=0.01`: the learning rate
- `n_epochs::Int64=1`: the number of epochs to be trained
- `weights::Matrix=nothing`: the initial weights (zeros if `nothing`)
- `learn_seq::Vector=nothing`: the learning sequence (defaults to row order `1:size(X,1)`)
- `save_history::Bool=false`: if true, a partial training history will be saved
- `history_cols::Vector=nothing`: the list of column indices you want saved in history, e.g. `[1,32,42]` or `[2]`
- `history_rows::Vector=nothing`: the list of row indices you want saved in history, e.g. `[1,32,42]` or `[2]`
- `verbose::Bool=false`: if true, more information will be printed out

# Returns
The trained weight matrix `W`; if `save_history=true`, the tuple `(W, history)`
where `history[epoch, :, :]` is the `W[history_rows, history_cols]` snapshot
taken after each epoch.

# Throws
- `ArgumentError`: if `X` and `Y` have different numbers of rows, or if
  `save_history=true` without `history_rows`/`history_cols`.
"""
function wh_learn(
    X,
    Y;
    eta = 0.01,
    n_epochs = 1,
    weights = nothing,
    learn_seq = nothing,
    save_history = false,
    history_cols = nothing,
    history_rows = nothing,
    verbose = false,
)
    X = Array(X)
    Y = Array(Y)
    # Validate at the API boundary so the hot loop can assume matching shapes.
    # (Bug fix: the message previously interpolated undefined `inputs`/`outputs`,
    # turning this ArgumentError into an UndefVarError.)
    if size(X, 1) != size(Y, 1)
        throw(ArgumentError("X($(size(X,1))) and Y($(size(Y,1))) length doesn't match"))
    end
    if save_history && (isnothing(history_rows) || isnothing(history_cols))
        throw(ArgumentError("history_rows and history_cols must be provided when save_history=true"))
    end
    if isnothing(weights)
        W = zeros(Float64, (size(X, 2), size(Y, 2)))
    else
        W = weights
    end
    # construct learn_seq if nothing
    if isnothing(learn_seq)
        learn_seq = 1:size(X, 1)
    end
    if save_history
        history = zeros(Float64, n_epochs, length(history_rows), length(history_cols))
    end
    # Preallocate scratch buffers once; the loop below is fully in-place.
    inputT = Matrix{Float64}(undef, (size(X, 2), 1))
    pred = Matrix{Float64}(undef, (1, size(Y, 2)))
    deltaW = Matrix{Float64}(undef, (size(X, 2), size(Y, 2)))
    if verbose
        # learn_seq is guaranteed non-nothing here (defaulted above).
        pb = Progress(length(learn_seq) * n_epochs)
    end
    for j = 1:n_epochs
        for i in learn_seq # for each event
            # pred = X[i:i, :]*W
            mul!(pred, view(X, i:i, :), W)
            # obsv = Y[i:i, :]-pred  (reuses the `pred` buffer)
            broadcast!(-, pred, view(Y, i:i, :), pred)
            # inputT = X[i:i, :]'
            transpose!(inputT, view(X, i:i, :))
            # update = inputT*obsv
            mul!(deltaW, inputT, pred)
            # deltaW = eta*update
            rmul!(deltaW, eta)
            # W += deltaW
            broadcast!(+, W, W, deltaW)
            verbose && ProgressMeter.next!(pb)
        end
        # push history (indexing already copies, no extra copy needed)
        if save_history
            history[j, :, :] = W[history_rows, history_cols]
        end
    end
    if save_history
        return W, history
    end
    W
end
"""
    make_learn_seq(freq)

Build a shuffled Widrow-Hoff learning sequence, where event `i` appears
`freq[i]` times. Returns `nothing` when `freq` is `nothing`. Shuffling uses a
fixed-seed `MersenneTwister` (keyword `random_seed`, default 314) so the
result is reproducible.
"""
function make_learn_seq(freq; random_seed = 314)
    isnothing(freq) && return nothing
    # Repeat each event index according to its frequency.
    seq = Int[]
    for (event, count) in enumerate(freq)
        append!(seq, fill(event, count))
    end
    return shuffle(MersenneTwister(random_seed), seq)
end