In [None]:
# first, on the noise-conditional score model:
# it could be a multi-input model like this

# data shape: for example for padded MNIST (32x32 images with 1 channel)
data_input = tf.keras.Input((32, 32, 1))
# noise shape: (1, 1, 1) so that, together with the batch axis, sigma has the same
# number of dimensions as the data and broadcasts in the element-wise division below
sigma_input = tf.keras.Input((1, 1, 1))

# `Network` is presumably a placeholder for your own (unconditional) network
score_output = Network(data_input)
# this is the conditioning method proposed in the improved paper (technique 3):
# s_theta(x, sigma) = s_theta(x) / sigma
conditional_score_output = score_output / sigma_input

score_model = tf.keras.Model([data_input, sigma_input], conditional_score_output)


# ...and can be called like this.
# sigma_batch: take a sigma (assumed scalar), repeat it over the batch axis,
# then add axes for height, width and channels -> shape (batch, 1, 1, 1),
# which is what sigma_input expects. Pass sigma_batch, not the raw sigma.
sigma_batch = tf.repeat(sigma, tf.shape(noisy_data_batch)[0])[:, None, None, None]
score_model([noisy_data_batch, sigma_batch])

In [None]:
# on the loss function:
# use the loss from equation 2 in the improved paper https://arxiv.org/pdf/2006.09011.pdf

# advanced note: compare this to equations 5 and 6 from the original https://arxiv.org/pdf/1907.05600.pdf
# there, they have a weight lambda(sigma_i) for each noise level.
# they propose lambda(sigma_i) = sigma_i**2.

# in equation 2 of the improved paper, they already inserted this weight and made some simplifications.
# equation 2 involves sigma_i * conditional_score_model.

# if we insert conditional_score_model = score_model / sigma_i, we have sigma_i * score_model / sigma_i.
# we can thus remove sigma_i from the equation and are left with the unconditional score_model!

# as such, you could implement the score model completely unconditionally and ignore sigma_input.
# you only have to remember to manually divide by sigma_i, when needed, e.g. for the langevin sampler.


# if all of this is confusing to you, ignore it and stick with the conditional model from the 1st cell.


# finally, you can use TF functions to sample sigma like this,
# assuming you have a 1D tensor sigmas_tensor that contains all the different sigmas.
# shape [] makes tf.random.uniform return a scalar index directly (no [1]-then-[0] dance),
# and tf.shape(...)[0] (unlike len()) also works inside a tf.function where the
# static length of sigmas_tensor may be unknown.
random_index = tf.random.uniform([], 0, tf.shape(sigmas_tensor)[0], dtype=tf.int32)
random_sigma = sigmas_tensor[random_index]

In [None]:
# the rest is on choosing noise scales etc.
# technique 1 of the improved paper: choose the largest noise scale sigma_1 as large as the maximum euclidean distance between all pairs of training data points.
# actually computing this distance will take a long time.
# but you can get a good approximation from a subset of data.
# for MNIST I get around 16.3 or so.

# round that estimate up to a convenient value for the largest noise scale
biggest_distance = 20

In [None]:
# how many scales to choose, i.e. the ratio gamma between successive scales. this is technique 2 from the paper.
# ideally the value C computed at the bottom should be ~1.
# but they say that aiming for a value such as 0.5 is an acceptable compromise.
# generally, the smaller you make gamma, the larger C will get.
from scipy import stats
import numpy as np

d = 32 * 32  # data dimensionality D: a padded MNIST image has 32*32 pixels
wish_gamma = 1.05  # desired ratio sigma_{i-1} / sigma_i between consecutive noise scales; use values > 1

# C = Phi(sqrt(2D) * (gamma - 1) + 3 * gamma) - Phi(sqrt(2D) * (gamma - 1) - 3 * gamma),
# where Phi is the standard normal CDF.
shift = np.sqrt(2 * d) * (wish_gamma - 1)
spread = 3 * wish_gamma
upper_limit = shift + spread
lower_limit = shift - spread

