Skip to content

Commit

Permalink
Add to() implementation for device-agnostic code
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnVinyard committed Jul 8, 2018
1 parent 1190a8d commit 6d94cc7
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions zounds/learn/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from zounds.timeseries import SampleRate


class PerceptualLoss(object):
class PerceptualLoss(nn.Module):
"""
`PerceptualLoss` computes loss/distance in a feature space that roughly
approximates early stages of the human audio processing pipeline, instead
Expand Down Expand Up @@ -70,11 +70,17 @@ def __init__(
.float().view(1, len(self.scale), 1)

def cuda(self, device=None):
self.weights = self.weights.cuda()
self.weights = self.weights.cuda(device=device)
if self.frequency_weights is not None:
self.frequency_weights = self.frequency_weights.cuda()
self.frequency_weights = self.frequency_weights.cuda(device=device)
return super(PerceptualLoss, self).cuda(device=device)

def to(self, device=None):
self.weights = self.weights.to(device)
if self.frequency_weights is not None:
self.frequency_weights = self.frequency_weights.to(device=device)
return super(PerceptualLoss, self).to(device=device)

def _transform(self, x):
x = x.view(x.shape[0], 1, -1)

Expand Down

0 comments on commit 6d94cc7

Please sign in to comment.