In [None]:
# Optional: IPython shell escape that runs `nvidia-smi`; nothing below depends on its output.
!nvidia-smi  # check GPU usage -- can ignore this

In [None]:
# ignore this cell -- stuff for our server
# (selects the first GPU and routes HTTP/HTTPS traffic through the local proxy)
import os

_proxy_url = 'http://proxy:3128/'
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "0",
    "HTTP_PROXY": _proxy_url,
    "HTTPS_PROXY": _proxy_url,
})

In [None]:
# imports
import os

import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import image as mpimage

from data.utils import parse_image_example
from modeling.layers import ConvNormAct, ResidualBlock

# shorthand alias for the Keras layers namespace
tfkl = tf.keras.layers

In [None]:
# needs preprocessed flickr faces data
batch_size = 256

# Training pipeline: shuffle records, parse them, form full batches, then apply
# random horizontal flips to the batches. prefetch() lets the input pipeline
# prepare the next batch while the current one is being consumed.
train_data = (
    tf.data.TFRecordDataset("data/flickr_64_train.TFR")
    .shuffle(60000)
    .map(parse_image_example)
    .batch(batch_size, drop_remainder=True)
    .map(tf.image.random_flip_left_right)
    .prefetch(tf.data.AUTOTUNE)
)

# Test pipeline: no shuffling, no augmentation, last partial batch is kept.
test_data = (
    tf.data.TFRecordDataset("data/flickr_64_test.TFR")
    .map(parse_image_example)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# pull the whole test split into one array (concatenated along the batch axis)
test_images = np.concatenate(list(test_data), axis=0)

# preview the first 64 test images on an 8x8 grid
plt.figure(figsize=(15,15))
for position, img in enumerate(test_images[:64], start=1):
    plt.subplot(8, 8, position)
    plt.imshow(img)
    plt.axis("off")
plt.show()

In [None]:
# for the improved paper version, we need the maximum euclidean distance between data points.
# NOTE: despite its name, `train_images_flat` is built from `test_images`.
# Distances are computed batch-wise via ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, so each
# step only needs a (batch, N) matrix instead of a (batch, N, 64*64*3) difference tensor.
train_images_flat = test_images.reshape((-1, 64*64*3))
# squared norm of every flattened image, computed once
_sq_norms = np.einsum("ij,ij->i", train_images_flat, train_images_flat)
max_distance_so_far = 0.

_batchs = 64
for ind in range(0, len(train_images_flat), _batchs):
    img = train_images_flat[ind:ind+_batchs]
    sq_distances = (_sq_norms[ind:ind+_batchs, None]
                    + _sq_norms[None, :]
                    - 2. * (img @ train_images_flat.T))
    # rounding can push (near-)zero squared distances slightly below 0; clamp before sqrt
    max_distance_here = np.sqrt(np.maximum(sq_distances.max(), 0.))
    max_distance_so_far = np.maximum(max_distance_here, max_distance_so_far)
    if not ind % 10:
        print(ind, max_distance_so_far)

In [None]:
# how many scales??
from scipy import stats

d = 64*64*3  # data dimensionality
wish_gamma = 1.015  # <-- this is the target ratio between successive noise scales, try values > 1

# the interval is centred on sqrt(2d) * (gamma - 1) and extends 3*gamma to either side
_interval_centre = np.sqrt(2*d) * (wish_gamma - 1)
_interval_half_width = 3*wish_gamma
upper_limit = _interval_centre + _interval_half_width
lower_limit = _interval_centre - _interval_half_width

# standard-normal probability mass between the two limits
_cdf_at_upper = stats.norm.cdf(upper_limit)
_cdf_at_lower = stats.norm.cdf(lower_limit)
c_value = _cdf_at_upper - _cdf_at_lower

print("C value is {}; should be 0.5 or higher! Too low? Make gamma smaller!".format(c_value))

In [None]:
n_noise_scales = 800
target_noise = 0.001
# geometric schedule from the largest data distance down to the target noise level
noise_scales = np.geomspace(
    start=max_distance_so_far,
    stop=target_noise,
    num=n_noise_scales,
    dtype=np.float32,
)
# ratio of the first two scales of the schedule
_first_scale, _second_scale = noise_scales[:2]
true_gamma = _first_scale / _second_scale
print("Gamma is {}, should be {} or lower! Too high? Make n_noise_scales larger!".format(true_gamma, wish_gamma))

In [None]:
# for langevin sampler
def wow_formula(gamma, t, eps, final_sigma=None):
    """Evaluate the step-size diagnostic used to tune the Langevin sampler.

    With ``s = final_sigma**2`` and ``r = 1 - eps / s`` this returns
    ``r**(2*t) * (gamma**2 - c) + c`` where ``c = 2*eps / (s - s * r**2)``.

    Args:
        gamma: ratio between successive noise scales.
        t: number of Langevin steps per noise scale.
        eps: base step size.
        final_sigma: smallest noise scale. When omitted, the last entry of the
            notebook-level ``noise_scales`` array is used (original behaviour).

    Returns:
        The value of the expression above.
    """
    if final_sigma is None:
        # fall back to the schedule defined earlier in the notebook
        final_sigma = noise_scales[-1]

    final_sig_sq = final_sigma**2
    shrink = 1 - (eps / final_sig_sq)

    # this term appears twice in the formula; compute it once
    third = 2*eps / (final_sig_sq - final_sig_sq * shrink**2)
    first = shrink**(2*t)
    second = gamma**2 - third

    return first*second + third

In [None]:
t_total = 2400
# integer number of Langevin steps available for each noise scale
t_per_noise_scale = t_total // n_noise_scales

epsilon = 5e-8
some_value = wow_formula(gamma=true_gamma, t=t_per_noise_scale, eps=epsilon)

_thingy_report = "The thingy value is {}! It should be close to 1! Try playing around with the epsilon value."
print(_thingy_report.format(some_value))
print(t_per_noise_scale)

In [None]:
# just plot some noise scales to get an impression of how the data looks like
# (every 40th scale: an 8x8 grid of noisy test images plus a histogram of their pixel values)
_clean_imgs = test_images[:64]
for scale in noise_scales[::40]:
    # noise shape follows the actual batch, so this also works with fewer than 64 test images
    noisy_imgs = _clean_imgs + scale*np.random.normal(size=_clean_imgs.shape)

    plt.figure(figsize=(15, 15))
    for ind, image in enumerate(noisy_imgs):
        plt.subplot(8, 8, ind+1)
        plt.imshow(np.clip(image, 0, 1))
        plt.axis("off")
    # one title per figure; no need to re-set it for every subplot
    plt.suptitle("Noise scale: {}".format(scale))
    plt.show()

    plt.hist(noisy_imgs.reshape(-1), bins=100)
    plt.show()

In [None]:
class ScoreMatching(tf.keras.Model):
    """Functional Keras model trained with denoising score matching.

    The network is called as ``self([images, noise_input])`` and is trained to
    predict ``-(noisy - clean) / noise``, i.e. the negated unit-scale noise that
    was added to the clean images. The Langevin sampler divides the network
    output by the current noise level before using it as update direction.
    """
    def __init__(self, inputs, outputs, noise_scales, **kwargs):
        """Create the model and store the noise schedule.

        Args:
            inputs: forwarded to the ``tf.keras.Model`` constructor.
            outputs: forwarded to the ``tf.keras.Model`` constructor.
            noise_scales: sequence of noise levels; the sampler walks through
                them from the first to the last entry.
            **kwargs: forwarded to the ``tf.keras.Model`` constructor.
        """
        super().__init__(inputs, outputs, **kwargs)
        self.loss_tracker = tf.keras.metrics.Mean("loss")
        
        self.num_noise_scales = len(noise_scales)
        self.noise_scales_tensor = tf.convert_to_tensor(noise_scales, dtype=tf.float32)
        
    def train_step(self, data):
        """Run one optimisation step on a batch and return the running mean loss."""
        with tf.GradientTape() as tape:
            loss = self.denoising_score_matching_loss(data, training=True)
        gradients = tape.gradient(loss, self.trainable_variables)
        # use the optimizer this model was compiled with, not a notebook global
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}
    
    def test_step(self, data):
        """Evaluate the loss on a batch (no weight update) and return the running mean."""
        loss = self.denoising_score_matching_loss(data, training=False)

        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}
    
    def denoising_score_matching_loss(self, image_batch, training=None):
        """Denoising score matching loss for one randomly drawn noise level.

        A single noise level is sampled uniformly from the schedule and shared
        by the whole batch. The loss is the batch mean of half the squared
        error (summed over axes 1, 2, 3) between the network output and
        ``-(noisy - clean) / noise``.

        Args:
            image_batch: batch of clean images.
            training: forwarded to the network call.

        Returns:
            Scalar loss tensor.
        """
        sampled_noise_index = tf.random.uniform([1], 0, self.num_noise_scales, dtype=tf.int32)[0]
        noise = self.noise_scales_tensor[sampled_noise_index]

        noisy_batch = image_batch + noise * tf.random.normal(tf.shape(image_batch))
        # one copy of the noise level per sample, shaped (batch, 1, 1, 1)
        noise_input = tf.repeat(noise, tf.shape(noisy_batch)[0])[:, None, None, None]

        score = self([noisy_batch, noise_input], training=training)
        target_score = -1 * (noisy_batch - image_batch) / noise
        loss = tf.reduce_mean(0.5 * tf.reduce_sum((score - target_score)**2, axis=[1,2,3]))

        # every noise level currently gets the same weight
        weight = 1

        return weight * loss
    
    # this is not used
    def FULLdenoising_score_matching_loss(self, image_batch, training=None):
        """Same loss as ``denoising_score_matching_loss`` but averaged over ALL noise levels.

        Runs one forward pass per noise level in the schedule and weights each
        per-level loss with ``1 / num_noise_scales``.
        """
        total_loss = 0.
        for noise in self.noise_scales_tensor:
            noisy_batch = image_batch + noise * tf.random.normal(tf.shape(image_batch))
            noise_input = tf.repeat(noise, tf.shape(noisy_batch)[0])[:, None, None, None]

            score = self([noisy_batch, noise_input], training=training)
            target_score = -1 * (noisy_batch - image_batch) / noise
            loss = tf.reduce_mean(0.5 * tf.reduce_sum((score - target_score)**2, axis=[1,2,3]))

            weight = 1 / self.num_noise_scales
            total_loss += weight * loss

        return total_loss
    
    @tf.function(jit_compile=True)
    def langevin_step(self, sample, alpha, noise, noise_input):
        """One Langevin update: step along network(sample)/noise plus Gaussian noise.

        Args:
            sample: current batch of samples.
            alpha: step size for the current noise level.
            noise: current noise level (scalar).
            noise_input: noise level repeated per sample, fed to the network.

        Returns:
            The updated batch of samples.
        """
        sample = (sample 
                  + alpha*self([sample, noise_input], training=False) / noise
                  + tf.math.sqrt(2*alpha)*tf.random.normal(tf.shape(sample)))
        return sample

    def langevin_sampler(self, n_steps, epsilon, n_samples=64, denoise=True, show_intermediate=False):
        """Generate samples by annealed Langevin dynamics over the noise schedule.

        Args:
            n_steps: number of Langevin updates per noise level.
            epsilon: base step size; the step size at a level is
                ``epsilon * (noise / last_noise)**2``.
            n_samples: number of samples to draw.
            denoise: if true, apply one final noise-free correction using the
                last noise level of the schedule.
            show_intermediate: falsy disables plotting. Otherwise the initial
                samples and the samples after every ``show_intermediate``-th
                noise level are plotted on an 8x8 grid (at most 64 samples fit).

        Returns:
            Tensor with ``n_samples`` generated samples.
        """
        # start from pure noise at the first (largest) scale of the schedule
        sample = self.noise_scales_tensor[0] * tf.random.normal((n_samples,) + self.input_shape[0][1:])

        if show_intermediate:
            plt.figure(figsize=(15, 15))
            for ind, image in enumerate(sample):
                plt.subplot(8, 8, ind+1)
                plt.imshow(image, cmap="Greys")
                plt.axis("off")
            plt.suptitle("Initial samples")
            plt.show()

        for index, noise in enumerate(self.noise_scales_tensor):
            alpha = tf.cast(epsilon * (noise / self.noise_scales_tensor[-1])**2, tf.float32)
            noise_input = tf.repeat(noise, n_samples, axis=0)[:, None, None, None]

            for step in tf.range(n_steps):
                sample = self.langevin_step(sample, alpha, noise, noise_input)

            if show_intermediate and not index % show_intermediate:
                plt.figure(figsize=(15, 15))
                for ind, image in enumerate(sample):
                    plt.subplot(8, 8, ind+1)
                    plt.imshow(image, cmap="Greys", vmin=0, vmax=1)
                    plt.axis("off")
                plt.suptitle("Noise scale {}".format(noise))
                plt.show()
                
        if denoise:
            # `noise` / `noise_input` still hold the last level of the schedule here.
            # Call this model itself rather than a notebook-level global instance.
            sample = sample + noise**2 * self([sample, noise_input], training=False) / noise

        return sample

