In [2]:
from typing import Tuple
import os 
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import numpy as np 
from typing import List 
import tensorflow as tf
from keras import layers, models, activations, losses, optimizers, utils


class GaussianDiffusion():
    """Linear-beta diffusion schedule plus the closed-form maps derived from it.

    All schedule tensors are 1-D float32 tensors of length ``t_size`` indexed
    by the integer time step.

    Attributes:
        t_num: number of diffusion steps.
        betas: per-step noise variances, linearly spaced in [beta_min, beta_max].
        alphas: ``1 - betas``.
        alphas_cumprod: running product of ``alphas``.
        alphas_cumprod_prev: ``alphas_cumprod`` shifted right by one step,
            with 1.0 in the first slot.
        alphas_1 / alphas_2: ``sqrt(alphas_cumprod)`` and
            ``sqrt(1 - alphas_cumprod)`` -- coefficients of the forward process.
        alphas_preds_1 / alphas_preds_2: ``sqrt(1 / alphas_cumprod)`` and
            ``sqrt(1 / alphas_cumprod - 1)`` -- coefficients used to recover
            the clean sample from a noisy one.
    """

    def __init__(self, beta_min: float = 1e-4, beta_max: float = 0.02, t_size: int = 1000):
        """Pre-compute the schedule.

        Args:
            beta_min: variance of the first step.
            beta_max: variance of the last step.
            t_size: number of diffusion steps.
        """
        self.t_num = t_size

        self.betas = tf.cast(tf.linspace(beta_min, beta_max, t_size), tf.float32)
        self.alphas = 1.0 - self.betas

        self.alphas_cumprod = tf.math.cumprod(self.alphas)
        # Float literal so the concatenated pieces share one dtype explicitly.
        self.alphas_cumprod_prev = tf.concat([[1.0], self.alphas_cumprod[:-1]], axis=0)
        self.alphas_1 = tf.sqrt(self.alphas_cumprod)
        self.alphas_2 = tf.sqrt(1.0 - self.alphas_cumprod)

        self.alphas_preds_1 = tf.sqrt(1. / self.alphas_cumprod)
        self.alphas_preds_2 = tf.sqrt(1. / self.alphas_cumprod - 1)

    def get_alphas(self, alphas: tf.Tensor, t: tf.Tensor, shape: List[int]) -> tf.Tensor:
        """Look up schedule values at steps ``t`` and make them broadcastable.

        Args:
            alphas: 1-D schedule tensor to index.
            t: integer step indices (``tf.gather`` requires an integer dtype).
            shape: shape of the tensor the result will be multiplied with;
                only its rank is used.

        Returns:
            The gathered values reshaped to ``[-1, 1, ..., 1]`` with
            ``len(shape)`` axes.
        """
        data = tf.gather(alphas, t)
        dims = [-1] + [1] * (len(shape) - 1)
        return tf.reshape(data, dims)

    def get_train_sample(self, x: tf.Tensor, t: tf.Tensor, noise: tf.Tensor) -> tf.Tensor:
        """Forward process: produce the noisy sample at step ``t``.

        Computes ``sqrt(alphas_cumprod[t]) * x + sqrt(1 - alphas_cumprod[t]) * noise``.
        The noise term is *added* so that ``get_predict_sample`` is the exact
        algebraic inverse of this method.

        Args:
            x: clean samples.
            t: integer step indices, one per sample.
            noise: noise with the same shape as ``x``.

        Returns:
            The noisy samples, same shape as ``x``.
        """
        alphas_a = self.get_alphas(self.alphas_1, t, x.shape)
        alphas_b = self.get_alphas(self.alphas_2, t, x.shape)
        return x * alphas_a + noise * alphas_b

    def get_predict_sample(self, x: tf.Tensor, t: tf.Tensor, noise: tf.Tensor) -> tf.Tensor:
        """Recover the clean sample from a noisy sample and its noise.

        Computes ``sqrt(1 / alphas_cumprod[t]) * x - sqrt(1 / alphas_cumprod[t] - 1) * noise``.

        Args:
            x: noisy samples at step ``t``.
            t: integer step indices, one per sample.
            noise: the (predicted) noise contained in ``x``.

        Returns:
            The estimate of the clean samples, same shape as ``x``.
        """
        alphas_a = self.get_alphas(self.alphas_preds_1, t, x.shape)
        alphas_b = self.get_alphas(self.alphas_preds_2, t, x.shape)
        return alphas_a * x - noise * alphas_b

