In [None]:
using Flux, PyPlot, Random, Compat

"""
    fun()

Train a continuous-time rate RNN on a periodic target and plot its progress.

The hidden state follows the Euler-discretised dynamics
`x ← x + dt*(-x + tanh.(J*x))` and is read out linearly as `r = wr*x`. The
target is the square wave `sign(sin(2πt/period))`. Each epoch performs one
ADAM step on a batch of `B` noisy copies of a fixed initial condition, then
records the loss on a fixed held-out batch and redraws two panels: the
held-out loss curve (left, log scale) and one rollout against the target
(right). Prints the epoch counter while running and the total training time
at the end; returns `nothing`.

NOTE(review): relies on the Tracker-based Flux API (`param`, `.data`),
presumably a pre-Zygote Flux release — confirm the pinned version.
"""
function fun()
    # --- network and optimiser hyperparameters ---
    N = 30              # number of hidden units
    R = 1               # number of readouts
    B = 20              # batch size: noisy initial conditions per update (one update per epoch)
    η = 0.001f0         # learning rate
    β₁, β₂ = 0.5, 0.5   # ADAM decay rates (varied in exercise 3)

    # --- task ---
    Nepochs = 100
    dt = 0.1f0          # Euler integration step
    T = 20              # trial duration
    period = 10f0       # period of the target wave
    # round instead of Int(): T/dt is a Float32 quotient and need not be exactly integral
    NT = round(Int, T / dt)   # number of timesteps
    σ = 0.1f0           # std deviation of the initial-condition noise

    t = dt * (1:NT)

    # Input placeholder, all zeros. NOTE: the dynamics below read neither `s`
    # nor `ws`; wiring them in is needed for the fixed-point exercise.
    s = zeros(Float32, NT, 1, B)
    # Target output, shape (NT, R, B): readout 1 of every batch member gets the
    # same square wave; any further readouts stay zero.
    rtarg = zeros(Float32, NT, R, B)
    rtarg[:, 1, :] .= sign.(sin.(2π * t / period))

    Random.seed!(1)
    wsInit = randn(Float32, N, 1)
    wrInit = randn(Float32, R, N) / Float32(sqrt(N))
    Jinit = randn(Float32, N, N) / Float32(sqrt(N))

    ws = param(wsInit)  # input to hidden (not used by the dynamics yet)
    wr = param(wrInit)  # hidden to readout
    J = param(Jinit)    # hidden to hidden

    # Squared readout error summed over all NT steps and divided by B.
    # Expects x0 to have B columns (N×B), matching rtarg[ti, :, :] (R×B).
    function calcloss(x0)
        loss = 0f0
        x = x0
        for ti in 1:NT
            x += dt * (-x + tanh.(J * x))
            r = wr * x
            loss += sum((r - rtarg[ti, :, :]) .^ 2) / B
        end
        return loss
    end

    # Same loss as `calcloss`, but also prints it and plots the readout of the
    # first batch column against the first target column in subplot 122.
    # Not called by the training loop below; kept as a debugging aid.
    function calclossVisual(x0)
        loss = 0f0
        x = x0
        rAll = Vector{Float32}[]
        for ti in 1:NT
            x += dt * (-x + tanh.(J * x))
            r = wr * x
            loss += sum((r - rtarg[ti, :, :]) .^ 2) / B
            # keep batch member 1 only, so the trace lines up with rtarg[:, 1, 1]
            push!(rAll, r.data[:, 1])
        end
        println("loss: ", loss.data)

        out = copy(reduce(hcat, rAll)')   # NT × R
        subplot(122)
        plot(out)
        plot(vec(rtarg[:, 1, 1]))
        return loss
    end

    # Roll the current network forward NTtest steps from x0 and plot the
    # readout against the square-wave target in subplot 122. Uses the raw
    # parameter arrays, so no autodiff tape is recorded. Intended for a single
    # initial condition (x0 is N×1); returns the per-step readouts.
    function visualLoss(x0, NTtest)
        Jd, wrd = J.data, wr.data
        x = x0
        rAll = Matrix{Float32}[]
        for ti in 1:NTtest
            x += dt * (-x + tanh.(Jd * x))
            push!(rAll, wrd * x)
        end
        out = copy(reduce(hcat, rAll)')   # NTtest × R
        rtargTest = sign.(sin.(2π * dt * (1:NTtest) / period))

        subplot(122)
        plot(out)
        plot(rtargTest)
        ylim(-1.5, 1.5)
        return rAll
    end

    opt = ADAM(η, (β₁, β₂))
    ps = (ws, J, wr)
    Random.seed!(6)
    xinit = -randn(Float32, N, 1)

    # Held-out batch, drawn once from a seed no training epoch uses (they use
    # 1:Nepochs). Reseeding with `ei` here would reproduce the training batch
    # exactly. A fixed batch also keeps the test curve comparable across epochs.
    Random.seed!(0)
    xtest = xinit .+ σ * randn(Float32, N, B)

    prevt = time()
    testError = Float32[]
    for ei in 1:Nepochs
        print(ei, "\r")
        Random.seed!(ei)
        x0 = xinit .+ σ * randn(Float32, N, B)   # training batch for this epoch
        Flux.train!(calcloss, ps, [[x0]], opt)

        # Store only the value: keeping the tracked loss would pin its whole tape.
        push!(testError, calcloss(xtest).data)
        GC.gc()

        clf()
        subplot(121)
        semilogy(testError)
        # Float32 noise keeps the rollout in the parameters' precision
        visualLoss(xinit .+ σ * randn(Float32, N), NT)
    end
    print("train time: ", time() - prevt)
end

In [None]:
# Run the whole experiment: trains for Nepochs epochs, redrawing the loss curve
# and a rollout each epoch, then prints the total training time.
fun()

### Exercises
       

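1. What is the minimum number of units to learn a sine wave? (The code above currently uses the square wave `sign(sin(2πt/period))` as its target; replace it with `sin(2πt/period)` for this exercise.)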
2. Implement a square wave as desired output. How many units does it take now?
3. Investigate the effect of changing the hyperparameters β₁, β₂, the learning rate η, and the trial duration T on learnability. Hypothesize why and when learning fails.
4. Implement the fixed-point task (input and output are constant). What is the minimum number of units to learn a fixed point?
5. Compare the resulting geometry of the representations for the sine-wave and fixed-point tasks with the analytical results of Rivkind & Barak (2017) and Mastrogiuseppe & Ostojic (2019).