c_value = stats.norm.cdf(upper_limit) - stats.norm.cdf(lower_limit)

print(f"C value is {c_value}; should be 0.5 or higher! Too low? Make gamma smaller!")

In [None]:
# above, we got the wish_gamma that we would LIKE to have.
# you could directly use that to define a noise scale.
# or you use geomspace, but that doesn't allow for specification of gamma.
# you can only say how many scales you want.

# basically, the larger n_noise_scales, the smaller gamma will be (noises will be closer together).
# increase n_noise_scales until gamma is at or below the target gamma you got above for a good C value.

n_noise_scales = 200
target_noise = 0.001  # smallest noise scale sigma_L

# geometric sequence from the largest noise scale down to target_noise (both endpoints included)
noise_scales = np.geomspace(
    biggest_distance, target_noise, n_noise_scales, dtype=np.float32
)

# consecutive ratios of a geometric sequence are all equal, so the first pair is enough.
# note: with these defaults gamma comes out at roughly 1.051, marginally above 1.05.
true_gamma = noise_scales[0] / noise_scales[1]
print(f"Gamma is {true_gamma}, should be {wish_gamma} or lower! Too high? Make n_noise_scales larger!")

In [None]:
# finally, technique 4 from the paper concerns the choice of epsilon for the langevin sampler.
def amazing_formula(gamma, t, eps, final_sigma=None):
    """Evaluate the technique-4 criterion for a Langevin step size ``eps``.

    Computes
        (1 - eps/sigma_L^2)^(2t) * (gamma^2 - k) + k,
        k = 2 eps / (sigma_L^2 - sigma_L^2 (1 - eps/sigma_L^2)^2),
    which should be as close to 1 as possible for a good choice of ``eps``.

    Args:
        gamma: ratio between successive noise scales (sigma_{i-1} / sigma_i).
        t: number of Langevin steps per noise scale.
        eps: Langevin step size.
        final_sigma: smallest noise scale sigma_L. Defaults to ``noise_scales[-1]``
            (the module-level array defined elsewhere in this notebook), which is
            what the function always used before this parameter existed.

    Returns:
        The value of the expression above.
    """
    if final_sigma is None:
        final_sigma = noise_scales[-1]
    ratio = eps / final_sigma**2
    # The denominator of k simplifies exactly:
    #   sigma_L^2 - sigma_L^2 (1 - r)^2 = sigma_L^2 * r * (2 - r) = eps * (2 - r),
    # so k = 2 / (2 - r). This avoids cancellation in 1 - (1 - r)^2 for tiny eps
    # and the 0/0 (NaN) at eps == 0, where k now takes its limit value 1.
    k = 2 / (2 - ratio)
    return (1 - ratio) ** (2 * t) * (gamma**2 - k) + k

In [None]:
# first, set t_total to however many steps you can afford to run in total.
# 1000 is probably on the low side.
# also maybe choose it as a multiple of n_noise_scales

t_total = 1000
t_per_noise_scale = t_total // n_noise_scales
# integer division gives 0 steps per noise scale whenever t_total < n_noise_scales;
# the sampler would then do nothing, so fail loudly instead of printing a bogus value.
if t_per_noise_scale < 1:
    raise ValueError(
        "t_total ({}) must be at least n_noise_scales ({}).".format(t_total, n_noise_scales)
    )

# if you kept all the numbers in the notebook the same,
# you likely won't find a much better epsilon than this; some_value should come out around 1.07.
# usually there is a single best epsilon for some_value below:
# both increasing and decreasing epsilon from there will increase the value.
epsilon = 0.00000007
some_value = amazing_formula(true_gamma, t_per_noise_scale, epsilon)

print("The thingy value is {}! It should be close to 1! "
      "Try playing around with the epsilon value.".format(some_value))

print("There will be {} iterations per noise scale.".format(t_per_noise_scale))


# "close to 1" is of course a relative statement. but e.g. if I remove two 0s from epsilon,
# i.e. I use 0.000007, the some_value is around 90 million. that's not close to 1.