In [None]:
import tensorflow_addons as tfa

def norm(**kwargs):
    """Build a GroupNormalization layer with 8 groups; keyword arguments are forwarded."""
    return tfa.layers.GroupNormalization(groups=8, **kwargs)

def residual_stack(inputs, filters, strides, blocks_per_level, mode, name):
    """Chain residual blocks level by level and record every block output.

    Args:
        inputs: tensor fed into the first block.
        filters: per-level filter counts.
        strides: per-level strides; only the first block of a level uses the
            level stride, all following blocks use stride 1.
        blocks_per_level: number of residual blocks in each level.
        mode: forwarded unchanged to ``ResidualBlock``.
        name: prefix for block names (``<name>_<level>_<block>``, both 1-based).

    Returns:
        Tuple ``(final_output, all_block_outputs)``.
    """
    block_outputs = []
    current = inputs
    for level_number, (n_filters, stride) in enumerate(zip(filters, strides), start=1):
        for block_number in range(1, blocks_per_level + 1):
            block_stride = stride if block_number == 1 else 1
            block_name = "_".join([name, str(level_number), str(block_number)])
            block = ResidualBlock(n_filters,
                                  mode,
                                  strides=block_stride,
                                  name=block_name,
                                  normalization=norm)
            current = block(current)
            block_outputs.append(current)

    return current, block_outputs

