upgraded tensorflow and handled torch cuda #619

Open · wants to merge 2 commits into base: master
hparams.py: 104 changes (54 additions, 50 deletions)
@@ -1,89 +1,93 @@
 import tensorflow as tf
 from text import symbols

+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self

 def create_hparams(hparams_string=None, verbose=False):
     """Create model hyperparameters. Parse nondefault from given string."""

-    hparams = tf.contrib.training.HParams(
+    hparams = AttrDict({
         ################################
         # Experiment Parameters #
         ################################
-        epochs=500,
-        iters_per_checkpoint=1000,
-        seed=1234,
-        dynamic_loss_scaling=True,
-        fp16_run=False,
-        distributed_run=False,
-        dist_backend="nccl",
-        dist_url="tcp://localhost:54321",
-        cudnn_enabled=True,
-        cudnn_benchmark=False,
-        ignore_layers=['embedding.weight'],
+        "epochs":500,
+        "iters_per_checkpoint":1000,
+        "seed":1234,
+        "dynamic_loss_scaling":True,
+        "fp16_run":False,
+        "distributed_run":False,
+        "dist_backend":"nccl",
+        "dist_url":"tcp://localhost:54321",
+        "cudnn_enabled":True,
+        "cudnn_benchmark":False,
+        "ignore_layers":['embedding.weight'],

         ################################
         # Data Parameters #
         ################################
-        load_mel_from_disk=False,
-        training_files='filelists/ljs_audio_text_train_filelist.txt',
-        validation_files='filelists/ljs_audio_text_val_filelist.txt',
-        text_cleaners=['english_cleaners'],
+        "load_mel_from_disk":False,
+        "training_files":'filelists/ljs_audio_text_train_filelist.txt',
+        "validation_files":'filelists/ljs_audio_text_val_filelist.txt',
+        "text_cleaners":['english_cleaners'],

         ################################
         # Audio Parameters #
         ################################
-        max_wav_value=32768.0,
-        sampling_rate=22050,
-        filter_length=1024,
-        hop_length=256,
-        win_length=1024,
-        n_mel_channels=80,
-        mel_fmin=0.0,
-        mel_fmax=8000.0,
+        "max_wav_value":32768.0,
+        "sampling_rate":22050,
+        "filter_length":1024,
+        "hop_length":256,
+        "win_length":1024,
+        "n_mel_channels":80,
+        "mel_fmin":0.0,
+        "mel_fmax":8000.0,

         ################################
         # Model Parameters #
         ################################
-        n_symbols=len(symbols),
-        symbols_embedding_dim=512,
+        "n_symbols":len(symbols),
+        "symbols_embedding_dim":512,

         # Encoder parameters
-        encoder_kernel_size=5,
-        encoder_n_convolutions=3,
-        encoder_embedding_dim=512,
+        "encoder_kernel_size":5,
+        "encoder_n_convolutions":3,
+        "encoder_embedding_dim":512,

         # Decoder parameters
-        n_frames_per_step=1,  # currently only 1 is supported
-        decoder_rnn_dim=1024,
-        prenet_dim=256,
-        max_decoder_steps=1000,
-        gate_threshold=0.5,
-        p_attention_dropout=0.1,
-        p_decoder_dropout=0.1,
+        "n_frames_per_step":1,  # currently only 1 is supported
+        "decoder_rnn_dim":1024,
+        "prenet_dim":256,
+        "max_decoder_steps":1000,
+        "gate_threshold":0.5,
+        "p_attention_dropout":0.1,
+        "p_decoder_dropout":0.1,

         # Attention parameters
-        attention_rnn_dim=1024,
-        attention_dim=128,
+        "attention_rnn_dim":1024,
+        "attention_dim":128,

         # Location Layer parameters
-        attention_location_n_filters=32,
-        attention_location_kernel_size=31,
+        "attention_location_n_filters":32,
+        "attention_location_kernel_size":31,

         # Mel-post processing network parameters
-        postnet_embedding_dim=512,
-        postnet_kernel_size=5,
-        postnet_n_convolutions=5,
+        "postnet_embedding_dim":512,
+        "postnet_kernel_size":5,
+        "postnet_n_convolutions":5,

         ################################
         # Optimization Hyperparameters #
         ################################
-        use_saved_learning_rate=False,
-        learning_rate=1e-3,
-        weight_decay=1e-6,
-        grad_clip_thresh=1.0,
-        batch_size=64,
-        mask_padding=True  # set model's padded outputs to padded values
-    )
+        "use_saved_learning_rate":False,
+        "learning_rate":1e-3,
+        "weight_decay":1e-6,
+        "grad_clip_thresh":1.0,
+        "batch_size":64,
+        "mask_padding":True  # set model's padded outputs to padded values
+    })

     if hparams_string:
         tf.logging.info('Parsing command line hparams: %s', hparams_string)
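A quick sketch of how the AttrDict replacement behaves (the epochs and batch_size values come from the hparams above; the rest is illustrative): keys stored in the dict are also reachable as attributes, so call sites that read hparams.epochs or hparams.fp16_run keep working without tf.contrib.training.HParams.

    class AttrDict(dict):
        def __init__(self, *args, **kwargs):
            super(AttrDict, self).__init__(*args, **kwargs)
            self.__dict__ = self  # attribute namespace and dict share storage

    hparams = AttrDict({"epochs": 500, "batch_size": 64})
    print(hparams.epochs)            # 500 -- attribute-style access, as used in train.py
    print(hparams["batch_size"])     # 64  -- plain dict access still works
    hparams.learning_rate = 1e-3     # new attributes also become dict entries
    print(hparams["learning_rate"])  # 0.001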
requirements.txt: 2 changes (1 addition, 1 deletion)
@@ -1,5 +1,5 @@
 matplotlib==2.1.0
-tensorflow==1.15.2
+tensorflow==2.15.0
 numpy==1.13.3
 inflect==0.2.5
 librosa==0.6.0
train.py: 5 changes (4 additions, 1 deletion)
@@ -71,7 +71,10 @@ def prepare_directories_and_logger(output_directory, log_directory, rank):


 def load_model(hparams):
-    model = Tacotron2(hparams).cuda()
+    if torch.cuda.is_available():
+        model = Tacotron2(hparams).cuda()
+    else:
+        model = Tacotron2(hparams)
     if hparams.fp16_run:
         model.decoder.attention_layer.score_mask_value = finfo('float16').min

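The same CPU fallback can also be expressed device-agnostically by resolving a torch.device once and calling .to(device); a minimal, self-contained sketch (nn.Linear stands in for Tacotron2(hparams), which is not importable here):

    import torch

    # Pick the target device once; falls back to CPU when no GPU is visible.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = torch.nn.Linear(4, 2).to(device)  # stand-in for Tacotron2(hparams)
    x = torch.randn(3, 4, device=device)      # inputs created directly on that device
    print(model(x).shape)                     # torch.Size([3, 2]) on CPU or GPU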
utils.py: 5 changes (4 additions, 1 deletion)
@@ -5,7 +5,10 @@

 def get_mask_from_lengths(lengths):
     max_len = torch.max(lengths).item()
-    ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
+    if torch.cuda.is_available():
+        ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
+    else:
+        ids = torch.arange(0, max_len, out=torch.LongTensor(max_len))
     mask = (ids < lengths.unsqueeze(1)).bool()
     return mask

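For reference, the branch in get_mask_from_lengths can also be avoided by allocating ids on whatever device lengths already lives on; a small sketch of that variant (not part of this PR):

    import torch

    def get_mask_from_lengths(lengths):
        # Same mask logic, but `ids` is created on the device of `lengths`.
        max_len = torch.max(lengths).item()
        ids = torch.arange(0, max_len, device=lengths.device, dtype=torch.long)
        mask = (ids < lengths.unsqueeze(1)).bool()
        return mask

    lengths = torch.tensor([2, 4, 1])
    print(get_mask_from_lengths(lengths))
    # tensor([[ True,  True, False, False],
    #         [ True,  True,  True,  True],
    #         [ True, False, False, False]])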