Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stable TTS & other improvements #19

Merged
merged 6 commits into from
Nov 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
config/*
!config/voxceleb2.yaml
!config/unblizzard.yaml
!config/kss.yaml

# logs, checkpoints
chkpt/
Expand Down
41 changes: 41 additions & 0 deletions config/kss.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# MelNet configuration for the KSS (Korean Single Speaker) dataset.
# Multi-document YAML: model / data / audio / train / log sections.
model:
  tier: 6
  layers: [12, 5, 4, 3, 2, 2]
  hidden: 512
  gmm: 10
---
data:
  path: 'KSS'
  extension: '*.wav'
---
audio:
  sr: 22050
  duration: 5.0
  n_mels: 256
  hop_length: 256
  win_length: 1536
  n_fft: 1536
  num_freq: 769
  ref_level_db: 20.0
  min_level_db: -80.0
---
train:
  num_workers: 4
  optimizer: 'adam'
  sgd:
    lr: 0.0001
    momentum: 0.9
  rmsprop: # from paper
    lr: 0.0001
    momentum: 0.9
  adam:
    lr: 0.0001
  # Gradient Accumulation
  # you'll be specifying batch size with argument of trainer.py
  # (update interval) * (batch size) = (paper's batch size) = 128
  update_interval: 1 # for batch size 1.
---
log:
  summary_interval: 1
  chkpt_dir: 'chkpt'
  log_dir: 'logs'
78 changes: 24 additions & 54 deletions datasets/wavloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,21 @@
from utils.utils import *
from utils.audio import MelGen
from utils.tierutil import TierUtil
# from text import text_to_sequence


from text import text_to_sequence

def create_dataloader(hp, args, train):
    """Build a DataLoader for either TTS or audio-only MelNet training.

    Args:
        hp: hyperparameter object (uses hp.train.num_workers).
        args: CLI args (uses args.tts and args.batch_size).
        train: True for the training split (enables shuffling).

    Returns:
        torch.utils.data.DataLoader over the selected dataset.
    """
    if args.tts:
        dataset = AudioTextDataset(hp, args, train)
        # Text sequences are variable-length, so they need custom
        # padding/collation before batching.
        collate_fn = TextCollate()
    else:
        dataset = AudioOnlyDataset(hp, args, train)
        # NOTE(review): assumes AudioOnlyDataset yields fixed-size
        # (source, target) pairs — confirm. The default collate can stack
        # them; TextCollate would index x[2] on a 2-tuple and crash.
        collate_fn = None

    return DataLoader(dataset=dataset,
                      batch_size=args.batch_size,
                      shuffle=train,
                      num_workers=hp.train.num_workers,
                      pin_memory=True,
                      drop_last=True,
                      collate_fn=collate_fn)

class AudioOnlyDataset(Dataset):
def __init__(self, hp, args, train):
Expand All @@ -42,12 +34,6 @@ def __init__(self, hp, args, train):
self.tierutil = TierUtil(hp)

# this will search all files within hp.data.path
self.file_list = []
# for i, f in enumerate(glob.glob(os.path.join(hp.data.path, '**', hp.data.extension), recursive=True)):
# wav = read_wav_np(f)
# duraton = (len(wav)/hp.audio.sr)
# if duraton < hp.audio.duration:
# self.file_list.append(f)
self.file_list = glob.glob(os.path.join(hp.data.path, '**', hp.data.extension), recursive=True)

random.seed(123)
Expand All @@ -67,7 +53,7 @@ def __len__(self):
return len(self.file_list)

def __getitem__(self, idx):
wav = read_wav_np(self.file_list[idx])
wav = read_wav_np(self.file_list[idx], sample_rate=self.hp.audio.sr)
wav = cut_wav(self.wavlen, wav)
mel = self.melgen.get_normalized_mel(wav)
source, target = self.tierutil.cut_divide_tiers(mel, self.tier)
Expand All @@ -88,19 +74,18 @@ def __init__(self, hp, args, train):
# this will search all files within hp.data.path
self.root_dir = hp.data.path
self.dataset = []
with open(os.path.join(self.root_dir, 'transcript.v.1.2.txt'), 'r') as f:
with open(os.path.join(self.root_dir, 'transcript.v.1.3.txt'), 'r') as f:
lines = f.read().splitlines()
for line in lines:
wav_name, _, _, text, _ = line.split('|')
wav_name = wav_name[2:-4] + '.wav'
wav_name, _, _, text, _, _ = line.split('|')

wav_path = os.path.join(self.root_dir, 'wavs', wav_name)
wav = read_wav_np(wav_path)
wav_path = os.path.join(self.root_dir, 'kss', wav_name)
wav = read_wav_np(wav_path, sample_rate=self.hp.audio.sr)
duraton = (len(wav)/hp.audio.sr)
if duraton < hp.audio.duration:
self.dataset.append((wav_path, text))

#if len(self.dataset) > 100: break
# if len(self.dataset) > 100: break


random.seed(123)
Expand All @@ -123,42 +108,27 @@ def __getitem__(self, idx):
text = self.dataset[idx][1]
seq = text_to_sequence(text)

wav = read_wav_np(self.dataset[idx][0])
wav = read_wav_np(self.dataset[idx][0], sample_rate=self.hp.audio.sr)
wav = cut_wav(self.wavlen, wav)
mel = self.melgen.get_normalized_mel(wav)
source, target = self.tierutil.cut_divide_tiers(mel, self.tier)

return seq, source, target



class TextCollate():
    """Collate (seq, source, target) triples into a padded TTS batch.

    Each batch item is a tuple of numpy arrays:
        seq:    1-D, variable-length text symbol ids.
        source: fixed-size mel sub-tier (conditioning input).
        target: fixed-size mel sub-tier (prediction target).
    """
    def __init__(self):
        return

    def __call__(self, batch):
        """Return (seq_padded, input_lengths, source_padded, target_padded).

        Text is right-padded with zeros to the longest sequence in the
        batch; true lengths are kept so the encoder can pack the sequences.
        """
        seq = [torch.from_numpy(x[0]).long() for x in batch]
        input_lengths = torch.LongTensor([x.shape[0] for x in seq])
        seq_padded = torch.nn.utils.rnn.pad_sequence(seq, batch_first=True)
        # Mel sub-tiers are already fixed-size, so a plain stack suffices.
        source_padded = torch.stack([torch.from_numpy(x[1]) for x in batch])
        target_padded = torch.stack([torch.from_numpy(x[2]) for x in batch])

        return seq_padded, input_lengths, source_padded, target_padded



Expand Down
2 changes: 0 additions & 2 deletions model/tier.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ def forward(self, x):
theta_hat = self.W_theta(h_f)

mu = theta_hat[..., :self.K] # eq. (3)
# std = torch.exp(theta_hat[..., self.K:2*self.K]) # eq. (4)
# pi = self.pi_softmax(theta_hat[..., 2*self.K:]) # eq. (5)
std = theta_hat[..., self.K:2*self.K]
pi = theta_hat[..., 2*self.K:]

Expand Down
91 changes: 45 additions & 46 deletions model/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,36 @@ def __init__(self, hp):

def attention(self, h_i, memory, ksi):
    """One step of monotonic GMM attention over the text memory.

    Args:
        h_i:    (B, hidden) current attention-RNN hidden state.
        ksi:    (B, M) previous mixture means (monotonically increasing).
        memory: (B, T, D) encoded text memory.

    Returns:
        context:     (B, 1, D) attention-weighted memory.
        weights:     (B, 1, T) normalized attention weights.
        termination: (B, T) mass of the mixture beyond each position's
                     right edge (used to detect end of the text).
        ksi:         (B, M) updated means for the next step.
    """
    phi_hat = self.W_g(h_i)

    # Means only ever move forward (exp >= 0) -> monotonic attention.
    ksi = ksi + torch.exp(phi_hat[:, :self.M])
    beta = torch.exp(phi_hat[:, self.M:2*self.M])             # mixture scales
    alpha = F.softmax(phi_hat[:, 2*self.M:3*self.M], dim=-1)  # mixture weights

    # Weight for memory position u is the mixture CDF mass on
    # [u + 0.5, u + 1.5), evaluated with logistic components.
    u = memory.new_tensor(np.arange(memory.size(1)), dtype=torch.float)
    u_R = u + 1.5
    u_L = u + 0.5

    term1 = torch.sum(
        alpha.unsqueeze(-1) * torch.sigmoid(
            (u_R - ksi.unsqueeze(-1)) / beta.unsqueeze(-1)
        ),
        keepdim=True,
        dim=1
    )

    term2 = torch.sum(
        alpha.unsqueeze(-1) * torch.sigmoid(
            (u_L - ksi.unsqueeze(-1)) / beta.unsqueeze(-1)
        ),
        keepdim=True,
        dim=1
    )

    weights = term1 - term2
    # Renormalize so the weights over memory positions sum to one.
    weights = weights / torch.sum(weights, dim=-1, keepdim=True)
    context = torch.bmm(weights, memory)

    # term1 is the CDF up to each right edge; 1 - term1 is the mass past it.
    termination = 1 - term1.squeeze(1)

    return context, weights, termination, ksi # (B, 1, D), (B, 1, T), (B, T)

Expand Down Expand Up @@ -73,11 +77,9 @@ def forward(self, input_h_c, memory, input_lengths):


class TTS(nn.Module):
def __init__(self, hp, freq, layers, tierN):
def __init__(self, hp, freq, layers):
super(TTS, self).__init__()
self.hp = hp
assert tierN==1, 'TTS tier must be 1'
self.tierN = tierN

self.W_t_0 = nn.Linear(1, hp.model.hidden)
self.W_f_0 = nn.Linear(1, hp.model.hidden)
Expand All @@ -91,20 +93,26 @@ def __init__(self, hp, freq, layers, tierN):

# map output to produce GMM parameter eq. (10)
self.W_theta = nn.Linear(hp.model.hidden, 3*self.K)

self.TextEncoder = nn.Sequential(nn.Embedding(len(symbols), hp.model.hidden),
nn.LSTM(input_size=hp.model.hidden,
hidden_size=hp.model.hidden//2,
batch_first=True,
bidirectional=True)
)

self.embedding_text = nn.Embedding(len(symbols), hp.model.hidden)
self.text_lstm = nn.LSTM(input_size=hp.model.hidden,
hidden_size=hp.model.hidden//2,
batch_first=True,
bidirectional=True)

self.attention = Attention(hp)

def text_encode(self, text, input_lengths):
    """Encode a zero-padded batch of text ids with the bidirectional LSTM.

    Args:
        text:          (B, T_max) long tensor of symbol ids, zero-padded.
        input_lengths: (B,) true length of each sequence in `text`.

    Returns:
        (B, T_max, hidden) tensor, zero-padded back out to T_max.
    """
    max_len = text.size(1)
    embedded = self.embedding_text(text)
    # Pack so the LSTM skips padded positions; enforce_sorted=False lets
    # batches arrive in any length order.
    packed_input = nn.utils.rnn.pack_padded_sequence(
        embedded, input_lengths, batch_first=True, enforce_sorted=False)
    packed_output, _ = self.text_lstm(packed_input)
    # Unpack to a dense tensor of the original time length for attention.
    padded_output, _ = nn.utils.rnn.pad_packed_sequence(
        packed_output, batch_first=True, total_length=max_len)
    return padded_output

def forward(self, x, text, input_lengths, output_lengths):
def forward(self, x, text, input_lengths):
# Extract memory
memory, _ = self.TextEncoder(text)
memory = self.text_encode(text, input_lengths)

# x: [B, M, T] / B=batch, M=mel, T=time
h_t = self.W_t_0(F.pad(x, [1, -1]).unsqueeze(-1))
Expand All @@ -127,23 +135,15 @@ def forward(self, x, text, input_lengths, output_lengths):
theta_hat = self.W_theta(h_f)

mu = theta_hat[:,:,:, :self.K] # eq. (3)
std = torch.exp(theta_hat[:,:,:, self.K:2*self.K]) # eq. (4)
pi = self.pi_softmax(theta_hat[:,:,:, 2*self.K:]) # eq. (5)

### MASKING ###
idx = torch.arange(1, mu.size(-2)+1, device=mu.device)
mask = (output_lengths.unsqueeze(-1) < idx.unsqueeze(0)).to(torch.bool) # B, T
mask = mask.unsqueeze(1).unsqueeze(3)

mu = torch.sigmoid(mu.masked_fill(mask, 0))
std = std.masked_fill(mask, 1/np.sqrt(2 * np.pi))
std = theta_hat[:,:,:, self.K:2*self.K] # eq. (4)
pi = theta_hat[:,:,:, 2*self.K:] # eq. (5)

return mu, std, pi, alignment


def sample(self, x, text, input_lengths):
# Extract memory
memory = self.TextEncoder(text)
memory = self.text_encode(text, input_lengths)

# x: [1, M, T] / B=1, M=mel, T=time
x_t, x_f = x.clone(), x.clone()
Expand All @@ -169,10 +169,9 @@ def sample(self, x, text, input_lengths):

theta_hat = self.W_theta(h_f)

mu = torch.sigmoid(theta_hat[:, :, :, :self.K]) # eq. (3)
pi = self.pi_softmax(theta_hat[:, :, :, 2*self.K:]) # eq. (5)

mu = torch.sum(mu*pi, dim=3)
mu = theta_hat[:, :, :, :self.K] # eq. (3)
std = theta_hat[:, :, :, self.K:2*self.K] # eq. (4)
pi = theta_hat[:, :, :, 2*self.K:] # eq. (5)

x_t[:,:,i+1] = mu[:,:,i]
x_f[:,i+1,:] = mu[:,i,:]
Expand Down
4 changes: 2 additions & 2 deletions model/upsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ def __init__(self, hp):
super(UpsampleRNN, self).__init__()
self.num_hidden = hp.model.hidden

self.rnn_x = nn.GRU(
self.rnn_x = nn.LSTM(
input_size=self.num_hidden, hidden_size=self.num_hidden, batch_first=True, bidirectional=True
)
self.rnn_y = nn.GRU(
self.rnn_y = nn.LSTM(
input_size=self.num_hidden, hidden_size=self.num_hidden, batch_first=True, bidirectional=True
)

Expand Down
2 changes: 2 additions & 0 deletions trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
help="Number of tier to train")
parser.add_argument('-b', '--batch_size', type=int, required=True,
                    help="Batch size")
# NOTE: argparse's type=bool is a trap — bool("False") is True, so any
# non-empty value (including "False" or "0") would enable TTS.  Parse the
# string explicitly instead, keeping the original `-s <value>` CLI shape.
parser.add_argument('-s', '--tts', required=True,
                    type=lambda v: v.lower() in ('true', '1', 'yes', 'y'),
                    help="Train TTS tier with text conditioning: true/false")
args = parser.parse_args()

hp = HParam(args.config)
Expand Down