<a href="https://colab.research.google.com/github/Pumafi/dl_spatial_gen_geol_facies/blob/main/ScoreBased_TF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Score-Based Generative Modeling (Tensorflow)

A TensorFlow translation of the PyTorch tutorial code for "Generative Modeling by Estimating Gradients of the Data Distribution" by Yang Song.

Source Tutorial : https://yang-song.net/blog/2021/score/

Source PyTorch code : https://colab.research.google.com/drive/120kYYBOVa1i0TD85RjlEkFjaWDxSFUx3?usp=sharing#scrollTo=YyQtV7155Nht

In [1]:
!pip install tqdm -U
!python3 -m pip install tensorflow_addons

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorflow_addons
  Downloading tensorflow_addons-0.19.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting typeguard>=2.7
  Downloading typeguard-3.0.1-py3-none-any.whl (30 kB)
Installing collected packages: typeguard, tensorflow_addons
Successfully installed tensorflow_addons-0.19.0 typeguard-3.0.1


In [2]:
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
import math

## Model

In [785]:
class GaussianFourierProjection(tf.keras.layers.Layer):
    """Gaussian random features for encoding time steps.

    Args:
      embed_dim: Size of the produced embedding. `embed_dim // 2` random
        frequencies are drawn; the output concatenates their sine and cosine.
      scale: Standard deviation of the random frequencies.
    """
    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Randomly sample weights during initialization. These weights are fixed
        # during optimization and are not trainable.
        # The frequencies are drawn directly with standard deviation `scale` and
        # kept as the tracked variable itself. Multiplying the variable by
        # `scale` here would leave `self.W` as a detached tensor that is not
        # refreshed when the layer's weights are restored from a checkpoint.
        self.W = self.add_weight(
            name="GFP",
            shape=(embed_dim // 2,),
            initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=scale),
            trainable=False,
        )

    def call(self, x):
        """Return `concat([sin(2*pi*x*W), cos(2*pi*x*W)], axis=-1)`.

        A trailing axis is appended to `x` and multiplied with the frequency
        vector, so the output has the shape of `x` plus one trailing axis of
        size `2 * (embed_dim // 2)`.
        """
        x_proj = tf.expand_dims(x, axis=-1) * tf.expand_dims(self.W, axis=0) * 2 * np.pi
        return tf.concat([tf.math.sin(x_proj), tf.math.cos(x_proj)], axis=-1)

class CustomLinear(tf.keras.layers.Layer):
    """Fully connected layer computing `x @ W + b`.

    Both parameters are initialised uniformly in
    `[-sqrt(1 / input_dim), sqrt(1 / input_dim)]`.

    Args:
      input_dim: Size of the last axis of the input.
      output_dim: Size of the last axis of the output.
    """
    def __init__(self, input_dim, output_dim):
        super().__init__()
        bound = math.sqrt(1.0 / input_dim)
        # The parameters must be registered through `add_weight`: plain tensors
        # returned by `tf.random.uniform` are not variables, so they would not
        # appear in `trainable_variables` and would never be updated by the
        # optimizer.
        self.W = self.add_weight(
            name="W",
            shape=(input_dim, output_dim),
            initializer=tf.keras.initializers.RandomUniform(minval=-bound, maxval=bound),
            trainable=True,
        )
        self.b = self.add_weight(
            name="b",
            shape=(1, output_dim),
            initializer=tf.keras.initializers.RandomUniform(minval=-bound, maxval=bound),
            trainable=True,
        )

    def call(self, x):
        """Contract the last axis of `x` with `W`, then add the bias `b`."""
        return tf.tensordot(x, self.W, 1) + self.b


class DenseFeatures(tf.keras.layers.Layer):
    """Linear projection whose result is broadcastable over feature maps.

    The projection is computed by a `CustomLinear(input_dim, output_dim)`
    layer; two singleton axes are then inserted at position 1 of its output.
    """
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.output_dim = output_dim
        self.dense = CustomLinear(input_dim, output_dim)

    def call(self, x):
        projected = self.dense(x)
        # Insert two singleton axes right after the leading (batch) axis.
        return projected[:, tf.newaxis, tf.newaxis]

# TODO : Dense that can be used in convolutions ? Concat ?
# ==> Used to apply to channels

