In [1]:
import torch
import math
from typing import Dict, List, Tuple, Any
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.checkpoint import checkpoint
from config import NODE_FEATURE_DIMENSION, EDGE_FEATURE_DIMENSION, MAX_NUM_PRIMITIVES, GRAPH_EMBEDDING_SIZE
from matplotlib import pyplot as plt
from dataset1 import SketchDataset
import os
os.chdir('SketchGraphs/')
import sketchgraphs.data as datalib
os.chdir('../')
%matplotlib inline

In [6]:
# %%
class GD3PM(nn.Module):
  def __init__(self, device : torch.device):
    super().__init__()
    self.device = device
    self.node_dim = NODE_FEATURE_DIMENSION
    self.edge_dim = EDGE_FEATURE_DIMENSION
    self.hidden_dim = 256
    self.num_tf_layers = 24
    self.num_heads = 16
    self.max_timestep = 500
    self.max_steps = self.max_timestep + 1
    self.noise_scheduler = CosineNoiseScheduler(self.max_timestep, self.device)
    self.time_embedder = TimeEmbedder(self.max_timestep, self.hidden_dim, self.device)
    self.mlp_in_nodes = nn.Sequential(nn.Linear(in_features = self.node_dim, out_features = self.hidden_dim, device = device),
                                            nn.LeakyReLU(.1),
                                            # nn.Dropout(p = 0.1),
                                            nn.Linear(in_features = self.hidden_dim, out_features = self.hidden_dim, device = device),
                                            nn.LeakyReLU(.1),
                                            # nn.Dropout(p = 0.1)
                                           )
    self.mlp_in_edges = nn.Sequential(nn.Linear(in_features = self.edge_dim, out_features = self.hidden_dim, device = device),
                                            nn.LeakyReLU(.1),
                                            # nn.Dropout(p = 0.1),
                                            nn.Linear(in_features = self.hidden_dim, out_features = self.hidden_dim, device = device),
                                            nn.LeakyReLU(.1),
                                            # nn.Dropout(p = 0.1)
                                           )
    self.tf_layers = nn.ModuleList([TransformerLayer(num_heads = self.num_heads,
                                                     node_dim = self.hidden_dim,
                                                     edge_dim = self.hidden_dim,
                                                     device = device
                                                    )
                                    for _ in range(self.num_tf_layers)])
    self.mlp_out_nodes = nn.Sequential(nn.Linear(in_features = self.hidden_dim, out_features = self.node_dim, device = device),
                                       nn.LeakyReLU(.1),
                                    #    nn.Dropout(p = 0.1),
                                       nn.Linear(in_features = self.node_dim, out_features = self.node_dim, device = device)
                                      )
    self.mlp_out_edges = nn.Sequential(nn.Linear(in_features = self.hidden_dim, out_features = self.edge_dim, device = device),
                                       nn.LeakyReLU(.1),
                                    #    nn.Dropout(p = 0.1),
                                       nn.Linear(in_features = self.edge_dim, out_features = self.edge_dim, device = device)
                                      )

  def forward(self, nodes : Tensor, edges : Tensor, timestep : Tensor):
    # embed timestep
    time_embs = self.time_embedder(timestep) # batch_size x hidden_dim
    nodes = self.mlp_in_nodes(nodes) # batch_size x num_nodes x hidden_dim
    edges = self.mlp_in_edges(edges) # batch_size x num_nodes x num_nodes x hidden_dim
    for idx, layer in enumerate(self.tf_layers):
      nodes, edges, time_embs = layer(nodes, edges, time_embs) if idx % 12 == 0 else checkpoint(layer, nodes, edges, time_embs, use_reentrant = False)
    nodes = self.mlp_out_nodes(nodes)
    edges = self.mlp_out_edges(edges)
    return nodes, edges
  
  @torch.no_grad()
  def sample(self, batch_size : int):
    # Sample Noise
    num_nodes = MAX_NUM_PRIMITIVES
    num_node_features = NODE_FEATURE_DIMENSION
    num_edge_features = EDGE_FEATURE_DIMENSION
    nodes = torch.zeros(batch_size, num_nodes, num_node_features)
    edges = torch.zeros(batch_size, num_nodes, num_nodes, num_edge_features)
    # binary noise (isConstructible)
    nodes[:,:,0] = torch.ones(size = (batch_size * num_nodes, 2)).multinomial(1)\
                        .reshape(batch_size, num_nodes).float()
    # categorical noise (primitive type)
    nodes[:,:,1:6] = F.one_hot(torch.ones(size = (batch_size * num_nodes, 5)).multinomial(1), 5)\
                      .reshape(batch_size, num_nodes, -1).float()
    # gaussian noise (primitive parameters)
    nodes[:,:,6:] = torch.randn(size = (batch_size, num_nodes, 14))
    # categorical noise (subnode a type)
    edges[:,:,:,0:4] = F.one_hot(torch.ones(size = (batch_size * num_nodes * num_nodes, 4)).multinomial(1), 4)\
                      .reshape(batch_size, num_nodes, num_nodes, -1).float()
    # categorical noise (subnode b type)
    edges[:,:,:,4:8] = F.one_hot(torch.ones(size = (batch_size * num_nodes * num_nodes, 4)).multinomial(1), 4)\
                      .reshape(batch_size, num_nodes, num_nodes, -1).float()
    # categorical noise (subnode a type)
    edges[:,:,:,8:] = F.one_hot(torch.ones(size = (batch_size * num_nodes * num_nodes, 9)).multinomial(1), 9)\
                     .reshape(batch_size, num_nodes, num_nodes, -1).float()
    
    nodes = nodes.to(self.device)
    edges = edges.to(self.device)
    return self.denoise(nodes, edges)

  @torch.no_grad()
  def denoise(self, nodes, edges):
    for t in reversed(range(1, self.max_steps)):
      # model expects a timestep for each batch
      batch_size = nodes.size(0)
      time = torch.Tensor([t]).expand(batch_size).int() # batch size of 1
      pred_node_noise, pred_edge_noise = self.forward(nodes, edges, time)
      # Normalize output into probabilities
      pred_node_noise[:,:,0] = F.sigmoid(input = pred_node_noise[:,:,0])
      pred_node_noise[:,:,1:6] = F.softmax(input = pred_node_noise[:,:,1:6], dim = 2)
      pred_edge_noise[:,:,:,0:4] = F.softmax(input = pred_edge_noise[:,:,:,0:4], dim = 3)
      pred_edge_noise[:,:,:,4:8] = F.softmax(input = pred_edge_noise[:,:,:,4:8], dim = 3)
      pred_edge_noise[:,:,:,8:] = F.softmax(input = pred_edge_noise[:,:,:,8:], dim = 3)

      nodes, edges = self.reverse_step(nodes, edges, pred_node_noise, pred_edge_noise, t)
    return nodes, edges
  @torch.no_grad()
  def noise(self, nodes, edges):
    nodes, edges, _ = self.noise_scheduler(nodes, edges, self.max_timestep)
    return nodes, edges
  
  @torch.no_grad()
  def reverse_step(self, nodes, edges, pred_node_noise, pred_edge_noise, timestep):
    # IsConstructible denoising
    nodes[:,:,0] = self.noise_scheduler.apply_bernoulli_posterior_step(nodes[:,:,0], pred_node_noise[:,:,0], timestep)
    # Primitive Types denoising
    nodes[:,:,1:6] = self.noise_scheduler.apply_multinomial_posterior_step(nodes[:,:,1:6], pred_node_noise[:,:,1:6], timestep)
    # Primitive parameters denoising
    nodes[:,:,6:] = self.noise_scheduler.apply_gaussian_posterior_step(nodes[:,:,6:], pred_node_noise[:,:,6:], timestep)
    # Subnode A denoising
    edges[:,:,:,0:4] = self.noise_scheduler.apply_multinomial_posterior_step(edges[:,:,:,0:4], pred_edge_noise[:,:,:,0:4], timestep)
    # Subnode B denoising
    edges[:,:,:,4:8] = self.noise_scheduler.apply_multinomial_posterior_step(edges[:,:,:,4:8], pred_edge_noise[:,:,:,4:8], timestep)
    # Constraint Types denoising
    edges[:,:,:,8:] = self.noise_scheduler.apply_multinomial_posterior_step(edges[:,:,:,8:], pred_edge_noise[:,:,:,8:], timestep)
    return nodes, edges
  
  # @torch.no_grad()
  # def forward_step(self, nodes, edges, timestep):
  #   bool_class_probs = torch.cat(((1 - nodes[:,:,0]).unsqueeze(-1), nodes[:,:,0].unsqueeze(-1)), dim = -1) # batch_size x num_nodes x 2 i.e. [p(fail), p(success)]
  #   nodes[:,:,0] = (bool_class_probs @ self.noise_scheduler.get_transition_matrix(2, timestep)).reshape(-1, 2).multinomial(1).reshape(nodes[:,:,0].size()).float()
  #   # Primitive Types noising
  #   nodes[:,:,1:6] = self.noise_scheduler.sample_discrete(nodes[:,:,1:6] @ self.noise_scheduler.get_transition_matrix(5, timestep)).float()
  #   # Primitive parameters noising
  #   nodes[:,:,6:] = self.noise_scheduler.get_transition_noise(nodes[:,:,6:], timestep)
  #   # Subnode A noising
  #   edges[:,:,:,0:4] = self.noise_scheduler.sample_discrete(edges[:,:,:,0:4] @ self.noise_scheduler.get_transition_matrix(4, timestep)).float()
  #   # Subnode B noising
  #   edges[:,:,:,4:8] = self.noise_scheduler.sample_discrete(edges[:,:,:,4:8] @ self.noise_scheduler.get_transition_matrix(4, timestep)).float()
  #   # Constraint Types noising
  #   edges[:,:,:,8:] = self.noise_scheduler.sample_discrete(edges[:,:,:,8:] @ self.noise_scheduler.get_transition_matrix(9, timestep)).float()
  #   return nodes, edges

