# Setup

In [5]:
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
from tensorflow import keras
from keras import layers
from diffusion_model_for_speech import *
from wavenet import *
from speech_dataset import *
import math
import numpy as np
import glob

# Hyperparameters

In [6]:
# Tunable settings for the whole notebook.
diffusion_steps = 20    # passed to DiffusionModel(diffusion_steps=...)
num_epochs = 10000      # upper bound passed to model.fit(epochs=...); matches the logged "Epoch 1/10000"
learning_rate = 2e-4    # AdamW learning rate
weight_decay = 1e-5     # AdamW weight decay

# Checkpoint locations are keyed by the number of diffusion steps so that runs
# with different settings do not overwrite each other.
checkpoint_path = "checkpoints/diffusion_model_for_voice/diffusion_steps_%d" % (diffusion_steps)
final_checkpoint_path = "checkpoints/diffusion_model_for_voice/final_diffusion_steps_%d" % (diffusion_steps)

# Data Pipeline

In [7]:
# Instantiate the project's speech dataset wrapper. `Speech_dataset` is not
# defined in this notebook; it presumably comes from the `speech_dataset`
# star-import above (TODO confirm).
dataset = Speech_dataset()

dataset size: 180
sample size: (66270,)


2023-04-01 14:38:47.496899: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22292 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:af:00.0, compute capability: 8.6
2023-04-01 14:38:48.005466: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)


In [4]:
# Build the iterable dataset used for training; each element unpacks as a
# (signal, logmel) pair in the cells below.
dataset_tf = dataset.audio_dataset()

In [5]:
# Peek at the first five elements to confirm the tensor shapes coming out of
# the pipeline.
for signal, logmel in dataset_tf.take(5):
    for label, tensor in (("Signal shape:", signal), ("Log Mel shape:", logmel)):
        print(label, tensor.shape)

Signal shape: (8, 6400)
Log Mel shape: (8, 25, 80)
Signal shape: (8, 6400)
Log Mel shape: (8, 25, 80)
Signal shape: (8, 6400)
Log Mel shape: (8, 25, 80)
Signal shape: (8, 6400)
Log Mel shape: (8, 25, 80)
Signal shape: (8, 6400)
Log Mel shape: (8, 25, 80)


# Training

In [6]:
# Build and compile the diffusion model, define the training callbacks, and
# split the dataset into fixed training / validation partitions.
model = DiffusionModel(diffusion_steps=diffusion_steps)

model.compile(
    optimizer=tfa.optimizers.AdamW(
        learning_rate=learning_rate,
        weight_decay=weight_decay),
    # Mean squared error; mean absolute error is a possible alternative.
    loss=keras.losses.MeanSquaredError(),
    #run_eagerly=True
)

# Keep only the weights from the epoch with the lowest logged "mel_loss".
# NOTE(review): "mel_loss" is presumably a metric reported by DiffusionModel
# during training; if selection should be based on held-out data, the
# validation counterpart (typically "val_mel_loss") would be the metric to
# monitor — confirm against the model's logged metric names.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    monitor="mel_loss",
    mode="min",
    save_best_only=True,
)

# Unconditionally overwrite the "final" weights at the end of every epoch, so
# the most recent state survives an interrupted run.
final_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=final_checkpoint_path,
    save_weights_only=True,
)

# Optional: halve the learning rate when "noise_loss" has not improved for
# 5 epochs, never going below 1e-6.
reduce_lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='noise_loss',
    factor=0.5,
    patience=5,
    verbose=1,
    min_lr=1e-6
)
# calculate mean and variance of training dataset for normalization
#model.normalizer.adapt(dataset_tf)
#model.load_weights(checkpoint_path)

# Count the elements with one full pass over the dataset (no intermediate list).
total_size = sum(1 for _ in dataset_tf)
if total_size == 0:
    raise ValueError("dataset_tf is empty; nothing to split or train on")

val_split = 0.1
val_size = int(total_size * val_split)
train_size = total_size - val_size