class NINBlock(layers.Layer):
    """Linear projection applied along the last axis of the input.

    The input is contracted with a ``(in_dim, num_units)`` weight matrix and a
    bias of length ``num_units`` is added; all leading axes are preserved.

    Args:
        num_units: size of the last axis of the output.
        init_scale: standard deviation of the zero-mean normal initialiser
            used for the weight matrix.
    """

    def __init__(self, num_units, init_scale=1.0, **kwargs):
        super(NINBlock, self).__init__(**kwargs)
        self.num_units = num_units
        self.init_scale = init_scale

    def build(self, input_shape):
        """Create the weight matrix and bias once the input width is known."""
        in_dim = input_shape[-1]
        # `name` is passed by keyword: the first positional parameter of
        # `add_weight` is the name in Keras 2 but the shape in Keras 3.
        self.W = self.add_weight(
            name="W",
            shape=(in_dim, self.num_units),
            initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.init_scale),
        )
        self.b = self.add_weight(name="b", shape=(self.num_units,), initializer="zeros")

    def call(self, inputs):
        """Return ``inputs @ W + b`` contracted over the last axis of ``inputs``."""
        return tf.tensordot(inputs, self.W, axes=1) + self.b

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(NINBlock, self).get_config()
        config.update({"num_units": self.num_units, "init_scale": self.init_scale})
        return config
    
class AttentionBlock(layers.Layer):
    """Self-attention over the spatial positions of a feature map.

    The two axes after the batch axis are flattened into one sequence axis,
    a Keras ``Attention`` layer is applied with the same tensor as query and
    value, and the result is reshaped back to the input's per-sample shape.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Create the reshape and attention sub-layers for this input shape."""
        height, width = input_shape[1], input_shape[2]
        # Flatten (height, width) into a single sequence axis.
        self.reshape_1 = layers.Reshape([height * width, -1])
        # Restore the original per-sample shape afterwards.
        self.reshape_2 = layers.Reshape(input_shape[1:])
        self.att = layers.Attention()

    def call(self, x: tf.Tensor):
        """Apply attention and return a tensor shaped like ``x``."""
        tokens = self.reshape_1(x)
        attended = self.att([tokens, tokens])
        return self.reshape_2(attended)

