In [31]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
import math
import tensorflow_addons as tfa
import time
import os

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input ,Conv2D, Conv2DTranspose, LeakyReLU, Activation, Concatenate, Dropout, BatchNormalization, LeakyReLU, Dense


In [32]:
from  VAE_trained import build_model, meanvar, reparameterize

In [33]:
# Image geometry shared by the cells below.
img_size = 128        # side length images are resized to in load_images
new_img_size = 32     # default spatial size of the diffusion model input and of the tensors `sample` creates
img_channels = 3      # channel count passed to build_model and used for the diffusion model input

# Directory holding the training images.
path = 'C:\\Users\\sayan\\Desktop\\anime_pics\\128x128_'

# os.listdir returns entries in arbitrary order; sort so the path list is the
# same on every run.
train_paths = [os.path.join(path, name) for name in sorted(os.listdir(path))]
if not train_paths:
    # An empty directory would otherwise only fail further down the pipeline
    # with a much less obvious error.
    raise ValueError(f"no training images found in {path!r}")

train_image = list(train_paths)
train_dataset = tf.data.Dataset.from_tensor_slices(train_image)

def load_images(path):
    """Read one JPEG file and return it as a float image scaled to [-1, 1].

    Args:
        path: file name of the image (scalar string tensor or str).

    Returns:
        Tensor of shape (img_size, img_size, img_channels) with values in
        [-1, 1].
    """
    raw = tf.io.read_file(path)
    # Force a fixed channel count: without `channels`, decode_jpeg keeps the
    # file's own channel count, so a grayscale JPEG would produce a 1-channel
    # tensor that cannot be batched together with the colour images.
    img = tf.image.decode_jpeg(raw, channels=img_channels)
    img = tf.image.resize(img, [img_size, img_size])
    # Map pixel values from [0, 255] to [-1, 1].
    img = (img / 127.5) - 1

    return img

# Shuffle the (cheap) path strings *before* decoding, so the full-size shuffle
# buffer holds file names instead of every decoded image.
train_dataset = train_dataset.shuffle(buffer_size=len(train_image))
# Decode several files in parallel; element order stays deterministic by default.
train_dataset = train_dataset.map(load_images, num_parallel_calls=tf.data.AUTOTUNE)
# Batch and overlap input preparation with training.
train_dataset = train_dataset.batch(16).prefetch(tf.data.AUTOTUNE)

In [34]:
# Number of entries in the filter list handed to build_model.
depth = 3

# Level indices 0 .. depth-1; the filter count doubles at every level,
# starting from 64.
channel_multiplier = list(range(depth))
filters = [64 * (2 ** level) for level in channel_multiplier]

In [35]:
# Build the encoder/decoder pair with build_model (imported from VAE_trained).
enc, dec = build_model(img_size, img_channels, filters)

In [36]:
# Restore the encoder weights; the file is looked up relative to the working directory.
enc.load_weights('LDM_encoder.h5')

In [37]:
# Restore the decoder weights; the file is looked up relative to the working directory.
dec.load_weights('LDM_decoder.h5')

In [38]:
# Longer aliases for the two VAE halves; the training loop below uses `enc`
# to encode batches and `decoder` to decode generated samples.
encoder, decoder = enc, dec

In [39]:
patch_size = 4      # default patch height and width in transformer()
num_layers = 12     # default number of transformer_block repetitions
hidden_size = 384   # token width: default for the patch projection, attention and MLP layers
num_heads = 6       # default attention head count; hidden_size must be divisible by it (asserted in AttentionBlock)
units = 64          # default `dim` handed to time_embedding by transformer()
temb_dim = 384      # NOTE(review): not referenced anywhere else in this notebook as shown

In [40]:
# Initializer used by both Dense layers that TimeMLP creates:
# normal distribution with mean 0 and stddev 0.05, fixed seed.
kernel_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.05, seed=32)