# Shuffle ONCE with a fixed seed and reshuffle_each_iteration=False. With the
# default (reshuffle on every iteration) take()/skip() would select different
# elements on every pass, so validation elements would leak into training.
# NOTE(review): this assumes dataset.audio_dataset() yields the same elements
# in the same order on every pass — confirm.
split_seed = 42
dataset_tf = dataset_tf.shuffle(
    buffer_size=total_size,
    seed=split_seed,
    reshuffle_each_iteration=False,
)
# The training partition is re-shuffled every epoch; the validation partition
# stays fixed.
train_dataset = dataset_tf.take(train_size).shuffle(buffer_size=train_size)
val_dataset = dataset_tf.skip(train_size)


In [7]:
# Peek at the first five training elements to confirm the split kept the
# expected tensor shapes.
for signal, logmel in train_dataset.take(5):
    for label, tensor in (("Signal shape:", signal), ("Log Mel shape:", logmel)):
        print(label, tensor.shape)

Signal shape: (8, 6400)
Log Mel shape: (8, 25, 80)
Signal shape: (8, 6400)
Log Mel shape: (8, 25, 80)
Signal shape: (8, 6400)
Log Mel shape: (8, 25, 80)
Signal shape: (8, 6400)
Log Mel shape: (8, 25, 80)
Signal shape: (8, 6400)
Log Mel shape: (8, 25, 80)


In [None]:
# Train on the training partition only. Fitting on the full `dataset_tf` would
# also train on the elements held out in `val_dataset`, making the validation
# metrics meaningless.
# Keras only supports `validation_split` for array/tensor inputs; with
# tf.data datasets the hold-out set has to be passed via `validation_data`.
model.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=val_dataset,
    callbacks=[
        #keras.callbacks.LambdaCallback(on_epoch_end=model.to_speech),
        checkpoint_callback,
        final_checkpoint_callback,
        #reduce_lr_callback
    ],
)

Epoch 1/10000


2023-04-01 12:05:41.559250: I tensorflow/stream_executor/cuda/cuda_dnn.cc:369] Loaded cuDNN version 8204
2023-04-01 12:05:44.252171: I tensorflow/stream_executor/cuda/cuda_blas.cc:1760] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.


Epoch 2/10000
Epoch 3/10000
Epoch 4/10000
Epoch 5/10000
Epoch 6/10000
Epoch 7/10000
Epoch 8/10000
Epoch 9/10000
Epoch 10/10000
Epoch 11/10000
Epoch 12/10000
Epoch 13/10000
Epoch 14/10000
Epoch 15/10000
Epoch 16/10000
Epoch 17/10000
Epoch 18/10000
Epoch 19/10000
Epoch 20/10000
Epoch 21/10000
Epoch 22/10000
Epoch 23/10000
Epoch 24/10000
Epoch 25/10000
Epoch 26/10000
Epoch 27/10000
Epoch 28/10000
Epoch 29/10000
Epoch 30/10000
Epoch 31/10000
Epoch 32/10000
Epoch 33/10000
Epoch 34/10000
Epoch 35/10000
Epoch 36/10000
Epoch 37/10000
Epoch 38/10000
Epoch 39/10000
Epoch 40/10000
Epoch 41/10000
Epoch 42/10000
Epoch 43/10000
Epoch 44/10000
Epoch 45/10000
Epoch 46/10000
Epoch 47/10000
Epoch 48/10000
Epoch 49/10000
Epoch 50/10000
Epoch 51/10000
Epoch 52/10000
Epoch 53/10000
Epoch 54/10000
Epoch 55/10000
Epoch 56/10000
Epoch 57/10000
Epoch 58/10000
Epoch 59/10000
Epoch 60/10000
Epoch 61/10000
Epoch 62/10000
Epoch 63/10000
Epoch 64/10000
Epoch 65/10000
Epoch 66/10000
Epoch 67/10000
Epoch 68/10000
Epo