def residual_stack_d(inputs, all_hidden, filters, strides, blocks_per_level, mode, name):
    """Chain residual blocks like ``residual_stack`` but with skip connections.

    Before every block except the very first one, the entry of ``all_hidden``
    at the running block index is concatenated to the current tensor along the
    last axis. If the two disagree in size along axis 1, the skip tensor is
    average-pooled first.

    Args:
        inputs: tensor fed into the first block.
        all_hidden: list of skip tensors indexed by running block index
            (entry 0 is never used). The list is not modified.
        filters: per-level filter counts.
        strides: per-level strides; only the first block of a level uses the
            level stride, all following blocks use stride 1.
        blocks_per_level: number of residual blocks in each level.
        mode: forwarded unchanged to ``ResidualBlock``.
        name: prefix for block names (``<name>_<level>_<block>``, both 1-based).

    Returns:
        Output tensor of the last block.
    """
    outputs = inputs
    global_ind = 0
    for level_ind, (level_filters, level_stride) in enumerate(zip(filters, strides)):
        for block_ind in range(blocks_per_level):
            if global_ind > 0:
                # work on a local reference so the caller's list stays untouched
                skip = all_hidden[global_ind]
                if outputs.shape[1] != skip.shape[1]:
                    skip = tfkl.AvgPool2D(padding="same")(skip)
                outputs = tf.concat((outputs, skip), axis=-1)
            global_ind += 1
            
            outputs = ResidualBlock(level_filters,
                                    mode, 
                                    strides=level_stride if block_ind == 0 else 1,
                                    name="_".join([name, str(level_ind+1), str(block_ind+1)]),
                                    normalization=norm)(outputs)
        
    return outputs

