activation normalization is optional, and parameterized
JohnVinyard committed Mar 8, 2018
1 parent ffa0675 commit 6e7d4af
Showing 1 changed file with 14 additions and 8 deletions.
zounds/learn/gated.py (22 changes: 14 additions & 8 deletions)
@@ -1,4 +1,3 @@
-from util import sample_norm
 from torch import nn
 from torch.nn import functional as F
 
@@ -13,8 +12,11 @@ def __init__(
             stride=1,
             padding=0,
             dilation=1,
-            attention_func=F.sigmoid):
+            attention_func=F.sigmoid,
+            norm=lambda x: x):
+
         super(GatedLayer, self).__init__()
+        self.norm = norm
         self.conv = layer_type(
             in_channels=in_channels,
             out_channels=out_channels,
@@ -35,9 +37,9 @@ def __init__(
 
     def forward(self, x):
         c = self.conv(x)
-        c = sample_norm(c)
+        c = self.norm(c)
         g = self.gate(x)
-        g = sample_norm(g)
+        g = self.norm(g)
         out = F.tanh(c) * self.attention_func(g)
         return out
 
@@ -51,7 +53,8 @@ def __init__(
             stride=1,
             padding=0,
             dilation=1,
-            attention_func=F.sigmoid):
+            attention_func=F.sigmoid,
+            norm=lambda x: x):
         super(GatedConvLayer, self).__init__(
             nn.Conv1d,
             in_channels,
@@ -60,7 +63,8 @@ def __init__(
             stride,
             padding,
             dilation,
-            attention_func)
+            attention_func,
+            norm)
 
 
 class GatedConvTransposeLayer(GatedLayer):
@@ -72,7 +76,8 @@ def __init__(
             stride=1,
             padding=0,
             dilation=1,
-            attention_func=F.sigmoid):
+            attention_func=F.sigmoid,
+            norm=lambda x: x):
         super(GatedConvTransposeLayer, self).__init__(
             nn.ConvTranspose1d,
             in_channels,
@@ -81,4 +86,5 @@ def __init__(
             stride,
             padding,
             dilation,
-            attention_func)
+            attention_func,
+            norm)

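After this change, activation normalization in GatedLayer is no longer hard-wired to sample_norm: norm defaults to the identity, and any callable that maps a tensor to a same-shaped tensor can be passed in; it is applied to both the convolution and gate pre-activations. A minimal usage sketch follows; the channel and length sizes and the nn.InstanceNorm1d choice are illustrative assumptions rather than part of this commit, and the leading constructor arguments (in_channels, out_channels, kernel_size) are inferred from the super() calls above.

import torch
from torch import nn
from torch.autograd import Variable

from zounds.learn.gated import GatedConvLayer

# Default: norm is the identity, so no activation normalization is applied.
plain = GatedConvLayer(64, 64, 3, padding=1)

# Parameterized: any callable works; InstanceNorm1d is one illustrative choice,
# applied to both the convolution and gate outputs before tanh/sigmoid.
normed = GatedConvLayer(64, 64, 3, padding=1, norm=nn.InstanceNorm1d(64))

# (batch, channels, samples); Variable keeps this compatible with the PyTorch
# versions that were current when this commit was made.
x = Variable(torch.randn(8, 64, 128))

print(plain(x).size())   # torch.Size([8, 64, 128])
print(normed(x).size())  # torch.Size([8, 64, 128])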