Add fold.jl #303

Closed · wants to merge 2 commits into from

Conversation

spazewalker

Adding Fold/Unfold functions, in reference to the PyTorch feature parity issue here.

@CarloLucibello (Member)

From a quick glance this looks mostly ok, thanks. Convenience methods for the 1d and 2d cases are missing.
Then some tests and rrules.

@DhairyaLGandhi (Member) left a comment

Thanks so much for looking into this! I've left a couple of thoughts, but overall I'm wondering if you've checked this on the GPU as well.
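
A minimal sketch of the kind of GPU check being asked about, assuming CUDA.jl is available. `unfold` stands for the PR's function and `w_dim` for its kernel-size argument, so the exact calls are hypothetical and left commented out:

using CUDA

x = rand(Float32, 6, 6, 6, 3, 2)                        # hypothetical W×H×D×C×N input
# xg = cu(x)                                            # move the input to the GPU
# @assert Array(unfold(xg, w_dim)) ≈ unfold(x, w_dim)   # GPU result should match the CPU one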

src/fold.jl Outdated
cdims = DenseConvDims(x_dim, w_dim; stride=stride, padding=padding, dilation=dilation)

# Calculate the total number of sliding blocks
col_dim = (im2col_dims(cdims))[1:2]

Member

What does the 1:2 stand for here? Are you trying to extract the features from, e.g., the channel-separation/batching dimension?

Author

im2col_dims() returns a tuple of 3 elements (L, M, N), where N is the number of threads. However, im2col!() assumes a 2d array col of shape (L, M). Hence the [1:2], to take only the first two.
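
A small sketch of the indexing being described, with hypothetical sizes; im2col_dims and DenseConvDims are NNlib internals used as in the PR, and their exact return shapes may differ between NNlib versions:

using NNlib: DenseConvDims, im2col_dims

x_dim = (6, 6, 6, 3, 2)              # hypothetical input size: W×H×D×C×N
w_dim = (2, 2, 2, 3, 4)              # hypothetical kernel size: kW×kH×kD×Cin×Cout
cdims = DenseConvDims(x_dim, w_dim)

dims3   = im2col_dims(cdims)         # (L, M, N), where N is a per-thread buffer dimension
col_dim = dims3[1:2]                 # keep only (L, M) for a single 2d col buffer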

src/fold.jl Outdated
# Iterate through all batchs
for i = 1:x_dim[end]
temp = fill(0., col_dim)
im2col!(temp, X[:,:,:,:,i], cdims)

Member

This would allocate the datapoint every time; maybe use views?

Author

Noted.
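
A hedged sketch of the suggestion, assuming col has been preallocated with the (L, M, batch) shape discussed above; view avoids copying the batch slice on every iteration, and writing into a view of col also avoids growing it with cat later. The helper name is hypothetical:

using NNlib: im2col!

function unfold_batches!(col, X, cdims)
    # iterate over the batch axis without materialising X[:,:,:,:,i] each time
    for i in axes(X, ndims(X))
        im2col!(view(col, :, :, i), view(X, :, :, :, :, i), cdims)
    end
    return col
end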

src/fold.jl Outdated

# Calculate the total number of sliding blocks
col_dim = (im2col_dims(cdims))[1:2]
col = undef

Member

I'm a little uncomfortable using undef since it has erratic behaviour in corner cases.

Author

Noted. I'll instead use fill() to initialise and then use views to fill it with actual values in the loop.
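
A minimal sketch of the preallocation described here, with hypothetical sizes standing in for col_dim and x_dim[end] from the snippet above and a Float64 element type assumed; the full (L, M, batch) buffer is created once instead of binding undef and growing it with cat:

col_dim = (125, 24)      # hypothetical (L, M) from im2col_dims
batches = 2              # hypothetical batch count, i.e. x_dim[end]

col = fill(0.0, col_dim[1], col_dim[2], batches)
# equivalently: col = zeros(Float64, col_dim..., batches)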

src/fold.jl Outdated
col = undef

# Iterate through all batchs
for i = 1:x_dim[end]

Member

We might want to consider the possibility of having some other dimension to fold over as well. Not critical though.

Author

x_dim[end] is the number of batches. This loop iterates over all batches, unfolds them, and then stacks each unfolded array on top of the previous one along a new dimension. Folding over other dimensions can be implemented on top of this by adding a dummy singleton dimension.
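
A short sketch of the dummy-dimension point, with a hypothetical 4d input that lacks a batch axis; `unfold` stands for the PR's function, so the last call is left commented out:

x4 = rand(Float32, 6, 6, 6, 3)       # hypothetical W×H×D×C input with no batch axis
x5 = reshape(x4, size(x4)..., 1)     # append a dummy singleton batch dimension
# cols = unfold(x5, w_dim)           # hypothetical call into the PR's unfold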

src/fold.jl Outdated
if i == 1
col = reshape(temp, col_dim[1], col_dim[2], 1)
else
col = cat(dims=3, col, temp)

Member

Typo?

Author

I'll drop this if-else clause and use views here too.

src/fold.jl Outdated
col = cat(dims=3, col, temp)
end
end
return col;

Member

Prefer to drop the return, and there is a trailing semicolon.

Author

I'll remove the semicolon. Are you talking about dropping the return statement or the return character?

@spazewalker (Author) left a comment

> From a quick glance this looks mostly ok, thanks. Convenience methods for the 1d and 2d cases are missing.
> Then some tests and rrules.

Should I implement them here, or will they be taken care of while implementing a layer in Flux.jl?

@spazewalker (Author) commented Mar 19, 2021

> Thanks so much for looking into this! I've left a couple of thoughts, but overall I'm wondering if you've checked this on the GPU as well.

I haven't tested it on the GPU yet; I'll check that.

@DhairyaLGandhi (Member)

We might not have a layer in Flux, or maybe we need to just hold the config to construct the cdims object.

@ToucheSir (Member)

> From a quick glance this looks mostly ok, thanks. Convenience methods for the 1d and 2d cases are missing.
> Then some tests and rrules.
>
> Should I implement them here, or will they be taken care of while implementing a layer in Flux.jl?

Thanks for the PR! Tests and rrules need to be here. I'm not sure what exactly is meant by convenience methods, but if it makes sense to call them directly (as opposed to only as a helper for a layer struct), then they should be in NNlib as well.

@spazewalker (Author)

> We might not have a layer in Flux, or maybe we need to just hold the config to construct the cdims object.

Okay, so I'll try to implement those convenience wrappers.

> From a quick glance this looks mostly ok, thanks. Convenience methods for the 1d and 2d cases are missing.
> Then some tests and rrules.
>
> Should I implement them here, or will they be taken care of while implementing a layer in Flux.jl?
>
> Thanks for the PR! Tests and rrules need to be here. I'm not sure what exactly is meant by convenience methods, but if it makes sense to call them directly (as opposed to only as a helper for a layer struct), then they should be in NNlib as well.

As per my understanding, those are wrappers around this implementation, so that it can handle 1d and 2d inputs too. I'll work on tests, but I'm not sure what `rrule`s are.
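
A hedged sketch of what such a convenience wrapper might look like for the 1d case: reshape the input up to the 3d spatial layout the core routine expects, pad the kernel dims with singleton entries, and delegate. `unfold` and its arguments follow the PR; the wrapper itself and its name are hypothetical:

function unfold1d(x::AbstractArray{T,3}, w_dim; kws...) where {T}
    # x is (W, C, N); insert two singleton spatial dims to get (W, 1, 1, C, N)
    x3 = reshape(x, size(x, 1), 1, 1, size(x, 2), size(x, 3))
    # pad the kernel dims (kW, Cin, Cout) to (kW, 1, 1, Cin, Cout)
    w3 = (w_dim[1], 1, 1, w_dim[2:end]...)
    unfold(x3, w3; kws...)
end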

@ToucheSir (Member)

rrules are what allow the AD system and gradients to work; more details at https://juliadiff.org/ChainRulesCore.jl/stable/#frule-and-rrule. For functions like this that use mutation and inner loops, you'll probably need another function that manually performs the backwards pass. They're a little abstracted, but it may be worth looking at how the conv rrules work:

@eval function rrule(::typeof($conv), x, w, cdims; kw...)
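
A hedged sketch of the shape such a rule could take, modelled loosely on the conv rules linked above; `unfold`/`fold` are the PR's functions, and treating fold (with the original input size) as the pullback of unfold is an assumption here, not something the PR establishes:

using ChainRulesCore

function ChainRulesCore.rrule(::typeof(unfold), x, w_dim; kws...)
    y = unfold(x, w_dim; kws...)
    function unfold_pullback(Δ)
        # manual backwards pass, since the forward uses mutation and loops
        x̄ = fold(unthunk(Δ), size(x), w_dim; kws...)   # hypothetical fold signature
        return (NoTangent(), x̄, NoTangent())
    end
    return y, unfold_pullback
end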

# Create DenseConvDims object
col_dim = size(col)
channels = col_dim[2]÷prod(w_dim)
x_dim = (out_dim... , fill(3-length(out_dim))... , channels,col_dim[3])

Member

The calls to fill seem suspect. Did you need ntuple there?

Author

I needed a tuple with the dimensions of out_dim appended with dummy dimensions to make it 3d. I think it should be fill(1, 3-length(out_dim)). It's okay to use fill here, right?

Member

Use insert_singleton_dimension instead. Or a reshape. fill is incorrect here.
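
A small sketch contrasting the approaches under discussion, using a hypothetical 2d out_dim:

out_dim = (4, 4)

# fill returns a Vector, so splatting it into a size is allocating and type-unstable:
dims_fill = (out_dim..., fill(1, 3 - length(out_dim))...)        # (4, 4, 1), but via an Array

# ntuple keeps everything a tuple, which is what a size argument wants:
dims_nt = (out_dim..., ntuple(_ -> 1, 3 - length(out_dim))...)   # (4, 4, 1)

# or simply reshape with trailing singleton dims, as suggested above:
# img3 = reshape(img, out_dim..., 1)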

end

return reshape(img, (out_dim... , channels,col_dim[3]))
end

Member

Definitely don't want to see this. Possibly the editor is using line endings different from Linux's; I'd check that.

@@ -0,0 +1,70 @@
export unfold, fold

Member

Let's not export them; the names would dirty the namespace very quickly.

@nikopj mentioned this pull request Nov 23, 2022
@mcabbott (Member)

Replaced by #444

@mcabbott mcabbott closed this Nov 28, 2022