
CLArrays issue #77

Closed
rveltz opened this issue Oct 9, 2017 · 3 comments

Comments

@rveltz

rveltz commented Oct 9, 2017

Hi,

When running the MNIST example below with CLArrays, I get the following error:

ERROR: LoadError: MethodError: Cannot `convert` an object of type Flux.OneHotMatrix{Array{Flux.OneHotVector,1}} to an object of type CLArrays.CLArray
This may have arisen from a call to the constructor CLArrays.CLArray(...),
since type constructors fall back to convert methods.
Stacktrace:
 [1] CLArrays.CLArray(::Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}) at ./sysimg.jl:24
 [2] include_from_node1(::String) at ./loading.jl:569
 [3] include(::String) at ./sysimg.jl:14
 [4] process_options(::Base.JLOptions) at ./client.jl:305
 [5] _start() at ./client.jl:371
while loading /Users/rveltz/work/prog_gd/julia/flux-mnist-cl.jl, in expression starting on line 20

It seems this is because the `OneHotMatrix` type has not been wrapped into CLArrays. Is it an easy fix?

Thank you for your help,

Best regards.

using Flux, MNIST
using Flux: onehotbatch, argmax, mse, throttle
using Base.Iterators: repeated

x, y = traindata()
y = onehotbatch(y, 0:9)

m = Chain(
  Dense(28^2, 32, relu),
  Dense(32, 10),
  softmax)

# using CuArrays
# x, y = cu(x), cu(y)
# m = mapparams(cu, m)

using CLArrays
CLArrays.init(CLArrays.devices()[2])  # use the second available OpenCL device
cl = CLArray
x, y = cl(x), cl(y)  # line 20: cl(y) fails because y is a Flux.OneHotMatrix
m = mapparams(cl, m)


loss(x, y) = mse(m(x), y)

dataset = repeated((x, y), 200)
evalcb = () -> @show(loss(x, y))
opt = SGD(params(m), 0.1)

Flux.train!(loss, dataset, opt, cb = throttle(evalcb, 5))

# Check the prediction for the first digit
argmax(m(x[:,1]), 0:9) == argmax(y[:,1], 0:9)
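Until CLArrays can convert `OneHotMatrix` directly, one possible workaround is to build the one-hot labels as a plain dense matrix, assuming `CLArray` accepts a dense `Array{Float32,2}` the same way it accepts `x`. The helper `dense_onehot` below is hypothetical (it is not part of Flux or CLArrays), and the sketch assumes Julia 1.x, where `findfirst` returns `nothing` when no match exists.

```julia
# Hypothetical helper (not part of Flux or CLArrays): build a dense
# Float32 one-hot matrix of size (number of classes) x (number of labels).
# Column j holds a single 1 in the row of the class equal to labels[j].
function dense_onehot(labels::AbstractVector, classes)
    out = zeros(Float32, length(classes), length(labels))
    for (j, label) in enumerate(labels)
        i = findfirst(isequal(label), classes)
        # Julia 1.x: findfirst returns `nothing` when the label is not a class
        i === nothing && throw(ArgumentError("label $label not in classes"))
        out[i, j] = 1f0
    end
    return out
end

# Possible use in the script above (assumes CLArrays is loaded and
# initialised as in the original script, with `cl = CLArray`):
# x, labels = traindata()
# y = cl(dense_onehot(labels, 0:9))
```

Because the result is an ordinary `Array{Float32,2}`, it avoids the missing `OneHotMatrix` conversion at the cost of storing the full 10 × (number of samples) matrix.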
@MikeInnes
Member

We need a little more support on the CLArrays side for this (see #72).

@rveltz
Author

rveltz commented Oct 13, 2017

OK!

@MikeInnes
Member

Closing so we can track this in #173.