Epoch 79/10000
Epoch 80/10000
Epoch 81/10000
Epoch 82/10000
Epoch 83/10000
Epoch 84/10000
Epoch 85/10000
Epoch 86/10000
Epoch 87/10000
Epoch 88/10000
Epoch 89/10000
Epoch 90/10000
Epoch 91/10000
Epoch 92/10000
Epoch 93/10000
Epoch 94/10000
Epoch 95/10000
Epoch 96/10000
Epoch 97/10000
Epoch 98/10000
Epoch 99/10000
Epoch 100/10000
Epoch 101/10000
Epoch 102/10000
Epoch 103/10000
Epoch 104/10000
Epoch 105/10000
Epoch 106/10000
Epoch 107/10000
Epoch 108/10000
Epoch 109/10000
Epoch 110/10000
Epoch 111/10000
Epoch 112/10000
Epoch 113/10000
Epoch 114/10000
Epoch 115/10000
Epoch 116/10000
Epoch 117/10000
Epoch 118/10000
Epoch 119/10000
Epoch 120/10000
Epoch 121/10000
Epoch 122/10000
Epoch 123/10000
Epoch 124/10000
Epoch 125/10000
Epoch 126/10000
Epoch 127/10000
Epoch 128/10000
Epoch 129/10000
Epoch 130/10000
Epoch 131/10000
Epoch 132/10000
Epoch 133/10000
Epoch 134/10000
Epoch 135/10000
Epoch 136/10000
Epoch 137/10000
Epoch 138/10000
Epoch 139/10000
Epoch 140/10000
Epoch 141/10000
Epoch 142/100

Epoch 155/10000
Epoch 156/10000
Epoch 157/10000
Epoch 158/10000
Epoch 159/10000
Epoch 160/10000
Epoch 161/10000
Epoch 162/10000
Epoch 163/10000
Epoch 164/10000
Epoch 165/10000
Epoch 166/10000
Epoch 167/10000
Epoch 168/10000
Epoch 169/10000
Epoch 170/10000
Epoch 171/10000
Epoch 172/10000
Epoch 173/10000
Epoch 174/10000
Epoch 175/10000
Epoch 176/10000
Epoch 177/10000
Epoch 178/10000
Epoch 179/10000
Epoch 180/10000
Epoch 181/10000
Epoch 182/10000
Epoch 183/10000
Epoch 184/10000
Epoch 185/10000
Epoch 186/10000
Epoch 187/10000
Epoch 188/10000
Epoch 189/10000
Epoch 190/10000
Epoch 191/10000
Epoch 192/10000


Epoch 193/10000
Epoch 194/10000
Epoch 195/10000
Epoch 196/10000
Epoch 197/10000
Epoch 198/10000
Epoch 199/10000
Epoch 200/10000
Epoch 201/10000
Epoch 202/10000
Epoch 203/10000
Epoch 204/10000
Epoch 205/10000
Epoch 206/10000
Epoch 207/10000
Epoch 208/10000
Epoch 209/10000
Epoch 210/10000
Epoch 211/10000
Epoch 212/10000
Epoch 213/10000
Epoch 214/10000
Epoch 215/10000
Epoch 216/10000
Epoch 217/10000
Epoch 218/10000
Epoch 219/10000
Epoch 220/10000
Epoch 221/10000
Epoch 222/10000
Epoch 223/10000
Epoch 224/10000
Epoch 225/10000
Epoch 226/10000
Epoch 227/10000
Epoch 228/10000
Epoch 229/10000


Epoch 230/10000
Epoch 231/10000
Epoch 232/10000
Epoch 233/10000
Epoch 234/10000
Epoch 235/10000
Epoch 236/10000
Epoch 237/10000
Epoch 238/10000
Epoch 239/10000
Epoch 240/10000
Epoch 241/10000
Epoch 242/10000
Epoch 243/10000
Epoch 244/10000
Epoch 245/10000
Epoch 246/10000
Epoch 247/10000
Epoch 248/10000
Epoch 249/10000
Epoch 250/10000
Epoch 251/10000
Epoch 252/10000
Epoch 253/10000
Epoch 254/10000
Epoch 255/10000
Epoch 256/10000
Epoch 257/10000
Epoch 258/10000
Epoch 259/10000
Epoch 260/10000
Epoch 261/10000
Epoch 262/10000
Epoch 263/10000
Epoch 264/10000
Epoch 265/10000
Epoch 266/10000
Epoch 267/10000


Epoch 268/10000
Epoch 269/10000
Epoch 270/10000
Epoch 271/10000
Epoch 272/10000
Epoch 273/10000
Epoch 274/10000
Epoch 275/10000
Epoch 276/10000
Epoch 277/10000
Epoch 278/10000
Epoch 279/10000
Epoch 280/10000
Epoch 281/10000
Epoch 282/10000
Epoch 283/10000
Epoch 284/10000
Epoch 285/10000
Epoch 286/10000
Epoch 287/10000
Epoch 288/10000
Epoch 289/10000
Epoch 290/10000
Epoch 291/10000
Epoch 292/10000
Epoch 293/10000
Epoch 294/10000
Epoch 295/10000
Epoch 296/10000
Epoch 297/10000
Epoch 298/10000
Epoch 299/10000
Epoch 300/10000
Epoch 301/10000
Epoch 302/10000
Epoch 303/10000
Epoch 304/10000
Epoch 305/10000


