# eight-gaussians.jl
# (381 lines, 319 loc, 10.9 KB)
## Start of code
## Import dependencies
cd(@__DIR__)
using Pkg
Pkg.activate(".")
Pkg.instantiate()
using Flux
using LinearAlgebra
using jInv.Mesh
using Printf
using Plots
using JLD
using Random
using Zygote
using DelimitedFiles
using LaTeXStrings
using Measures
include("viewers.jl")
include("gaussian-util.jl")
## Experiment parameters
d = 100 # problem dimension (state space is R^d)
# --- interaction kernel ---
μ = 1.0 # kernel mean, i.e. the value of K(x, x)
# bandwidth scaled with the dimension as 1.25 * sqrt(d/2)
ind_sigma_value = 1.25 * sqrt(d / 2)
# alternative fixed bandwidth: ind_sigma_value = 0.2
σ = fill(ind_sigma_value, d) # identical bandwidth along every dimension
# --- optimization ---
epochs = 10000 # number of training iterations
ha = 0.6       # step size of the coefficient (`a`) updates
hv = 0.6       # step size of the velocity updates
hvtemp = 1.0   # extrapolation weight applied to (vel - veltemp) before the `a` update
# Penalty weights, ordered as [costL, costAf, costPsi].
# The costAf weight must stay 1.0.
α = [1.0, 1.0, 10.0]
# --- convenience ---
R = Float64         # working floating-point type
ar = x -> R.(x)     # elementwise conversion to R
# ----------------------------------------------------------------
# Structure
# ----------------------------------------------------------------
"""
    NonlocalMFG{R}

Problem data and cost cache for the nonlocal mean-field game. Instances are
callable as `J(vel, veltemp, hv)`, which evaluates the penalized objective and
overwrites `J.cs`.

# Fields
- `a`: coefficient array; column `k` holds the coefficients `a_i` at time step `k`.
- `fBasis`: feature (basis) function applied to a matrix of agent states.
- `K`: kernel matrix; `M`: its inverse.
- `Q`: obstacle function.
- `Psi`: terminal-cost function.
- `X0`: training points (starting positions); this script stores them as `(d, 1, nex)`.
- `w`: quadrature weights for `X0`.
- `α`: penalty weights for the objective terms.
- `tspan`: time interval `[t0, t1]`.
- `nt`: number of time steps in the ODE solve.
- `cs`: costs without penalty from the most recent objective evaluation.
"""
mutable struct NonlocalMFG{R}
    a::Any                   # each column is the list of a_i's at one time step
    fBasis::Any              # basis function
    K::AbstractArray{R}      # kernel matrix
    M::AbstractArray{R}      # inverse of kernel matrix
    Q::Any                   # obstacle function
    Psi::Any                 # terminal function
    X0::AbstractArray{R}     # training points
    w::AbstractVector{R}     # quadrature weights for X0
    α::Any                   # penalties for the objective terms
    tspan::AbstractVector{R} # time interval
    nt::Any                  # number of time steps in ODE solve
    cs::Any                  # costs w/o penalty (overwritten on every evaluation)
end

"""
    DataStore

Mutable holder for the last logged gradient norm. The value is stored as `Float32`.
"""
mutable struct DataStore
    norm_grads::Float32
end

