Conv constructor should accept scalar pad/stride #235

Closed
MikeInnes opened this issue Apr 14, 2018 · 6 comments

Comments

@MikeInnes
Member

MikeInnes commented Apr 14, 2018

The constructor should accept scalar pad/stride values and just expand them to N dimensions.

So should the pooling function (so the logic probably needs to go in NNlib).

@tejank10
Contributor

I'll take this up

@tejank10
Contributor

tejank10 commented Apr 15, 2018

Would that mean instead of

Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
     stride::NTuple{N,Integer} = map(_->1,k),
     pad::NTuple{N,Integer} = map(_->0,k))

we should have something like

Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
     stride::Integer = 1, pad::Integer = 0)

?
Because both of them cannot exist together. Please correct me if I am mistaken.

@MikeInnes
Member Author

MikeInnes commented Apr 15, 2018

They can exist together, but you'll have to remove the type annotations from the keyword arguments. Higher-level constructors can default to 1/0 and forward whatever they are given. When we actually construct a Conv object, call a function like

expand(::Val{N}, i::Integer) where N = ntuple(_ -> i, Val(N))
expand(::Val{N}, i::NTuple{N}) where N = i
# any other input should raise an error

In the conv case, N is ndims(weight)-2. In the pooling case it's length(kernel).
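
A minimal sketch of how that helper might look in current Julia, where `ntuple` takes `Val(N)` rather than the 0.6-era `Val{N}`. The fallback method, its error message, and the example `weight` and `kernel` values are assumptions added for illustration; they are not taken from Flux or NNlib.

```julia
# Expand a scalar to an N-tuple; pass an N-tuple through unchanged.
expand(::Val{N}, i::Integer) where {N} = ntuple(_ -> i, Val(N))
expand(::Val{N}, i::NTuple{N,Integer}) where {N} = i
# Hypothetical fallback: anything else is rejected with a readable error.
expand(::Val{N}, i) where {N} =
    throw(ArgumentError("expected an Integer or a $(N)-tuple of Integers, got $(repr(i))"))

# Conv case: N is the number of spatial dimensions, ndims(weight) - 2.
weight = zeros(Float32, 3, 3, 1, 8)   # illustrative 3×3 kernel, 1 => 8 channels
stride = expand(Val(ndims(weight) - 2), 2)

# Pooling case: N is the length of the kernel tuple.
kernel = (2, 2)
pad = expand(Val(length(kernel)), (0, 1))
```

Here `stride` comes out as a 2-tuple built from the scalar, while `pad` is returned as given because it already has the right length. A tuple of the wrong length falls through to the fallback method.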

@tejank10
Copy link
Contributor

I had tried that approach, and it gave me method overwriting warnings. It was something like the following:

Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
     stride::Integer = 1, pad::Integer = 0) where N =
  Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
       stride = map(_->stride,k), pad = map(_->pad,k))

@MikeInnes
Member Author

Right, don't use dispatch in the Conv constructor for this, just dispatch inside expand at the last minute. If that's not clear I'll put up a code snippet.
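
One way to read this advice, as a rough sketch: Julia does not dispatch on keyword arguments, so two constructors that differ only in keyword types overwrite each other. Leaving the keywords untyped and normalising them with `expand` at construction time avoids that. `SketchConv` is a hypothetical stand-in type, and the zero-initialised weights are a placeholder; neither is Flux's actual `Conv` code.

```julia
# Same helper as discussed in the thread, repeated so this block runs alone.
expand(::Val{N}, i::Integer) where {N} = ntuple(_ -> i, Val(N))
expand(::Val{N}, i::NTuple{N,Integer}) where {N} = i

# Hypothetical stand-in for a conv layer; not Flux's actual Conv type.
struct SketchConv{N,A<:AbstractArray}
    weight::A
    stride::NTuple{N,Int}
    pad::NTuple{N,Int}
end

# Innermost constructor: keywords are untyped, and expand normalises them here.
function SketchConv(weight::AbstractArray; stride = 1, pad = 0)
    N = ndims(weight) - 2
    SketchConv{N,typeof(weight)}(weight, expand(Val(N), stride), expand(Val(N), pad))
end

# Higher-level constructor: defaults to 1/0 and forwards whatever it is given.
SketchConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}; stride = 1, pad = 0) where {N} =
    SketchConv(zeros(Float32, k..., ch...); stride = stride, pad = pad)

a = SketchConv((3, 3), 1 => 8; stride = 2)     # scalar stride, default pad
b = SketchConv((3, 3), 1 => 8; pad = (1, 0))   # tuple pad, default stride
```

Both calls go through the same pair of methods; only `expand` distinguishes a scalar from a tuple.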

@MikeInnes
Member Author

Closed by #237.