In [41]:
def time_embedding(timesteps, dim):
    """Sinusoidal embedding of diffusion timesteps.

    Args:
        timesteps: tensor of timestep indices; it is indexed as ``[:, None]``,
            so a 1-D tensor of shape (batch,) is expected.
        dim: requested embedding width.

    Returns:
        Float32 tensor of shape (batch, 2 * (dim // 2)): ``dim // 2`` sine
        features followed by ``dim // 2`` cosine features. An odd ``dim``
        therefore yields ``dim - 1`` columns.
    """
    half_dim = dim // 2
    # Frequencies decay geometrically from 1 down to 1/10000:
    #   freq_i = exp(-i * ln(10000) / (half_dim - 1)),  i = 0 .. half_dim - 1
    # max(..., 1) avoids a ZeroDivisionError when half_dim == 1 (dim of 2 or 3);
    # for every larger dim the value is unchanged.
    log_step = math.log(10000) / max(half_dim - 1, 1)
    freqs = tf.exp(tf.range(half_dim, dtype=tf.float32) * -log_step)
    # Integer timesteps are accepted; everything below is float32.
    positions = tf.cast(timesteps, dtype=tf.float32)
    # Outer product: (batch, 1) * (1, half_dim) -> (batch, half_dim)
    angles = positions[:, None] * freqs[None, :]
    return tf.concat([tf.sin(angles), tf.cos(angles)], axis=-1)

In [42]:
def TimeMLP(units, activation_fn=keras.activations.swish):
    """Return a callable that runs its input through two Dense layers.

    The first layer has `units` outputs and applies `activation_fn`; the
    second has `units` outputs and no activation. Both use the notebook-level
    `kernel_init`. New layers are created every time the returned callable is
    invoked.
    """
    def apply(inputs):
        hidden = layers.Dense(units, activation=activation_fn, kernel_initializer=kernel_init)(inputs)
        return layers.Dense(units, kernel_initializer=kernel_init)(hidden)
    return apply

In [43]:
class SineCosinePositionalEmbedding2D(tf.keras.layers.Layer):
    """Weight-free 2-D sine/cosine positional embedding.

    The embedding of grid cell (row, col) is the concatenation of four
    quarter-width bands: sin(row), cos(row), sin(col), cos(col), each
    evaluated at `embed_dim // 4` geometrically spaced frequencies.
    """

    def __init__(self, embed_dim, **kwargs):
        super().__init__(**kwargs)
        assert embed_dim % 4 == 0, "Embed dim must be divisible by 4"
        self.embed_dim = embed_dim
        self.d_quarter = embed_dim // 4

    def call(self, height, width):
        """Return the embedding table for a height x width grid, shape [H*W, D]."""
        # Frequencies shared by the row and column bands.
        exponent = tf.range(self.d_quarter, dtype=tf.float32)
        div_term = tf.exp(exponent * -(tf.math.log(10000.0) / self.d_quarter))  # [D/4]

        # Position indices as column vectors so they broadcast against div_term.
        row_idx = tf.cast(tf.range(height), tf.float32)[:, tf.newaxis]  # [H, 1]
        col_idx = tf.cast(tf.range(width), tf.float32)[:, tf.newaxis]   # [W, 1]

        row_angles = row_idx * div_term  # [H, D/4]
        col_angles = col_idx * div_term  # [W, D/4]

        def over_columns(per_row):
            # [H, D/4] -> [H, W, D/4]: repeat each row's features across all columns.
            return tf.tile(per_row[:, tf.newaxis, :], [1, width, 1])

        def over_rows(per_col):
            # [W, D/4] -> [H, W, D/4]: repeat each column's features across all rows.
            return tf.tile(per_col[tf.newaxis, :, :], [height, 1, 1])

        pos_emb = tf.concat(
            [
                over_columns(tf.sin(row_angles)),
                over_columns(tf.cos(row_angles)),
                over_rows(tf.sin(col_angles)),
                over_rows(tf.cos(col_angles)),
            ],
            axis=-1,
        )  # [H, W, D]

        # Flatten the grid in row-major order.
        return tf.reshape(pos_emb, [height * width, self.embed_dim])  # [H*W, D]

