# Simplified Diffusion Schrodinger Bridge


```bib
@article{tang2024simplified,
  title={Simplified Diffusion Schrodinger Bridge},
  author={Tang, Zhicong and Hang, Tiankai and Gu, Shuyang and Chen, Dong and Guo, Baining},
  journal={arXiv preprint arXiv:2403.14623},
  year={2024}
}
```

In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D, MaxPool2D, UpSampling2D
from tqdm.auto import tqdm
from PIL import Image

import matplotlib.pyplot as plt

# Seed the global TensorFlow and NumPy random generators with the same value
# so that the random draws made further down start from a fixed state.
seed = 123
tf.random.set_seed(seed)
np.random.seed(seed)

In [None]:
# Load MNIST through TensorFlow Datasets and keep only the 'train' split.
ds = tfds.load('mnist')
data = ds['train']

# Optional: uncomment to work on a shuffled 1000-example subset instead.
# data = data.shuffle(60000, seed=seed).take(1000)

In [None]:
def preprocess_fn(entry):
    """Turn one dataset record into a binarized float32 image.

    Args:
        entry: mapping with an 'image' entry holding the pixel tensor
            (presumably uint8 in 0..255, as the division by 255 suggests;
            confirm against the dataset).

    Returns:
        A float32 tensor of the same shape as the input image whose values are
        1.0 where the rescaled pixel is above 0.5 and 0.0 elsewhere.
    """
    # Use tf.cast rather than the float() builtin: float() only works on
    # tensors when AutoGraph rewrites it (e.g. inside Dataset.map) and raises
    # on a non-scalar eager tensor.
    image = tf.cast(entry['image'], tf.float32) / 255.
    # Casting the boolean mask yields exactly 1.0 / 0.0 in float32.
    return tf.cast(image > .5, tf.float32)

# Batch size shared by every input pipeline in this notebook.
bs = 256