class ScoreNet(tf.keras.Model):
    """A time-dependent score-based model built upon U-Net architecture.

    Args:
      marginal_prob_std: Callable mapping a vector of time steps to the
        standard deviation of the perturbation kernel; the network output is
        divided by it.
      channels: Number of feature channels at each of the four resolutions.
      embed_dim: Dimensionality of the time embedding.
    """
    def __init__(self, marginal_prob_std, channels=(32, 64, 128, 256), embed_dim=256):
        super().__init__()

        # Layers
        ## Gaussian random features embedding layer for time
        self.embeding = tf.keras.Sequential([GaussianFourierProjection(embed_dim=embed_dim),
                                             CustomLinear(embed_dim, embed_dim)])

        # Dimension Reduction
        self.conv1 = tf.keras.layers.Conv2D(channels[0], (3, 3), strides=1, use_bias=False)
        self.dense1 = DenseFeatures(embed_dim, channels[0])
        self.gnorm1 = tfa.layers.GroupNormalization(4, epsilon=1e-05)

        self.conv2 = tf.keras.layers.Conv2D(channels[1], (3, 3), strides=2, use_bias=False)
        self.dense2 = DenseFeatures(embed_dim, channels[1])
        self.gnorm2 = tfa.layers.GroupNormalization(32, epsilon=1e-05)

        self.conv3 = tf.keras.layers.Conv2D(channels[2], (3, 3), strides=2, use_bias=False)
        self.dense3 = DenseFeatures(embed_dim, channels[2])
        self.gnorm3 = tfa.layers.GroupNormalization(32, epsilon=1e-05)

        self.conv4 = tf.keras.layers.Conv2D(channels[3], (3, 3), strides=2, use_bias=False)
        self.dense4 = DenseFeatures(embed_dim, channels[3])
        self.gnorm4 = tfa.layers.GroupNormalization(32, epsilon=1e-05)

        # Dimension Reconstruction
        self.tconv4 = tf.keras.layers.Conv2DTranspose(channels[2], (3, 3), strides=2, use_bias=False)
        self.tdense4 = DenseFeatures(embed_dim, channels[2])
        self.tgnorm4 = tfa.layers.GroupNormalization(32, epsilon=1e-05)

        self.tconv3 = tf.keras.layers.Conv2DTranspose(channels[1], (3, 3), strides=2, use_bias=False, output_padding=1)
        self.tdense3 = DenseFeatures(embed_dim, channels[1])
        self.tgnorm3 = tfa.layers.GroupNormalization(32, epsilon=1e-05)

        self.tconv2 = tf.keras.layers.Conv2DTranspose(channels[0], (3, 3), strides=2, use_bias=False, output_padding=1)
        self.tdense2 = DenseFeatures(embed_dim, channels[0])
        self.tgnorm2 = tfa.layers.GroupNormalization(32, epsilon=1e-05)

        self.tconv1 = tf.keras.layers.Conv2DTranspose(1, (3, 3), strides=1)

        # The swish activation function
        self.act = lambda x: x * tf.math.sigmoid(x)
        self.marginal_prob_std = marginal_prob_std

    def call(self, x, t):
        """Estimate the score of the perturbed batch `x` at time steps `t`.

        Args:
          x: Batch of channels-last feature maps (features are concatenated
            along the last axis).
          t: Vector with one time step per batch element.

        Returns:
          The output of the final transposed convolution divided, per batch
          element, by `marginal_prob_std(t)`.
        """
        # Obtain the Gaussian random feature embedding for t
        embed = self.act(self.embeding(t))

        # Encoding path: convolution, add time information, group norm, activation
        h1 = self.conv1(x)
        h1 += self.dense1(embed)
        h1 = self.gnorm1(h1)
        h1 = self.act(h1)
        h2 = self.conv2(h1)
        h2 += self.dense2(embed)
        h2 = self.gnorm2(h2)
        h2 = self.act(h2)
        h3 = self.conv3(h2)
        h3 += self.dense3(embed)
        h3 = self.gnorm3(h3)
        h3 = self.act(h3)
        h4 = self.conv4(h3)
        h4 += self.dense4(embed)
        h4 = self.gnorm4(h4)
        h4 = self.act(h4)

        # Decoding path
        h = self.tconv4(h4)
        h += self.tdense4(embed)
        h = self.tgnorm4(h)
        h = self.act(h)

        ## Skip connection from the encoding path
        h = self.tconv3(tf.concat([h, h3], axis=-1))
        h += self.tdense3(embed)
        h = self.tgnorm3(h)
        h = self.act(h)

        h = self.tconv2(tf.concat([h, h2], axis=-1))
        h += self.tdense2(embed)
        h = self.tgnorm2(h)
        h = self.act(h)

        h = self.tconv1(tf.concat([h, h1], axis=-1))

        # Normalize output. The per-sample std has shape (batch,); it must be
        # reshaped to (batch, 1, 1, 1) before dividing. Dividing by the raw
        # vector would broadcast it against the trailing channel axis and
        # produce a (batch, H, W, batch) tensor instead of one std per sample.
        std = tf.reshape(self.marginal_prob_std(t), (-1, 1, 1, 1))
        return h / std

## Setting up the SDE

In [786]:
import functools

