In [1]:
import tsgm
import tensorflow as tf



In [None]:

# Build the WaveGAN model; its `_generator` and `_discriminator` networks are
# trained directly by `train_step`.
# NOTE(review): depending on the tsgm version this class may only be exposed as
# `tsgm.models.architectures.WaveGANArchitecture` — verify the import path.
wavegan = tsgm.models.WaveGANArchitecture(
    latent_dim=100,   # length of the noise vector fed to the generator (read back as wavegan.latent_dim)
    seq_len=16384,    # samples per example — presumably ~1 s at 16 kHz as in the WaveGAN paper; confirm
    feat_dim=1,       # values per time step; presumably mono audio — confirm
    kernel_size=25,   # convolution kernel length (25 is the WaveGAN paper's value) — TODO confirm tsgm semantics
    phase_rad=2,      # presumably the phase-shuffle radius; verify against tsgm docs
)

In [None]:

# from_logits=True: the loss applies the sigmoid itself, so discriminator outputs
# are treated as raw logits.
# NOTE(review): confirm the tsgm WaveGAN discriminator's last layer has no sigmoid.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# One Adam instance per network so each keeps its own moment estimates
# for its own weights.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

In [None]:
def train_step(real_audio):
    """Run one simultaneous generator/discriminator update on a batch of real audio.

    Args:
        real_audio: Batch of real waveforms. Assumed shape is
            (batch, seq_len, feat_dim), matching the generator output — TODO confirm.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` of scalar loss tensors for this batch.
    """
    # Derive the batch size from the input rather than an undefined global
    # `batch_size`. This also keeps a smaller final batch paired with the same
    # number of generated samples.
    num_real = tf.shape(real_audio)[0]
    noise = tf.random.normal([num_real, wavegan.latent_dim])

    # Two tapes so each network's gradients are taken w.r.t. its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_audio = wavegan._generator(noise, training=True)

        real_output = wavegan._discriminator(real_audio, training=True)
        fake_output = wavegan._discriminator(generated_audio, training=True)

        # Generator wants the discriminator to label its samples as real (1).
        gen_loss = cross_entropy(tf.ones_like(fake_output), fake_output)
        # Discriminator: real samples -> 1, generated samples -> 0.
        disc_loss = (
            cross_entropy(tf.ones_like(real_output), real_output)
            + cross_entropy(tf.zeros_like(fake_output), fake_output)
        )

    gen_vars = wavegan._generator.trainable_variables
    disc_vars = wavegan._discriminator.trainable_variables
    gradients_of_generator = gen_tape.gradient(gen_loss, gen_vars)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, gen_vars))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, disc_vars))

    return gen_loss, disc_loss


In [None]:
def train(dataset, epochs, step_fn=None):
    """Run the adversarial training loop and print mean losses after each epoch.

    Args:
        dataset: Iterable of real-audio batches, iterated once per epoch
            (e.g. a batched ``tf.data.Dataset``). Must yield at least one batch.
        epochs: Number of full passes over ``dataset``.
        step_fn: Optional callable ``batch -> (gen_loss, disc_loss)`` that performs
            one optimisation step. Defaults to ``train_step``.

    Raises:
        ValueError: If ``dataset`` yields no batches in an epoch. Previously this
            surfaced as an UnboundLocalError when printing the losses.
    """
    step = train_step if step_fn is None else step_fn
    for epoch in range(epochs):
        gen_total = 0.0
        disc_total = 0.0
        num_batches = 0
        for batch in dataset:
            gen_loss, disc_loss = step(batch)
            # Accumulate without calling float() per step, to avoid forcing a
            # device-to-host sync on every batch.
            gen_total = gen_total + gen_loss
            disc_total = disc_total + disc_loss
            num_batches += 1

        if num_batches == 0:
            raise ValueError("dataset yielded no batches; cannot train")

        # Report epoch means rather than the last batch's (noisy) loss.
        gen_loss = float(gen_total) / num_batches
        disc_loss = float(disc_total) / num_batches
        # TODO: save checkpoints / generate samples here as needed.
        print(f"Epoch {epoch+1}, Gen Loss: {gen_loss:.4f}, Disc Loss: {disc_loss:.4f}")

# NOTE(review): `load_audio_dataset` is not defined in any visible cell, so this
# raises NameError on Restart & Run All. TODO: define or import it in an earlier
# cell. It should presumably yield batches shaped like the generator output —
# (batch, 16384, 1) given the architecture settings — confirm.
EPOCHS = 100
dataset = load_audio_dataset()
train(dataset, epochs=EPOCHS)