In [63]:
import importlib
import data.VCTK
import model.model
import model.trainer
import torch

In [64]:
# Re-import data.VCTK so that edits to the module are picked up without
# restarting the kernel.
importlib.reload(data.VCTK)

# Root directory of the (reduced) VCTK corpus, relative to the working directory.
corpus_root = "VCTK-Corpus-smaller/"
dataset = data.VCTK.VCTKDataset(corpus_root)

In [65]:
# Device onto which the models and the sample tensors are moved.
# Set the name to "cuda" to run on a GPU instead of the CPU.
device_name = "cpu"
device = torch.device(device_name)

In [66]:
# Re-import model.model so that edits to the module are picked up without
# restarting the kernel.
importlib.reload(model.model)

# Drop the instance left over from a previous run of this cell, if any,
# before building a fresh one.
if "myModel" in globals():
    del myModel

myModel = model.model.SpectrogramModel().to(device)
param_count = myModel.get_param_count()
print(f"Model has {param_count} parameters.")

Model has 129184 parameters.


In [67]:
import sys
import os.path

# Make the ForwardTacotron checkout importable; it supplies the
# models.fatchord_version module that defines WaveRNN.
# The membership check keeps sys.path from gaining a duplicate entry every
# time this cell is re-run.
forward_tacotron_dir = os.path.abspath("ForwardTacotron")
if forward_tacotron_dir not in sys.path:
    sys.path.append(forward_tacotron_dir)
import models.fatchord_version
# Re-import so that edits to the module are picked up without a kernel restart.
importlib.reload(models.fatchord_version)

# deallocate if exists
try:
    del vocoderModel
except NameError:
    pass

# Build the WaveRNN vocoder and move it to the selected device.
vocoderModel = models.fatchord_version.WaveRNN(
    rnn_dims=512,
    fc_dims=512,
    bits=9, # OrigAuthor: bit depth of signal
    pad=2, # OrigAuthor: this will pad the input so that the resnet can 'see' wider than input length
    upsample_factors=(5, 5, 8), # OrigAuthor: NB - this needs to correctly factorise hop_length
    feat_dims=80,
    compute_dims=128,
    res_out_dims=128,
    res_blocks=10,
    hop_length=200,  # 5 * 5 * 8 == 200, matching upsample_factors above
    sample_rate=16000,
    mode="RAW", # OrigAuthor: either 'RAW' (softmax on raw bits) or 'MOL' (sample from mixture of logistics)
).to(device)

In [68]:
def reasonable(n):
    """Convert a byte count ``n`` into megabytes (1 MB = 1000**2 bytes)."""
    return n/(1000**2)


# Report GPU memory usage for device 0: total, reserved, allocated, and the
# part of the reserved pool that is not currently allocated.
if torch.cuda.is_available():
    t = torch.cuda.get_device_properties(0).total_memory
    r = torch.cuda.memory_reserved(0)
    a = torch.cuda.memory_allocated(0)
else:
    # Without CUDA, get_device_properties() raises; report zeros instead so
    # the notebook still runs end to end on a CPU-only machine.
    t = r = a = 0
f = r-a  # free inside reserved
print(reasonable(t),reasonable(r),reasonable(a),reasonable(f))

11996.954624 0.0 0.0 0.0


In [69]:
# Remove a stale `x` left over from an earlier run, if there is one.
if "x" in globals():
    del x

# Allocated CUDA memory before the sample is moved to the device.
before = torch.cuda.memory_allocated(0)

text, clips, spectros = dataset[0]
batch = 1
# Take one audio clip and its spectrogram, and repeat them `batch` times
# along a new leading dimension for the overfitting test.
clips = clips[0,:].to(device).repeat(batch,1)
spectros = spectros[0,:,:].to(device).repeat(batch,1,1)
print("clips", clips.shape)
print("spectros", spectros.shape)

# Allocated CUDA memory afterwards; the difference is what the sample cost.
after = torch.cuda.memory_allocated(0)
delta_bytes = after-before
print(f"Took {reasonable(delta_bytes):.1f}MB")

clips torch.Size([1, 25600])
spectros torch.Size([1, 80, 129])
Took 0.0MB


