In [None]:
# Energy function network: a 32x32x1 image goes through four stride-2
# swish convolutions, is flattened, passes a 64-unit swish dense layer and
# ends in a single linear output unit.
ebm_input = layers.Input(shape=(32, 32, 1))
x = ebm_input
for n_filters in (16, 32, 64, 64):
    x = layers.Conv2D(
        n_filters,
        kernel_size=5,
        strides=2,
        padding="same",
        activation=activations.swish,
    )(x)
x = layers.Flatten()(x)
x = layers.Dense(64, activation=activations.swish)(x)
ebm_output = layers.Dense(1)(x)
model = models.Model(ebm_input, ebm_output)

In [None]:
def generate_samples(model, inp_imgs, steps, step_size, noise, return_img_per_step=False):
    """Refine a batch of images with noisy, clipped gradient steps.

    Every step adds Gaussian noise to the images, clips them to [-1, 1],
    differentiates ``-model(inp_imgs)`` with respect to the images, clips the
    gradient to [-0.03, 0.03] and moves the images by ``-step_size * grads``
    before clipping to [-1, 1] again.

    Args:
        model: Callable applied to the image batch; its output is
            differentiated with respect to the images.
        inp_imgs: Starting image batch (tensor or NumPy array). It is cast to
            a float32 tensor first, so the caller's object is not modified.
        steps: Number of update steps to run.
        step_size: Multiplier applied to the clipped gradient.
        noise: Standard deviation of the Gaussian noise added at each step.
        return_img_per_step: When True, return the batch after every step
            stacked along a new leading axis instead of only the final batch.

    Returns:
        The final image batch, or (if ``return_img_per_step`` is True) a
        tensor stacking the batch produced by each of the ``steps`` steps.
    """
    # Work on a float32 tensor copy: the noise from tf.random.normal is
    # float32, and an in-place `+=` could otherwise write into the caller's array.
    inp_imgs = tf.cast(inp_imgs, tf.float32)
    imgs_per_step = []
    for _ in range(steps):
        inp_imgs = inp_imgs + tf.random.normal(
            tf.shape(inp_imgs), mean=0.0, stddev=noise
        )
        inp_imgs = tf.clip_by_value(inp_imgs, -1.0, 1.0)
        with tf.GradientTape() as tape:
            # The images are plain tensors, not variables, so they must be
            # watched explicitly for the tape to record gradients.
            tape.watch(inp_imgs)
            out_score = -model(inp_imgs)
        grads = tape.gradient(out_score, inp_imgs)
        grads = tf.clip_by_value(grads, -0.03, 0.03)
        inp_imgs = inp_imgs - step_size * grads
        inp_imgs = tf.clip_by_value(inp_imgs, -1.0, 1.0)
        if return_img_per_step:
            imgs_per_step.append(inp_imgs)
    if return_img_per_step:
        return tf.stack(imgs_per_step, axis=0)
    return inp_imgs

In [None]:
class Buffer:
    """Pool of previously generated images used to seed new sampling runs.

    The pool starts with 128 uniform-noise images in [-1, 1) and is capped at
    8192 entries, each stored as a tensor of shape (1, 32, 32, 1).
    """

    def __init__(self, model):
        """Store the model and fill the pool with 128 random images.

        Args:
            model: Model handed to ``generate_samples`` when sampling.
        """
        self.model = model
        self.examples = [
            tf.random.uniform(shape=(1, 32, 32, 1)) * 2 - 1
            for _ in range(128)
        ]

    def sample_new_exmps(self, steps, step_size, noise):
        """Produce a batch of 128 generated images and add them to the pool.

        A binomially drawn number of images (128 trials, p = 0.05) start from
        fresh uniform noise; the remainder are drawn with replacement from the
        pool. The combined batch is passed through ``generate_samples``, the
        results are prepended to the pool, and the pool is cut to 8192 entries.

        Args:
            steps: Number of sampling steps forwarded to ``generate_samples``.
            step_size: Step size forwarded to ``generate_samples``.
            noise: Noise standard deviation forwarded to ``generate_samples``.

        Returns:
            The batch of 128 images returned by ``generate_samples``.
        """
        n_new = np.random.binomial(128, 0.05)
        rand_imgs = tf.random.uniform((n_new, 32, 32, 1)) * 2 - 1
        old_imgs = tf.concat(
            random.choices(self.examples, k=128 - n_new), axis=0
        )
        inp_imgs = tf.concat([rand_imgs, old_imgs], axis=0)
        inp_imgs = generate_samples(
            self.model, inp_imgs, steps=steps, step_size=step_size, noise=noise
        )
        # Newest samples go first so truncation drops the oldest entries.
        self.examples = tf.split(inp_imgs, 128, axis=0) + self.examples
        self.examples = self.examples[:8192]
        return inp_imgs

