
Improve the execution time of the shift transform? #41

Closed
iver56 opened this issue Nov 12, 2020 · 13 comments · Fixed by #42 or #51
@iver56
Collaborator

iver56 commented Nov 12, 2020

Like https://github.com/iver56/audiomentations/blob/6f0b8b6c783a6c8268eb00120ab6bb25cab1aab7/audiomentations/augmentations/transforms.py#L273

@iver56 iver56 self-assigned this Nov 12, 2020
@mpariente
Contributor

Suggestion: use torch.roll
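
For reference, torch.roll circularly shifts a tensor along a dimension, wrapping the values that fall off the end back to the start. A minimal sketch:

```python
import torch

x = torch.arange(5)
# Roll right by 2: each element moves 2 positions forward, wrapping around
rolled = torch.roll(x, shifts=2)  # tensor([3, 4, 0, 1, 2])
```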

@iver56
Collaborator Author

iver56 commented Nov 12, 2020

Yep :) On it

@iver56
Collaborator Author

iver56 commented Nov 12, 2020

Hmm, it looks like torch.roll doesn't support rolling each example in a batch by a different amount. I might have to use a for loop unless there's a better alternative.
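
The straightforward loop-based workaround might look like this (a sketch; `shift_per_example` is a hypothetical helper name, not part of the library):

```python
import torch

def shift_per_example(batch: torch.Tensor, shifts: torch.Tensor) -> torch.Tensor:
    """Roll each example in a (batch, channels, time) tensor by its own amount.

    `shifts` is a 1-D tensor with one integer shift per example.
    """
    return torch.stack(
        [torch.roll(example, int(shift), dims=-1) for example, shift in zip(batch, shifts)]
    )
```

The Python-level loop over the batch is exactly the overhead a vectorized approach would avoid.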

@mogwai
Contributor

mogwai commented Nov 12, 2020

import math
import torch

@torch.jit.script
def batch_roll(tensor: torch.Tensor, max_pct: float = 0.2, per_channel: bool = False):
    b, c, t = tensor.shape

    # Maximum number of time steps to roll by
    high = max(1, math.floor(max_pct * t))

    # One random shift per example, or per channel if requested
    pc = c if per_channel else 1
    r = torch.randint(high, (b, pc, 1))

    # Shifted time indices, wrapped into the range [0, t)
    x = torch.arange(t).view(1, 1, t)
    idxs = (x - r + t) % t

    # Offset into the flattened tensor: example i, channel j starts at (i * c + j) * t
    add = (torch.arange(b * c) * t).view(b, c, 1)
    idxs = add + idxs

    return tensor.flatten()[idxs.flatten()].view(b, c, t)

x = torch.arange(5).view(1, 1, 5).repeat(3, 3, 1)
batch_roll(x, max_pct=1.0)
# Input
tensor([[[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]],

        [[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]],

        [[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]]])

# Rolled (shift amounts are random; one possible result)
tensor([[[1, 2, 3, 4, 0],
         [1, 2, 3, 4, 0],
         [1, 2, 3, 4, 0]],

        [[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]],

        [[1, 2, 3, 4, 0],
         [1, 2, 3, 4, 0],
         [1, 2, 3, 4, 0]]])

@iver56
Collaborator Author

iver56 commented Nov 12, 2020

Thanks 👍

@mogwai
Contributor

mogwai commented Nov 12, 2020

Looking at your PR, it seems you also want to set a lower bound on the roll amount, so that would need to be added to this method. Shall I submit this code to the PR?

@iver56
Collaborator Author

iver56 commented Nov 12, 2020

Yes please 👍

@mpariente
Contributor

That's cool!
Why torch.jit.script? How much faster is it?
I'm asking because starting the JIT compiler takes time at startup.

@mogwai
Contributor

mogwai commented Nov 12, 2020

Not sure how much faster it is. After benchmarking a few TorchScript functions, I assume it always improves things. However, I didn't account for the time it takes to JIT-compile something, as that is usually negligible compared to the savings accumulated over a function's many invocations across a few epochs.

I can remove it for now, or comment it out. I was going to open a ticket about JIT-compiling as many of the transforms in this library as possible, as this would help with speed. I'd be happy to help with this too.
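
One way to check whether scripting actually pays off for a given transform is a direct timing comparison (an illustrative sketch; results vary by function, tensor size, and hardware):

```python
import timeit
import torch

def roll_eager(x: torch.Tensor) -> torch.Tensor:
    return torch.roll(x, 3, dims=-1)

roll_scripted = torch.jit.script(roll_eager)

x = torch.randn(16, 2, 16000)
roll_scripted(x)  # warm up: the first call triggers compilation

eager_time = timeit.timeit(lambda: roll_eager(x), number=100)
scripted_time = timeit.timeit(lambda: roll_scripted(x), number=100)
print(f"eager: {eager_time:.4f}s, scripted: {scripted_time:.4f}s")
```

Note that this measures steady-state throughput only; the one-time cost of starting the JIT compiler is paid before the timed loop.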

@mpariente
Contributor

Sounds cool!

It's the time to start the JIT compiler, not the time to compile a particular function, that I was referring to. If you're interested, you can have a look here: pytorch/pytorch#33418 (comment)

@mogwai
Contributor

mogwai commented Nov 12, 2020

That explains why torchaudio doesn't JIT-compile anything itself, but only tests that some of its code can be JIT-compiled. Perhaps this library could aim for the same thing?

@iver56
Collaborator Author

iver56 commented Nov 12, 2020

Perhaps it's cleaner if I merge my pull request, and then you can create a new pull request that proposes the improved implementation? We also have to check that the alternative is actually faster :)
There is some simple execution time measurement in scripts/demo.py already, but we could extend it if needed.

@iver56 iver56 reopened this Nov 13, 2020
@iver56 iver56 changed the title Shift transform Improve the execution time of the shift transform? Nov 13, 2020
@iver56
Collaborator Author

iver56 commented Nov 30, 2020

I compared the two approaches in #51 and figured out that the one with torch.roll is faster and less memory-hungry. Let's keep that one. Closing this issue now.
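
For context, a fully vectorized alternative built on torch.gather also avoids the per-example loop without flattening the tensor (a sketch under assumed names, not the implementation that was merged):

```python
import torch

def batch_shift_gather(x: torch.Tensor, shifts: torch.Tensor) -> torch.Tensor:
    # x: (batch, channels, time); shifts: (batch,) integer shift per example
    b, c, t = x.shape
    # Shifted time indices, wrapped into [0, t), broadcast over channels
    idx = (torch.arange(t) - shifts.view(b, 1, 1)) % t  # (b, 1, t)
    return torch.gather(x, 2, idx.expand(b, c, t))
```

Each output element at time j is read from time (j - shift) mod t, which is equivalent to rolling right by the per-example shift.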

@iver56 iver56 closed this as completed Nov 30, 2020