In [3]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tqdm.auto import trange, tqdm

In [4]:
# Stream the LSUN bedroom images as individual RGB tensors resized to 64x64.
# labels/label_mode are None, so each element is an image only, and
# batch_size=None leaves the dataset unbatched (batching happens at training time).
dataset = tf.keras.preprocessing.image_dataset_from_directory(
    "../data0/lsun/bedroom",
    labels=None,
    label_mode=None,
    class_names=None,
    color_mode='rgb',
    batch_size=None,
    image_size=(64, 64),
    shuffle=True,
    seed=2024,
    validation_split=None,
    subset=None,
    interpolation='bilinear',
    follow_links=False,
    crop_to_aspect_ratio=False
)

# The loader yields float32 pixel values in [0, 255]. The diffusion schedule
# adds Gaussian noise with standard deviation <= 1, which is negligible at
# that scale, so rescale the images to [-1, 1] before they are noised.
# To display an image later, invert with (image + 1.0) * 127.5.
dataset = dataset.map(
    lambda image: image / 127.5 - 1.0,
    num_parallel_calls=tf.data.AUTOTUNE,
)

Found 303125 files.


In [4]:
# Image and batching configuration
IMG_SIZE = 64
BATCH_SIZE = 125

# Diffusion schedule: `timesteps` linearly spaced betas from 1e-4 to 0.02,
# the per-step alphas (1 - beta) and their running product alpha_bar.
timesteps = 1000
beta = np.linspace(1e-4, 0.02, num=timesteps)
alpha = 1 - beta
alpha_bar = alpha.cumprod()

In [5]:
# Forward diffusion process
def forward_diffusion_sample(x0, t, alpha_bar):
    """Noise ``x0`` directly to timestep(s) ``t`` in closed form.

    Returns ``sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * eps`` where
    ``eps`` is standard Gaussian noise drawn from NumPy's global RNG.

    Args:
        x0: Clean sample(s); ``x0.shape`` determines the shape of the noise.
        t: Integer index (or array of indices) into ``alpha_bar``.
        alpha_bar: NumPy array of cumulative alpha products, indexed by timestep.

    Returns:
        The noised sample. The noise that was added is not returned.
    """
    eps = np.random.normal(size=x0.shape)

    # One coefficient per selected timestep, shaped (-1, 1, 1, 1) so it
    # broadcasts over the trailing three axes of a 4-D batch.
    keep = alpha_bar[t].reshape(-1, 1, 1, 1)
    signal_scale = np.sqrt(keep)
    noise_scale = np.sqrt(1 - keep)

    return signal_scale * x0 + noise_scale * eps

In [6]:
# Build the vanilla pixel-space diffusion network
def make_simple_model():
    """Return an uncompiled, fully convolutional Keras model.

    The network takes an (IMG_SIZE, IMG_SIZE, 3) image, applies two 3x3
    convolutions with 64 filters and ReLU activation, then a 1x1 convolution
    with 3 filters and no activation. All layers use 'same' padding.
    """
    image_in = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3))

    # Two identical hidden feature-extraction layers.
    features = image_in
    for _ in range(2):
        features = layers.Conv2D(
            64, kernel_size=3, padding='same', activation='relu'
        )(features)

    # Linear projection back to three channels.
    image_out = layers.Conv2D(3, kernel_size=1, padding='same')(features)
    return tf.keras.models.Model(image_in, image_out)

In [7]:
# Instantiate the network and compile it with Adam (learning rate 2e-4)
# and a mean-squared-error loss.
model = make_simple_model()
loss_func = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4)
model.compile(optimizer=optimizer, loss=loss_func)

In [8]:
# Training Loop
def train_model(model, dataset, alpha_bar, timesteps, epochs=1):
    """Train ``model`` to map forward-diffused images back to the clean images.

    For every batch, one timestep per image is drawn uniformly from
    ``[0, timesteps)``, the batch is noised with ``forward_diffusion_sample``,
    and ``model.train_on_batch`` is called with the noisy batch as input and
    the clean batch as target. The sampled timestep is used only to noise the
    input; it is not passed to the model.

    Args:
        model: Compiled Keras model (``train_on_batch`` is called on it).
        dataset: Dataset exposing ``.batch()``; it is batched here with
            ``BATCH_SIZE``, so it is expected to be unbatched.
        alpha_bar: Array of cumulative alpha products indexed by timestep.
        timesteps: Exclusive upper bound for the sampled timestep indices.
        epochs: Number of full passes over ``dataset``.

    Prints the loss of the last batch after each epoch (``None`` if the
    dataset produced no batches).
    """
    for epoch in range(epochs):
        loss = None  # stays None if the dataset yields no batches this epoch
        for x in tqdm(dataset.batch(BATCH_SIZE)):
            # The final batch is shorter than BATCH_SIZE whenever the dataset
            # size is not a multiple of it, so size `t` from the actual batch
            # to keep the per-sample coefficients broadcastable against `x`.
            batch_len = int(x.shape[0])
            t = np.random.randint(0, timesteps, size=(batch_len,))
            xt = forward_diffusion_sample(x, t, alpha_bar)
            loss = model.train_on_batch(xt, x)
        print(f"Epoch {epoch + 1}, Loss: {loss}")

In [None]:
# Run training for 10 epochs over the full dataset
train_model(
    model=model,
    dataset=dataset,
    alpha_bar=alpha_bar,
    timesteps=timesteps,
    epochs=10,
)

  0%|          | 0/1200 [00:00<?, ?it/s]

