# Generate a 3D point cloud with keras_cv Stable Diffusion and OpenAI Point-E

In [1]:
!pip install -q git+https://github.com/keras-team/keras-cv
!pip install -q git+https://github.com/openai/CLIP.git
!pip install -q transformers
!pip -q install --upgrade gdown
!pip -q install trimesh

  Building wheel for keras-cv (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 53 kB 2.0 MB/s 
[?25h  Building wheel for clip (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 5.8 MB 29.4 MB/s 
[K     |████████████████████████████████| 182 kB 78.3 MB/s 
[K     |████████████████████████████████| 7.6 MB 66.8 MB/s 
[K     |████████████████████████████████| 669 kB 33.7 MB/s 
[?25h

In [2]:
import tensorflow as tf
from tensorflow import keras
from keras_cv.models.stable_diffusion import StableDiffusion, StableDiffusionV2
from PIL import Image
import numpy as np
import plotly
import clip
from transformers import TFAutoModel
import math
import matplotlib.pyplot as plt

In [None]:
# Load the pretrained 'openai/clip-vit-large-patch14' checkpoint through the TensorFlow auto-model API.
clip_model = TFAutoModel.from_pretrained('openai/clip-vit-large-patch14')
# Keep a handle on the vision sub-model of the first layer; it is passed to TextTo3D further down.
clip_vision_model = clip_model.layers[0].vision_model

Downloading:   0%|          | 0.00/4.52k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.71G [00:00<?, ?B/s]

In [None]:
# Switch the global Keras dtype policy to 'mixed_float16' before any of the models below are built.
keras.mixed_precision.set_global_policy('mixed_float16')

In [None]:
# Let's define the model configurations.
# Each entry holds the architecture hyper-parameters of one network plus the
# URL and file hash of its pretrained weight file.
config = {
    # Smaller image-conditioned point diffusion model.
    "base40M": {
        "heads": 8,
        "layers": 12,
        "n_ctx": 1024,
        "width": 512,
        "cond_ctx": 256,
        "cond_width": 1024,
        "input_channels": 6,
        "output_channels": 12,
        "weights": {"origin": "https://huggingface.co/Jobayer/keras_point_e/resolve/main/point_diffusion1.h5",
                    "file_hash": "6385fe7c70793f74a99683b105a13a6f0cbc324bf93adb22f6f8cab824279d3b"}
    },
    # Larger image-conditioned point diffusion model (the one TextTo3D instantiates).
    "base300M": {
        "heads": 16,
        "layers": 24,
        "n_ctx": 1024,
        "width": 1024,
        "cond_ctx": 256,
        "cond_width": 1024,
        "input_channels": 6,
        "output_channels": 12,
        "weights": {"origin": "https://huggingface.co/Jobayer/keras_point_e/resolve/main/point_diffusion_300M.h5",
                    "file_hash": "84220fb81e504ebe65c0a4701f9643f1875f0e0f114539da766c426ce2bd9231"}
    },
    # Upsampler that is additionally conditioned on a low-resolution point cloud.
    "upsample": {
        "channel_biases": [0.0, 0.0, 0.0, -1.0, -1.0, -1.0],
        "channel_scales": [2.0, 2.0, 2.0, 0.007843137255, 0.007843137255, 0.007843137255],
        "cond_ctx": 256,
        "cond_drop_prob": 0.1,
        "heads": 8,
        "init_scale": 0.25,
        # This key used to be listed twice in this dict (same value both times).
        "input_channels": 6,
        "layers": 12,
        "n_ctx": 3072,
        "width": 512,
        "cond_width": 1024,
        "low_res_ctx": 1024,
        "output_channels": 12,
        "weights": {"origin": "https://huggingface.co/Jobayer/keras_point_e/resolve/main/point_upsampler.h5",
                    "file_hash": "004e909c76e8ea0fc987dfb7ae84eca22bb918f09fa2fa828ce6884ff8a1a93f"}
    },
    # Cross-attention signed-distance-function model.
    "sdf": {
        "decoder_heads": 4,
        "decoder_layers": 4,
        "encoder_heads": 4,
        "encoder_layers": 8,
        "init_scale": 0.25,
        "n_ctx": 4096,
        "name": "CrossAttentionPointCloudSDFModel",
        "width": 256,
        "weights": {"origin": "https://huggingface.co/Jobayer/keras_point_e/resolve/main/sdf.h5",
                    "file_hash": "c0261aca2f0603cece98c981615682225c1e8ea2f1a1b99432ba0a328efa1e62"}
    },
}

In [None]:
class MultiHeadAttention(keras.layers.Layer):
    """Multi-head self-attention using one fused query/key/value projection.

    Args:
        embed_dim: Feature width of the input and of the output projection.
        num_heads: Number of attention heads; each head gets
            `embed_dim // num_heads` channels.
        causal: Stored on the layer, but `call` never reads it, so no
            attention mask is applied whatever its value.
    """

    def __init__(self, embed_dim=768, num_heads=12, causal=True, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.causal = causal
        self.head_dim = self.embed_dim // self.num_heads
        # Multiplied into both queries and keys, so their product ends up
        # scaled by head_dim ** -0.5.
        self.scale = ((self.head_dim)**-0.5)**0.5
        self.in_proj = keras.layers.Dense(self.embed_dim*3)
        self.out_proj = keras.layers.Dense(self.embed_dim)

    def call(self, hidden_state):
        """Attend over a rank-3 `hidden_state` and return a tensor of the same shape."""
        _, seq_len, channels = hidden_state.shape
        fused = self.in_proj(hidden_state)
        # Each head's slice carries its query, key and value side by side.
        per_head = tf.reshape(fused, [-1, seq_len, self.num_heads, self.head_dim*3])
        q, k, v = tf.split(per_head, 3, axis=-1)
        scores = tf.einsum("bthc,bshc->bhts", q*self.scale, k*self.scale)
        probs = tf.nn.softmax(scores)
        context = tf.einsum("bhts,bshc->bthc", probs, v)
        merged = tf.reshape(context, (-1, seq_len, channels))
        return self.out_proj(merged)

In [None]:
class MultiHeadCrossAttention(keras.layers.Layer):
    """Multi-head cross-attention: queries from one sequence, keys/values from another.

    Args:
        embed_dim: Feature width of the query, key/value and output projections.
        num_heads: Number of attention heads; each head gets
            `embed_dim // num_heads` channels.
    """

    def __init__(self, embed_dim=768, num_heads=12, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads
        # Multiplied into both queries and keys, so their product ends up
        # scaled by head_dim ** -0.5.
        self.scale = ((self.head_dim)**-0.5)**0.5
        self.query = keras.layers.Dense(self.embed_dim)
        self.key_value = keras.layers.Dense(self.embed_dim*2)
        self.out_proj = keras.layers.Dense(self.embed_dim)

    def call(self, hidden_state, encoded_hidden_states):
        """Let `hidden_state` attend over `encoded_hidden_states`.

        Args:
            hidden_state: Rank-3 query sequence.
            encoded_hidden_states: Sequence that supplies keys and values; its
                length (axis 1) may differ from the query length.

        Returns:
            Tensor with the same shape as `hidden_state`.
        """
        _, tgt_length, embed_dim = hidden_state.shape
        # Keys/values must be reshaped with the *source* length. Using the
        # query length here only works when both sequences happen to be
        # equally long; otherwise elements spill across the batch axis.
        src_length = encoded_hidden_states.shape[1]
        query = self.query(hidden_state)
        query = tf.reshape(query, [-1, tgt_length, self.num_heads, self.head_dim])
        key_value = self.key_value(encoded_hidden_states)
        key_value = tf.reshape(key_value, [-1, src_length, self.num_heads, self.head_dim*2])
        key, value = tf.split(key_value, 2, axis=-1)
        attn_weights = tf.einsum("bthc,bshc->bhts", query*self.scale, key*self.scale)
        attn_weights = tf.nn.softmax(attn_weights)
        attn_output = tf.einsum("bhts,bshc->bthc", attn_weights, value)
        attn_output = tf.reshape(attn_output, (-1, tgt_length, embed_dim))
        return self.out_proj(attn_output)

In [None]:
class CLIPEncoderLayer(keras.layers.Layer):
    """Pre-norm transformer block: self-attention, then a GELU MLP, each with a residual add.

    Args:
        embed_dim: Feature width of the block; the MLP expands to 4x this width.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): sub-layer creation order is kept as-is; it presumably
        # has to line up with the pretrained weight files.
        self.layer_norm1 = keras.layers.LayerNormalization(epsilon=1e-5)
        self.attn = MultiHeadAttention(embed_dim, num_heads, causal=True)
        self.layer_norm2 = keras.layers.LayerNormalization(epsilon=1e-5)
        self.fc1 = keras.layers.Dense(embed_dim * 4)
        self.fc2 = keras.layers.Dense(embed_dim)

    def call(self, inputs):
        """Apply the block to `inputs` and return a tensor of the same shape."""
        # Attention sub-block with skip connection.
        attended = inputs + self.attn(self.layer_norm1(inputs))
        # Feed-forward sub-block with skip connection.
        hidden = self.fc1(self.layer_norm2(attended))
        hidden = self.fc2(tf.nn.gelu(hidden))
        return hidden + attended

In [None]:
class CLIPDecoderLayer(keras.layers.Layer):
    """Pre-norm transformer block: cross-attention, then a GELU MLP, each with a residual add.

    Args:
        embed_dim: Feature width of the block; the MLP expands to 4x this width.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): sub-layer creation order is kept as-is; it presumably
        # has to line up with the pretrained weight files.
        self.layer_norm1 = keras.layers.LayerNormalization(epsilon=1e-5)
        self.attn = MultiHeadCrossAttention(embed_dim, num_heads)
        self.layer_norm2 = keras.layers.LayerNormalization(epsilon=1e-5)
        self.layer_norm3 = keras.layers.LayerNormalization(epsilon=1e-5)
        self.fc1 = keras.layers.Dense(embed_dim * 4)
        self.fc2 = keras.layers.Dense(embed_dim)

    def call(self, hidden_state, encoded_hidden_state):
        """Update `hidden_state` by attending over `encoded_hidden_state`.

        Returns a tensor with the same shape as `hidden_state`.
        """
        # The encoder sequence and the query sequence each get their own norm.
        memory = self.layer_norm1(encoded_hidden_state)
        queries = self.layer_norm2(hidden_state)
        attended = hidden_state + self.attn(queries, memory)
        # Feed-forward sub-block with skip connection.
        hidden = self.fc1(self.layer_norm3(attended))
        hidden = self.fc2(tf.nn.gelu(hidden))
        return hidden + attended

In [None]:
class CLIPEmbedding(keras.layers.Layer):
    """Layer-normalise the input, then project it to `width` features."""

    def __init__(self, width=512, **kwargs):
        super().__init__(**kwargs)
        self.ln = keras.layers.LayerNormalization(epsilon=1e-5)
        self.fc = keras.layers.Dense(width)

    def call(self, x):
        """Return `fc(ln(x))`."""
        return self.fc(self.ln(x))

In [None]:
class CLIPImageGridPointDiffusionTransformer(keras.Model):
    """Functional Keras model of an image-conditioned point diffusion transformer.

    Model inputs, in order:
        1. points:            (n_ctx, input_channels)
        2. timestep embedding: (width,)
        3. conditioning grid:  (cond_ctx, cond_width)

    The three are embedded to `width` features, concatenated along the token
    axis, run through `layers` encoder blocks, and the first `n_ctx` tokens are
    projected to `output_channels`.

    Args:
        config: Dict with the keys 'n_ctx', 'input_channels', 'width',
            'cond_ctx', 'cond_width', 'layers', 'heads', 'output_channels' and,
            when `download_weights` is true, 'weights' ('origin', 'file_hash').
        img_ctx_length: Accepted for compatibility; not used by this constructor.
        name: Optional model name.
        download_weights: When true, fetch the weight file named in
            `config['weights']` and load it into the model.
    """

    def __init__(self, config, img_ctx_length=256, name=None, download_weights=True):
        points_input = keras.layers.Input(shape=(config['n_ctx'], config['input_channels']))
        # Trailing comma matters: `(config['width'])` is a bare int, not a shape tuple.
        t_emb_input = keras.layers.Input(shape=(config['width'],))
        cond_input = keras.layers.Input(shape=(config['cond_ctx'], config['cond_width']))

        # Two-layer MLP on the timestep embedding.
        t_emb = keras.layers.Dense(config['width']*4)(t_emb_input)
        t_emb = tf.nn.gelu(t_emb)
        t_emb = keras.layers.Dense(config['width'])(t_emb)

        x = keras.layers.Dense(config['width'])(points_input)
        cond_emb = CLIPEmbedding(config['width'])(cond_input)
        # Token layout: [points | one timestep token | conditioning tokens].
        x = keras.layers.Concatenate(axis=1)([x, t_emb[:, None], cond_emb])
        x = keras.layers.LayerNormalization(epsilon=1e-5)(x)
        for _ in range(config['layers']):
            x = CLIPEncoderLayer(config['width'], config['heads'])(x)
        x = keras.layers.LayerNormalization(epsilon=1e-5)(x)
        # Only the point tokens are decoded; the conditioning tokens are dropped.
        output = keras.layers.Dense(config['output_channels'])(x[:, :config['n_ctx']])
        super().__init__([points_input, t_emb_input, cond_input], output, name=name)
        if download_weights:
            weights_fpath = keras.utils.get_file(
                origin=config["weights"]["origin"],
                file_hash=config["weights"]["file_hash"],
            )
            self.load_weights(weights_fpath)


In [None]:
class CLIPImageGridPointUpsamplerTransformer(keras.Model):
    """Functional Keras model of the point-cloud upsampler.

    Same layout as the base diffusion transformer, with one extra model input:
    a low-resolution point cloud that is embedded and appended to the token
    sequence.

    Model inputs, in order:
        1. points:             (n_ctx, input_channels)
        2. timestep embedding:  (width,)
        3. conditioning grid:   (cond_ctx, cond_width)
        4. low-res point cloud: (low_res_ctx, input_channels)

    Args:
        config: Dict with the keys 'n_ctx', 'input_channels', 'width',
            'cond_ctx', 'cond_width', 'low_res_ctx', 'layers', 'heads',
            'output_channels' and, when `download_weights` is true, 'weights'.
            'channel_scales' / 'channel_biases' are only read when
            `scale_low_res` is true.
        img_ctx_length: Accepted for compatibility; not used by this constructor.
        name: Optional model name.
        download_weights: When true, fetch the weight file named in
            `config['weights']` and load it into the model.
        scale_low_res: When true, multiply the low-res cloud by
            `config['channel_scales']` and add `config['channel_biases']`
            before embedding it. Defaults to False, which reproduces the
            previous behaviour: the scaled tensor used to be computed and then
            thrown away, so the raw input was what got embedded.
            NOTE(review): the right setting depends on whether callers pass the
            low-res cloud already in the scaled space -- confirm before
            enabling.
    """

    def __init__(self, config, img_ctx_length=256, name=None, download_weights=True,
                 scale_low_res=False):
        points_input = keras.layers.Input(shape=(config['n_ctx'], config['input_channels']))
        # Trailing comma matters: `(config['width'])` is a bare int, not a shape tuple.
        t_emb_input = keras.layers.Input(shape=(config['width'],))
        cond_input = keras.layers.Input(shape=(config['cond_ctx'], config['cond_width']))
        low_res_input = keras.layers.Input(shape=(config['low_res_ctx'], config['input_channels']))

        # Two-layer MLP on the timestep embedding.
        t_emb = keras.layers.Dense(config['width']*4, name='f_c')(t_emb_input)
        t_emb = tf.nn.gelu(t_emb)
        t_emb = keras.layers.Dense(config['width'], name='fc')(t_emb)

        # Start from the raw input so neither optional step can hit an
        # unassigned name, and so the result is what actually gets embedded.
        low_res = low_res_input
        if scale_low_res:
            if config.get('channel_scales') is not None:
                low_res = low_res * tf.convert_to_tensor(config['channel_scales'])[None, None, :]
            if config.get('channel_biases') is not None:
                low_res = low_res + tf.convert_to_tensor(config['channel_biases'])[None, None, :]
        low_res = keras.layers.Dense(config['width'])(low_res)

        cond_emb = CLIPEmbedding(config['width'])(cond_input)

        x = keras.layers.Dense(config['width'], name='ic')(points_input)
        # Token layout: [points | one timestep token | conditioning tokens | low-res tokens].
        x = keras.layers.Concatenate(axis=1)([x, t_emb[:, None], cond_emb, low_res])
        x = keras.layers.LayerNormalization(epsilon=1e-5)(x)
        for _ in range(config['layers']):
            x = CLIPEncoderLayer(config['width'], config['heads'])(x)
        x = keras.layers.LayerNormalization(epsilon=1e-5)(x)
        # Only the point tokens are decoded; every conditioning token is dropped.
        output = keras.layers.Dense(config['output_channels'])(x[:, :config['n_ctx']])
        super().__init__([points_input, t_emb_input, cond_input, low_res_input], output, name=name)
        if download_weights:
            weights_fpath = keras.utils.get_file(
                origin=config["weights"]["origin"],
                file_hash=config["weights"]["file_hash"],
            )
            self.load_weights(weights_fpath)



In [None]:
class PointCloudSDFModel(keras.Model):
    """Functional Keras model that maps query points plus a point cloud to one scalar per query.

    Model inputs, in order:
        1. query points: (n_ctx, 3)
        2. point cloud:  (n_ctx, 3)

    The point cloud is embedded and passed through `encoder_layers` encoder
    blocks; the queries are embedded and passed through `decoder_layers`
    decoder blocks that cross-attend to the encoded cloud. The output has
    shape (batch, n_ctx).

    Args:
        config: Dict with the keys 'n_ctx', 'width', 'encoder_layers',
            'encoder_heads', 'decoder_layers', 'decoder_heads' and, when
            `download_weights` is true, 'weights' ('origin', 'file_hash').
        name: Optional model name.
        download_weights: When true, fetch the weight file named in
            `config['weights']` and load it into the model.
    """

    def __init__(self, config, name=None, download_weights=True):
        query_input = keras.layers.Input(shape=(config['n_ctx'], 3))
        point_cloud = keras.layers.Input(shape=(config['n_ctx'], 3))

        encoded_pc = keras.layers.Dense(config['width'])(point_cloud)
        for _ in range(config['encoder_layers']):
            encoded_pc = CLIPEncoderLayer(config['width'], config['encoder_heads'])(encoded_pc)

        x = keras.layers.Dense(config['width'])(query_input)
        for _ in range(config['decoder_layers']):
            x = CLIPDecoderLayer(config['width'], config['decoder_heads'])(x, encoded_pc)
        x = keras.layers.LayerNormalization(epsilon=1e-5)(x)
        x = keras.layers.Dense(1)(x)
        # Drop only the trailing size-1 feature axis. A bare tf.squeeze(x)
        # would also remove the batch axis whenever the batch size is 1.
        output = tf.squeeze(x, axis=-1)
        super().__init__([query_input, point_cloud], output, name=name)
        if download_weights:
            weights_fpath = keras.utils.get_file(
                origin=config["weights"]["origin"],
                file_hash=config["weights"]["file_hash"],
            )
            self.load_weights(weights_fpath)


In [None]:
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """Discretise a continuous alpha-bar function into a beta schedule.

    `alpha_bar(t)` gives the cumulative product of (1 - beta) up to time `t`,
    with `t` running over [0, 1]. Step `i` spans `i / N` to `(i + 1) / N`, and
    its beta is `1 - alpha_bar(end) / alpha_bar(start)`.

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a callable applied to a tensor of times in [0, 1] that
                      returns the cumulative product of (1 - beta) at each.
    :param max_beta: upper clamp for every beta; keep it below 1 to prevent
                     singularities.
    :return: a tensor holding `num_diffusion_timesteps` betas.
    """
    step_index = tf.range(num_diffusion_timesteps)
    interval_start = step_index / num_diffusion_timesteps
    interval_end = (step_index + 1) / num_diffusion_timesteps
    survival_ratio = alpha_bar(interval_end) / alpha_bar(interval_start)
    return tf.minimum(1 - survival_ratio, max_beta)

In [None]:
# Notebook-wide sampling settings; the TextTo3D methods read several of these as globals.
num_diffusion_timesteps = 1024
# Step counts for (Stable Diffusion image, base point cloud, upsampled point cloud).
num_steps = (25, 64, 64)
batch_size = 1
unconditional_guidance_scale = 3
# Passed to Stable Diffusion's text_to_image call.
seed = 0

In [None]:
# Instantiate the keras_cv StableDiffusion pipeline with its default arguments; TextTo3D uses it for the text-to-image step.
stable_diffusion = StableDiffusion()

In [None]:
class TextTo3D:
    """Text -> image -> coloured point cloud pipeline.

    Stages: Stable Diffusion renders an image from the prompt, the CLIP vision
    model embeds it, a point diffusion model samples a 1024-point cloud from
    that embedding, and an upsampler samples 3072 further points conditioned on
    the first cloud.

    Besides its constructor arguments, the class reads the notebook globals
    `config`, `batch_size`, `unconditional_guidance_scale` and `seed`.

    Args:
        stable_diffusion: Object providing `text_to_image` and
            `_get_timestep_embedding`.
        vision_model: CLIP vision model providing `embeddings`,
            `pre_layernorm` and `encoder`.
        num_steps: Step counts for (image generation, base point cloud,
            upsampled point cloud).
    """

    # Length of the beta schedules the point models are sampled with.
    _NUM_TIMESTEPS = 1024

    def __init__(self, stable_diffusion, vision_model, num_steps=(25,64,64)):
        self.sd = stable_diffusion
        self.point_diffusion_model = CLIPImageGridPointDiffusionTransformer(config['base300M'])
        self.point_upsample_model = CLIPImageGridPointUpsamplerTransformer(config['upsample'])
        # Built and loaded here, but none of the methods below call it.
        self.sdf_model = PointCloudSDFModel(config['sdf'])
        self.clip_vision_model = vision_model
        # Only the last element returned by clip.load (the image preprocessing) is kept.
        self.preprocessing = clip.load('ViT-L/14')[-1]
        self.num_steps = num_steps

    def _ddim_sample(self, model, width, latent, cond_pairs, alphas_cumprod, num_steps):
        """Run the deterministic denoising loop shared by both point-cloud stages.

        Args:
            model: Network called as `model([latent, t_emb, *conditioning])`.
            width: Size of the timestep embedding the model expects.
            latent: Initial noise tensor with 6 channels per point.
            cond_pairs: List of `(conditional, unconditional)` tensor pairs;
                each pair is stacked along the batch axis and appended to the
                model inputs in order.
            alphas_cumprod: Cumulative product of (1 - beta) per timestep.
            num_steps: Approximate number of sampling steps.

        Returns:
            The latent after the final step.
        """
        # max(1, ...) keeps the stride valid if more steps than timesteps are requested.
        stride = max(1, self._NUM_TIMESTEPS // num_steps)
        timesteps = tf.range(1, self._NUM_TIMESTEPS, stride)
        alphas = [alphas_cumprod[t] for t in timesteps]
        alphas_prev = [1.0] + alphas[:-1]
        progbar = keras.utils.Progbar(len(timesteps))
        iteration = 0
        for index, timestep in list(enumerate(timesteps))[::-1]:
            latent_prev = tf.cast(latent, tf.float32)  # Set aside the previous latent vector
            # The timestep embedding comes from the Stable Diffusion object
            # held by this instance (previously an undefined global `sd`).
            t_emb = self.sd._get_timestep_embedding(timestep, batch_size, dim=width)
            # Conditional and unconditional passes share one doubled batch.
            model_inputs = [
                tf.concat([latent, latent], axis=0),
                tf.concat([t_emb, t_emb], axis=0),
            ]
            model_inputs.extend(
                tf.concat([cond, uncond], axis=0) for cond, uncond in cond_pairs
            )
            model_out = model(model_inputs)
            # Guidance is applied to the first three output channels only;
            # channels 3-5 are taken from the conditional half unchanged.
            eps_xyz, eps_rest = model_out[..., :3], model_out[..., 3:6]
            cond_eps, uncond_eps = tf.split(eps_xyz, 2, axis=0)
            guided = uncond_eps + unconditional_guidance_scale * (cond_eps - uncond_eps)
            eps = tf.concat([guided, eps_rest[:batch_size]], axis=-1)
            eps = tf.cast(eps, tf.float32)

            a_t, a_prev = alphas[index], alphas_prev[index]
            # Estimate the clean sample, then re-noise it to the previous timestep.
            pred_x0 = (latent_prev - math.sqrt(1 - a_t) * eps) / math.sqrt(a_t)
            latent = eps * math.sqrt(1.0 - a_prev) + math.sqrt(a_prev) * pred_x0
            iteration += 1
            progbar.update(iteration)
        return latent

    def image_to_point_cloud(self, img_emb, num_steps=250):
        """Sample a (1, 1024, 6) point cloud conditioned on `img_emb`.

        Uses the cosine alpha-bar schedule and an all-zero embedding for the
        unconditional branch. The initial noise is drawn without a seed.
        """
        latent = tf.random.normal(shape=(1,1024,6))
        betas = betas_for_alpha_bar(
            self._NUM_TIMESTEPS,
            lambda t: tf.math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
        alphas_cumprod = tf.math.cumprod(1.0 - betas)
        return self._ddim_sample(
            self.point_diffusion_model,
            config['base300M']['width'],
            latent,
            [(img_emb, tf.zeros_like(img_emb))],
            alphas_cumprod,
            num_steps,
        )

    def upsample_point_cloud(self, img_emb, low_res, num_steps=250):
        """Sample a (1, 3072, 6) point cloud conditioned on `img_emb` and `low_res`.

        Uses a linear beta schedule and all-zero tensors for both conditioning
        inputs in the unconditional branch. The initial noise is drawn without
        a seed.
        """
        latent = tf.random.normal(shape=(1,3072,6))
        betas = tf.linspace((1000/1024) * 0.0001,  (1000/1024) * 0.02, self._NUM_TIMESTEPS)
        alphas_cumprod = tf.math.cumprod(1.0 - betas)
        return self._ddim_sample(
            self.point_upsample_model,
            config['upsample']['width'],
            latent,
            [(img_emb, tf.zeros_like(img_emb)), (low_res, tf.zeros_like(low_res))],
            alphas_cumprod,
            num_steps,
        )

    def clip_image_encoder(self, image):
        """Embed the first image of the batch `image` with the CLIP vision model.

        Returns the encoder output with the first token removed.
        """
        image = Image.fromarray(image[0])
        image = self.preprocessing(image)
        embd = self.clip_vision_model.embeddings(tf.convert_to_tensor(image)[None,])
        embd = self.clip_vision_model.pre_layernorm(embd)
        img_emb = self.clip_vision_model.encoder(embd,
                     attention_mask=None,
                     causal_attention_mask=True,
                     output_attentions=False,
                     output_hidden_states=False,
                     return_dict=False)[0][:,1:]
        return img_emb

    def text_to_3d(self, prompt='A 3D avater of a dog'):
        """Run the full pipeline for `prompt`.

        Returns:
            Dict with "image" (the Stable Diffusion output) and "pc" (base and
            upsampled clouds concatenated along the point axis). The global
            `seed` is only passed to the image step.
        """
        image = self.sd.text_to_image(prompt=prompt, num_steps=self.num_steps[0], seed=seed)
        img_emb = self.clip_image_encoder(image)
        pc = self.image_to_point_cloud(img_emb, self.num_steps[1])
        upsampled_pc = self.upsample_point_cloud(img_emb, pc, self.num_steps[2])
        final_pc = tf.concat([pc, upsampled_pc], axis=1)
        return {"image": image,
                "pc":final_pc
        }

    def plot(self, latent):
        """Return a plotly 3D scatter figure for one point cloud.

        Columns 0-2 of `latent` are used as x/y/z; columns 3-5 are clipped to
        [0, 1], multiplied by 255 and used as the rgb colour of each marker.
        """
        marker_colors = [
            'rgb({},{},{})'.format(r,g,b)
            for r,g,b in zip(np.clip(latent[:,3],0.,1.)*255.,
                             np.clip(latent[:,4], 0.,1.)*255.,
                             np.clip(latent[:,5], 0.,1.)*255.)
        ]
        points = plotly.graph_objects.Scatter3d(
            x=latent[:,0], y=latent[:,1], z=latent[:,2],
            mode='markers',
            marker=dict(size=2, color=marker_colors),
        )
        # Hide all three axes so only the points are drawn.
        layout = dict(
            scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False))
        )
        return plotly.graph_objects.Figure(data=[points], layout=layout)

In [None]:
# Build the pipeline; this constructs the point models (which download their weights by default) and loads CLIP 'ViT-L/14'.
text_to_3d = TextTo3D(stable_diffusion, clip_vision_model, num_steps=num_steps)

In [None]:
# Run the full pipeline; the result is a dict with the generated "image" and the point cloud "pc".
sample = text_to_3d.text_to_3d('A high quality 3D render of a dog. Left view')

In [None]:
def plot_image(image):
    """Show `image` in a new 10x10 inch matplotlib figure with the axes hidden."""
    figure_size = (10, 10)
    plt.figure(figsize=figure_size)
    plt.imshow(image)
    plt.axis('off')

In [None]:
# Display the first generated image of the batch.
plot_image(sample['image'][0])

In [None]:
# Render the first point cloud of the batch as an interactive 3D scatter plot.
text_to_3d.plot(sample['pc'][0])