Improve the execution time of the shift transform? #41
Suggestion: use
Yep :) On it
Hmm, it looks like torch.roll doesn't support rolling each example in a batch by a different amount. I might have to do a for loop unless there's a better alternative.
import torch

@torch.jit.script
def batch_roll(tensor: torch.Tensor, max_pct: float = 0.2, per_channel: bool = False):
    b, c, t = tensor.shape
    # Maximum number of samples to roll by (at least 1, since randint needs high > 0)
    high = max(1, int(max_pct * t))
    # Flat indices reshaped to (b, c, t); mod t gives each element's position in its row
    x = torch.arange(b * c * t).view(b, c, t)
    # One roll amount per example, or one per channel if requested
    pc = c if per_channel else 1
    r = torch.randint(high, (b, pc, 1))
    # Shifted within-row positions, forced into the range 0 <= i < t
    idxs = (x - r + t) % t
    # Add each row's offset (batch and channel strides) to get back to flat indices
    add = (torch.arange(b * c) * t).view(b, c, 1)
    idxs = add + idxs
    return tensor.flatten()[idxs.flatten()].view(b, c, t)

x = torch.arange(5).view(1, 1, 5).repeat(5, 3, 1)
batch_roll(x, 1.0)
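For what it's worth, the indexing trick can be sanity-checked against a plain loop over torch.roll, since rolling each row by its own amount is exactly what the gather does. Below is a self-contained sketch; roll_per_example is an illustrative name, not part of the PR:

```python
import torch

def roll_per_example(x: torch.Tensor, shifts: torch.Tensor) -> torch.Tensor:
    """Roll each (c, t) example in a (b, c, t) batch by its own shift amount."""
    b, c, t = x.shape
    pos = torch.arange(t).view(1, 1, t)        # within-row positions 0..t-1
    idx = (pos - shifts.view(b, 1, 1)) % t     # source position for each output slot
    return torch.gather(x, 2, idx.expand(b, c, t))

x = torch.randn(4, 3, 8)
shifts = torch.tensor([0, 1, 3, 7])
out = roll_per_example(x, shifts)

# Each example should match torch.roll applied with its own shift
for i in range(4):
    assert torch.equal(out[i], torch.roll(x[i], int(shifts[i]), dims=-1))
```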
Thanks 👍
Looking at your PR, if you want to set a lower bound on the amount to roll by, then that would need to be added to the method as well. Shall I submit this code to the PR?
Yes please 👍 |
That's cool!
Not sure how much faster it is. After benchmarking a few torch scripts, I assume it always improves things; however, I didn't account for the time it takes to jit something, as that is usually negligible compared to the savings on a function's invocations over a few epochs. I can remove it for now, or comment it out. I was also going to open a ticket about jit-compiling as many of the transforms in this library as possible, since that would help with speed. I'd be happy to help with this too.
Sounds cool! It's not the time to jit a function, but the time to start the jit compiler that I was referring to. If you're interested, you can have a look here: pytorch/pytorch#33418 (comment)
That explains why torchaudio doesn't jit anything itself, but only tests that some of its code can be jit compiled. Perhaps this library could aim for the same thing?
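If it helps, a test in that style might look like the sketch below; shift_transform here is a made-up stand-in for one of the library's transforms, not actual audiomentations or torchaudio code:

```python
import torch

def shift_transform(samples: torch.Tensor, num_places: int) -> torch.Tensor:
    # Hypothetical transform, used only to illustrate the test pattern
    return torch.roll(samples, num_places, dims=-1)

def test_can_be_jit_scripted():
    # Mirrors the torchaudio approach: don't ship jitted code,
    # just verify the function is TorchScript-compatible
    scripted = torch.jit.script(shift_transform)
    x = torch.randn(2, 1, 16)
    assert torch.equal(scripted(x, 3), shift_transform(x, 3))

test_can_be_jit_scripted()
```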
Perhaps it's cleaner if I merge my pull request, and then you can create a new pull request that proposes the improved implementation? We also have to check that the alternative is actually faster :) |
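In case it's useful for that check, here's a rough timing harness comparing a per-example loop over torch.roll with a single vectorized gather. The function names and tensor sizes are illustrative, and the numbers will vary by machine:

```python
import time
import torch

def roll_loop(x: torch.Tensor, shifts: torch.Tensor) -> torch.Tensor:
    # Baseline: one torch.roll call per example
    return torch.stack(
        [torch.roll(x[i], int(shifts[i]), dims=-1) for i in range(x.shape[0])]
    )

def roll_gather(x: torch.Tensor, shifts: torch.Tensor) -> torch.Tensor:
    # Candidate: build per-example source indices and gather once
    b, c, t = x.shape
    idx = (torch.arange(t).view(1, 1, t) - shifts.view(b, 1, 1)) % t
    return torch.gather(x, 2, idx.expand(b, c, t))

x = torch.randn(64, 2, 16000)
shifts = torch.randint(16000, (64,))

for fn in (roll_loop, roll_gather):
    start = time.perf_counter()
    for _ in range(10):
        fn(x, shifts)
    print(fn.__name__, time.perf_counter() - start)
```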
I compared the two approaches in #51 and figured out that the one with
Like https://github.com/iver56/audiomentations/blob/6f0b8b6c783a6c8268eb00120ab6bb25cab1aab7/audiomentations/augmentations/transforms.py#L273