class TimeEmbedder(nn.Module):
  def __init__(self, max_timestep : int, embedding_dimension : int, device : torch.device):
    super().__init__()
    self.device = device
    self.embed_dim = embedding_dimension
    self.max_steps = max_timestep + 1
    self.max_timestep = max_timestep
      
    timesteps = torch.arange(self.max_steps, device = self.device).unsqueeze(1) # num_timesteps x 1
    scales = torch.exp(torch.arange(0, self.embed_dim, 2, device = self.device) * (-math.log(10000.0) / self.embed_dim)).unsqueeze(0) # 1 x (embedding_dimension // 2)
    self.time_embs = torch.zeros(self.max_steps, self.embed_dim, device = self.device) # num_timesteps x embedding_dimension
    self.time_embs[:, 0::2] = torch.sin(timesteps * scales) # fill even columns with sin(timestep * 1000^-(2*i/embedding_dimension))
    self.time_embs[:, 1::2] = torch.cos(timesteps * scales) # fill odd columns with cos(timestep * 1000^-(2*i/embedding_dimension))
      
  def forward(self, timestep : Tensor):
    return self.time_embs[timestep] # batch_size x embedding_dimension

class CosineNoiseScheduler(nn.Module):
    def __init__(self, max_timestep : int, device : torch.device):
        super().__init__()
        self.device = device
        self.max_steps = max_timestep + 1
        self.max_timestep = max_timestep
        self.offset = .008 # Fixed offset to improve noise prediction at early timesteps
        # Cosine Beta Schedule Formula: https://arxiv.org/abs/2102.09672     1.00015543316 is 1/a(0), for offset = .008
        self.cumulative_precisions = torch.cos((torch.linspace(0, 1, self.max_steps).to(self.device) + self.offset) * 0.5 * math.pi / (1 + self.offset)) ** 2 * 1.00015543316
        self.cumulative_variances = 1 - self.cumulative_precisions
        self.variances = torch.cat([torch.Tensor([0]).to(self.device), 1 - (self.cumulative_precisions[1:] / self.cumulative_precisions[:-1])]).clamp(.0001, .9999)
        self.precisions = 1 - self.variances
        self.sqrt_cumulative_precisions = torch.sqrt(self.cumulative_precisions)
        self.sqrt_cumulative_variances = torch.sqrt(self.cumulative_variances)
        self.sqrt_precisions = torch.sqrt(self.precisions)
        self.sqrt_variances = torch.sqrt(self.variances)
        self.sqrt_posterior_variances = torch.cat([torch.Tensor([0]).to(self.device), torch.sqrt(self.variances[1:] * self.cumulative_variances[:-1] / self.cumulative_variances[1:])])

    def forward(self, nodes : Tensor, edges : Tensor, timestep : Tensor):
      ''' Apply noise to graph '''
      noisy_nodes = torch.zeros(size = nodes.size(), device = nodes.device)
      noisy_edges = torch.zeros(size = edges.size(), device = edges.device)

      # nodes = batch_size x num_nodes x NODE_FEATURE_DIMENSION ; edges = batch_size x num_nodes x num_nodes x EDGE_FEATURE_DIMENSION
      bernoulli_is_constructible = nodes[:,:,0] # batch_size x num_nodes x 1
      categorical_primitive_types = nodes[:,:,1:6] # batch_size x num_nodes x 5
      gaussian_primitive_parameters = nodes[:,:,6:] # batch_size x num_nodes x 14
      # subnode just means if the constraint applies to the start, center, or end of a primitive
      categorical_subnode_a_types = edges[:,:,:,0:4] # batch_size x num_nodes x 4
      categorical_subnode_b_types = edges[:,:,:,4:8] # batch_size x num_nodes x 4
      categorical_constraint_types = edges[:,:,:,8:] # batch_size x num_nodes x 9

      # IsConstructible noise
      noisy_nodes[:,:,0] = self.apply_binary_noise(bernoulli_is_constructible, timestep)
      # Primitive Types noise
      noisy_nodes[:,:,1:6] = self.apply_discrete_noise(categorical_primitive_types, timestep) # noised_primitive_types
      # Primitive parameters noise
      gaussian_noise = torch.randn_like(gaussian_primitive_parameters) # standard gaussian noise
      noisy_nodes[:,:,6:] = self.apply_gaussian_noise(gaussian_primitive_parameters, timestep, gaussian_noise)
      # Subnode A noise
      noisy_edges[:,:,:,0:4] = self.apply_discrete_noise(categorical_subnode_a_types, timestep) # noised_subnode_a_types
      # Subnode B noise
      noisy_edges[:,:,:,4:8] = self.apply_discrete_noise(categorical_subnode_b_types, timestep) # noised_subnode_a_types
      # Constraint Types noise
      noisy_edges[:,:,:,8:] = self.apply_discrete_noise(categorical_constraint_types, timestep) # noised_constraint_types

      return noisy_nodes, noisy_edges, gaussian_noise

    def get_transition_noise(self, parameters : Tensor, timestep : int, gaussian_noise : Tensor = None):
      if gaussian_noise is None:
        gaussian_noise = torch.randn_like(parameters) # standard gaussian noise
      return self.sqrt_precisions[timestep] * parameters + self.sqrt_variances[timestep] * gaussian_noise

    def apply_gaussian_noise(self, parameters : Tensor, timestep : Tensor | int, gaussian_noise : Tensor = None):
      if gaussian_noise is None:
        gaussian_noise = torch.randn_like(parameters) # standard gaussian noise
      
      if type(timestep) is int: timestep = [timestep]
      # parameters shape is batch_size x num_nodes x num_params
      # gaussian_noise shape is batch_size x num_nodes x num_params
      batched_precisions = self.sqrt_cumulative_precisions[timestep,None,None] # (b,1,1) or (1,1,1)
      batched_variances = self.sqrt_cumulative_variances[timestep,None,None]   # (b,1,1) or (1,1,1)
      return batched_precisions * parameters + batched_variances * gaussian_noise

    def apply_gaussian_posterior_step(self, parameters : Tensor, pred_noise : Tensor, timestep : int):
      var = self.variances[timestep]
      sqrt_cumulative_var = self.sqrt_cumulative_variances[timestep]
      sqrt_precision = self.sqrt_precisions[timestep]
      
      denoised_mean = (parameters - pred_noise * var / sqrt_cumulative_var) / sqrt_precision
      if timestep > 1:
        gaussian_noise = torch.randn_like(parameters)
        return denoised_mean + gaussian_noise * self.sqrt_posterior_variances[timestep]
      else:
        return denoised_mean

    def get_transition_matrix(self, dimension : int, timestep : int | Tensor):
      if type(timestep) is int: assert timestep > 0; timestep = [timestep]
      batched_precisions = self.sqrt_precisions[timestep,None,None] # (batch_size, 1, 1) or (1, 1, 1)
      return batched_precisions * torch.eye(dimension, dtype = torch.float32, device = self.device) + (1 - batched_precisions) / dimension # (batch_size, d, d) or (1, d, d)

    def get_cumulative_transition_matrix(self, dimension : int, timestep : int | Tensor):
      if type(timestep) is int: assert timestep > 0; timestep = [timestep]
      batched_precisions = self.sqrt_cumulative_precisions[timestep,None,None] # (batch_size, 1, 1) or (1, 1, 1)
      return batched_precisions * torch.eye(dimension, dtype = torch.float32, device = self.device) + (1 - batched_precisions) / dimension # (batch_size, d, d) or (1, d, d)
    
    def get_posterior_transition_matrix(self, xt : Tensor, timestep : Tensor | int) -> torch.Tensor:
      xt_size, xt = self.flatten_middle(xt) # (b, n, d) or (b, n * n, d), for convenience let m = n or n * n
      d = xt_size[-1]

      qt = xt @ self.get_transition_matrix(d, timestep).permute(0, 2, 1) # (b, m, d), since xt is onehot we are plucking out rows corresponding to p(x_t = class | x_(t-1))
      qt_bar = xt @ self.get_cumulative_transition_matrix(d, timestep).permute(0, 2, 1) # (b, m, d), since xt is onehot we are plucking out rows corresponding to p(x_t = class | x_0)

      q = qt.unsqueeze(2) / qt_bar.unsqueeze(3) # (b, m, d, d), perform an outer product so element at (b, m, i, j) = p(x_t = class | x_(t-1) = j) / p(x_t = class | x_0 = i)
      q = q * self.get_cumulative_transition_matrix(d, timestep - 1).unsqueeze(1) # (b, m, d, d), broadcast multiply so element at (b, m, i, j) = p(x_t = class | x_(t-1) = j) * p(x_(t-1) = j | x_0 = i) / p(x_t = class | x_0 = i)

      return q.view(size = xt_size + (d,)) # reshape into (b, n, d, d) or (b, n, n, d, d)

    def apply_discrete_noise(self, x_one_hot : Tensor, timestep : Tensor | int):
      size, x = self.flatten_middle(x_one_hot)
      q = self.get_cumulative_transition_matrix(size[-1], timestep) # (b, d, d) or (1, d, d)
      distribution = x @ q # (b, n, d) or (b, n * n, d)
      distribution = distribution.view(size) # (b, n, d) or (b, n, n, d)
      return self.sample_discrete_distribution(distribution).float()
    
    def apply_multinomial_posterior_step(self, classes_one_hot : Tensor, pred_class_probs : Tensor, timestep : int):
      # classes_one_hot = (b, n, d) or (b, n, n, d)
      # pred_class_probs = (b, n, d) or (b, n, n, d)
      if timestep > 1:
        q = self.get_posterior_transition_matrix(classes_one_hot, timestep) # (b, n, d, d) or (b, n, n, d, d)
        pred_class_probs = pred_class_probs.unsqueeze(-2) # (b, n, 1, d) or (b, n, n, 1, d), make probs into row vector
        posterior_distribution = pred_class_probs @ q # (b, n, 1, d) or (b, n, n, 1, d), batched vector-matrix multiply
        posterior_distribution = posterior_distribution.squeeze(-2) # (b, n, d) or (b, n, n, d)
        return self.sample_discrete_distribution(posterior_distribution).float()
      else:
        return pred_class_probs
      
    def apply_binary_noise(self, boolean_flag : Tensor, timestep : int | Tensor):
      boolean_flag = boolean_flag.unsqueeze(-1)
      one_hot = torch.cat([1 - boolean_flag, boolean_flag], dim = -1) # (b, n, 2)
      noised_one_hot = self.apply_discrete_noise(one_hot, timestep) # (b, n, 2)
      return noised_one_hot[...,1] # (b, n)

    def apply_bernoulli_posterior_step(self, boolean_flag : Tensor, pred_boolean_prob : Tensor, timestep : int):
      if timestep > 1:
        boolean_flag = boolean_flag.unsqueeze(-1) # b, n, 1
        pred_boolean_prob = pred_boolean_prob.unsqueeze(-1) # b, n, 1
        one_hot_xt = torch.cat([1 - boolean_flag, boolean_flag], dim = -1) # (b, n, 2)
        probs = torch.cat([1 - pred_boolean_prob, pred_boolean_prob], dim = -1) # (b, n, 2)
        noised_one_hot = self.apply_multinomial_posterior_step(one_hot_xt, probs, timestep) # (b, n, 2)
        return noised_one_hot[...,1] # (b, n)
      else:
        return pred_boolean_prob
      
    def sample_discrete_distribution(self, tensor : Tensor):
       size = tensor.size()
       num_classes = size[-1]
       return F.one_hot(tensor.reshape(-1, num_classes).multinomial(1), num_classes).reshape(size)
    
    def flatten_middle(self, x : Tensor):
      prev_size = x.size() # shape of x_one_hot is (b, n, d) or (b, n, n, d)
      return prev_size, x.view(prev_size[0], -1, prev_size[-1]) # (b, n, d) or (b, n * n, d)

