In [21]:
from matplotlib import pyplot as plt
import numpy as np
from IPython.display import Audio, display
from IPython.display import clear_output
import os
from scipy.io import wavfile
from scipy import signal
import scipy
import matplotlib
import tensorflow as tf
import time

from tensorflow.keras import layers

matplotlib.rcParams['figure.figsize'] = [8, 8]

In [2]:
frequency_bins = 2000
nfft = frequency_bins * 2 - 1
nperseg = nfft // 2
noverlap = nperseg // 64

target_sequence_length = 16
sequence_duration = (nperseg - noverlap) * (target_sequence_length - 1)

def transform_sample(sample):
    sample = sample * (15000 / np.amax(np.abs(sample)))
    
    _, _, zxx = signal.stft(sample, nperseg=nperseg, noverlap=noverlap, nfft=nfft)
    result = zxx.real
    logged = np.log10(np.abs(zxx.real) + 10)
    result = np.where(zxx.real < 0, -logged + 1, result)
    result = np.where(zxx.real > 0, logged - 1, result)
    
    result = result * (1 / 5.7)
    result = result + 0.5
    assert np.amin(result) >= 0
    assert np.amax(result) <= 1

    return np.swapaxes(result, 0, 1)

def invert_sample(sample):
    result = np.swapaxes(sample, 0, 1)
    result = (result - 0.5) * 5.7
    logprep = np.where(result < 0, result - 1, result)
    logprep = np.where(logprep > 0, result + 1, logprep)
    alogged = (10 ** np.abs(logprep)) - 10
    result = np.where(result < 0, -alogged, result)
    result = np.where(result > 0, alogged, result)
    return signal.istft(result, nperseg=nperseg, noverlap=noverlap)[1]

In [3]:
source_files = [f"./data/highfreq/{i}.wav" for i in range(1, 2)]

prepared_data = []

sample_rate = 32000
sampling_step = sequence_duration // 2

assert target_sequence_length / 4 == target_sequence_length // 4

for source in source_files:
    rate, data = wavfile.read(source)
    
    steps = (len(data) - sequence_duration) // sampling_step
    for i in range(steps):
        start = i * sampling_step
        end = start + sequence_duration
        transformed = transform_sample(data[start:end])
        prepared_data.append(transformed)

prepared_data = np.array(prepared_data)
print(prepared_data.shape)

(375, 16, 2000)


In [4]:
buffer = prepared_data[32]
display(Audio(invert_sample(buffer), rate=sample_rate))

In [5]:
def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(
        target_sequence_length // 4 * 2000 * 256,
        use_bias=False, input_shape=(100,))
    )
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Reshape(
        (target_sequence_length // 4, frequency_bins, 256)))
    
    model.add(layers.Conv2DTranspose(
        128, (5, 5), strides=(1, 1), padding="same", use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Conv2DTranspose(
        64, (5, 5), strides=(2, 1), padding="same", use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Conv2DTranspose(
        1, (5, 5), strides=(2, 1), padding="same", use_bias=False,
        activation="tanh"
    ))
    model.add(layers.Reshape((target_sequence_length, frequency_bins)))
    model.summary()
    return model

In [6]:
generator = make_generator_model()
noise = tf.random.normal([1, 100])
generated_sample = generator(noise, training=False).numpy()
print(generated_sample.shape)

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 2048000)           204800000 
_________________________________________________________________
batch_normalization (BatchNo (None, 2048000)           8192000   
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 2048000)           0         
_________________________________________________________________
reshape (Reshape)            (None, 4, 2000, 256)      0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 4, 2000, 128)      819200    
_________________________________________________________________
batch_normalization_1 (Batch (None, 4, 2000, 128)      512       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 4, 2000, 128)      0

In [7]:
display(Audio(invert_sample(generated_sample[0]), rate=sample_rate))

In [8]:
def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv1D(
        64, 5, strides=2, padding="same",
        input_shape=(target_sequence_length, frequency_bins)))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
    
    model.add(layers.Conv1D(
        128, 5, strides=2, padding="same"))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
    
    model.add(layers.Flatten())
    model.add(layers.Dense(1))
    print(model.summary())
    return model

In [9]:
discriminator = make_discriminator_model()
decision = discriminator(generated_sample)
print(decision)

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv1d (Conv1D)              (None, 8, 64)             640064    
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 8, 64)             0         
_________________________________________________________________
dropout (Dropout)            (None, 8, 64)             0         
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 4, 128)            41088     
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 4, 128)            0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 4, 128)            0         
_________________________________________________________________
flatten (Flatten)            (None, 512)              

In [10]:
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

In [11]:
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

In [12]:
checkpoint_dir = "./training-checkpoints/audio-gan"
checkpoint_prefix = os.path.join(checkpoint_dir, "audio-gan")
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator
)

In [13]:
BATCH_SIZE = 32
EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 4

seed = tf.random.normal([num_examples_to_generate, noise_dim])

In [22]:
@tf.function
def train_step(sequences):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_samples = generator(noise, training=True)

        real_output = discriminator(sequences, training=True)
        fake_output = discriminator(generated_samples, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

        gradients_of_generator = gen_tape.gradient(
            gen_loss, generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(
            disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(
        zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(
        zip(gradients_of_discriminator, discriminator.trainable_variables))
    
    return gen_loss, disc_loss

In [23]:
def train(dataset, epochs):
    print("Beginning training...")
    for epoch in range(epochs):
        start = time.time()
        
        for batch in dataset:
            gen_loss, disc_loss = train_step(batch)

        if (epoch + 1) % 100 == 0:
            clear_output(wait=True)

        if (epoch + 1) % 10 == 0 or epoch == 0:
            output_generation_results(generator, epoch + 1, seed)
            print(f"Time for epoch {epoch + 1} is {time.time() - start}")
            print(f"Gen loss: {gen_loss}, Disc loss: {disc_loss}")
        
        if (epoch + 1) % 2000 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
        
    
    output_generation_results(generator, epochs, seed)

In [19]:
def output_generation_results(model, epoch, test_input):
    predictions = model(test_input, training=False)
    print(f"Epoch {epoch}")
    for prediction in predictions:
        display(Audio(invert_sample(prediction.numpy()), rate=sample_rate))

In [17]:
BUFFER_SIZE = 60000
train_dataset = (
    tf.data.Dataset.from_tensor_slices(prepared_data)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

In [24]:
train(train_dataset, 2000000)

Epoch 2400


Time for epoch 2400 is 10.211037397384644
Gen loss: 1.1156129837036133, Disc loss: 1.1582651138305664
Epoch 2410


Time for epoch 2410 is 10.243515968322754
Gen loss: 1.3301520347595215, Disc loss: 0.9049346446990967
Epoch 2420


Time for epoch 2420 is 10.244282722473145
Gen loss: 0.8846646547317505, Disc loss: 1.107039451599121


ResourceExhaustedError:  OOM when allocating tensor with shape[100,2048000] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[node MatMul_2 (defined at <ipython-input-22-4a651fea2272>:14) ]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.
 [Op:__inference_train_step_5602]

Errors may have originated from an input operation.
Input Source operations connected to node MatMul_2:
 random_normal (defined at <ipython-input-22-4a651fea2272>:3)

Function call stack:
train_step
