In [1]:
from utils import generate_real_samples, generate_latent_points, load_real_samples, generate_fake_samples, generate_images
from numpy import ones
import matplotlib.pyplot as plt
from statistics import mean
from wgan import WGAN
from tensorflow.keras.models import load_model

In [2]:
def _plot_loss_history(c1_hist, c2_hist, g_hist, plot_path):
    """Save a line plot of per-step critic (real/fake) and generator losses to ``plot_path``."""
    fig, ax = plt.subplots()
    try:
        ax.plot(c1_hist, label='crit_real')
        ax.plot(c2_hist, label='crit_fake')
        ax.plot(g_hist, label='gen')
        ax.set(title='WGAN training losses', xlabel='training step', ylabel='loss')
        ax.legend()
        fig.savefig(plot_path)
    finally:
        # Close this specific figure (even if saving fails) so repeated training runs
        # in the same kernel don't accumulate open figures.
        plt.close(fig)


def train(g_model, c_model, gan_model, dataset, latent_dim, n_epochs=10, n_batch=64, n_critic=5,
          plot_path='plot_line_plot_loss.png'):
    """Train a WGAN with ``n_critic`` critic updates per generator update.

    Each training step runs ``n_critic`` critic iterations. Each iteration is one
    ``c_model.train_on_batch`` call on ``n_batch // 2`` real samples and one on
    ``n_batch // 2`` generated samples. The step ends with a single generator update
    through ``gan_model`` on ``n_batch`` latent points labelled ``-1``. Per-step losses
    are printed, and the full loss history is plotted to ``plot_path`` at the end.

    Parameters
    ----------
    g_model : generator model, passed to ``generate_fake_samples``.
    c_model : critic model; must expose ``train_on_batch``.
    gan_model : combined model whose ``train_on_batch`` performs the generator update.
    dataset : real samples. Only ``dataset.shape[0]`` is read here; the object is
        otherwise handed to ``generate_real_samples``.
    latent_dim : size of the latent space.
    n_epochs : number of passes over the dataset.
    n_batch : batch size; must be >= 2 so each half batch is non-empty.
    n_critic : critic iterations per generator update; must be >= 1.
    plot_path : file the loss-history figure is saved to.

    Raises
    ------
    ValueError
        If ``n_batch < 2``, ``n_critic < 1``, or the dataset has fewer than ``n_batch``
        samples (which would otherwise run zero training steps without any warning).
    """
    if n_batch < 2:
        raise ValueError('n_batch must be >= 2 so each half batch is non-empty, got %r' % (n_batch,))
    if n_critic < 1:
        raise ValueError('n_critic must be >= 1, got %r' % (n_critic,))
    # Number of full batches per epoch; a trailing partial batch is never used.
    bat_per_epo = dataset.shape[0] // n_batch
    if bat_per_epo == 0:
        raise ValueError(
            'dataset has %d samples, fewer than n_batch=%d, so no training steps would run'
            % (dataset.shape[0], n_batch))
    n_steps = bat_per_epo * n_epochs
    # Each critic iteration sees half a batch of real and half a batch of fake samples.
    half_batch = n_batch // 2
    # Generator targets never change, so build them once instead of every step.
    # NOTE(review): -1 is assumed to be the label the critic's loss treats as "real";
    # confirm against generate_real_samples / the WGAN loss definition.
    y_gan = -ones((n_batch, 1))
    # Per-step histories: mean critic loss on real / fake batches, and generator loss.
    c1_hist, c2_hist, g_hist = [], [], []
    for step in range(n_steps):
        c1_tmp, c2_tmp = [], []
        # Update the critic n_critic times for every generator update.
        for _ in range(n_critic):
            X_real, y_real = generate_real_samples(dataset, half_batch)
            c1_tmp.append(c_model.train_on_batch(X_real, y_real))
            X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            c2_tmp.append(c_model.train_on_batch(X_fake, y_fake))
        c1_hist.append(mean(c1_tmp))
        c2_hist.append(mean(c2_tmp))
        # Update the generator through the combined model on fresh latent points.
        X_gan = generate_latent_points(latent_dim, n_batch)
        g_loss = gan_model.train_on_batch(X_gan, y_gan)
        g_hist.append(g_loss)
        print('> %d, c1=%.3f, c2=%.3f g=%.3f' % (step + 1, c1_hist[-1], c2_hist[-1], g_loss))

    _plot_loss_history(c1_hist, c2_hist, g_hist, plot_path)

In [3]:
# Seed the global RNGs so Restart & Run All reproduces the run as closely as possible.
# NOTE(review): presumably utils draws noise/indices from numpy's global RNG and the Keras
# layers draw initial weights from TensorFlow's -- confirm. GPU kernels may still be
# nondeterministic even with seeds set.
import numpy as np
import tensorflow as tf

SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Training hyperparameters, stated explicitly instead of relying on train()'s defaults.
N_EPOCHS = 10
N_BATCH = 64
N_CRITIC = 5

# Load the real training samples
dataset = load_real_samples()
# Dimensionality of the latent space fed to the generator
latent_dim = 50
# Create the WGAN wrapper from the dataset shape and latent size
wgan = WGAN(dataset.shape, latent_dim)
# Build the critic (the WGAN class calls it a "discriminator") and the generator
discriminator = wgan.build_discriminator()
generator = wgan.build_generator()
# Build the combined model; train() uses it to perform the generator update
wgan_model = wgan.build_wgan()

# Train the WGAN
train(g_model=generator, c_model=discriminator, gan_model=wgan_model, dataset=dataset,
      latent_dim=latent_dim, n_epochs=N_EPOCHS, n_batch=N_BATCH, n_critic=N_CRITIC)

  super().__init__(name, **kwargs)


> 1, c1=-2.574, c2=-0.021 g=-0.351
> 2, c1=-7.138, c2=-0.023 g=-1.715
> 3, c1=-9.827, c2=-0.030 g=-3.224
> 4, c1=-12.673, c2=-0.038 g=-4.396
> 5, c1=-14.753, c2=-0.045 g=-5.729
> 6, c1=-16.321, c2=-0.052 g=-7.067
> 7, c1=-17.874, c2=-0.072 g=-8.178
> 8, c1=-19.766, c2=-0.080 g=-9.194
> 9, c1=-20.291, c2=-0.104 g=-10.055
> 10, c1=-21.549, c2=-0.136 g=-10.871
> 11, c1=-22.320, c2=-0.161 g=-11.748
> 12, c1=-23.514, c2=-0.199 g=-12.276
> 13, c1=-24.264, c2=-0.226 g=-13.119
> 14, c1=-25.739, c2=-0.263 g=-13.524
> 15, c1=-26.342, c2=-0.325 g=-14.321
> 16, c1=-27.085, c2=-0.374 g=-14.697
> 17, c1=-28.256, c2=-0.436 g=-15.436
> 18, c1=-29.160, c2=-0.511 g=-15.792
> 19, c1=-29.294, c2=-0.604 g=-16.355
> 20, c1=-29.956, c2=-0.707 g=-16.920
> 21, c1=-30.786, c2=-0.782 g=-17.285
> 22, c1=-31.579, c2=-0.889 g=-17.731
> 23, c1=-32.023, c2=-1.018 g=-18.323
> 24, c1=-33.098, c2=-1.125 g=-18.686
> 25, c1=-34.012, c2=-1.286 g=-18.910
> 26, c1=-34.776, c2=-1.423 g=-19.322
> 27, c1=-35.058, c2=-1.584 g=-1