Epoch 306/10000
Epoch 307/10000
Epoch 308/10000
Epoch 309/10000
Epoch 310/10000
Epoch 311/10000
Epoch 312/10000
Epoch 313/10000
Epoch 314/10000
Epoch 315/10000
Epoch 316/10000
Epoch 317/10000
Epoch 318/10000
Epoch 319/10000
Epoch 320/10000
Epoch 321/10000
Epoch 322/10000
Epoch 323/10000
Epoch 324/10000
Epoch 325/10000
Epoch 326/10000
Epoch 327/10000
Epoch 328/10000
Epoch 329/10000
Epoch 330/10000
Epoch 331/10000
Epoch 332/10000
Epoch 333/10000
Epoch 334/10000
Epoch 335/10000
Epoch 336/10000
Epoch 337/10000
Epoch 338/10000
Epoch 339/10000
Epoch 340/10000
Epoch 341/10000
Epoch 342/10000
Epoch 343/10000


Epoch 344/10000
Epoch 345/10000
Epoch 346/10000
Epoch 347/10000
Epoch 348/10000
Epoch 349/10000
Epoch 350/10000
Epoch 351/10000
Epoch 352/10000
Epoch 353/10000
Epoch 354/10000
Epoch 355/10000
Epoch 356/10000
Epoch 357/10000
Epoch 358/10000
Epoch 359/10000
Epoch 360/10000
Epoch 361/10000
Epoch 362/10000
Epoch 363/10000
Epoch 364/10000
Epoch 365/10000
Epoch 366/10000
Epoch 367/10000
Epoch 368/10000
Epoch 369/10000
Epoch 370/10000
Epoch 371/10000
Epoch 372/10000
Epoch 373/10000
Epoch 374/10000
Epoch 375/10000
Epoch 376/10000
Epoch 377/10000
Epoch 378/10000
Epoch 379/10000
Epoch 380/10000
Epoch 381/10000


Epoch 382/10000
Epoch 383/10000
Epoch 384/10000
Epoch 385/10000
Epoch 386/10000
Epoch 387/10000
Epoch 388/10000
Epoch 389/10000
Epoch 390/10000
Epoch 391/10000
Epoch 392/10000
Epoch 393/10000
Epoch 394/10000
Epoch 395/10000
Epoch 396/10000
Epoch 397/10000
Epoch 398/10000
Epoch 399/10000
Epoch 400/10000
Epoch 401/10000
Epoch 402/10000
Epoch 403/10000
Epoch 404/10000
Epoch 405/10000
Epoch 406/10000
Epoch 407/10000
Epoch 408/10000
Epoch 409/10000
Epoch 410/10000
Epoch 411/10000
Epoch 412/10000
Epoch 413/10000
Epoch 414/10000
Epoch 415/10000
Epoch 416/10000
Epoch 417/10000
Epoch 418/10000
Epoch 419/10000


Epoch 420/10000
Epoch 421/10000
Epoch 422/10000
Epoch 423/10000
Epoch 424/10000
Epoch 425/10000
Epoch 426/10000
Epoch 427/10000
Epoch 428/10000
Epoch 429/10000
Epoch 430/10000
Epoch 431/10000
Epoch 432/10000
Epoch 433/10000
Epoch 434/10000
Epoch 435/10000
Epoch 436/10000
Epoch 437/10000
Epoch 438/10000
Epoch 439/10000
Epoch 440/10000
Epoch 441/10000
Epoch 442/10000
Epoch 443/10000
Epoch 444/10000
Epoch 445/10000
Epoch 446/10000
Epoch 447/10000
Epoch 448/10000
Epoch 449/10000
Epoch 450/10000
Epoch 451/10000
Epoch 452/10000
Epoch 453/10000
Epoch 454/10000
Epoch 455/10000
Epoch 456/10000
Epoch 457/10000


