In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt

  if not hasattr(np, "object"):


In [2]:
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

In [None]:
IMAGE_SIZE = 32
CHANNELS = 1

BATCH_SIZE = 64
EPOCHS = 120

STEPS = 60             
STEP_SIZE = 10.0       
NOISE = 0.005          
GRADIENT_CLIP = 0.03   

ALPHA = 0.1            
LEARNING_RATE = 1e-4

BUFFER_SIZE = 8192     
RANDOM_RESTART_RATE = 0.05   

LOG_DIR = "./logs_ebm"
os.makedirs(LOG_DIR, exist_ok=True)

In [4]:
def preprocess_mnist(x):
    x = x.astype("float32")
    x = (x - 127.5) / 127.5  # [-1, 1]
    # pad to 32x32 with -1.0
    pad = (IMAGE_SIZE - 28) // 2
    x = np.pad(
        x,
        pad_width=((0, 0), (pad, pad), (pad, pad)),
        mode="constant",
        constant_values=-1.0,
    )
    x = np.expand_dims(x, axis=-1)  # (N, 32, 32, 1)
    return x
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
x_train = preprocess_mnist(x_train)
x_test = preprocess_mnist(x_test)

train_ds = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(10_000, seed=SEED)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

In [5]:
def build_energy_network():
    inp = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))
    x = layers.Conv2D(16, 5, strides=2, padding="same", activation="swish")(inp)
    x = layers.Conv2D(32, 3, strides=2, padding="same", activation="swish")(x)
    x = layers.Conv2D(64, 3, strides=2, padding="same", activation="swish")(x)
    x = layers.Conv2D(64, 3, strides=2, padding="same", activation="swish")(x)
    x = layers.Flatten()(x)
    x = layers.Dense(64, activation="swish")(x)
    out = layers.Dense(1)(x)  # energy
    return models.Model(inp, out, name="energy_network")

energy_net = build_energy_network()
energy_net.summary()

In [6]:
@tf.function
def langevin_sample(model, x_init, steps=STEPS, step_size=STEP_SIZE, noise=NOISE, grad_clip=GRADIENT_CLIP):
    x = tf.identity(x_init)

    for _ in tf.range(steps):
        # Inject small Gaussian noise
        x = x + noise * tf.random.normal(tf.shape(x), dtype=x.dtype)

        with tf.GradientTape() as tape:
            tape.watch(x)
            e = model(x, training=False)  # energies
            # Sum energies so tape can take gradient wrt x
            e_sum = tf.reduce_sum(e)

        grads = tape.gradient(e_sum, x)
        grads = tf.clip_by_value(grads, -grad_clip, grad_clip)

        # Update pixels: x <- x + step_size * grads
        x = x + step_size * grads

        # Keep within valid range [-1, 1]
        x = tf.clip_by_value(x, -1.0, 1.0)

    return x

In [7]:
class ReplayBuffer:
    def __init__(self, size, image_shape):
        self.size = int(size)
        self.image_shape = tuple(image_shape)
        self._buffer = None
        self._filled = 0

    def init(self):
        # Start buffer with random images in [-1,1]
        self._buffer = np.random.uniform(-1.0, 1.0, size=(self.size, *self.image_shape)).astype("float32")
        self._filled = self.size

    def sample(self, batch_size):
        idx = np.random.randint(0, self._filled, size=batch_size)
        return self._buffer[idx]

    def update(self, samples):
        # Overwrite random positions with new samples
        samples = samples.numpy() if isinstance(samples, tf.Tensor) else samples
        n = samples.shape[0]
        idx = np.random.randint(0, self.size, size=n)
        self._buffer[idx] = samples

buffer = ReplayBuffer(BUFFER_SIZE, (IMAGE_SIZE, IMAGE_SIZE, CHANNELS))
buffer.init()