In [None]:
class EBM(models.Model):
    """Keras model wrapper that trains the energy network with a contrastive loss.

    Real images are compared against images produced by the sampling buffer;
    the loss is the difference of mean outputs plus a squared-output
    regularizer weighted by ``alpha``.
    """

    def __init__(self):
        super(EBM, self).__init__()
        # Uses the module-level `model` defined earlier in the notebook.
        self.model = model
        self.buffer = Buffer(self.model)
        self.alpha = 0.1
        self.loss_metric = metrics.Mean(name="loss")
        # NOTE(review): this metric tracks reg_loss but is named "log" — confirm intended.
        self.reg_loss_metric = metrics.Mean(name="log")
        self.cdiv_loss_metric = metrics.Mean(name="cdiv")
        self.real_out_metric = metrics.Mean(name="real")
        self.fake_out_metric = metrics.Mean(name="fake")

    @property
    def metrics(self):
        """Tracked metrics; ``test_step`` reports only the entries from index 2 on."""
        return [
            self.loss_metric,
            self.reg_loss_metric,
            self.cdiv_loss_metric,
            self.real_out_metric,
            self.fake_out_metric,
        ]

    def train_step(self, real_imgs):
        """Run one optimization step on a batch of real images.

        Args:
            real_imgs: Batch of real training images.

        Returns:
            Dict mapping each metric name to its current value.
        """
        # Add a little noise to the real images, then keep them in [-1, 1].
        real_imgs += tf.random.normal(
            shape=tf.shape(real_imgs), mean=0, stddev=0.005
        )
        real_imgs = tf.clip_by_value(real_imgs, -1.0, 1.0)
        fake_imgs = self.buffer.sample_new_exmps(
            steps=60, step_size=10, noise=0.005
        )
        inp_imgs = tf.concat([real_imgs, fake_imgs], axis=0)
        with tf.GradientTape() as training_tape:
            # Splitting in two equal halves assumes the real batch has as many
            # images as the fake batch from the buffer — TODO confirm batching.
            real_out, fake_out = tf.split(self.model(inp_imgs), 2, axis=0)
            cdiv_loss = tf.reduce_mean(fake_out, axis=0) - tf.reduce_mean(
                real_out, axis=0
            )
            reg_loss = self.alpha * tf.reduce_mean(
                real_out ** 2 + fake_out ** 2, axis=0
            )
            loss = reg_loss + cdiv_loss
        grads = training_tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(
            zip(grads, self.model.trainable_variables)
        )
        self.loss_metric.update_state(loss)
        self.reg_loss_metric.update_state(reg_loss)
        self.cdiv_loss_metric.update_state(cdiv_loss)
        self.real_out_metric.update_state(tf.reduce_mean(real_out, axis=0))
        self.fake_out_metric.update_state(tf.reduce_mean(fake_out, axis=0))
        # Call result() so values are returned, matching test_step.
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, real_imgs):
        """Evaluate a batch of real images against uniform-noise images.

        Args:
            real_imgs: Batch of real validation images.

        Returns:
            Dict with the "cdiv", "real" and "fake" metric values.
        """
        batch_size = real_imgs.shape[0]
        fake_imgs = tf.random.uniform((batch_size, 32, 32, 1)) * 2 - 1
        inp_imgs = tf.concat([real_imgs, fake_imgs], axis=0)
        real_out, fake_out = tf.split(self.model(inp_imgs), 2, axis=0)
        cdiv = tf.reduce_mean(fake_out, axis=0) - tf.reduce_mean(real_out, axis=0)
        self.cdiv_loss_metric.update_state(cdiv)
        self.real_out_metric.update_state(tf.reduce_mean(real_out, axis=0))
        self.fake_out_metric.update_state(tf.reduce_mean(fake_out, axis=0))
        return {m.name: m.result() for m in self.metrics[2:]}


# Build, compile and train. These statements must sit at module level:
# inside the class body the name `EBM` does not exist yet.
ebm = EBM()
ebm.compile(optimizer=optimizers.Adam(learning_rate=0.0001), run_eagerly=True)
ebm.fit(x_train, epochs=60, validation_data=x_test)

In [None]:
# Ten 32x32 single-channel starting images drawn uniformly from [-1, 1).
start_imgs = 2 * np.random.uniform(size=(10, 32, 32, 1)) - 1

# Run the sampler on the trained network, asking for the per-step images.
gen_img = generate_samples(
    model=ebm.model,
    inp_imgs=start_imgs,
    steps=1000,
    step_size=10,
    noise=0.005,
    return_img_per_step=True,
)