# Graph Transformer Layer outlined by DiGress Graph Diffusion
class TransformerLayer(nn.Module):
    def __init__(self, num_heads : int, node_dim : int, edge_dim : int, device : torch.device):
        super().__init__()
        self.num_heads = num_heads
        self.node_dim = node_dim
        self.edge_dim = edge_dim

        self.mlp_add_embs = nn.Sequential(nn.Linear(in_features = self.node_dim, out_features = self.node_dim, device = device),)
        self.mlp_mul_embs = nn.Sequential(nn.Linear(in_features = self.node_dim, out_features = self.node_dim, device = device),)
        
        self.attention_heads = MultiHeadAttention(node_dim = self.node_dim, edge_dim = self.edge_dim, num_heads = self.num_heads, device = device)

        self.layer_norm_nodes = nn.Sequential(nn.LayerNorm(normalized_shape = self.node_dim, device = device),
                                              nn.LeakyReLU(.1)
                                             )
        self.layer_norm_edges = nn.Sequential(nn.LayerNorm(normalized_shape = self.edge_dim, device = device),
                                              # nn.Dropout(p = 0.1)
                                             )

        self.mlp_nodes = nn.Sequential(nn.Linear(in_features = self.node_dim, out_features = self.node_dim, device = device),
                                       nn.LeakyReLU(.1),
                                    #    nn.Dropout(p = 0.1),
                                       nn.Linear(in_features = self.node_dim, out_features = self.node_dim, device = device),
                                      )
        
        self.mlp_edges = nn.Sequential(nn.Linear(in_features = self.edge_dim, out_features = self.edge_dim, device = device),
                                       nn.LeakyReLU(.1),
                                    #    nn.Dropout(p = 0.1),
                                       nn.Linear(in_features = self.edge_dim, out_features = self.edge_dim, device = device),
                                      )
        
        self.layer_norm_nodes2 = nn.Sequential(nn.LayerNorm(normalized_shape = self.node_dim, device = device),
                                               nn.LeakyReLU(.1)
                                              )
        self.layer_norm_edges2 = nn.Sequential(nn.LayerNorm(normalized_shape = self.edge_dim, device = device),
                                              #  nn.Dropout(p = 0.1)
                                              )
    
    def forward(self, nodes : Tensor, edges : Tensor, time_embs : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        # Incorporate timestep information
        shifts = self.mlp_add_embs(time_embs)
        scales = self.mlp_mul_embs(time_embs)

        nodes = nodes + shifts.unsqueeze(1) + nodes * scales.unsqueeze(1) # batch_size x num_nodes x node_dim
        edges = edges + shifts.unsqueeze(1).unsqueeze(1) + edges * scales.unsqueeze(1).unsqueeze(1) # batch_size x num_nodes x num_nodes x node_dim
        # Perform multi head attention
        attn_nodes, attn_edges = self.attention_heads(nodes, edges)

        # Layer normalization with a skip connection
        attn_nodes = self.layer_norm_nodes(attn_nodes + nodes) # batch_size x num_nodes x node_dim
        attn_edges = self.layer_norm_edges(attn_edges + edges) # batch_size x num_nodes x num_nodes x edge_dim
        del nodes
        del edges

        # MLP out
        new_nodes = self.mlp_nodes(attn_nodes) # batch_size x num_nodes x node_dim
        new_edges = self.mlp_edges(attn_edges) # batch_size x num_nodes x num_nodes x edge_dim

        # Second layer normalization with a skip connection
        new_nodes = self.layer_norm_nodes2(new_nodes + attn_nodes) # batch_size x num_nodes x node_dim
        new_edges = self.layer_norm_edges2(new_edges + attn_edges) # batch_size x num_nodes x num_nodes x edge_dim
        del attn_nodes
        del attn_edges

        new_embs = F.relu(shifts + scales + time_embs)
        return new_nodes, new_edges, new_embs

# Outer Product Attention Head
class MultiHeadAttention(nn.Module):
    def __init__(self, node_dim : int, edge_dim : int, num_heads : int, device : torch.device):
        super().__init__()
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.num_heads = num_heads
        self.attn_dim = node_dim // num_heads

        self.lin_query = nn.Linear(in_features = self.node_dim, out_features = self.node_dim, device = device)
        self.lin_key = nn.Linear(in_features = self.node_dim, out_features = self.node_dim, device = device)
        self.lin_value = nn.Sequential(nn.Linear(in_features = self.node_dim, out_features = self.node_dim, device = device),
                                      #  nn.LeakyReLU(.1),
                                      #  nn.Linear(in_features = self.node_dim, out_features = self.node_dim, device = device)
                                      )

        self.lin_mul = nn.Sequential(nn.Linear(in_features = self.edge_dim, out_features = self.node_dim, device = device),
                                    #  nn.GELU(approximate='tanh'),
                                    #  nn.Linear(in_features = self.node_dim, out_features = self.node_dim, device = device)
                                    )
        self.lin_add = nn.Sequential(nn.Linear(in_features = self.edge_dim, out_features = self.node_dim, device = device),
                                    #  nn.GELU(approximate='tanh'),
                                    #  nn.Linear(in_features = self.node_dim, out_features = self.node_dim, device = device)
                                    )
        #self.edge_film = FiLM(self.edge_dim, self.node_dim, device = device)

        self.lin_nodes_out = nn.Sequential(
                                           nn.Linear(in_features = self.node_dim, out_features = self.node_dim, device = device),
                                          #  nn.LeakyReLU(.1),
                                          #  nn.Linear(in_features = self.node_dim, out_features = self.node_dim, device = device)
                                          )
        self.lin_edges_out = nn.Sequential(nn.LeakyReLU(.1),
                                           nn.Linear(in_features = self.node_dim, out_features = self.node_dim, device = device),
                                          #  nn.LeakyReLU(.1),
                                          #  nn.Linear(in_features = self.node_dim, out_features = self.edge_dim, device = device)
                                          )

    def forward(self, nodes : Tensor, edges : Tensor):
        batch_size, num_nodes, _ = nodes.size()
        
        # Outer Product Attention -------
        queries = self.lin_query(nodes).view(batch_size, num_nodes, self.num_heads, -1) # batch_size x num_nodes x num_heads x attn_dim
        keys = self.lin_key(nodes).view(batch_size, num_nodes, self.num_heads, -1)      # batch_size x num_nodes x num_heads x attn_dim
        # queries = queries.unsqueeze(2)                            # batch_size x num_nodes x 1 x num_heads x attn_dim 
        # keys = keys.unsqueeze(1)                                  # batch_size x 1 x num_nodes x num_heads x attn_dim 
        attention = queries.unsqueeze(2) * keys.unsqueeze(1) / math.sqrt(self.node_dim) # batch_size x num_nodes x num_nodes x num_heads x attn_dim
        del queries
        del keys

        # Condition attention based on edge features
        edges_mul = self.lin_mul(edges).view(batch_size, num_nodes, num_nodes, self.num_heads, -1) # batch_size x num_nodes x num_nodes x num_heads x attn_dim
        edges_add = self.lin_add(edges).view(batch_size, num_nodes, num_nodes, self.num_heads, -1) # batch_size x num_nodes x num_nodes x num_heads x attn_dim
        del edges
        new_edges = attention * edges_mul + attention + edges_add # batch_size x num_nodes x num_nodes x num_heads x attn_dim
        del edges_add
        del edges_mul
        
        # Normalize attention
                                                                           # batch_size x num_nodes x num_nodes x num_heads (Finish dot product)
        attention = torch.softmax(input = new_edges.sum(dim = 4), dim = 2) # batch_size x num_nodes x num_nodes x num_heads (softmax) 

        # Weight node representations and sum
        values = self.lin_value(nodes).view(batch_size, num_nodes, self.num_heads, -1)  # batch_size x num_nodes x num_heads x attn_dim
        del nodes
                                                                                                             # batch_size x num_nodes x num_heads x attn_dim
        weighted_values = (attention.unsqueeze(4) * values.unsqueeze(1)).sum(dim = 2).flatten(start_dim = 2) # batch_size x num_nodes x node_dim
        del values
        # Flatten attention heads
        new_edges = new_edges.flatten(start_dim = 3)
        # weighted_values = weighted_values.flatten(start_dim = 2)
        
        # Combine attention heads
        new_nodes = self.lin_nodes_out(weighted_values)
        new_edges = self.lin_edges_out(new_edges)

        return new_nodes, new_edges
    
class SoftAttentionLayer(nn.Module):
  def __init__(self, dim : int, num_heads : int, device : torch.device):
    super().__init__()
    self.device = device
    self.node_dim = dim
    self.edge_dim = dim
    self.num_heads = num_heads
    self.attn_dim = dim // num_heads
    self.num_nodes = MAX_NUM_PRIMITIVES

    concat_dim = 2 * self.node_dim + self.edge_dim
    self.mlp_haggr_weights = nn.Sequential(nn.Linear(in_features = concat_dim, out_features = concat_dim, device = self.device),
                                           nn.LeakyReLU(.1),
                                           nn.Linear(in_features = concat_dim, out_features = 1, device = self.device),
                                           nn.Softmax(dim = 2)
                                          )
    
    self.mlp_haggr_values = nn.Sequential(nn.Linear(in_features = concat_dim, out_features = concat_dim, device = self.device),
                                           nn.LeakyReLU(.1),
                                           nn.Linear(in_features = concat_dim, out_features = self.node_dim, device = self.device),
                                         )

    self.query_key_value_mlp = nn.Linear(in_features = self.node_dim, out_features = 3 * self.node_dim, device = self.device)

    self.layer_norm_embs = nn.Sequential(nn.LayerNorm(normalized_shape = self.node_dim, device = self.device),
                                         nn.LeakyReLU(.1),
                                        )

    self.node_mlp = nn.Sequential(nn.Linear(in_features = self.node_dim, out_features = self.node_dim, device = self.device),
                                  nn.LeakyReLU(.1),
                                  nn.Linear(in_features = self.node_dim, out_features = self.node_dim, device = self.device),
                                 )
    self.edge_mlp = nn.Sequential(nn.Linear(in_features = 2 * self.node_dim, out_features = 2 * self.node_dim, device = self.device),
                                  nn.LeakyReLU(.1),
                                  nn.Linear(in_features = 2 * self.node_dim, out_features = self.node_dim, device = self.device),
                                 )
    
    self.layer_norm_out_nodes = nn.Sequential(nn.LayerNorm(normalized_shape = self.node_dim, device = self.device),
                                              nn.LeakyReLU(.1),
                                             )
    
    self.layer_norm_out_edges = nn.Sequential(nn.LayerNorm(normalized_shape = self.node_dim, device = self.device),
                                              nn.LeakyReLU(.1),
                                             )
    
  def forward(self, nodes : Tensor, edges : Tensor) -> Tuple[Tensor, Tensor]:
    # Outer Product Concatenation
    hstack = nodes.unsqueeze(2).expand(-1, -1, self.num_nodes, -1) # (b, n, n, d)
    vstack = hstack.permute(0, 2, 1, 3) # (b, n, n, d)
    graph_features = torch.cat(tensors = (hstack, vstack, edges), dim = 3) # (b, n, n, 3 * d)

    # Soft Attentional Encoder
    haggr_weights = self.mlp_haggr_weights(graph_features).permute(0, 1, 3, 2) # (b, n, 1, n)
    haggr_values = self.mlp_haggr_values(graph_features) # (b, n, n, d)
    graph_embs = (haggr_weights @ haggr_values).squeeze(2) # (b, n, d)

    # Low Dimensional Attention
    b, n, d = graph_embs.size()
    query_key_value = self.query_key_value_mlp(graph_embs).view(b, n, self.num_heads, 3 * self.attn_dim).permute(0, 2, 1, 3) # (b, h, n, 3 * attn_dim)
    queries, keys, values = query_key_value.reshape(b * self.num_heads, n, 3 * self.attn_dim).chunk(3, dim = 2) # (b * h, n, attn_dim) is shape for the three tensors
    attn_embs = F.scaled_dot_product_attention(queries, keys, values).view(b, self.num_heads, n, self.attn_dim) # (b * h, n, attn_dim)
    attn_embs = attn_embs.permute(0, 2, 1, 3).reshape(b, n, self.node_dim) # (b, n, d)

    # Residual Connection and LayerNorm
    self.layer_norm_embs(attn_embs + graph_embs)

    # Outer Product Decoder
    emb_hstack = attn_embs.unsqueeze(2).expand(-1, -1, self.num_nodes, -1) # (b, n, n, d)
    emb_vstack = emb_hstack.permute(0, 2, 1, 3) # (b, n, n, d)
    emb_edges = torch.cat(tensors = (emb_hstack, emb_vstack), dim = 3) # (b, n, n, 2 * d)

    new_edges = self.edge_mlp(emb_edges) # (b, n, n, d)
    new_nodes = self.node_mlp(attn_embs) # (b, n, d)

    # Residual Connection and LayerNorm
    new_edges = self.layer_norm_out_nodes(new_edges + new_edges)
    new_nodes = self.layer_norm_out_nodes(new_nodes + nodes)

    return new_nodes, new_edges

In [87]:
scheduler = CosineNoiseScheduler(max_timestep = 500, device = 'cpu')

In [None]:
# T = 500
# offset = .008
# precisions = torch.cos((torch.linspace(0, 1, T + 1) + offset) * 0.5 * math.pi / (1 + offset)) ** 2
plt.plot(torch.linspace(0, scheduler.max_timestep, scheduler.max_steps), scheduler.sqrt_posterior_variances)
scheduler.sqrt_posterior_variances[0]

In [3]:
model = GD3PM(device = 2)

In [7]:
nodes = torch.randn(2, 24, 32).to(2)
edges = torch.randn(2, 24, 24, 32).to(2)

layer = SoftAttentionLayer(32, 8, 2)

layer(nodes, edges)

(tensor([[[-7.1943e-02,  3.8405e-02, -3.8214e-03,  ...,  8.6844e-01,
           -8.0635e-03,  5.0961e-01],
          [ 4.4182e-01,  3.8175e-01,  7.4728e-01,  ..., -4.6879e-02,
           -6.0337e-02, -4.6395e-02],
          [ 1.8042e+00, -3.1544e-02, -2.0649e-01,  ...,  1.4710e+00,
            1.3961e+00,  5.4910e-01],
          ...,
          [ 1.0218e+00,  1.1989e+00,  1.1728e+00,  ..., -9.2041e-03,
           -1.0685e-02, -2.6172e-02],
          [-9.1021e-02,  1.6088e-03,  1.0256e+00,  ..., -9.2267e-02,
            1.0816e+00, -1.2574e-01],
          [ 1.8320e+00, -1.7733e-02, -3.9278e-02,  ...,  6.8177e-01,
           -2.4127e-02,  9.4185e-01]],
 
         [[ 2.2566e+00,  1.1865e+00, -3.6246e-02,  ..., -2.0274e-01,
           -1.3671e-01, -6.7719e-02],
          [-8.2890e-02, -1.2020e-01,  6.7462e-01,  ...,  1.1532e-01,
           -5.6733e-02,  8.8157e-01],
          [-6.2868e-02, -1.9113e-02,  9.2896e-01,  ..., -3.1593e-03,
            5.5211e-01, -9.2183e-02],
          ...,
    

In [42]:
true_nodes = torch.cat((dataset.nodes[0].unsqueeze(0), dataset.nodes[1].unsqueeze(0)), dim = 0)
true_edges = torch.cat((dataset.edges[0].unsqueeze(0), dataset.edges[1].unsqueeze(0)), dim = 0)
timesteps = torch.randint(0, scheduler.max_steps, (2,))
nodes, edges, _ = scheduler(true_nodes, true_edges, timesteps)

In [5]:
dataset = SketchDataset('data/')

In [6]:
nodes = dataset.nodes[0]
edges = dataset.edges[0]
noisy_nodes, noisy_edges, _ = scheduler(nodes.unsqueeze(0), edges.unsqueeze(0), 0)
sketch = SketchDataset.preds_to_sketch(nodes, edges)
noised_sketch = SketchDataset.preds_to_sketch(noisy_nodes.squeeze(0), noisy_edges.squeeze(0))

In [6]:
import io

def fig_to_tensor(fig):
    with io.BytesIO() as buff:
        fig.savefig(buff, format='raw')
        buff.seek(0)
        data = torch.frombuffer(buff.getvalue(), dtype=torch.uint8)
    w, h = fig.canvas.get_width_height()
    plt.close()
    return data.reshape((int(h), int(w), -1)).permute(2, 0, 1)

In [7]:
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(f'runs2/new_test')

In [37]:
state_dict = torch.load(f"model_checkpoint_gd3pm_ddp_Adam_mse-25_kld-.001_24layers16heads256hidden.pth")

from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove 'module.' of dataparallel
    new_state_dict[name]=v

model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [None]:
scheduler = CosineNoiseScheduler(max_timestep = 500, device = 2)

vid_tensor = torch.zeros(size = (2, scheduler.max_steps, 4, 480, 640))
nodes = torch.cat([dataset.nodes[0].unsqueeze(0), dataset.nodes[1].unsqueeze(0)], dim = 0) # 2 x num_nodes x node_features
edges = torch.cat([dataset.edges[0].unsqueeze(0), dataset.edges[1].unsqueeze(0)], dim = 0) # 2 x num_nodes x num_nodes x edge_features
nodes = nodes.to(2)
edges = edges.to(2)

for timestep in range(1, scheduler.max_steps):
    nodes, edges, _ = scheduler(nodes, edges, timestep)
    
    vid_tensor[0][timestep] = fig_to_tensor(datalib.render_sketch(SketchDataset.preds_to_sketch(nodes[0].squeeze(0), edges[0].squeeze(0))))
    vid_tensor[1][timestep] = fig_to_tensor(datalib.render_sketch(SketchDataset.preds_to_sketch(nodes[1].squeeze(0), edges[1].squeeze(0))))


writer.add_video("Forward Process 2", vid_tensor, 0)

In [69]:
def reverse_test(noise_scheduler, nodes, edges, true_nodes, true_edges, timestep):
    # IsConstructible denoising
      nodes[:,:,0] = noise_scheduler.apply_bernoulli_posterior_step(nodes[:,:,0], true_nodes[:,:,0], timestep)
      # Primitive Types denoising
      nodes[:,:,1:6] = noise_scheduler.apply_multinomial_posterior_step(nodes[:,:,1:6], true_nodes[:,:,1:6], timestep)
      # Primitive parameters denoising
      true_noise = (nodes[:,:,6:] - noise_scheduler.sqrt_cumulative_precisions[timestep] * true_nodes[:,:,6:]) / noise_scheduler.sqrt_cumulative_variances[timestep]
      nodes[:,:,6:] = noise_scheduler.apply_gaussian_posterior_step(nodes[:,:,6:], true_noise, timestep)
      # Subnode A denoising
      edges[:,:,:,0:4] = noise_scheduler.apply_multinomial_posterior_step(edges[:,:,:,0:4], true_edges[:,:,:,0:4], timestep)
      # Subnode B denoising
      edges[:,:,:,4:8] = noise_scheduler.apply_multinomial_posterior_step(edges[:,:,:,4:8], true_edges[:,:,:,4:8], timestep)
      # Constraint Types denoising
      edges[:,:,:,8:] = noise_scheduler.apply_multinomial_posterior_step(edges[:,:,:,8:], true_edges[:,:,:,8:], timestep)
      return nodes, edges


In [9]:
vid_tensor = torch.zeros(size = (2, model.noise_scheduler.max_steps, 4, 480, 640))
nodes = torch.cat([dataset.nodes[0].unsqueeze(0), dataset.nodes[1].unsqueeze(0)], dim = 0) # 2 x num_nodes x node_features
edges = torch.cat([dataset.edges[0].unsqueeze(0), dataset.edges[1].unsqueeze(0)], dim = 0) # 2 x num_nodes x num_nodes x edge_features
true_nodes = nodes.to(2)
true_edges = edges.to(2)

nodes, edges = model.noise(true_nodes, true_edges)
last_step = model.noise_scheduler.max_steps - 1
for timestep in reversed(range(1, model.noise_scheduler.max_steps)):
    true_noise = (nodes[:,:,6:] - model.noise_scheduler.sqrt_cumulative_precisions[timestep] * true_nodes[:,:,6:]) / model.noise_scheduler.sqrt_cumulative_variances[timestep]
    temp_nodes = true_nodes.clone().to(2)
    temp_nodes[:,:,6:] = true_noise
    
    nodes, edges = model.reverse_step(nodes, edges, temp_nodes, true_edges, timestep)
    # Fill tensor in reverse order
    vid_tensor[0][last_step - timestep] = fig_to_tensor(datalib.render_sketch(SketchDataset.preds_to_sketch(nodes[0].squeeze(0), edges[0].squeeze(0))))
    vid_tensor[1][last_step - timestep] = fig_to_tensor(datalib.render_sketch(SketchDataset.preds_to_sketch(nodes[1].squeeze(0), edges[1].squeeze(0))))

writer.add_video("Reverse Process 2", vid_tensor, 0)

In [13]:
writer.add_video("Forward Process", vid_tensor, 0, 1)
writer.flush()

In [11]:
vid_tensor[0,0]

tensor([[[255., 255., 255.,  ..., 255., 255., 255.],
         [255., 255., 255.,  ..., 255., 255., 255.],
         [255., 255., 255.,  ..., 255., 255., 255.],
         ...,
         [255., 255., 255.,  ..., 255., 255., 255.],
         [255., 255., 255.,  ..., 255., 255., 255.],
         [255., 255., 255.,  ..., 255., 255., 255.]],

        [[255., 255., 255.,  ..., 255., 255., 255.],
         [255., 255., 255.,  ..., 255., 255., 255.],
         [255., 255., 255.,  ..., 255., 255., 255.],
         ...,
         [255., 255., 255.,  ..., 255., 255., 255.],
         [255., 255., 255.,  ..., 255., 255., 255.],
         [255., 255., 255.,  ..., 255., 255., 255.]],

        [[255., 255., 255.,  ..., 255., 255., 255.],
         [255., 255., 255.,  ..., 255., 255., 255.],
         [255., 255., 255.,  ..., 255., 255., 255.],
         ...,
         [255., 255., 255.,  ..., 255., 255., 255.],
         [255., 255., 255.,  ..., 255., 255., 255.],
         [255., 255., 255.,  ..., 255., 255., 255.]],

In [116]:
timestep = torch.randn(size = (3, 24, 2))
x = torch.randn(size = (2, 2, 2))

In [117]:
print(timestep)
print(x)

tensor([[[-1.2011e+00, -4.6220e-01],
         [-1.7614e+00,  8.0611e-01],
         [ 2.7006e-01, -1.8678e+00],
         [-8.5972e-01, -4.4823e-01],
         [ 9.2746e-01, -7.0386e-01],
         [-1.0524e+00,  1.6565e+00],
         [ 1.5218e-01,  1.2881e+00],
         [-7.5594e-02,  3.0399e-02],
         [-7.5931e-01,  2.2981e+00],
         [ 6.4255e-01,  1.4393e-01],
         [ 1.0635e-01,  2.8390e-01],
         [-4.8726e-04, -1.1606e+00],
         [-1.2173e+00, -8.3023e-01],
         [-5.5989e-01, -7.9893e-02],
         [-5.3027e-01,  1.5316e-01],
         [-1.4280e-01,  7.5005e-01],
         [ 1.2629e-02,  8.7853e-01],
         [ 6.3101e-01,  1.6780e-01],
         [-1.0147e+00,  4.6798e-01],
         [ 9.1126e-01, -4.4612e-01],
         [ 2.6410e-01, -3.6911e-02],
         [ 7.8104e-02,  7.2003e-01],
         [ 5.6776e-01,  8.4092e-01],
         [ 3.8145e-01,  1.3298e+00]],

        [[ 6.7772e-01, -7.9119e-01],
         [ 9.4839e-01, -1.5864e-02],
         [ 1.8776e+00, -2.9493e+00],

In [126]:
post = torch.einsum('bij,jkt->bikt', timestep, x)
print(post.size())

torch.Size([3, 24, 2, 2])


In [137]:
bool_flag = F.one_hot(torch.ones(3, 24, 2).reshape(3 * 24, 2).multinomial(1), 2).reshape(3, 24, 2).unsqueeze(2)
print(bool_flag.size())
out = bool_flag.float() @ post
out.size()

torch.Size([3, 24, 1, 2])


torch.Size([3, 24, 1, 2])

In [35]:
temp = torch.randn(1, 5, 2, 2)
bool_flag = torch.ones(1, 5, 2).reshape(1 * 5, 2).multinomial(1).reshape(1, 5, 1)
print(temp)
print(bool_flag)

temp.gather(2, bool_flag.expand(-1, -1, 2).unsqueeze(2)).squeeze(2)

tensor([[[[ 2.0195,  0.8963],
          [ 0.4355,  0.3487]],

         [[-1.1306, -1.1809],
          [-1.3799, -1.1867]],

         [[ 0.4282, -0.9926],
          [-0.5353, -1.1723]],

         [[ 0.4859, -0.3495],
          [ 0.6585, -1.6807]],

         [[-0.2893,  1.2866],
          [-0.8726, -0.1874]]]])
tensor([[[1],
         [0],
         [0],
         [1],
         [1]]])


tensor([[[ 0.4355,  0.3487],
         [-1.1306, -1.1809],
         [ 0.4282, -0.9926],
         [ 0.6585, -1.6807],
         [-0.8726, -0.1874]]])

In [27]:
qt = torch.randint(1, 10, (2, 2))
qt_1bar = torch.randint(1, 10, (2, 2))
qt_bar = torch.randint(1, 10, (2, 2))
print(qt)
print(qt_1bar)
print(qt_bar)

tensor([[9, 3],
        [9, 9]])
tensor([[4, 3],
        [7, 6]])
tensor([[5, 8],
        [8, 6]])


In [28]:
cond_posterior_xt_x0 = (qt.T.unsqueeze(1) * qt_1bar.unsqueeze(0)) / qt_bar.T.unsqueeze(2)
cond_posterior_xt_x0

tensor([[[7.2000, 5.4000],
         [7.8750, 6.7500]],

        [[1.5000, 3.3750],
         [3.5000, 9.0000]]])