"""
    NonlocalMFG(a, fBasis, K, M, Q, Psi, X0; w, α, tspan, nt, cs = 0)

Keyword convenience constructor.

`cs` is the cost cache that every objective evaluation overwrites; it defaults to
the placeholder `0`. Previously this constructor forwarded only 11 arguments to
the 12-field struct, so every call failed with a `MethodError`.
"""
NonlocalMFG(
    a,
    fBasis,
    K::AbstractArray{R},
    M::AbstractArray{R},
    Q,
    Psi,
    X0::AbstractArray{R};
    w,
    α,
    tspan,
    nt,
    cs = 0,
) where {R<:Real} = NonlocalMFG(
    a,
    fBasis,
    K,
    M,
    Q,
    Psi,
    X0,
    w,
    α,
    tspan,
    nt,
    cs,
)
"""
    (J::NonlocalMFG{R})(vel, veltemp, hv)

Evaluate the penalized objective for the velocity field `vel`.

The returned value is

    α[1]*costL + α[2]*costAf + α[3]*costPsi + ‖vel - veltemp‖² / (2hv)

where, with step `h = (tspan[2] - tspan[1]) / nt` and trajectories
`Z = getZ(vel, J.X0)`:
- `costL` is the `w`-weighted kinetic energy `0.5 h Σ |vel|²`.
- `costAf` is `Σ_k h ⟨a_k, fBasis(Z_k)⟩`, `w`-weighted over agents.
- `costPsi` is the `w`-weighted terminal cost `Psi(Z_end)`.

Side effect: overwrites `J.cs` with
`[costL, costAf, costPsi, costPotential, costInteraction]`.

Only fields of `J` are used. The original read the globals `X0`, `a` and
`fBasis`; those refer to the same objects in this script but break for any
other `J`.

NOTE(review): `getZ` hard-codes the step `1/nt`, which agrees with `h` only when
`tspan` has length 1.
"""
function (J::NonlocalMFG{R})(vel, veltemp, hv) where {R<:Real}
    # J.X0 is (d, 1, nex); only the number of agents is needed here
    nex = size(J.X0, 3)
    # uniform time step over tspan
    h = R((J.tspan[2] - J.tspan[1]) / J.nt)
    # trajectories starting from this problem's own training points
    Z = getZ(vel, J.X0)
    cAf = zeros(R, 1, nex)
    costInteraction = 0.0
    for k = 1:J.nt
        # features at time step k, computed once and reused by both integrals
        fZ = J.fBasis(Z[:, k, :])
        cAf += h .* sum(J.a[:, k] .* fZ, dims = 1) # A*f cost, one entry per agent
        costInteraction += h * 0.5 * sum((fZ * J.w) .^ 2)
    end
    costAf = dot(vec(cAf), J.w)
    # lagrangian (kinetic) cost
    cL = 0.5 * h * sum(sum(vel .^ 2, dims = 1), dims = 2)
    costL = dot(vec(cL), J.w)
    # terminal cost
    phi1 = J.Psi(Z[:, end, :])
    costPsi = R.(dot(vec(phi1), vec(J.w)))
    # cost as a potential (reported only, not part of the returned objective)
    costPotential = costL * J.α[1] + costPsi * J.α[3] + costInteraction
    J.cs = [costL, costAf, costPsi, costPotential, costInteraction]
    # proximal term keeping vel close to the previous iterate veltemp
    subtr = vec(vel - veltemp)
    prox = 0.5 / hv * dot(subtr, subtr)
    return dot(vec(J.cs[1:3]), vec(J.α)) + prox
end
"""
    getZ(V, X0)

Integrate agent trajectories with forward Euler from the starting positions `X0`.

`V` is indexed as `(d, nt, nex)`: one velocity per dimension, time step and agent.
The step is `1/nt`. Each step appends one time slice along dimension 2:
`z_{i+1} = z_i + V[:, i, :] / nt`. When `X0` is `(d, 1, nex)` the result is
`(d, nt + 1, nex)`, with the starting positions in slice 1.
"""
function getZ(V, X0)
    d, nt, nex = size(V)
    Δt = Float64(1 / nt)
    # append the next Euler state to the trajectory accumulated so far
    advance(traj, i) =
        cat(traj, reshape(traj[:, end, :] + V[:, i, :] * Δt, (d, 1, nex)), dims = 2)
    return foldl(advance, 1:nt; init = X0)
end
"""
    getZv2(V, X0, interpolate = 2)

Like `getZ`, but each of the `nt` time steps is split into `interpolate` equal
forward-Euler sub-steps of size `1/(nt * interpolate)`.

Intermediate states are kept, so trajectories are densely sampled for plotting.
When `X0` is `(d, 1, nex)` the result is `(d, nt * interpolate + 1, nex)`.
"""
function getZv2(V, X0, interpolate=2)
    d, nt, nex = size(V)
    h = Float64(1 / nt)
    hh = 1.0 / interpolate
    traj = X0
    # outer index i selects the velocity; it is reused for `interpolate` sub-steps
    for i in 1:nt, _ in 1:interpolate
        znext = traj[:, end, :] + V[:, i, :] * h * hh
        traj = cat(traj, reshape(znext, (d, 1, nex)), dims = 2)
    end
    return traj
