In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from nbdev import *
%nbdev_default_export dcgan

Cells will be exported to tf2_gans.dcgan,
unless a different module is specified after an export flag: `%nbdev_export special.module`


In [None]:
from tf2_gans.losses import BCE_generator_loss, BCE_discriminator_loss
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras

# DCGAN

In [None]:
class DCGAN(object):
    """Bundle a Generator/Discriminator pair with its losses, optimizers and checkpoint.

    Args:
        noise_dim: Length of the noise vector fed to the generator.
        out_channels: Number of channels in generated images.
    """

    def __init__(self, noise_dim, out_channels=1):
        self.noise_dim = noise_dim
        self.generator = Generator(
            noise_dim=noise_dim, out_channels=out_channels)
        self.discriminator = Discriminator()
        # NOTE(review): the discriminator ends in a sigmoid, so these losses are
        # presumably probability-based (not from_logits) -- confirm in tf2_gans.losses.
        self.generator_loss = BCE_generator_loss
        self.discriminator_loss = BCE_discriminator_loss
        # Adam(2e-4, beta_1=0.5) for both networks.
        self.generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4,
                                                            beta_1=0.5)
        self.discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4,
                                                                beta_1=0.5)

        # Tracks both networks and both optimizers so a restore resumes training state.
        self.checkpoint = tf.train.Checkpoint(generator_optimizer=self.generator_optimizer,
                                              generator=self.generator,
                                              discriminator_optimizer=self.discriminator_optimizer,
                                              discriminator=self.discriminator)

    def save(self, file_prefix):
        """Write a checkpoint of networks and optimizers.

        Args:
            file_prefix: Path prefix handed to ``tf.train.Checkpoint.save``.

        Returns:
            The checkpoint path reported by ``tf.train.Checkpoint.save``.
        """
        return self.checkpoint.save(file_prefix)

    def generate(self, input_image, training=True):
        """Run the generator on a batch of inputs.

        Args:
            input_image: Generator input; presumably a batch of noise vectors of
                shape (batch, noise_dim) -- the name is kept for compatibility.
            training: Forwarded to the generator. Defaults to True (the
                previous, hard-coded behaviour). Pass False to sample in
                inference mode, where BatchNormalization uses its moving
                statistics instead of batch statistics and does not update them.

        Returns:
            The generator output tensor.
        """
        return self.generator(input_image, training=training)

    @tf.function
    def train_step(self, noise, target):
        """Run one adversarial update of discriminator then generator.

        Args:
            noise: Generator input batch.
            target: Batch of real images.

        Returns:
            Dict with the scalar losses under ``'gen_total_loss'`` and ``'disc_loss'``.
        """
        # Both tapes record the same forward pass so each network can be
        # differentiated with respect to its own loss.
        with tf.GradientTape() as disc_tape, tf.GradientTape() as gen_tape:
            generated_images = self.generator(noise, training=True)

            real_output = self.discriminator(target, training=True)
            fake_output = self.discriminator(generated_images, training=True)

            gen_loss = self.generator_loss(fake_output)
            disc_loss = self.discriminator_loss(real_output, fake_output)

        # The discriminator is updated first; generator gradients are then
        # taken from gen_tape.
        discriminator_gradients = disc_tape.gradient(disc_loss,
                                                     self.discriminator.trainable_variables)
        self.discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                                         self.discriminator.trainable_variables))

        generator_gradients = gen_tape.gradient(gen_loss,
                                                self.generator.trainable_variables)

        self.generator_optimizer.apply_gradients(zip(generator_gradients,
                                                     self.generator.trainable_variables))

        return {'gen_total_loss': gen_loss,
                'disc_loss': disc_loss}