# Binarize -> shuffle with a 1024-element buffer -> batch -> prefetch.
train_data = (
    data
    .map(preprocess_fn)
    .shuffle(1024)
    .batch(bs)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
def corrupt(xs, gamma=0.01):
    """Add scaled uniform noise to a batch of images.

    Args:
        xs: floating-point batch; the reshape of `gamma` below assumes four
            axes with the batch axis first.
        gamma: noise scale. Either a scalar applied to the whole batch or a
            1-D tensor with one value per batch element.

    Returns:
        `xs + u * gamma`, where `u` is drawn uniformly from [0, 1) with the
        same shape and dtype as `xs`.
    """
    xs = tf.convert_to_tensor(xs)
    # tf.shape() is used instead of xs.shape so the function also works when a
    # dimension (typically the batch size) is not known statically, and the
    # noise is drawn in the dtype of xs to avoid float32/float64 mismatches.
    noise = tf.random.uniform(tf.shape(xs), dtype=xs.dtype)
    # Reshape to (N, 1, 1, 1) so a per-example gamma broadcasts over H, W, C.
    gamma = tf.reshape(tf.cast(gamma, xs.dtype), (-1, 1, 1, 1))
    return xs + noise * gamma

In [None]:
class BasicUNet(Model):
    '''Small convolutional U-Net with two additive skip connections.

    Three "down" convolutions (with max-pooling after the first two) are
    followed by three "up" convolutions (with upsampling and a skip connection
    before the last two). Every convolution is followed by a swish activation.

    Example:
        net = BasicUNet()
        x = tf.random.uniform((8, 28, 28, 1))
        net(x).shape
        # TensorShape([8, 28, 28, 1])
    '''

    def __init__(self, in_channels=1, out_channels=1):
        # in_channels is accepted but not referenced anywhere in this class.
        super().__init__(name='basic-unet')
        encoder_widths = (32, 64, 64)
        decoder_widths = (64, 32, out_channels)
        self.down_layers = [
            Conv2D(width, kernel_size=5, padding="same")
            for width in encoder_widths
        ]
        self.up_layers = [
            Conv2D(width, kernel_size=5, padding="same")
            for width in decoder_widths
        ]
        self.act = tf.keras.activations.swish  # also known as SiLU
        self.downscale = MaxPool2D(2)
        self.upscale = UpSampling2D(size=2)

    def call(self, x, t=None):
        # t is accepted but not used by the forward pass.
        skips = []
        for depth, conv in enumerate(self.down_layers):
            x = self.act(conv(x))
            # The third (final) down layer is neither stored nor pooled.
            if depth >= 2:
                continue
            skips.append(x)
            x = self.downscale(x)

        for depth, conv in enumerate(self.up_layers):
            # Every up layer except the first upsamples and then adds the most
            # recently stored skip activation.
            if depth:
                x = self.upscale(x) + skips.pop()
            x = self.act(conv(x))

        return x

In [None]:
# Two separate networks with the same architecture. The training loop below
# optimizes them in alternation: bnet on batches that start from the dataset,
# fnet on batches that start from uniform random noise.
bnet = BasicUNet()
fnet = BasicUNet()

In [None]:
T = 5        # propagation passes per half-bridge
epochs = 5   # number of backward/forward alternations

# Define a loss function
loss_fn = tf.keras.losses.MeanSquaredError()
# One optimizer per network: a Keras optimizer keeps per-variable state and
# current versions reject variables they were not built with.
opt = tf.keras.optimizers.Adam(learning_rate=1e-3)       # updates bnet
fnet_opt = tf.keras.optimizers.Adam(learning_rate=1e-3)  # updates fnet
losses, avg_losses = [], []


def _train_half_bridge(net, optimizer, batches, propagate, n_steps):
    """Train `net` to undo one propagation step, for `n_steps` chained passes.

    In each pass every batch `xb` is pushed through `propagate` (skipped when
    it is None) and `corrupt`, and `net` is updated to map the result back to
    `xb`. The propagated samples of one pass become the inputs of the next.

    Args:
        net: model being optimized.
        optimizer: optimizer used exclusively for `net`.
        batches: iterable of batched image tensors for the first pass.
        propagate: callable applied before `corrupt`, or None for identity.
        n_steps: number of chained passes.

    Returns:
        (step_losses, samples): the loss of every optimizer step as floats and
        the propagated samples of the last pass as one tensor (None when no
        batch was processed).
    """
    step_losses = []
    samples = None
    for _ in range(n_steps):
        propagated = []
        for xb in batches:
            # The propagated batch does not depend on net's weights, so it is
            # computed outside the tape.
            noised_xb = corrupt(xb if propagate is None else propagate(xb))
            with tf.GradientTape() as tape:
                loss = loss_fn(xb, net(noised_xb))
            grads = tape.gradient(loss, net.trainable_weights)
            optimizer.apply_gradients(zip(grads, net.trainable_weights))
            step_losses.append(float(loss))
            propagated.append(noised_xb)
        if not propagated:
            break
        # Concatenate once per pass and re-batch, so the next pass again
        # iterates over batches rather than over single images.
        samples = tf.concat(propagated, axis=0)
        batches = tf.data.Dataset.from_tensor_slices(samples).batch(bs)
    return step_losses, samples


# NOTE(review): the alternating scheme below follows the structure of the
# original cell; confirm it against the referenced paper.
for epoch in tqdm(range(epochs), desc='epoch'):
    # Backward half-bridge: start from the (binarized) training images.
    train_data = data.map(preprocess_fn).shuffle(1024).batch(bs).prefetch(tf.data.AUTOTUNE)
    phase_losses, noised_data = _train_half_bridge(
        bnet, opt, train_data,
        propagate=None if epoch == 0 else fnet,  # fnet is untrained in epoch 0
        n_steps=T,
    )
    losses.extend(phase_losses)
    # Average loss of this backward phase
    avg_losses.append(np.mean(phase_losses))

    # Forward half-bridge: start from uniform noise shaped like the data.
    dataset_size = int(noised_data.shape[0])
    prior_data = tf.data.Dataset.from_tensor_slices(
        tf.random.uniform((dataset_size, *noised_data.shape[1:]))
    ).batch(bs)
    phase_losses, noised_data = _train_half_bridge(
        fnet, fnet_opt, prior_data, propagate=bnet, n_steps=T,
    )
    losses.extend(phase_losses)
    # Average loss of this forward phase
    avg_losses.append(np.mean(phase_losses))

# View the loss curve
plt.plot(losses)
plt.ylim(0, 0.1);

# Fetch some data (using the first 8 for easy plotting), preprocessed the same
# way as the training images so that the values lie in [0, 1]
batch = data.map(preprocess_fn).take(8)
xs = tf.stack(list(batch))

# Corrupt the images with a range of amounts
amount = tf.linspace(0.0, 1.0, xs.shape[0])
noised_xs = corrupt(xs, amount)


AttributeError: 'numpy.ndarray' object has no attribute 'map'

In [None]:

# Loss curve. The y-range is capped at 0.1, so larger values are cut off.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_ylim(0, 0.1)
ax.set(xlabel='logged step', ylabel='MSE loss', title='Training loss');


In [None]:
def make_grid(xs, rows=1, cols=8):
    """Arrange a batch of single-channel images into one greyscale PIL image.

    Args:
        xs: batch of images, either (N, H, W) or (N, H, W, 1); a tensor or
            anything else `np.asarray` accepts. Floating-point input is
            assumed to lie in [0, 1]; integer input is assumed to be 0..255.
        rows: number of grid rows.
        cols: number of grid columns; `rows * cols` must equal N.

    Returns:
        The 'L'-mode image produced by `image_grid`.
    """
    arr = np.asarray(xs)
    # Drop only the trailing channel axis. A bare squeeze() would also drop
    # the batch axis when the batch holds a single image.
    if arr.ndim == 4 and arr.shape[-1] == 1:
        arr = arr[..., 0]
    # Rescale floats to 0..255 before building the images: a float array
    # becomes a mode-'F' image whose values are truncated when it is pasted
    # into the 'L' canvas, which collapses [0, 1) to 0.
    if np.issubdtype(arr.dtype, np.floating):
        arr = np.clip(arr, 0.0, 1.0) * 255.0
    arr = np.clip(np.rint(arr), 0, 255).astype(np.uint8)
    images = [Image.fromarray(x) for x in arr]
    return image_grid(images, rows, cols, 'L')

def image_grid(imgs, rows, cols, mode='RGB'):
    """Paste PIL images row by row onto a new canvas of the given mode.

    The tile size is taken from the first image. `len(imgs)` must equal
    `rows * cols`; this is checked with an assert.
    """
    assert len(imgs) == rows*cols

    tile_w, tile_h = imgs[0].size
    canvas = Image.new(mode, size=(cols*tile_w, rows*tile_h))

    for idx, tile in enumerate(imgs):
        row, col = divmod(idx, cols)
        canvas.paste(tile, box=(col*tile_w, row*tile_h))
    return canvas

In [None]:
n_steps = 15

# Start from random noise
x = tf.random.uniform((8, 28, 28, 1))
step_history = [x]

for _ in range(n_steps):
    # One pass through the backward network
    pred = bnet(x)
    # Add the same kind of noise that bnet's inputs carried during training
    x = corrupt(pred)
    # Keep the values in [0, 1] before storing the step for plotting
    x = tf.clip_by_value(x, 0, 1)
    step_history.append(x)

# One row per stored state: the initial noise first, then each sampling step.
fig, axs = plt.subplots(n_steps+1, 1, figsize=(8, 8), sharex=True)
axs[0].set_xlabel('x(t)')
for i, (ax, step) in enumerate(zip(axs, step_history)):
    ax.imshow(make_grid(step), cmap='Greys')
    if i > 0:
        ax.set_ylabel(f't=-{i}')