class ResnetBlock(layers.Layer):
    """Residual convolution block conditioned on a time embedding.

    Called with a pair ``(x, t)``. ``x`` goes through two
    norm -> swish -> 3x3-conv stages; between them the time embedding ``t``
    (swish -> dense -> reshaped to broadcast over the spatial axes) is added.
    The input ``x`` is added back as a shortcut, projected with a ``NINBlock``
    when its channel count differs from ``filters``.

    Args:
        filters: number of output channels.
    """

    def __init__(self, filters: int, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters

    def _norm_act_conv(self):
        """Build one LayerNormalization -> swish -> 3x3 'same' Conv2D stack."""
        return models.Sequential([
            layers.LayerNormalization(),
            layers.Activation(activations.swish),
            layers.Conv2D(self.filters, (3, 3), padding="same"),
        ])

    def build(self, input_shape):
        """Create the sub-layers; ``input_shape`` is the pair (x shape, t shape)."""
        x_shape, _ = input_shape

        # Shortcut path: identity when the channel count already matches.
        if x_shape[-1] == self.filters:
            self.nin = layers.Activation(activations.linear)
        else:
            self.nin = NINBlock(self.filters)

        self.block_x_1 = self._norm_act_conv()
        self.block_x_2 = self._norm_act_conv()

        # Projects the time embedding to `filters` channels, shaped (1, 1, C)
        # per sample so it broadcasts across the spatial axes.
        self.block_t = models.Sequential([
            layers.Activation(activations.swish),
            layers.Dense(self.filters),
            layers.Reshape((1, 1, -1)),
        ])

    def call(self, inputs: Tuple[tf.Tensor, tf.Tensor]):
        """Return ``shortcut(x) + conv2(conv1(x) + time_projection(t))``."""
        image, time_emb = inputs
        hidden = self.block_x_1(image)
        hidden = self.block_x_2(hidden + self.block_t(time_emb))
        return self.nin(image) + hidden

class NoiseModel():
    """U-Net style network that predicts the noise in a noisy image.

    Attributes:
        GD: the ``GaussianDiffusion`` schedule used for sampling.
        t_size: number of diffusion steps, taken from ``GD.t_num``.
        model: compiled Keras model with inputs ``[image, time]`` and a
            single-channel output of the same spatial size as ``image``.
    """

    def __init__(self) -> None:
        self.GD = GaussianDiffusion()
        # Single source of truth: the sampler must walk exactly the steps the
        # schedule defines.
        self.t_size = self.GD.t_num
        self.model: tf.keras.Model = self.get_model()
        self.model.compile(optimizers.Adam(1e-4), losses.MeanSquaredError())

    def embedding_series(self, t: tf.Tensor, filters: int) -> tf.Tensor:
        """Sinusoidal embedding of the time step.

        Args:
            t: time steps of shape ``(batch, 1)``.
            filters: width of the embedding.

        Returns:
            Tensor of shape ``(batch, filters)``: sines in the first half,
            cosines in the second, zero-padded by one column when ``filters``
            is odd.
        """
        half_dim = filters // 2
        # max(..., 1) keeps the exponent finite when half_dim is 1
        # (filters of 2 or 3); larger widths are unaffected.
        scale = -(tf.math.log(10000.0) / max(half_dim - 1, 1))
        freqs = tf.exp(tf.range(half_dim, dtype=scale.dtype) * scale)
        emb = t * freqs[tf.newaxis, :]
        emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=1)
        if filters % 2 == 1:
            emb = tf.pad(emb, [[0, 0], [0, 1]])
        return emb

    def get_model(self):
        """Build the (uncompiled) noise-prediction network.

        The encoder stacks ``res_blocks`` ResnetBlocks per level and halves
        the resolution between levels with a stride-2 convolution; every
        intermediate output is kept as a skip connection. The decoder mirrors
        it, concatenating one skip connection before each ResnetBlock and
        doubling the resolution between levels. An AttentionBlock follows any
        ResnetBlock whose output height is listed in ``attention_width``.
        """
        filters = 32
        res_blocks = 2
        attention_width = (48,)
        filters_coef = (1, 2, 4, 8)

        x = input_image = layers.Input((192, 240, 1), name = "image")
        t = input_time = layers.Input((1,), name = "time")

        # Time conditioning: sinusoidal embedding followed by a two-layer MLP.
        t = self.embedding_series(t, filters)
        t = layers.Dense(512, activation=activations.swish)(t)
        t = layers.Dense(512)(t)

        # Encoder; `encoder` doubles as the stack of skip connections.
        encoder = [layers.Conv2D(filters, (3, 3), padding="same")(x)]
        for i, coef in enumerate(filters_coef):
            for _ in range(res_blocks):
                r = ResnetBlock(filters * coef)([encoder[-1], t])
                if r.shape[1] in attention_width:
                    r = AttentionBlock()(r)
                encoder.append(r)

            if i != len(filters_coef) - 1:
                encoder.append(layers.Conv2D(filters * coef, (3, 3), padding="same", strides=2)(encoder[-1]))

        # Bottleneck at the deepest level's width.
        bottleneck_filters = filters * filters_coef[-1]
        r = ResnetBlock(bottleneck_filters)([encoder[-1], t])
        r = AttentionBlock()(r)
        r = ResnetBlock(bottleneck_filters)([r, t])

        # Decoder: res_blocks + 1 blocks per level so that every encoder
        # entry (including the initial / down-sampling convolutions) is used.
        for i, coef in enumerate(reversed(filters_coef)):
            for _ in range(res_blocks + 1):
                c = layers.Concatenate()([r, encoder.pop()])
                r = ResnetBlock(filters * coef)([c, t])
                if r.shape[1] in attention_width:
                    r = AttentionBlock()(r)
            if i != len(filters_coef) - 1:
                r = layers.UpSampling2D((2, 2))(r)
                r = layers.Conv2D(filters * coef, (3, 3), padding="same")(r)

        x = layers.LayerNormalization()(r)
        x = layers.Activation(activations.swish)(x)
        x = layers.Conv2D(1, (3, 3), padding="same")(x)

        return models.Model([input_image, input_time], x)

    @tf.function
    def train_batch(self, x_noisy: tf.Tensor, t: tf.Tensor, noise: tf.Tensor):
        """Run one optimisation step.

        Args:
            x_noisy: batch of noisy images fed to the model.
            t: time steps of shape ``(batch, 1)``.
            noise: target the model output is compared against.

        Returns:
            ``{"loss_mse": loss}`` with the scalar loss of this step.
        """
        with tf.GradientTape() as tape:
            noise_pred = self.model([x_noisy, t])
            loss = self.model.loss(noise, noise_pred)

        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.model.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        return {"loss_mse": loss}

    def predict(self, images: tf.Tensor = None):
        """Generate samples by running the reverse diffusion process.

        Starting from ``images`` (or from one sample of standard normal noise
        when ``images`` is None), walks the steps ``t_size - 1 .. 0``. At each
        step the model's noise prediction is turned into an estimate of the
        clean image, which is combined with the current sample into the
        posterior mean; Gaussian noise with the posterior standard deviation
        is added on every step except the last.

        Args:
            images: optional starting batch, treated as the fully-noised
                state. Must match the model's image input shape.

        Returns:
            float32 tensor with the same shape as the starting batch.
        """
        gd = self.GD
        if images is None:
            # input_shape[0] is a tuple, so unpack instead of list + tuple.
            x = tf.random.normal([1, *self.model.input_shape[0][1:]])
        else:
            x = tf.cast(images, tf.float32)
        batch = int(x.shape[0])

        # Coefficients of the posterior q(x_{t-1} | x_t, x_0).
        one_minus_cumprod = 1.0 - gd.alphas_cumprod
        coef_x0 = gd.betas * tf.sqrt(gd.alphas_cumprod_prev) / one_minus_cumprod
        coef_xt = (1.0 - gd.alphas_cumprod_prev) * tf.sqrt(gd.alphas) / one_minus_cumprod
        sigmas = tf.sqrt(gd.betas * (1.0 - gd.alphas_cumprod_prev) / one_minus_cumprod)

        for step in range(self.t_size - 1, -1, -1):
            t_index = tf.fill([batch], step)             # int32, for schedule lookup
            t_input = tf.fill([batch, 1], float(step))   # float32, for the model
            noise_pred = tf.convert_to_tensor(self.model.predict_on_batch([x, t_input]))

            x0_pred = gd.get_predict_sample(x, t_index, noise_pred)
            x = (gd.get_alphas(coef_x0, t_index, x.shape) * x0_pred
                 + gd.get_alphas(coef_xt, t_index, x.shape) * x)
            if step > 0:
                # No noise on the final step: its output is the sample itself.
                x = x + gd.get_alphas(sigmas, t_index, x.shape) * tf.random.normal(tf.shape(x))

        return x

    def train(self):
        """Placeholder: no training loop is implemented yet."""
        pass
            
                

# Build and compile the network. `train()` currently has an empty body, so
# the second call performs no work.
m = NoiseModel()
m.train()

{'loss_mse': <tf.Tensor: shape=(), dtype=float32, numpy=2.9299161>}

In [3]:
# Smoke test: one optimisation step on a dummy batch of 4 all-zero 192x240x1
# images, with time step 1 and an all-ones target.
m.train_batch(np.zeros((4, 192, 240, 1)), np.ones((4, 1)), np.ones((4, 192, 240, 1)))

{'loss_mse': <tf.Tensor: shape=(), dtype=float32, numpy=1.3384873>}