In [44]:
def PatchEmbedding(tensor, patch_h, patch_w, dims = hidden_size):
  """Turn an image tensor into a sequence of embedded tokens with positions added.

  Args:
    tensor: 4-D tensor; its four dynamic dimensions are read as
      (batch, height, width, channels).
    patch_h, patch_w: patch height / width, used to derive the token count
      and the per-token feature size.
    dims: output width of the Dense projection and of the positional embedding.

  Returns:
    Tensor of shape (batch, h_patches * w_patches, dims).
  """
  
  batch_size, height, width, channels = tf.shape(tensor)[0], tf.shape(tensor)[1], tf.shape(tensor)[2], tf.shape(tensor)[3]
  h_patches, w_patches = height // patch_h, width // patch_w
  # NOTE(review): these are plain reshapes of the row-major (H, W, C) data, so
  # each token is a run of patch_h * patch_w consecutive pixels in raster
  # order, not a patch_h x patch_w spatial window. The final Reshape in
  # `transformer` undoes exactly this layout, so the model is self-consistent,
  # but the 2-D positional grid below does not match the tokens' real
  # positions. True spatial patches would need a 6-D reshape + transpose here
  # and the matching inverse in `transformer` (and presumably retraining).
  tensor = tf.reshape(tensor, (batch_size, h_patches * w_patches, patch_h , patch_w , channels))
  tensor = tf.reshape(tensor, (batch_size, h_patches * w_patches, patch_h * patch_w * channels))

  # Linear projection of every token to `dims` features (a new Dense layer per call).
  tensor = layers.Dense(dims)(tensor)
  
  #temb = get_2d_sincos_pos_embed(dims, h_patches, w_patches)
  # Fixed sine/cosine table for an h_patches x w_patches grid: (h_patches * w_patches, dims).
  temb = SineCosinePositionalEmbedding2D(dims)(h_patches, w_patches)
  # Add a leading batch axis so the table broadcasts over the batch.
  temb = tf.expand_dims(temb, axis=0)
  tensor = tensor + temb
  
  return tensor




In [None]:

class AttentionBlock(layers.Layer):
    """Multi-head self-attention over a (batch, tokens, units) tensor.

    Queries, keys and values are separate Dense projections of the same
    input; the per-head results are concatenated and passed through a final
    Dense projection. No normalisation and no residual connection is applied
    inside this layer.

    Args:
        heads: number of attention heads; must divide `units`.
        units: width of the q/k/v projections and of the output.
        **kwargs: forwarded to the base Keras `Layer` (name, dtype, ...).
    """

    def __init__(self, heads, units, **kwargs):
        # Forward kwargs so name/dtype/trainable are honoured instead of
        # being silently dropped.
        super().__init__(**kwargs)
        self.units = units
        self.num_heads = heads
        assert self.units%self.num_heads == 0, 'nummber of heads is incompatible with hidden dims size'
        self.units_per_heads = self.units // self.num_heads

        self.query_ = layers.Dense(units)
        self.key_ = layers.Dense(units)
        self.value_ = layers.Dense(units)
        self.proj = layers.Dense(units)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from config."""
        config = super().get_config()
        config.update({"heads": self.num_heads, "units": self.units})
        return config

    def split_heads(self, input, batch_size):
        """Reshape (batch, tokens, units) -> (batch, heads, tokens, units_per_heads)."""
        input = tf.reshape(input, (batch_size, -1, self.num_heads, self.units_per_heads))
        return tf.transpose(input, perm=[0, 2, 1, 3])

    def dot_prod(self, x):
        """Scaled dot-product attention.

        Args:
            x: tuple (q, k, v) of tensors shaped (..., tokens, units_per_heads);
               any leading dimensions are treated as batch dimensions.

        Returns:
            Tensor with the same shape as `q`.
        """
        q_i, k_i, v_i = x
        scale = tf.cast(self.units_per_heads, q_i.dtype) ** -0.5
        dot = tf.matmul(q_i, k_i, transpose_b=True) * scale
        w = tf.nn.softmax(dot, axis=-1)
        return tf.matmul(w, v_i)

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]

        q = self.split_heads(self.query_(inputs), batch_size)
        k = self.split_heads(self.key_(inputs), batch_size)
        v = self.split_heads(self.value_(inputs), batch_size)

        # matmul and softmax already operate over all leading (batch, head)
        # dimensions, so the attention is computed in one batched call rather
        # than looping over the batch with tf.map_fn.
        attn = self.dot_prod((q, k, v))

        # (batch, heads, tokens, d) -> (batch, tokens, heads * d)
        attn = tf.transpose(attn, perm=[0, 2, 1, 3])
        attn = tf.reshape(attn, (batch_size, -1, self.units))
        attn = self.proj(attn)

        return attn








In [46]:
def MLP(attn_score, units = hidden_size):
    """Position-wise feed-forward network: Dense(3 * units) -> GELU -> Dense(units).

    Args:
        attn_score: input tensor; the Dense layers act on its last axis.
        units: output width; the hidden layer is three times as wide.

    Returns:
        Tensor with the same leading dimensions as the input and `units`
        features. New Dense layers are created on every call.
    """
    x = layers.Dense(3*units)(attn_score)
    # `approximate` is a boolean flag (True selects the tanh approximation);
    # the previous string 'tanh' only worked because it is truthy.
    x = keras.activations.gelu(x, approximate=True)
    x = layers.Dense(units)(x)
    return x

In [47]:
def transformer_block(x, t, heads = num_heads, hidden_size = hidden_size):
  """One conditioned transformer block: modulated attention followed by a modulated MLP.

  A zero-initialised Dense layer maps the conditioning `t` to six
  `hidden_size`-wide vectors (scale, shift and gate for the attention branch,
  then scale, shift and gate for the MLP branch). Because the layer starts at
  zero and swish(0) == 0, all six start out as zeros.

  Args:
    x: token tensor; transformer() passes (batch, tokens, hidden_size).
    t: conditioning tensor that broadcasts against `x`; transformer() passes
      (batch, 1, hidden_size).
    heads: number of attention heads.
    hidden_size: token width.

  Returns:
    Tensor with the same shape as `x`.
  """
  # Parameter-free layer norms; scale and shift come from the conditioning instead.
  norm_attn = layers.LayerNormalization(center = False, scale = False, trainable = False)
  norm_mlp = layers.LayerNormalization(center = False, scale = False, trainable = False)

  # NOTE(review): swish is applied to the output of the projection, so every
  # modulation value is bounded below by swish's minimum; reference DiT
  # implementations apply the activation before the linear layer - confirm intent.
  modulation = layers.Dense(
      6 * hidden_size,
      kernel_initializer='zeros',
      bias_initializer='zeros',
      activation=tf.keras.activations.swish,
  )(t)
  (attn_scale, attn_shift, attn_gate,
   mlp_scale, mlp_shift, mlp_gate) = tf.split(modulation, num_or_size_splits=6, axis=-1)

  # Attention branch.
  # NOTE(review): the skip connection adds to the normalised + modulated
  # activations, not to the block input `x`, so the block is not an identity
  # map when the gates are zero - confirm this is intended.
  attn_in = norm_attn(x) * (1 + attn_scale) + attn_shift
  attn_out = AttentionBlock(heads, hidden_size)(attn_in)
  hidden = attn_in + attn_gate * attn_out

  # MLP branch, same pattern.
  mlp_in = norm_mlp(hidden) * (1 + mlp_scale) + mlp_shift
  mlp_out = MLP(mlp_in, hidden_size)

  return mlp_in + mlp_gate * mlp_out


In [48]:
def transformer(num_layers = num_layers, patch_height = patch_size, patch_width = patch_size,units = units, hidden_size = hidden_size, img_size = new_img_size, img_channels = img_channels, num_heads = num_heads):
    """Build the DiT noise-prediction model.

    Args:
        num_layers: number of transformer blocks.
        patch_height, patch_width: patch size used to tokenise the input.
        units: width of the sinusoidal timestep embedding.
        hidden_size: token width used by the patch projection, the timestep
            MLP and every transformer block.
        img_size: side length of the (square) input.
        img_channels: number of input/output channels.
        num_heads: attention heads per block; must divide `hidden_size`.

    Returns:
        A `keras.Model` named 'DiT' mapping [image, timestep] to a tensor with
        the same shape as the image input.

    Raises:
        ValueError: if `img_size` is not a multiple of the patch size.
    """
    if img_size % patch_height or img_size % patch_width:
        raise ValueError(
            f"img_size ({img_size}) must be divisible by the patch size "
            f"({patch_height}x{patch_width})"
        )

    image_input = layers.Input(shape=(img_size, img_size, img_channels), name="image_input")
    time_input = keras.Input(shape=(), dtype=tf.int64, name="time_input")

    num_patches_h = img_size // patch_height
    num_patches_w = img_size // patch_width

    # Thread hidden_size through explicitly; previously PatchEmbedding and the
    # blocks fell back to the notebook-level defaults regardless of the
    # value passed to this function.
    x = PatchEmbedding(image_input, patch_height, patch_width, dims=hidden_size)

    # Timestep conditioning: sinusoidal features -> MLP -> (batch, 1, hidden_size)
    # so it broadcasts over the token axis.
    t = time_embedding(time_input, units)
    t = TimeMLP(hidden_size)(t)[:, None, :]

    for _ in range(num_layers):
        x = transformer_block(x, t, heads=num_heads, hidden_size=hidden_size)

    # Project every token back to patch_height * patch_width * img_channels values ...
    x = layers.Dense(patch_height*patch_width*img_channels)(x)

    # ... and lay the tokens' values back out in order as an image-shaped
    # tensor (the inverse of the reshape done in PatchEmbedding).
    x = layers.Reshape((patch_height*num_patches_h, patch_width*num_patches_w, img_channels))(x)

    return keras.Model([image_input, time_input], x, name= 'DiT')


In [49]:
# Instantiate the DiT model with the default hyper-parameters defined in the cells above.
diffusion_model = transformer()

In [50]:
# Linear beta schedule: `timesteps` values evenly spaced in [beta_start, beta_end].
timesteps = 1000
beta_start = 0.0001
beta_end = 0.02
betas = tf.linspace(beta_start, beta_end, timesteps)

# alpha_t = 1 - beta_t and its running product alpha_bar_t.
alphas = 1.0 - betas
alphas_cumprod = tf.math.cumprod(alphas)
# alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1 for the first step.
alphas_cumprod_prev = tf.concat([tf.ones([1]), alphas_cumprod[:-1]], axis=0)

# Coefficients used by forward_diffusion: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * noise.
sqrt_alphas_cumprod = tf.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = tf.sqrt(1.0 - alphas_cumprod)
# Not used by the functions shown in this notebook.
sqrt_reciprocal_alphas_cumprod = tf.sqrt(1.0 / alphas_cumprod)

In [51]:
def forward_diffusion(x0, t, func_sqrt_alphas_cumprod = sqrt_alphas_cumprod,func_sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod):
    """Noise clean samples `x0` to timestep `t` in a single step.

    Args:
        x0: clean samples; the coefficients are reshaped to [-1, 1, 1, 1], so
            a 4-D (batch, H, W, C) tensor is expected.
        t: integer timestep index per sample, shape (batch,).
        func_sqrt_alphas_cumprod: lookup table of sqrt(alpha_bar_t).
        func_sqrt_one_minus_alphas_cumprod: lookup table of sqrt(1 - alpha_bar_t).

    Returns:
        (xt, noise): the noised samples and the Gaussian noise that was mixed in.
    """
    def _coef_at_t(table):
        # Per-sample coefficient, shaped to broadcast over (H, W, C).
        return tf.reshape(tf.gather(table, t), [-1, 1, 1, 1])

    noise = tf.random.normal(tf.shape(x0))

    signal_scale = _coef_at_t(func_sqrt_alphas_cumprod)
    noise_scale = _coef_at_t(func_sqrt_one_minus_alphas_cumprod)

    # x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise
    xt = signal_scale * x0 + noise_scale * noise
    return xt, noise

In [52]:
def reverse_diffusion( xt, t, predicted_noise, func_alphas = alphas, func_alphas_cumprod = alphas_cumprod, func_alphas_cumprod_prev = alphas_cumprod_prev, func_betas = betas):
    """Take one reverse (denoising) step from x_t to x_{t-1}.

    Args:
        xt: current noisy samples; coefficients are reshaped to [-1, 1, 1, 1],
            so a 4-D (batch, H, W, C) tensor is expected.
        t: integer timestep index per sample, shape (batch,).
        predicted_noise: noise estimate for `xt` (same shape as `xt`).
        func_alphas, func_alphas_cumprod, func_alphas_cumprod_prev, func_betas:
            schedule lookup tables indexed by `t`.

    Returns:
        The sample for the previous timestep; no noise is added where t == 0.
    """
    def _at_t(schedule):
        # Per-sample schedule value, shaped to broadcast over (H, W, C).
        return tf.reshape(tf.gather(schedule, t), [-1, 1, 1, 1])

    alpha_t = _at_t(func_alphas)
    abar_t = _at_t(func_alphas_cumprod)
    abar_prev_t = _at_t(func_alphas_cumprod_prev)
    beta_t = _at_t(func_betas)

    # mean = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
    eps_coef = beta_t / tf.sqrt(1.0 - abar_t)
    mean = (1.0 / tf.sqrt(alpha_t)) * (xt - eps_coef * predicted_noise)
    # var = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
    var = beta_t * (1.0 - abar_prev_t) / (1.0 - abar_t)

    noise = tf.random.normal(tf.shape(xt))

    # 1.0 for every sample with t != 0, 0.0 for t == 0 (final step is noise-free).
    nonzero_mask = tf.reshape(tf.cast(t != 0, tf.float32), [-1, 1, 1, 1])

    return mean + nonzero_mask * tf.sqrt(var) * noise

In [53]:
def sample(func_model, num_steps=None, batch_size=1):
    """Generate samples by running the reverse diffusion chain from pure noise.

    Args:
        func_model: noise-prediction model, called as
            ``func_model([xt, t], training=False)``.
        num_steps: number of reverse steps; defaults to the full schedule
            length `timesteps`. The chain visits t = num_steps-1 .. 0, so a
            smaller value skips the highest-noise steps rather than re-spacing
            the schedule.
        batch_size: number of samples to generate.

    Returns:
        Tensor of shape (batch_size, new_img_size, new_img_size, img_channels).

    Raises:
        ValueError: if `num_steps` exceeds the schedule length.
    """
    if num_steps is None:
        # Follow the schedule length instead of a hard-coded 1000.
        num_steps = timesteps
    if num_steps > timesteps:
        # Larger indices would read past the end of the schedule tables.
        raise ValueError(
            f"num_steps ({num_steps}) exceeds the schedule length ({timesteps})"
        )

    # The reverse chain starts from standard normal noise.
    xt = tf.random.normal((batch_size, new_img_size, new_img_size, img_channels))

    # At each step the model predicts the noise contained in x_t and
    # reverse_diffusion uses it to move one step towards x_0. Jumping straight
    # from x_t to x_0 would not give a meaningful image, hence the loop.
    for i in reversed(range(num_steps)):
        t = tf.fill([batch_size], i)
        predicted_noise = func_model([xt, t], training=False)
        xt = reverse_diffusion(xt, t, predicted_noise)

    return xt

In [54]:
def train_step_(x_batch, model=None):
    """Compute the noise-prediction loss and its gradients for one batch.

    Args:
        x_batch: batch of clean inputs for the diffusion model.
        model: model to evaluate and differentiate; defaults to the
            notebook-level `diffusion_model`.

    Returns:
        (loss, gradients): the scalar MSE between true and predicted noise,
        and its gradients w.r.t. `model.trainable_variables`.
    """
    if model is None:
        model = diffusion_model

    batch_size = tf.shape(x_batch)[0]
    # One random timestep per sample.
    t = tf.random.uniform([batch_size], 0, timesteps, dtype=tf.int32)
    with tf.GradientTape() as tape:
        xt, noise = forward_diffusion(x_batch, t)
        predicted_noise = model([xt, t], training=True)
        loss = tf.reduce_mean(tf.square(noise - predicted_noise))

    gradients = tape.gradient(loss, model.trainable_variables)
    return loss, gradients


def train_diffusion_model(diffusion_model, dataset, epochs=100, learning_rate=1e-4):
    """Train `diffusion_model` on VAE latents of the images in `dataset`.

    Each image batch is encoded with the notebook-level `enc` (via `meanvar`
    and `reparameterize`) and the result is used as the clean input of the
    diffusion step. Every 10 epochs four samples are generated, decoded with
    the notebook-level `decoder` and plotted.

    Args:
        diffusion_model: noise-prediction model to train (updated in place).
        dataset: iterable of image batches.
        epochs: number of passes over `dataset`.
        learning_rate: Adam learning rate.

    Raises:
        ValueError: if `dataset` yields no batches.
    """
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

    diffusion_model.compile(optimizer=optimizer)

    @tf.function
    def train_step(x_batch):
        # Pass the model explicitly so the gradients are taken for the same
        # model whose variables the optimizer updates.
        loss, gradients = train_step_(x_batch, diffusion_model)
        optimizer.apply_gradients(zip(gradients, diffusion_model.trainable_variables))
        return loss

    for epoch in range(epochs):
        epoch_loss = 0.0
        num_batches = 0

        for batch in dataset:
            mean, var = meanvar(batch, enc, training=False)
            batch_ = reparameterize(mean, var)
            loss = train_step(batch_)
            epoch_loss += loss
            num_batches += 1

        if num_batches == 0:
            raise ValueError("dataset yielded no batches; nothing to train on")

        avg_loss = float(epoch_loss) / num_batches
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")

        # Generate samples every 10 epochs
        if (epoch + 1) % 10 == 0:
            samples = sample(diffusion_model,batch_size=4)
            decoded_samples = decoder(samples, training=False)

            # NOTE(review): imshow clips float images to [0, 1]; if the decoder
            # outputs values in [-1, 1] like the training images, rescale
            # before plotting - confirm the decoder's output range.
            fig, axes = plt.subplots(1, 4, figsize=(12, 3))
            for i in range(4):
                axes[i].imshow(decoded_samples[i])
                axes[i].axis('off')

            # Title the whole figure; plt.title would only label the last subplot.
            fig.suptitle(f"Generated samples at epoch {epoch + 1}")
            plt.show()
            # Release the figure so long runs do not accumulate open figures.
            plt.close(fig)

In [55]:
# Resume from previously saved weights; requires 'DiT.h5' to exist in the working
# directory (it is the file written by the save cell at the end of the notebook).
diffusion_model.load_weights('DiT.h5')

In [56]:
# Long-running: the recorded run was stopped manually (KeyboardInterrupt).
# Weights are only written by the separate save cell below.
train_diffusion_model(diffusion_model, train_dataset, epochs=100000)

KeyboardInterrupt: 

In [None]:
# Persist the current weights to the same file the load cell above reads.
diffusion_model.save_weights('DiT.h5')