In [None]:
class Trainer(object):
    """Skeleton epoch loop with progress reporting and periodic previews.

    Subclasses are expected to override ``train_epoch``, ``dataset_generator``
    and ``generate_preview``. ``set_training_parameters`` must be called before
    ``train``: it is the only place the loop configuration is assigned.
    """

    def __init__(self):
        # Flipped to True the first time `train` runs.
        self.has_trained = False

    def set_training_parameters(self,
                                epochs=10,
                                continue_training=False,
                                save_fq=0,
                                preview_fq=5,
                                progress_fn=None,
                                preview_fn=None):
        """Store the configuration used by ``train``.

        Args:
            epochs: Number of epochs ``train`` runs.
            continue_training: Stored flag; ``train`` does not act on it yet.
            save_fq: Stored; periodic checkpointing is not implemented here.
            preview_fq: Approximate number of previews over the whole run
                (one every ``max(1, epochs // preview_fq)`` epochs).
                0 or a negative value disables previews.
            progress_fn: Optional callback receiving the progress percentage.
            preview_fn: Optional callback receiving ``generate_preview()``'s result.
        """
        self.epochs = epochs
        self.continue_training = continue_training
        self.save_fq = save_fq
        self.preview_fq = preview_fq
        self.progress_fn = progress_fn
        self.preview_fn = preview_fn

        print("parameters set")

    def train_epoch(self):
        """Run one epoch of training. No-op here; subclasses override."""
        pass

    def preprocess_data(self, raw_data):
        """Store ``raw_data`` on the instance and return it unchanged."""
        self.raw_data = raw_data
        return self.raw_data

    def dataset_generator(self):
        """Yield training examples for ``create_tf_dataset``. No-op here; subclasses override."""
        pass

    def create_tf_dataset(self):
        """Build a batched ``tf.data.Dataset`` from ``dataset_generator``.

        Relies on ``self.output_type`` and ``self.batch_size``, which this class
        never assigns -- a subclass must set them before calling this.
        """
        self.dataset = tf.data.Dataset.from_generator(self.dataset_generator,
                                                      self.output_type)
        self.dataset = self.dataset.batch(self.batch_size)
        return self.dataset

    def init_dataset(self, dataset, batch_size=1, buffer_size=100):
        """Shuffle then batch ``dataset``, store it on the instance and return it.

        Args:
            dataset: Presumably a ``tf.data.Dataset`` (needs ``shuffle``/``batch``).
            batch_size: Batch size passed to ``batch``.
            buffer_size: Shuffle buffer size passed to ``shuffle``.
        """
        self.dataset = dataset.shuffle(buffer_size)
        self.dataset = self.dataset.batch(batch_size)
        return self.dataset

    def train(self):
        """Run ``self.epochs`` epochs, calling ``end_epoch`` after each one."""
        if self.progress_fn:
            self.progress_fn(0.0)
        self.has_trained = True

        if not self.continue_training:
            # TODO: reset model weights before a fresh run (not implemented).
            pass

        for epoch in range(1, self.epochs + 1):
            self.train_epoch()
            self.end_epoch(epoch)

        # TODO: nothing is actually saved yet; only the message is printed.
        print("saving last epoch")

    def end_epoch(self, epoch):
        """Report progress and, when due, emit a preview for 1-based ``epoch``."""
        if self.progress_fn:
            self.progress_fn((100 * epoch) // self.epochs)

        # Only compute the preview interval when a preview can actually be
        # emitted; clamp it to >= 1 so runs shorter than `preview_fq` epochs
        # preview every epoch instead of dividing by zero.
        if self.preview_fn and self.preview_fq > 0:
            preview_every = max(1, self.epochs // self.preview_fq)
            if epoch % preview_every == 0:
                self.preview_fn(self.generate_preview())

        print("Epoch: ", epoch)

        # TODO: periodic checkpointing driven by `save_fq` is not implemented.

    def generate_preview(self):
        """Return preview data for ``preview_fn``. Empty dict here; subclasses override."""
        return {}

In [None]:
class Discriminator(keras.Model):
    """Convolutional discriminator producing one sigmoid score per input image.

    Each down block is Conv2D(stride 2, 'same', no bias) -> BatchNorm ->
    LeakyReLU -> Dropout; the filter count doubles from block to block,
    starting at ``conv_dim``. The features are then flattened into a single
    sigmoid unit.

    Args:
        conv_dim: Filter count of the first down block.
        n_down_blocks: Number of stride-2 down blocks.
        dropout_rate: Dropout rate after each block.
        leaky_alpha: Negative slope of each LeakyReLU.
    """

    def __init__(self,
                 conv_dim=64,
                 n_down_blocks=2,
                 dropout_rate=0.3,
                 leaky_alpha=0.2):
        super().__init__()
        stack = []
        for depth in range(n_down_blocks):
            filters = conv_dim * 2 ** depth
            stack.extend([
                keras.layers.Conv2D(filters, 4, strides=2, padding='same',
                                    use_bias=False),
                keras.layers.BatchNormalization(),
                keras.layers.LeakyReLU(leaky_alpha),
                keras.layers.Dropout(dropout_rate),
            ])
        stack.append(keras.layers.Flatten())
        stack.append(keras.layers.Dense(1, activation="sigmoid"))
        self.main = keras.Sequential(stack, name="main")

    def call(self, inputs):
        """Score a batch of images."""
        return self.main(inputs)


In [None]:
class Generator(keras.Model):
    """DCGAN-style generator mapping a batch of noise vectors to images.

    The noise vector is projected by a Dense layer to an
    ``init_size x init_size x noise_dim`` seed tensor, refined by a stride-1
    transposed convolution, then upsampled by ``n_up_blocks`` stride-2
    transposed-conv blocks and a final stride-2 transposed convolution with a
    ``tanh`` output.

    Args:
        output_size: Side length of the square output image. Must be a
            positive multiple of ``2 ** (n_up_blocks + 1)``.
        conv_dim: Base filter count; the first transposed convolution uses
            ``conv_dim * 2 ** (n_up_blocks + 1)`` filters and each upsampling
            block uses half the previous count.
        noise_dim: Length of the input noise vector; also the channel depth of
            the projected seed tensor.
        n_up_blocks: Number of intermediate stride-2 upsampling blocks.
        c_dim: Not used by this class; kept for interface compatibility.
        out_channels: Number of channels in the generated image.

    Raises:
        ValueError: If ``output_size`` is not a positive multiple of
            ``2 ** (n_up_blocks + 1)``; the network could then not produce an
            ``output_size x output_size`` image.
    """

    def __init__(self,
                 output_size=28,
                 conv_dim=64,
                 noise_dim=100,
                 n_up_blocks=1,
                 c_dim=5,
                 out_channels=1):
        super(Generator, self).__init__()
        # Total spatial upsampling: n_up_blocks stride-2 blocks plus the final
        # stride-2 output layer.
        upsample = 2 ** (n_up_blocks + 1)
        if output_size <= 0 or output_size % upsample != 0:
            raise ValueError(
                "output_size must be a positive multiple of "
                f"2 ** (n_up_blocks + 1) = {upsample}, got {output_size}")
        init_size = output_size // upsample
        top_dim = conv_dim * upsample

        self.main = keras.Sequential(name="main")
        self.main.add(keras.layers.Dense(init_size * init_size * noise_dim,
                                         use_bias=False,
                                         input_shape=(noise_dim,)))
        # The Dense layer emits init_size*init_size*noise_dim values per sample,
        # so the target shape needs an explicit channel axis.
        self.main.add(keras.layers.Reshape((init_size, init_size, noise_dim)))
        # Stride 1 with 'same' padding keeps the seed at init_size x init_size.
        # ('valid' would add kernel_size - 1 = 3 pixels and the final image
        # would come out larger than output_size.)
        self.main.add(keras.layers.Conv2DTranspose(top_dim, 4,
                                                   strides=(1, 1),
                                                   padding='same',
                                                   use_bias=False))
        self.main.add(keras.layers.BatchNormalization())
        self.main.add(keras.layers.LeakyReLU())

        curr_dim = top_dim // 2

        # Each block doubles the spatial size and halves the filter count.
        for _ in range(n_up_blocks):
            self.main.add(keras.layers.Conv2DTranspose(curr_dim,
                                                       4,
                                                       strides=(2, 2),
                                                       padding='same',
                                                       use_bias=False))
            self.main.add(keras.layers.BatchNormalization())
            self.main.add(keras.layers.LeakyReLU())
            curr_dim = curr_dim // 2

        # Final doubling to output_size; tanh bounds pixels to [-1, 1].
        self.main.add(keras.layers.Conv2DTranspose(out_channels,
                                                   (5, 5),
                                                   strides=(2, 2),
                                                   padding='same',
                                                   use_bias=False,
                                                   activation='tanh'))

    def call(self, inputs):
        """Generate a batch of images from a batch of noise vectors."""
        return self.main(inputs)

In [None]:
# nbdev: write the export-flagged cells of the project notebooks into the
# package modules; the "Converted ..." lines below are this cell's output.
notebook2script()

Converted 01_dcgan.ipynb.
Converted blocks.ipynb.
Converted core.ipynb.
Converted index.ipynb.
Converted losses.ipynb.