In [70]:
# Sanity check: show the input spectrogram shape, then the shape of the
# model's output for that input (displayed as the cell's result).
print(spectros.shape)
myModel(spectros).shape

torch.Size([1, 80, 129])


torch.Size([1, 80, 128])

In [71]:
# Drop the trainer left over from a previous run of this cell, if any.
if "myTrainer" in globals():
    del myTrainer

# Re-import model.trainer so that edits to the module are picked up without
# restarting the kernel, then build a trainer for the two models.
importlib.reload(model.trainer)
myTrainer = model.trainer.Trainer(myModel, vocoderModel, device)

In [72]:
# Ask PyTorch's caching allocator to release unused cached GPU memory.
torch.cuda.empty_cache()

In [73]:
# Run a single training step on the prepared clip/spectrogram batch and show
# whatever train_step returns.
loss = myTrainer.train_step(clips, spectros)
print("loss", loss)

teehee
audios torch.Size([1, 25600]) pred torch.Size([1, 25600])
tensor(3.6780, grad_fn=<AddBackward0>) tensor(1.5073, grad_fn=<L1LossBackward0>) tensor(2.1707, grad_fn=<DivBackward1>)
loss tensor(3.6780, grad_fn=<AddBackward0>)


In [74]:
import tqdm.notebook as tqdm

# Repeatedly train on the same single batch (overfitting test) and print the
# three loss terms after every step.
num_steps = 10
for step in tqdm.tqdm(range(num_steps)):
    # NOTE(review): assumes train_step returns (total, spectro, vocoder)
    # loss tensors -- confirm against model/trainer.py.
    total_loss, spectro_loss, vocoder_loss = \
        myTrainer.train_step(clips, spectros)
    # The losses still carry autograd history, and calling .numpy() directly
    # on such a tensor raises; detach (and move to CPU) first.
    print(
        "total", total_loss.detach().cpu().numpy(),
        "spectro", spectro_loss.detach().cpu().numpy(),
        "vocoder", vocoder_loss.detach().cpu().numpy())


  0%|          | 0/100 [00:00<?, ?it/s]

teehee
audios torch.Size([1, 25600]) pred torch.Size([1, 25600])
tensor(1.7347, grad_fn=<AddBackward0>) tensor(1.5190, grad_fn=<L1LossBackward0>) tensor(0.2156, grad_fn=<DivBackward1>)
teehee
audios torch.Size([1, 25600]) pred torch.Size([1, 25600])
tensor(-5.2986, grad_fn=<AddBackward0>) tensor(1.6274, grad_fn=<L1LossBackward0>) tensor(-6.9260, grad_fn=<DivBackward1>)
teehee
audios torch.Size([1, 25600]) pred torch.Size([1, 25600])
tensor(-29.6101, grad_fn=<AddBackward0>) tensor(1.7403, grad_fn=<L1LossBackward0>) tensor(-31.3504, grad_fn=<DivBackward1>)
teehee
audios torch.Size([1, 25600]) pred torch.Size([1, 25600])
tensor(-45.4433, grad_fn=<AddBackward0>) tensor(1.9017, grad_fn=<L1LossBackward0>) tensor(-47.3450, grad_fn=<DivBackward1>)
teehee
audios torch.Size([1, 25600]) pred torch.Size([1, 25600])
tensor(-118.6740, grad_fn=<AddBackward0>) tensor(2.0889, grad_fn=<L1LossBackward0>) tensor(-120.7629, grad_fn=<DivBackward1>)
teehee
audios torch.Size([1, 25600]) pred torch.Size([1, 25

In [76]:
import pathlib
import os.path

# Checkpoints live in ./checkpoints, created on first use.
checkpoint_dir = pathlib.Path("checkpoints")
checkpoint_dir.mkdir(exist_ok=True)
model_path = checkpoint_dir / "model2.pth"

# Refuse to clobber a checkpoint that is already on disk; pick a new file
# name (or remove the old file) to save again.
if os.path.exists(model_path):
    raise Exception("Won't overwrite existing models")

# Persist only the model's parameters (its state dict).
torch.save(myModel.state_dict(), model_path)

Exception: Won't overwrite existing models