In [None]:
# model inputs: the image and one noise value per sample
inp = tf.keras.Input((64, 64, 3))
noise_input = tf.keras.Input((1, 1, 1))

# architecture settings shared by encoder and decoder
blocks_per_level = 4
filters = [64, 128, 256, 512, 768]
strides = [1, 2, 2, 2, 2]

encoder_output, all_hidden = residual_stack(
    inp, filters, strides, blocks_per_level, "conv", "encoder"
)

# decoder walks the encoder activations and filter counts in reverse order
decoder_output = residual_stack_d(
    encoder_output,
    list(reversed(all_hidden)),
    reversed(filters),
    strides,
    blocks_per_level,
    "upconv",
    "decoder",
)
# 1x1 convolution down to 3 output channels
decoder_final = tfkl.Conv2D(3, 1)(decoder_output)

score_model = ScoreMatching([inp, noise_input], decoder_final, noise_scales)
score_model.summary()

In [None]:
# visualise how quickly ema**t decays for the chosen EMA momentum
ema = 0.9999
t = np.arange(250000)
plt.plot(t, ema**t)
plt.xlabel("t")
plt.ylabel("ema**t")
plt.title("ema**t for ema = {}".format(ema))
plt.show()

In [None]:
train_steps = 250000
n_data = 60000
# convert the step budget into whole epochs
_steps_per_epoch = n_data // batch_size
n_epochs = train_steps // _steps_per_epoch