Epoch 458/10000
Epoch 459/10000
Epoch 460/10000
Epoch 461/10000
Epoch 462/10000
Epoch 463/10000
Epoch 464/10000
Epoch 465/10000
Epoch 466/10000
Epoch 467/10000
Epoch 468/10000
Epoch 469/10000
Epoch 470/10000
Epoch 471/10000
Epoch 472/10000
Epoch 473/10000
Epoch 474/10000
Epoch 475/10000
Epoch 476/10000
Epoch 477/10000
Epoch 478/10000
Epoch 479/10000
Epoch 480/10000
Epoch 481/10000
Epoch 482/10000
Epoch 483/10000
Epoch 484/10000
Epoch 485/10000
Epoch 486/10000
Epoch 487/10000
Epoch 488/10000
Epoch 489/10000
Epoch 490/10000
Epoch 491/10000
Epoch 492/10000
Epoch 493/10000
Epoch 494/10000
Epoch 495/10000


Epoch 496/10000
Epoch 497/10000
Epoch 498/10000
Epoch 499/10000
Epoch 500/10000
Epoch 501/10000
Epoch 502/10000
Epoch 503/10000
Epoch 504/10000
Epoch 505/10000
Epoch 506/10000
Epoch 507/10000
Epoch 508/10000
Epoch 509/10000
Epoch 510/10000
Epoch 511/10000
Epoch 512/10000
Epoch 513/10000
Epoch 514/10000

# Inference

In [8]:
# Prepare the inference inputs: one reference utterance plus Gaussian noise of
# the same length, both written to disk for listening.
# NOTE(review): `librosa.output` was removed in librosa 0.8, so these calls
# need an older librosa release — confirm the pinned version.
out_dir = '/home/lqm/Diffusion-Bison/out/'
sample_rate = 16000

# Take the first raw utterance and trim it to a whole multiple of 256 samples.
speech = next(iter(dataset.rawset))
usable_len = speech.shape[0] // 256 * 256
speech = speech[:usable_len]
librosa.output.write_wav(os.path.join(out_dir, 'test_gt.wav'), speech.numpy(), sample_rate)

# Noise with a leading batch axis, matching speech[None].
noise = tf.random.normal(tf.shape(speech[None]))
librosa.output.write_wav(os.path.join(out_dir, 'test_noise.wav'), noise[0].numpy(), sample_rate)

In [9]:
# Sanity check: the noise has the shape of speech[None], i.e. the trimmed
# utterance with a leading axis of size 1.
noise.shape

TensorShape([1, 66048])

In [10]:
# Rebuild the model and restore the weights stored at `checkpoint_path`, the
# file written by the save_best_only checkpoint callback.
model_best = DiffusionModel(diffusion_steps=diffusion_steps)
best_optimizer = tfa.optimizers.AdamW(
    learning_rate=learning_rate,
    weight_decay=weight_decay,
)
# Compiled with mean absolute error here.
model_best.compile(optimizer=best_optimizer, loss=keras.losses.mean_absolute_error)
model_best.load_weights(checkpoint_path)


<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f9fa46b3c10>

In [11]:
# Compute the log-mel features of the trimmed utterance and bundle them with
# the waveform as the input to generate().
_, logmel = dataset.mel_fn(speech[None])
data = [speech, logmel]

# The second argument is presumably the diffusion step to start from (the log
# begins at "step19" for 20 diffusion steps) — TODO confirm against
# DiffusionModel.generate.
start_step = int(0.95*diffusion_steps)
pred_sig, ori_sig = model_best.generate(data, start_step, noise)

# Convert both returned tensors to NumPy arrays for writing to disk.
pred_s = pred_sig.numpy()
ori_s = ori_sig.numpy()
#librosa.output.write_wav(os.path.join('/home/lqm/Diffusion-Bison/out/', 'step.wav'),s[0],16000)

2023-04-01 14:38:56.311274: I tensorflow/stream_executor/cuda/cuda_blas.cc:1760] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2023-04-01 14:38:56.347110: W tensorflow/stream_executor/cuda/cuda_dnn.cc:342] There was an error before creating cudnn handle: cudaErrorMemoryAllocation : out of memory


signal shape (66048,)
logmel shape (1, 258, 80)
initial_noise shape (1, 66048)
step19...


2023-04-01 14:38:57.125576: I tensorflow/stream_executor/cuda/cuda_dnn.cc:369] Loaded cuDNN version 8204