end
## Plotting function
# Only the first `d_plot` coordinates are visualized.
d_plot = 2
# Reference 8-Gaussian mixture used as the background image in `plotting`.
# It has eight equally weighted components centered on the unit circle at angles
# 0, 2π/8, ..., 14π/8; the per-axis spread parameter is 0.1^2 (presumably a variance).
sig_plot = ar(0.1 * 0.1 * ones(R, d_plot))
Gs_plot = Array{Gaussian{R,Vector{R}}}(undef, 8)
ang_plot = range(0, stop = 2 * pi, length = length(Gs_plot) + 1)[1:end-1]
for (k, θ) in enumerate(ang_plot)
    center = ar(1.0 * [cos(θ); sin(θ); zeros(d_plot - 2)])
    Gs_plot[k] = Gaussian(d_plot, sig_plot, center, R(1.0 / length(Gs_plot)))
end
rho0_plot = GaussianMixture(Gs_plot)
"""
    plotting(X0, vel, iter, time_elapsed, save_figure = false, rel_norm_grads = 1, grad_a = 1)

Draw the reference 8-Gaussian density `rho0_plot` on a 64×64 grid over
[-1.5, 1.5]² and overlay the trajectories of every other agent.

Trajectories come from `getZv2(vel, X0, 2)`, i.e. two sub-steps per time step.
The figure is displayed and, when `save_figure` is true, written to the global
`filename`.

The title reports run statistics read from the globals `d`, `nTrain`, `σ`, `μ`,
`nBasis` and the cached costs `J.cs` (accessed via `.data`, so they must be
tracked values). `iter` is currently unused.
"""
function plotting(X0, vel, iter, time_elapsed, save_figure = false, rel_norm_grads = 1, grad_a = 1)
    # plotting recipe (sets global Plots defaults)
    default( tickfont = (4, "arial", :grey),
        guidefont = (4, "arial", :black),
        legendfont = (4, "arial", :grey),
        titlefont = (7, "arial", :grey),
        legend = false,
        xformatter = :plain,
        yminorgrid = true,
        dpi = 200,
        linewidth = 1.5,
        markersize = 3.0,
    )
    # plotting grid: 64x64 cells on [-1.5, 1.5]^2
    domain = [-1.5 1.5 -1.5 1.5]
    n = [64 64]
    MM = getRegularMesh(domain, n)
    Xc = Matrix(getCellCenteredGrid(MM)')
    # background: reference initial distribution
    r0 = rho0_plot(Xc)
    p1 = viewImage2D(r0, MM, aspect_ratio = :equal, clims = (minimum(r0), maximum(r0)), c = :amp)
    # trajectories with two sub-steps per time step; plot every other agent
    Ztraj = getZv2(vel, X0, 2)
    XX = Ztraj[:, :, 1:2:end]
    for k in axes(XX, 3)
        plot!(
            p1,
            XX[1, :, k],
            XX[2, :, k],
            legend = false,
            linewidth = 1,
            seriestype = [:scatter],
            aspect_ratio = :equal,
            markersize = 4,
            markeralpha = 0.4,
            markercolor = :lightblue,
            markerstrokewidth = 1,
            markerstrokealpha = 1,
            markerstrokecolor = :black,
            markerstrokestyle = :dot
        )
    end
    # include run information (gradient norms, costs) in the title
    str_title = @sprintf("d: %d agents: %d sigma: %0.2f mu: %1.2f nBasis: %d\n time: %2.2e rel||grads||: %0.2e aResidual: %0.2e\ncP: %2.2e cL: %2.2e cI: %2.2e cPsi: %2.2e\n\n", d, nTrain, σ[1], μ, nBasis, time_elapsed, rel_norm_grads, grad_a, (J.cs[4].data), J.cs[1].data, J.cs[5].data, J.cs[3].data)
    title!(p1, str_title)
    # resized copy of the figure; pass it explicitly instead of relying on the current plot
    p1 = plot(p1, size = (340, 360))
    display(p1)
    if save_figure
        savefig(p1, filename)
    end
end
#-----------------------------------------------------------------
# Training parameters
# ----------------------------------------------------------------
nTrain = 256  # number of training samples (agents)
nTSteps = 12  # number of time steps
r = 512       # number of random features for the interaction kernel
nBasis = r
# d_ = number of leading dimensions the interaction acts on (2: first two, d: all)
d_ = d
# Random-feature frequencies are sampled from a Gaussian built with 1 ./ σ.^2
# ("1/variance of the original kernel"). Half of the r features use cos and half
# use sin, each weighted by 2/r.
varK̂ = inv.(σ[1:d_] .^ 2)
K̂ = Gaussian(d_, varK̂, zeros(d_))
ωs = sample(K̂, floor(Int32, r / 2))
ck = fill(2 / r, floor(Int32, r / 2))
"""
    fBasis(x)

Map each column of `x` (one agent state) to `r` random features.

Only the first `d_` coordinates are used. The result stacks
`sqrt.(ck) .* cos.(ωs' * x)` on top of `sqrt.(ck) .* sin.(ωs' * x)` and scales
everything by `sqrt(μ)`. Reads the globals `μ`, `ck`, `ωs` and `d_`.
"""
function fBasis(x)
    xs = x[1:d_, :]
    amp = sqrt.(ck)
    return sqrt(μ) * vcat(amp .* cos.(ωs' * xs), amp .* sin.(ωs' * xs))
end
# Coefficients `a`: column t holds the nBasis coefficients for time step t.
a = randn(Float64, nBasis, nTSteps)
a0 = copy(a) # snapshot of the initial coefficients
# Kernel matrix in feature space (the identity) and its inverse.
K = Matrix{Float64}(I, nBasis, nBasis)
M = inv(K)
# Terminal function
xtarget = zeros(d) # target point (mean of target), here the origin
"""
    Psi(x)

Terminal cost: the squared Euclidean distance from each column of `x` to `xtarget`.
The result is a `1 × size(x, 2)` row.
"""
function Psi(x)
    return sum((x[1:d, :] .- xtarget) .^ 2; dims = 1)
end
"""
    Q(x)

Obstacle function. There is no obstacle in this experiment, so it always returns `0.0`.
"""
function Q(x)
    return 0.0
end
## Initial Density: 8 Gaussians
sig_original = 0.1 # sigma of initial distribution
sig = ar(sig_original * sig_original * ones(R, d))
# Eight equally weighted components centered on the unit circle in the first two coordinates.
Gs = Array{Gaussian{R,Vector{R}}}(undef, 8)
ang = range(0, stop = 2 * pi, length = length(Gs) + 1)[1:end-1]
for (k, θ) in enumerate(ang)
    center = ar([cos(θ); sin(θ); zeros(d - 2)])
    Gs[k] = Gaussian(d, sig, center, R(1.0 / length(Gs)))
end
rho0 = GaussianMixture(Gs)
# training data: nTrain samples from rho0, stored as (d, 1, nTrain)
X0 = reshape(sample(rho0, nTrain), d, 1, nTrain)
# equal quadrature weights over the training samples
w = fill(1.0 / size(X0, 3), size(X0, 3))
## Initialize the nonlocal MFG (time interval [0, 1], cost cache placeholder 0)
J = NonlocalMFG(a, fBasis, K, M, Q, Psi, X0, w, α, R[0.0, 1.0], nTSteps, 0)
#----------------------------------------------------------------
# Update function
# ---------------------------------------------------------------
# Initial velocities: zero for every dimension, time step and agent.
vel = zeros(d, nTSteps, nTrain)
veltemp = copy(vel)
# Initial objective value and its gradient w.r.t. the velocities.
# norm_grads0 is the reference for the relative gradient norms reported later.
c, back = Zygote.Tracker.forward(v -> J(v, veltemp, hv), vel)
grads0 = back(1)[1].data
norm_grads0 = norm(grads0, Inf)
grads = copy(grads0)
data_stored = DataStore(0.0)
"""
    updateAll(ha, hv, hvtemp, iter, doplot = false)

Perform one primal-dual iteration and return `rel_norm_grads`.

Steps:
1. Gradient step on the global velocities `vel` with step `hv`, using the
   proximal objective `J(vel, veltemp, hv)` where `veltemp` is the previous `vel`.
2. Gradient-ascent step with step `ha` on each column of the coefficients `J.a`.
   Trajectories come from the extrapolated velocity `vel + hvtemp*(vel - veltemp)`.

Mutates the globals `vel` and `grads`, the array `J.a`, and (every 10 iterations)
`data_stored`. Plots every 100 iterations when `doplot == true`.

The return value `rel_norm_grads` is the sup-norm of the gradient from the
*previous* call, divided by `norm_grads0`.
"""
function updateAll(ha, hv, hvtemp, iter, doplot=false)
    global vel, norm_grads0, grads
    start_time = time()
    # progress measures based on the gradient computed in the previous call
    norm_grads = norm(grads, Inf)
    rel_norm_grads = abs(norm_grads / norm_grads0)
    veltemp = copy(vel)
    # primal step: gradient of the penalized objective w.r.t. the velocities
    _, back = Zygote.Tracker.forward(v -> J(v, veltemp, hv), vel)
    grads = back(1)[1].data
    vel -= grads * hv
    # trajectories from the extrapolated velocity drive the coefficient update
    Z = getZ(vel + hvtemp * (vel - veltemp), J.X0)
    grad_a = 0.0
    for tval = 1:J.nt # integrate across time
        # features at time tval, integrated against the quadrature weights
        foZ = R.(J.fBasis(Z[:, tval, :]))
        integ = vec(sum(reshape(J.w, 1, :) .* foZ, dims = 2))
        # gradient for a at time tval
        grad_at = -J.M * J.a[:, tval] + integ
        # ascent step on the coefficients (J.a is the same array as the global `a`)
        J.a[:, tval] = J.a[:, tval] + ha * grad_at
        grad_a += norm(grad_at)^2
    end
    # overall coefficient residual
    grad_a = sqrt(grad_a) / J.nt
    end_time = time()
    # plot every 100 iterations
    if mod(iter + 1, 100) == 0
        if doplot == true
            plotting(J.X0, vel, floor(Int, iter / 50), end_time - start_time, true, rel_norm_grads, grad_a)
        end
    end
    # print the costs every 10 iterations
    if mod(iter + 1, 10) == 0
        costL, costAf, costPsi, costPotential, costInteraction = J.cs
        @printf "epochs: %d costPotential: %0.4f costL: %0.4f costInterac: %0.4f costPsi: %0.4f norm grads: %0.4f rel grad norm: %0.4f dualnorm: %0.4e time: %0.4f\n" iter+1 (costPotential) costL costInteraction costPsi norm_grads rel_norm_grads grad_a end_time-start_time
        data_stored.norm_grads = norm_grads
    end
    return rel_norm_grads
end
## ----------------------------------------------------------------
# Run the training loop
# ---------------------------------------------------------------
filename = ("./figures/8-gaussian-d-$(d)-d_-$(d_)-nBasis-$(nBasis)-dual-sigma-$(ind_sigma_value)-mu-$(μ).png")
# savefig fails if the output folder does not exist yet
mkpath(dirname(filename))
print("save into ", filename, "\n\n")
for i in 1:epochs
    updateAll(ha, hv, hvtemp, i, true)
end
## Saving results
writefilenameparam = "./data/primal-dual-8-gaussian-d-$(d)-d_-$(d_)-sigma-$(ind_sigma_value)-mu-$(μ)-N-$(nTrain)-nBasis-$(nBasis)-param.txt"
mkpath(dirname(writefilenameparam))
# NOTE(review): writedlm writes one entry per row. The arrays X0 and vel appear
# to be written flattened, so their shapes must be rebuilt from d, nTSteps and
# nTrain when reading back; verify before relying on the file.
writedlm(writefilenameparam, [nTrain, nTSteps, d, nBasis, ind_sigma_value, μ, X0, vel])
## End of code