# cosine-decayed learning rate over the full step budget
lr = tf.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.0001,
    decay_steps=train_steps,
)
# Adam with EMA of the weights enabled
optimizer = tf.optimizers.Adam(
    learning_rate=lr,
    use_ema=True,
    ema_momentum=ema,
)

score_model.compile(optimizer=optimizer, jit_compile=True)

In [None]:
class ImageGenCallback(tf.keras.callbacks.Callback):
    """Plot a batch of generated samples at the start of every ``frequency``-th epoch.

    Sampling uses the notebook-level ``t_per_noise_scale`` and ``epsilon``.
    """
    def __init__(self, frequency, **kwargs):
        super().__init__(**kwargs)
        self.frequency = frequency

    def on_epoch_begin(self, epoch, logs=None):
        # only act on epochs that are a multiple of the configured frequency
        if epoch % self.frequency:
            return

        samples = self.model.langevin_sampler(t_per_noise_scale, epsilon)

        plt.figure(figsize=(15,15))
        for position, image in enumerate(samples, start=1):
            plt.subplot(8, 8, position)
            clipped = np.clip(image, 0, 1)
            plt.imshow(clipped)
            plt.axis("off")
        plt.suptitle("Random generations")
        plt.show()


# `callbacks` is documented as a list of callback instances, so pass one explicitly
score_model.fit(train_data, validation_data=test_data, epochs=n_epochs, callbacks=[ImageGenCallback(10)])

In [None]:
# make sure the target directory exists so the save cannot fail on a missing folder
os.makedirs("weights", exist_ok=True)
score_model.save_weights("weights/weights_score_faces.hdf5")

In [None]:
# sample a batch with the trained model and show it on an 8x8 grid
generated_batch = score_model.langevin_sampler(t_per_noise_scale, epsilon)

plt.figure(figsize=(15,15))
for position, image in enumerate(generated_batch, start=1):
    plt.subplot(8, 8, position)
    clipped = np.clip(image, 0, 1)
    plt.imshow(clipped)
    plt.axis("off")
plt.suptitle("Random generations")
plt.show()