@tf.function
def marginal_prob_std(t, sigma=25.):
    r"""Compute the standard deviation of $p_{0t}(x(t) | x(0))$.

    The value returned is $\sqrt{(\sigma^{2t} - 1) / (2 \log \sigma)}$.

    Args:
      t: A vector of time steps.
      sigma: The $\sigma$ in our SDE.

    Returns:
      The standard deviation, one value per entry of `t`.
    """
    log_sigma = tf.math.log(sigma)
    variance = (sigma ** (2 * t) - 1.) / 2 / log_sigma
    return tf.math.sqrt(variance)

@tf.function
def diffusion_coeff(t, sigma=25.):
    r"""Compute the diffusion coefficient $\sigma^t$ of our SDE.

    Args:
      t: A vector of time steps.
      sigma: The $\sigma$ in our SDE.

    Returns:
      The vector of diffusion coefficients, one per entry of `t`.
    """
    coefficients = sigma ** t
    return coefficients

# The sigma of the SDE, exposed as a Colab form field.
sigma = 25.0  # @param {'type':'number'}

# Bind sigma once so the rest of the notebook only has to pass the time steps.
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)

In [787]:
# Debugging notes: `std` looks correct, but `y` does not.

In [788]:
# Make every `@tf.function` in this notebook execute eagerly instead of as a
# traced graph. NOTE(review): presumably enabled for debugging (e.g. calling
# `.numpy()` inside the loss) -- confirm whether it is still needed.
tf.config.run_functions_eagerly(True)

In [789]:
@tf.function
def loss_fn(model, x, marginal_prob_std, eps=1e-5):
    """The loss function for training score-based generative models.

    Args:
      model: A callable `model(perturbed_x, t)` returning the time-dependent
        score estimate for the perturbed batch.
      x: A mini-batch of training data with four axes, batch axis first.
      marginal_prob_std: A callable giving the standard deviation of the
        perturbation kernel for a vector of time steps.
      eps: Lower bound of the sampled time steps, for numerical stability.

    Returns:
      A scalar: the batch mean of the per-sample sum of squares of
      `score * std + z`.
    """
    # Use the dynamic shape so the function also works when it is traced as a
    # graph with an unknown batch dimension (static `x.shape[0]` is None then).
    x_shape = tf.shape(x)
    # One time step per sample, drawn uniformly from [eps, 1).
    random_t = tf.random.uniform(x_shape[:1], minval=0., maxval=1.0) * (1. - eps) + eps
    z = tf.random.normal(x_shape)
    # Reshape the per-sample std so it broadcasts over the non-batch axes.
    std = tf.reshape(marginal_prob_std(random_t), (-1, 1, 1, 1))
    perturbed_x = x + z * std
    score = model(perturbed_x, random_t)
    y = score * std + z
    loss = tf.math.reduce_mean(tf.reduce_sum(y**2, axis=[1, 2, 3]))
    return loss

## Loading the data

In [790]:
from tensorflow.keras.datasets import fashion_mnist, mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Convert to float32, divide by 255 and append a trailing channel axis.
x_train, x_test = (
    (images.astype("float32") / 255)[..., np.newaxis]
    for images in (x_train, x_test)
)

In [791]:
# Sanity check: display the shape of the training array after preprocessing.
x_train.shape

(60000, 28, 28, 1)

In [792]:
## size of a mini-batch
batch_size = 32  #@param {'type':'integer'}

# Build the input pipeline: (image, label) pairs, shuffled through a
# 1024-element buffer and grouped into mini-batches.
train_dataset = (
    tf.data.Dataset
    .from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=1024)
    .batch(batch_size)
)


## Training

In [793]:
import tqdm

In [794]:
from tqdm import notebook

In [None]:
n_epochs = 50  #@param {'type':'integer'}

## learning rate
lr = 1e-4  #@param {'type':'number'}

score_model = ScoreNet(marginal_prob_std=marginal_prob_std_fn)
optimizer = tf.keras.optimizers.experimental.Adam(learning_rate=lr)

# Progress bar over epochs; its description shows the latest average loss.
tqdm_epoch = notebook.trange(n_epochs)

for epoch in tqdm_epoch:
    # Running sum of per-sample losses and number of samples seen this epoch.
    avg_loss, num_items = 0., 0
    for images, _ in train_dataset:
        n_images = images.shape[0]

        # Record the forward pass so the loss can be differentiated.
        with tf.GradientTape() as tape:
            loss = loss_fn(score_model, images, marginal_prob_std_fn)

        # Weight the batch-mean loss by the batch size.
        avg_loss += loss.numpy() * n_images
        num_items += n_images

        # One optimizer step on all trainable variables of the model.
        weights = score_model.trainable_variables
        gradients = tape.gradient(loss, weights)
        optimizer.apply_gradients(zip(gradients, weights))

    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))


  0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
# 25616.156604: 16%
# 25753.208589: 18%



In [None]:
def aef