Skip to content

Commit

Permalink
glow_old.py: adding WN to glow_old to avoid res_skip_layers
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaelvalle committed Nov 14, 2018
1 parent dfefc09 commit 88f0ee0
Showing 1 changed file with 86 additions and 4 deletions.
90 changes: 86 additions & 4 deletions glow_old.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,88 @@
import copy
import torch
from glow import Invertible1x1Conv, remove
from glow import WN

@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b,n_channels):
n_channels_int = n_channels[0]
in_act = input_a+input_b
t_act = torch.nn.functional.tanh(in_act[:, :n_channels_int, :])
s_act = torch.nn.functional.sigmoid(in_act[:,n_channels_int:, :])
acts = t_act * s_act
return acts


class WN(torch.nn.Module):
"""
This is the WaveNet like layer for the affine coupling. The primary difference
from WaveNet is the convolutions need not be causal. There is also no dilation
size reset. The dilation only doubles on each layer
"""
def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
kernel_size):
super(WN, self).__init__()
assert(kernel_size % 2 == 1)
assert(n_channels % 2 == 0)
self.n_layers = n_layers
self.n_channels = n_channels
self.in_layers = torch.nn.ModuleList()
self.res_layers = torch.nn.ModuleList()
self.skip_layers = torch.nn.ModuleList()
self.cond_layers = torch.nn.ModuleList()

start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
start = torch.nn.utils.weight_norm(start, name='weight')
self.start = start

# Initializing last layer to 0 makes the affine coupling layers
# do nothing at first. This helps with training stability
end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1)
end.weight.data.zero_()
end.bias.data.zero_()
self.end = end

for i in range(n_layers):
dilation = 2 ** i
padding = int((kernel_size*dilation - dilation)/2)
in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size,
dilation=dilation, padding=padding)
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
self.in_layers.append(in_layer)

cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels, 1)
cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
self.cond_layers.append(cond_layer)

# last one is not necessary
if i < n_layers - 1:
res_layer = torch.nn.Conv1d(n_channels, n_channels, 1)
res_layer = torch.nn.utils.weight_norm(res_layer, name='weight')
self.res_layers.append(res_layer)

skip_layer = torch.nn.Conv1d(n_channels, n_channels, 1)
skip_layer = torch.nn.utils.weight_norm(skip_layer, name='weight')
self.skip_layers.append(skip_layer)

def forward(self, forward_input):
audio, spect = forward_input
audio = self.start(audio)

for i in range(self.n_layers):
acts = fused_add_tanh_sigmoid_multiply(
self.in_layers[i](audio),
self.cond_layers[i](spect),
torch.IntTensor([self.n_channels]))

if i < self.n_layers - 1:
res_acts = self.res_layers[i](acts)
audio = res_acts + audio

if i == 0:
output = self.skip_layers[i](acts)
else:
output = self.skip_layers[i](acts) + output

return self.end(output)


class WaveGlow(torch.nn.Module):
Expand Down Expand Up @@ -140,12 +221,13 @@ def infer(self, spect, sigma=1.0):

return audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data

def remove_weightnorm(self):
waveglow = copy.deepcopy(self)
@staticmethod
def remove_weightnorm(model):
waveglow = copy.deepcopy(model)
for WN in waveglow.WN:
WN.start = torch.nn.utils.remove_weight_norm(WN.start)
WN.in_layers = remove(WN.in_layers)
WN.cond_layers = remove(WN.cond_layers)
WN.res_layers = remove(WN.res_layers)
WN.skip_layers = remove(WN.skip_layers)
self = waveglow
return waveglow

0 comments on commit 88f0ee0

Please sign in to comment.