In [None]:
# alternative sampler version where we start at some noise scale that is not the largest one.
# "encode" data to noisy version, then decode.
# -> get similar, but different samples.
def langevin_sampler_partial(self, start, n_steps, epsilon, from_=0, n_samples=64, denoise=True, show_intermediate=False):
    """Annealed Langevin sampling that starts from noised versions of given images.

    ``start`` is perturbed with noise at schedule index ``from_``; sampling
    then runs over the noise levels from that index to the end of the schedule.

    Args:
        self: model providing ``noise_scales_tensor`` and ``langevin_step``.
        start: batch of images to start from (needs a static shape).
        n_steps: number of Langevin updates per noise level.
        epsilon: base step size; the step size at a level is
            ``epsilon * (noise / last_noise)**2``.
        from_: index into the noise schedule at which sampling starts.
        n_samples: kept for backward compatibility; the batch size is taken
            from ``start`` so the two can no longer disagree.
        denoise: if true, apply one final noise-free correction using the
            last noise level.
        show_intermediate: falsy disables plotting. Otherwise the initial
            samples and the samples after every ``show_intermediate``-th
            visited level are plotted on an 8x8 grid (at most 64 samples fit).

    Returns:
        Tensor with one generated sample per image in ``start``.
    """
    batch = start.shape[0]
    sample = start + self.noise_scales_tensor[from_] * tf.random.normal(start.shape)

    if show_intermediate:
        plt.figure(figsize=(15, 15))
        for ind, image in enumerate(sample):
            plt.subplot(8, 8, ind+1)
            plt.imshow(image, cmap="Greys")
            plt.axis("off")
        plt.suptitle("Initial samples")
        plt.show()

    for index, noise in enumerate(self.noise_scales_tensor[from_:]):
        alpha = tf.cast(epsilon * (noise / self.noise_scales_tensor[-1])**2, tf.float32)
        noise_input = tf.repeat(noise, batch, axis=0)[:, None, None, None]

        for step in tf.range(n_steps):
            sample = self.langevin_step(sample, alpha, noise, noise_input)

        if show_intermediate and not index % show_intermediate:
            plt.figure(figsize=(15, 15))
            for ind, image in enumerate(sample):
                plt.subplot(8, 8, ind+1)
                plt.imshow(image, cmap="Greys", vmin=0, vmax=1)
                plt.axis("off")
            plt.suptitle("Noise scale {}".format(noise))
            plt.show()

    # weird formula using the final noise
    if denoise:
        # use the model that was passed in, not the notebook-level `score_model`
        sample = sample + noise**2 * self([sample, noise_input], training=False) / noise

    return sample

In [None]:
# noise the first 64 test images at schedule index 400, then sample back down
_start_images = test_images[:64]
generated_batch = langevin_sampler_partial(
    score_model, _start_images, t_per_noise_scale, epsilon, from_=400
)

plt.figure(figsize=(15,15))
for position, image in enumerate(generated_batch, start=1):
    plt.subplot(8, 8, position)
    plt.imshow(np.clip(image, 0, 1))
    plt.axis("off")
plt.suptitle("Random generations")
plt.show()

In [None]:
# 63 copies of the first test image; the grid shows the original followed by 63 variations
base_image = np.repeat(test_images[:1], 63, axis=0)

plt.figure(figsize=(15,15))
# slot 1: the unmodified base image
plt.subplot(8, 8, 1)
plt.imshow(base_image[0])
plt.axis("off")

generated_batch = langevin_sampler_partial(
    score_model, base_image, t_per_noise_scale, epsilon, from_=300, n_samples=63
)
# slots 2..64: the generated variations
for position, image in enumerate(generated_batch, start=2):
    plt.subplot(8, 8, position)
    plt.imshow(np.clip(image, 0, 1))
    plt.axis("off")
plt.suptitle("Random generations")
plt.show()

In [None]:
# inpainting taken from the paper.
# doesn't work too well, in my experience...
@tf.function(jit_compile=True)
def langevin_step_inpaint(self, sample, alpha, noise, noise_input, y, mask):
    """One Langevin update followed by re-imposing the known region.

    After the regular update, entries where ``mask`` is 1 are replaced by
    ``y``; entries where ``mask`` is 0 keep the updated sample.
    """
    drift = alpha*self([sample, noise_input], training=False) / noise
    diffusion = tf.math.sqrt(2*alpha)*tf.random.normal(tf.shape(sample))
    updated = sample + drift + diffusion

    free_region = 1 - mask
    return updated * free_region + y * mask

