
Missing Axis Swap in ExtractPatches and MergePatches #2

Closed
young-geng opened this issue Feb 27, 2022 · 4 comments
Labels
bug Something isn't working

Comments

@young-geng

In patch_utils.py, the modules ExtractPatches and MergePatches are missing an axis swap between the two reshapes, so each extracted "patch" is a horizontal stripe of pixels rather than a square block. For example, if we follow the code in ExtractPatches:

```python
>>> import jax.numpy as jnp
>>> inputs = jnp.arange(16).reshape(1, 4, 4, 1)
>>> inputs[0, :, :, 0]
DeviceArray([[ 0,  1,  2,  3],
             [ 4,  5,  6,  7],
             [ 8,  9, 10, 11],
             [12, 13, 14, 15]], dtype=int32)
>>> patch_size = 2
>>> batch, height, width, channels = inputs.shape
>>> height, width = height // patch_size, width // patch_size
>>> x = jnp.reshape(inputs, (batch, height, patch_size, width, patch_size, channels))
>>> x = jnp.reshape(x, (batch, height * width, patch_size ** 2 * channels))
>>> x[0, 0, :]
DeviceArray([0, 1, 2, 3], dtype=int32)
```

We see that the first extracted patch is not the top-left 2x2 block [0, 1, 4, 5], but the horizontal stripe [0, 1, 2, 3].
To fix this, we should swap axes between the two reshapes so that the two patch_size axes are adjacent before flattening. For ExtractPatches, this should be:

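```python
batch, height, width, channels = inputs.shape
height, width = height // patch_size, width // patch_size
# (batch, patch_row, row_in_patch, patch_col, col_in_patch, channels)
x = jnp.reshape(
    inputs, (batch, height, patch_size, width, patch_size, channels)
)
# -> (batch, patch_row, patch_col, row_in_patch, col_in_patch, channels)
x = jnp.swapaxes(x, 2, 3)
x = jnp.reshape(x, (batch, height * width, patch_size ** 2 * channels))
```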
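To check the fix in isolation, here is a minimal sketch that ports the corrected ExtractPatches body into a standalone function and compares every patch against an explicit slice of the image. The function name extract_patches is hypothetical, not the project's API. The sketch uses NumPy instead of jax.numpy so it runs without JAX; reshape and swapaxes behave the same way in both for this purpose.

```python
import numpy as np


def extract_patches(inputs, patch_size):
    """Hypothetical standalone port of the fixed ExtractPatches logic.

    (batch, H, W, C) -> (batch, num_patches, patch_size**2 * C)
    """
    batch, height, width, channels = inputs.shape
    height, width = height // patch_size, width // patch_size
    # (batch, patch_row, row_in_patch, patch_col, col_in_patch, channels)
    x = np.reshape(inputs, (batch, height, patch_size, width, patch_size, channels))
    # Make the two in-patch axes adjacent:
    # (batch, patch_row, patch_col, row_in_patch, col_in_patch, channels)
    x = np.swapaxes(x, 2, 3)
    return np.reshape(x, (batch, height * width, patch_size ** 2 * channels))


inputs = np.arange(16).reshape(1, 4, 4, 1)
patches = extract_patches(inputs, patch_size=2)
print(patches[0, 0])  # [0 1 4 5], the top-left 2x2 block

# Cross-check every patch against explicit slicing, walking the patch
# grid in row-major order (patch rows outer, patch columns inner).
images = np.random.default_rng(0).normal(size=(2, 6, 6, 3))
p = 2
got = extract_patches(images, p)
expected = np.stack(
    [
        images[:, i:i + p, j:j + p, :].reshape(images.shape[0], -1)
        for i in range(0, 6, p)
        for j in range(0, 6, p)
    ],
    axis=1,
)
print(got.shape, np.array_equal(got, expected))
```

Slicing is the most direct definition of a patch, so comparing against it catches the stripe problem regardless of image size.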

For MergePatches, this should be:

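```python
batch, length, _ = inputs.shape
height = width = int(length**0.5)
# Patches arrive as (patch_row, patch_col, row_in_patch, col_in_patch, C)
x = jnp.reshape(inputs, (batch, height, width, patch_size, patch_size, -1))
# -> (patch_row, row_in_patch, patch_col, col_in_patch, C)
x = jnp.swapaxes(x, 2, 3)
x = jnp.reshape(x, (batch, height * patch_size, width * patch_size, -1))
```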
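One way to confirm the two snippets are consistent is a round trip: merging the extracted patches should reproduce the input exactly. The sketch below wraps both snippets in hypothetical standalone functions (extract_patches and merge_patches are illustrative names), uses NumPy in place of jax.numpy, and assumes a square image whose side is divisible by patch_size, as the int(length**0.5) line already implies.

```python
import numpy as np


def extract_patches(inputs, patch_size):
    # Hypothetical port of the fixed ExtractPatches logic.
    batch, height, width, channels = inputs.shape
    height, width = height // patch_size, width // patch_size
    x = np.reshape(inputs, (batch, height, patch_size, width, patch_size, channels))
    x = np.swapaxes(x, 2, 3)
    return np.reshape(x, (batch, height * width, patch_size ** 2 * channels))


def merge_patches(inputs, patch_size):
    # Hypothetical port of the proposed MergePatches logic (square grid assumed).
    batch, length, _ = inputs.shape
    height = width = int(length ** 0.5)
    # Layout from extract_patches: (patch_row, patch_col, row_in_patch, col_in_patch, C)
    x = np.reshape(inputs, (batch, height, width, patch_size, patch_size, -1))
    # Interleave back to (patch_row, row_in_patch, patch_col, col_in_patch, C)
    x = np.swapaxes(x, 2, 3)
    return np.reshape(x, (batch, height * patch_size, width * patch_size, -1))


images = np.random.default_rng(0).normal(size=(2, 8, 8, 3))
results = {}
for patch_size in (1, 2, 4, 8):
    patches = extract_patches(images, patch_size)
    restored = merge_patches(patches, patch_size)
    results[patch_size] = np.array_equal(restored, images)
    print(patch_size, patches.shape, results[patch_size])
```

Because both functions only reshape and swap axes, the round trip should be exact, so np.array_equal is used instead of a tolerance-based comparison.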
@young-geng added the bug label on Feb 27, 2022
@DarshanDeshpande
Owner

@young-geng Thanks for pointing it out. If you would like to create a PR, please go ahead; otherwise, I'll correct it in a couple of days along with the next push.

@DarshanDeshpande
Owner

Fixed in 2f063c9. Closing this issue.

@young-geng
Author

Unfortunately, the problem has not been completely fixed. Because we changed the order of axes in ExtractPatches, we have to take that into account when reshaping in MergePatches. Therefore, in the first reshape of MergePatches, we need to change

```python
x = jnp.reshape(inputs, (batch, height, patch_size, width, patch_size, -1))
```

to

```python
x = jnp.reshape(inputs, (batch, height, width, patch_size, patch_size, -1))
```
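The sketch below compares the two reshapes with a round trip through the fixed extraction. It assumes the rest of MergePatches is the swapaxes plus final reshape from the first comment; the corrected flag and the helper names are hypothetical and exist only for this comparison. NumPy stands in for jax.numpy.

```python
import numpy as np


def extract_patches(inputs, patch_size):
    # Hypothetical port of the fixed ExtractPatches logic.
    batch, height, width, channels = inputs.shape
    height, width = height // patch_size, width // patch_size
    x = np.reshape(inputs, (batch, height, patch_size, width, patch_size, channels))
    x = np.swapaxes(x, 2, 3)
    return np.reshape(x, (batch, height * width, patch_size ** 2 * channels))


def merge_patches(inputs, patch_size, corrected=True):
    # `corrected` is a hypothetical flag: False uses the reshape quoted as
    # "change", True uses the one quoted as "to". The swapaxes and final
    # reshape are assumed to match the first comment.
    batch, length, _ = inputs.shape
    height = width = int(length ** 0.5)
    if corrected:
        shape = (batch, height, width, patch_size, patch_size, -1)
    else:
        shape = (batch, height, patch_size, width, patch_size, -1)
    x = np.reshape(inputs, shape)
    x = np.swapaxes(x, 2, 3)
    return np.reshape(x, (batch, height * patch_size, width * patch_size, -1))


def round_trips(image_size, patch_size, corrected):
    images = np.arange(image_size * image_size).reshape(1, image_size, image_size, 1)
    patches = extract_patches(images, patch_size)
    return np.array_equal(merge_patches(patches, patch_size, corrected), images)


for image_size, patch_size in ((4, 2), (8, 2), (6, 2)):
    old = round_trips(image_size, patch_size, corrected=False)
    new = round_trips(image_size, patch_size, corrected=True)
    print(f"{image_size}x{image_size}, patch_size={patch_size}: old={old}, new={new}")
```

Under these assumptions, the old reshape still round-trips a 4x4 image with patch_size=2. In that case the patch grid width (2) equals patch_size, so the two target shapes are identical. Inputs such as 8x8 or 6x6 expose the mismatch, so a test needs an image whose grid width differs from patch_size.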

Thanks a lot for working on this and fixing the problem! This project is awesome!

@DarshanDeshpande
Owner

Fixed. I have also added more specific tests for both ExtractPatches and MergePatches. Let me know if there is anything else :)