In [11]:
class EBM(tf.keras.Model):
    def __init__(self, energy_model, replay_buffer, **kwargs):
        super().__init__(**kwargs)
        self.energy_model = energy_model
        self.replay_buffer = replay_buffer

        # Track metrics
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.cdiv_tracker = tf.keras.metrics.Mean(name="cdiv")
        self.reg_tracker = tf.keras.metrics.Mean(name="reg")
        self.real_e_tracker = tf.keras.metrics.Mean(name="real_energy")
        self.fake_e_tracker = tf.keras.metrics.Mean(name="fake_energy")

    @property
    def metrics(self):
        return [self.loss_tracker, self.cdiv_tracker, self.reg_tracker, self.real_e_tracker, self.fake_e_tracker]

    def sample_negatives(self, batch_size):
        # Random restart with small probability; else from replay buffer
        use_random = np.random.binomial(1, RANDOM_RESTART_RATE, size=batch_size).astype(bool)
        start = self.replay_buffer.sample(batch_size).astype("float32")
        if np.any(use_random):
            start[use_random] = np.random.uniform(-1.0, 1.0, size=(np.sum(use_random), IMAGE_SIZE, IMAGE_SIZE, CHANNELS)).astype("float32")

        start = tf.convert_to_tensor(start)
        neg = langevin_sample(self.energy_model, start, steps=STEPS, step_size=STEP_SIZE, noise=NOISE, grad_clip=GRADIENT_CLIP)
        return neg

    def train_step(self, data):
        # data is a batch of real images
        x_real = data

        # Add small noise to real images (matches reference notebook)
        x_real = x_real + NOISE * tf.random.normal(tf.shape(x_real), dtype=x_real.dtype)
        x_real = tf.clip_by_value(x_real, -1.0, 1.0)

        # Generate negative samples with CD (from replay buffer + Langevin)
        batch_size = x_real.shape[0]  # Python int when drop_remainder=True
        x_fake = self.sample_negatives(batch_size)
        with tf.GradientTape() as tape:
            # Compute energies for real and fake
            x_all = tf.concat([x_real, x_fake], axis=0)
            e_all = self.energy_model(x_all, training=True)
            e_real, e_fake = tf.split(e_all, num_or_size_splits=2, axis=0)

            # Contrastive divergence loss
            cdiv_loss = tf.reduce_mean(e_fake) - tf.reduce_mean(e_real)

            # Energy regularization (stabilizes training)
            reg_loss = ALPHA * tf.reduce_mean(tf.square(e_real) + tf.square(e_fake))

            loss = cdiv_loss + reg_loss

        grads = tape.gradient(loss, self.energy_model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.energy_model.trainable_variables))

        # Update replay buffer with newly generated negatives
        self.replay_buffer.update(x_fake)

        # Update metrics
        self.loss_tracker.update_state(loss)
        self.cdiv_tracker.update_state(cdiv_loss)
        self.reg_tracker.update_state(reg_loss)
        self.real_e_tracker.update_state(tf.reduce_mean(e_real))
        self.fake_e_tracker.update_state(tf.reduce_mean(e_fake))

        return {m.name: m.result() for m in self.metrics}

In [None]:
ebm = EBM(energy_net, buffer)
ebm.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    run_eagerly=True,   # IMPORTANT
)

history = ebm.fit(
    train_ds,
    epochs=EPOCHS,
    verbose=1,
)

Epoch 1/120


AttributeError: 'SymbolicTensor' object has no attribute 'numpy'

In [10]:
plt.figure()
plt.plot(history.history["loss"])
plt.title("EBM Training Loss (Total Loss)")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True)
plt.show()

NameError: name 'history' is not defined

<Figure size 640x480 with 0 Axes>

In [None]:
plt.figure()
plt.plot(history.history["cdiv"], label="Contrastive Divergence (mean E_fake - mean E_real)")
plt.plot(history.history["reg"], label="Regularization")
plt.title("EBM Training Components")
plt.xlabel("Epoch")
plt.ylabel("Value")
plt.legend()
plt.grid(True)
plt.show()

In [None]:
def visualize_generated_samples(model, n=16, steps=1000):
    x0 = tf.random.uniform((n, IMAGE_SIZE, IMAGE_SIZE, CHANNELS), minval=-1.0, maxval=1.0)
    xg = langevin_sample(model, x0, steps=steps, step_size=STEP_SIZE, noise=NOISE, grad_clip=GRADIENT_CLIP)
    imgs = (xg.numpy() + 1.0) / 2.0  # back to [0,1]
    imgs = np.clip(imgs, 0, 1)

    cols = int(np.sqrt(n))
    rows = int(np.ceil(n / cols))
    plt.figure(figsize=(cols * 2, rows * 2))
    for i in range(n):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(imgs[i, :, :, 0], cmap="gray")
        plt.axis("off")
    plt.suptitle(f"Generated Samples (Langevin steps={steps})")
    plt.show()