def langevin_sampler_inpaint(self, start, mask, n_steps, epsilon, n_samples=64, denoise=True, show_intermediate=False):
    """Annealed Langevin sampling with part of each image fixed to ``start``.

    Sampling starts from pure noise at the first noise level. At every noise
    level a noised copy ``y`` of ``start`` is drawn, and after each Langevin
    update the region where ``mask`` is 1 is overwritten with ``y`` (see
    ``langevin_step_inpaint``).

    Args:
        self: model providing ``noise_scales_tensor``.
        start: batch of reference images (needs a static shape).
        mask: array broadcastable to the images; 1 marks the region that is
            taken from ``start``, 0 the region that is generated.
        n_steps: number of Langevin updates per noise level.
        epsilon: base step size; the step size at a level is
            ``epsilon * (noise / last_noise)**2``.
        n_samples: kept for backward compatibility; the batch size is taken
            from ``start`` so the two can no longer disagree.
        denoise: if true, apply one final noise-free correction using the
            last noise level.
        show_intermediate: falsy disables plotting. Otherwise the initial
            samples and the samples after every ``show_intermediate``-th
            level are plotted on an 8x8 grid (at most 64 samples fit).

    Returns:
        Tensor with one generated sample per image in ``start``.
    """
    batch = start.shape[0]
    sample = self.noise_scales_tensor[0] * tf.random.normal(start.shape)

    if show_intermediate:
        plt.figure(figsize=(15, 15))
        for ind, image in enumerate(sample):
            plt.subplot(8, 8, ind+1)
            plt.imshow(image, cmap="Greys")
            plt.axis("off")
        plt.suptitle("Initial samples")
        plt.show()

    for index, noise in enumerate(self.noise_scales_tensor):
        alpha = tf.cast(epsilon * (noise / self.noise_scales_tensor[-1])**2, tf.float32)
        noise_input = tf.repeat(noise, batch, axis=0)[:, None, None, None]
        
        # reference images perturbed at the current noise level
        y = start + noise * tf.random.normal(start.shape)

        for step in tf.range(n_steps):
            sample = langevin_step_inpaint(self, sample, alpha, noise, noise_input, y, mask)

        if show_intermediate and not index % show_intermediate:
            plt.figure(figsize=(15, 15))
            for ind, image in enumerate(sample):
                plt.subplot(8, 8, ind+1)
                plt.imshow(image, cmap="Greys", vmin=0, vmax=1)
                plt.axis("off")
            plt.suptitle("Noise scale {}".format(noise))
            plt.show()

    # weird formula using the final noise
    if denoise:
        # use the model that was passed in, not the notebook-level `score_model`
        sample = sample + noise**2 * self([sample, noise_input], training=False) / noise

    return sample

In [None]:
# any binary mask is fine
# here: rows 0-31 are set to 1, rows 32-63 stay 0
mask = np.zeros((64, 64, 3), dtype=np.float32)
mask[:32] = 1

plt.imshow(mask)

In [None]:
# inpaint the first 64 test images using the mask defined above
_reference_images = test_images[:64]
generated_batch = langevin_sampler_inpaint(
    score_model, _reference_images, mask, t_per_noise_scale, epsilon
)

plt.figure(figsize=(15,15))
for position, image in enumerate(generated_batch, start=1):
    plt.subplot(8, 8, position)
    plt.imshow(np.clip(image, 0, 1))
    plt.axis("off")
plt.suptitle("Random generations")
plt.show()

In [None]:
# 63 inpainting runs that all start from the same (first) test image
base_image = np.repeat(test_images[:1], 63, axis=0)
generated_batch = langevin_sampler_inpaint(
    score_model, base_image, mask, t_per_noise_scale, epsilon, n_samples=63
)

plt.figure(figsize=(15,15))
# slot 1: the original image, slots 2..64: the inpainted versions
plt.subplot(8, 8, 1)
plt.imshow(base_image[0])
plt.axis("off")
for position, image in enumerate(generated_batch, start=2):
    plt.subplot(8, 8, position)
    plt.imshow(np.clip(image, 0, 1))
    plt.axis("off")
plt.suptitle("Random generations")
plt.show()

In [None]:
# larger samples show that 64x64 is not sufficient
# only 16 samples on a 4x4 grid, so each one is drawn larger
generated_batch = score_model.langevin_sampler(t_per_noise_scale, epsilon, n_samples=16)

plt.figure(figsize=(15,15))
for position, image in enumerate(generated_batch, start=1):
    plt.subplot(4, 4, position)
    clipped = np.clip(image, 0, 1)
    plt.imshow(clipped)
    plt.axis("off")
plt.suptitle("Random generations")
plt.show()

In [None]:
# four samples on a 2x2 grid
generated_batch = score_model.langevin_sampler(t_per_noise_scale, epsilon, n_samples=4)

plt.figure(figsize=(15,15))
for position, image in enumerate(generated_batch, start=1):
    plt.subplot(2, 2, position)
    clipped = np.clip(image, 0, 1)
    plt.imshow(clipped)
    plt.axis("off")
plt.suptitle("Random generations")
plt.show()