signal_loss 0.0440080426633358
Mel loss: 3.1987593
step18...
signal_loss 0.04351256042718887
Mel loss: 3.1399357
step17...
signal_loss 0.03655099496245384
Mel loss: 2.895056
step16...
signal_loss 0.03617250546813011
Mel loss: 2.8157704
step15...
signal_loss 0.032315075397491455
Mel loss: 2.5797672
step14...
signal_loss 0.03480013832449913
Mel loss: 2.7183242
step13...
signal_loss 0.03205684944987297
Mel loss: 2.5322673
step12...
signal_loss 0.03129318729043007
Mel loss: 2.4822145
step11...
signal_loss 0.029960177838802338
Mel loss: 2.3162453
step10...
signal_loss 0.02919125370681286
Mel loss: 2.2823093
step9...
signal_loss 0.0292984526604414
Mel loss: 2.3007755
step8...
signal_loss 0.02920069545507431
Mel loss: 2.2474706
step7...
signal_loss 0.02793525718152523
Mel loss: 1.8299419
step6...
signal_loss 0.028817934915423393
Mel loss: 1.9239103
step5...
signal_loss 0.028321731835603714
Mel loss: 1.8756804
step4...
signal_loss 0.03005647286772728
Mel loss: 1.9259465
step3...
signal_loss 0.

In [12]:
# Save both signals returned by the best-checkpoint model as 16 kHz wav files.
out_dir = '/home/lqm/Diffusion-Bison/out/'
for wav_name, samples in (('best_pred_step.wav', pred_s), ('best_ori_step.wav', ori_s)):
    librosa.output.write_wav(os.path.join(out_dir, wav_name), samples, 16000)

In [13]:
# Rebuild the model and restore the weights stored at `final_checkpoint_path`,
# the file written by the unconditional checkpoint callback.
model_final = DiffusionModel(diffusion_steps=diffusion_steps)
final_optimizer = tfa.optimizers.AdamW(
    learning_rate=learning_rate,
    weight_decay=weight_decay,
)
# Compiled with mean absolute error here.
model_final.compile(optimizer=final_optimizer, loss=keras.losses.mean_absolute_error)
model_final.load_weights(final_checkpoint_path)


<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f9fa43c6400>

In [14]:
# Repeat the generation with the final-checkpoint model, reusing the same
# `data` and `noise` so the result is comparable with the best-checkpoint run.
start_step = int(0.95*diffusion_steps)
pred_sig, ori_sig = model_final.generate(data, start_step, noise)
pred_s, ori_s = pred_sig.numpy(), ori_sig.numpy()

signal shape (66048,)
logmel shape (1, 258, 80)
initial_noise shape (1, 66048)
step19...
signal_loss 0.040738631039857864
Mel loss: 3.139355
step18...
signal_loss 0.036955758929252625
Mel loss: 2.9476438
step17...
signal_loss 0.03298822417855263
Mel loss: 2.7424583
step16...
signal_loss 0.036415114998817444
Mel loss: 2.8384395
step15...
signal_loss 0.03396943584084511
Mel loss: 2.7241504
step14...
signal_loss 0.03251845762133598
Mel loss: 2.6080275
step13...
signal_loss 0.03018995374441147
Mel loss: 2.35445
step12...
signal_loss 0.03012385219335556
Mel loss: 2.3001935
step11...
signal_loss 0.03133554011583328
Mel loss: 2.3554063
step10...
signal_loss 0.030574453994631767
Mel loss: 2.2707727
step9...
signal_loss 0.028864117339253426
Mel loss: 2.0836008
step8...
signal_loss 0.02849842980504036
Mel loss: 2.0218234
step7...
signal_loss 0.028412001207470894
Mel loss: 1.9493036
step6...
signal_loss 0.029902905225753784
Mel loss: 1.8463776
step5...
signal_loss 0.03133463114500046
Mel loss: 1.

In [15]:
# Save both signals returned by the final-checkpoint model as 16 kHz wav files.
out_dir = '/home/lqm/Diffusion-Bison/out/'
for wav_name, samples in (('final_pred_step.wav', pred_s), ('final_ori_step.wav', ori_s)):
    librosa.output.write_wav(os.path.join(out_dir, wav_name), samples, 16000)