In [1]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import trimesh
from skimage import measure
import meshplot as mp
from torch.utils.data import DataLoader, Dataset
import os
import time
from datetime import timedelta, datetime
import random
import math

In [2]:
# 4-layer auto-decoder MLP (the checkpoint loaded below was trained on 10 shapes)
class MLP_4(nn.Module):
    """Auto-decoder: (x, y, z) + learned per-shape latent code -> one scalar per point.

    The scalar is presumably a signed distance (the notebook extracts its 0 level set with
    marching cubes) -- TODO confirm against the training code.

    Args:
        n_shapes: number of rows in the shape-code embedding table.
        shape_code_length: length of each latent shape code.
        n_inner_nodes: width of the three hidden layers.
    """

    def __init__(self, n_shapes, shape_code_length, n_inner_nodes):
        super(MLP_4, self).__init__()
        self.shape_code_length = shape_code_length
        # Shape codes stored as an embedding; max_norm=0.01 makes nn.Embedding renormalise the
        # looked-up rows in place so their L2 norm is at most 0.01. TODO: take this outside.
        self.shape_codes = nn.Embedding(n_shapes, shape_code_length, max_norm=0.01)

        self.linear1 = nn.Linear(3 + shape_code_length, n_inner_nodes)  # (x, y, z) + shape code
        self.linear2 = nn.Linear(n_inner_nodes, n_inner_nodes)
        self.linear3 = nn.Linear(n_inner_nodes, n_inner_nodes)
        self.linear4 = nn.Linear(n_inner_nodes, 1)

        self.relu = nn.ReLU()

    def forward(self, shape_idx, x):
        """Evaluate the decoder.

        Args:
            shape_idx: integer tensor with one shape index per row of `x` (flattened internally).
            x: (N, 3) query points.
        Returns:
            (N, 1) tensor of decoder outputs.
        """
        # (1, N, L) lookup flattened back to (N, L). (A leftover debug print of this shape,
        # which fired on every forward pass, has been removed.)
        shape_code = self.shape_codes(shape_idx.view(1, -1))
        shape_code = shape_code.view(-1, self.shape_code_length)
        shape_code_with_xyz = torch.cat((x, shape_code), dim=1)  # concatenate horizontally

        out = self.linear1(shape_code_with_xyz)
        out = self.relu(out)
        out = self.linear2(out)
        out = self.relu(out)
        out = self.linear3(out)
        out = self.relu(out)
        out = self.linear4(out)

        return out
    

filename = './models/autodecoder_08052022_073446'  # trained for 10 shapes
model = MLP_4(10, 256, 256)
# map_location='cpu': everything downstream (the query grid P, the decoder `mod`, the dataset)
# lives on CPU, so load there regardless of the device the checkpoint was saved from.
model.load_state_dict(torch.load(filename, map_location='cpu'))

# The learned latent codes, one row per training shape: the data the diffusion model learns.
shape_codes = model.state_dict()['shape_codes.weight']
print(shape_codes.shape)

torch.Size([10, 256])


In [3]:
class DecoderPartial(nn.Module):
    """Decoder half of MLP_4, driven by an explicit shape code instead of an embedding lookup.

    The layer attributes are named exactly as in MLP_4 (linear1..linear4) so that an MLP_4
    state dict can be loaded into this module.
    """

    def __init__(self, shape_code_length, n_inner_nodes):
        super().__init__()
        self.shape_code_length = shape_code_length
        self.linear1 = nn.Linear(3 + shape_code_length, n_inner_nodes)  # (x, y, z) + shape code
        self.linear2 = nn.Linear(n_inner_nodes, n_inner_nodes)
        self.linear3 = nn.Linear(n_inner_nodes, n_inner_nodes)
        self.linear4 = nn.Linear(n_inner_nodes, 1)
        self.relu = nn.ReLU()

    def forward(self, shape_code, x):
        """Evaluate the decoder for one shape code at every row of `x`.

        Args:
            shape_code: a single code, shape (L,) or (1, L); it is tiled once per point.
            x: (N, 3) query points.
        Returns:
            (N, 1) tensor of decoder outputs.
        """
        n_points = x.size()[0]
        tiled = shape_code.repeat([n_points, 1]).view(-1, self.shape_code_length)
        hidden = torch.cat((x, tiled), dim=1)
        for layer in (self.linear1, self.linear2, self.linear3):
            hidden = self.relu(layer(hidden))
        return self.linear4(hidden)

In [4]:
mod = DecoderPartial(256, 256)
# Load only the decoder weights from the auto-decoder checkpoint. The checkpoint also carries
# `shape_codes.weight`, which DecoderPartial has no slot for. Filter it out explicitly and load
# strictly: with strict=False a missing or renamed decoder layer would be silently skipped,
# leaving that layer at its random initialisation.
autodecoder_state = torch.load(filename, map_location='cpu')
decoder_keys = mod.state_dict().keys()
mod.load_state_dict({k: v for k, v in autodecoder_state.items() if k in decoder_keys})
mod.eval()

DecoderPartial(
  (linear1): Linear(in_features=259, out_features=256, bias=True)
  (linear2): Linear(in_features=256, out_features=256, bias=True)
  (linear3): Linear(in_features=256, out_features=256, bias=True)
  (linear4): Linear(in_features=256, out_features=1, bias=True)
  (relu): ReLU()
)

## Diffusion

In [6]:
# The diffusion model's training data: the learned shape codes, one row per shape.
dataset = shape_codes

In [7]:
# Sanity check: (number of shapes, shape-code length).
dataset.shape

torch.Size([10, 256])

In [8]:
# Standardize the codes with one global mean/std taken over all entries. Both are computed
# from `shape_codes`, not from `dataset`. Otherwise re-running this cell would re-standardize
# already-standardized data and overwrite mean/std with ~0/~1, silently breaking every later
# `std * code + mean` de-standardization.
codes_np = np.asarray(shape_codes)
std = np.std(codes_np)
mean = np.mean(codes_np)
dataset = (shape_codes - mean) / std

In [9]:
def make_beta_schedule(schedule='linear', n_timesteps=1000, start=1e-5, end=1e-2):
    """Build the per-step noise variances beta_t of the diffusion forward process.

    Args:
        schedule: 'linear' (evenly spaced from start to end), 'quad' (evenly spaced in
            sqrt(beta), then squared) or 'sigmoid' (sigmoid over [-6, 6] rescaled into
            [start, end]).
        n_timesteps: number of diffusion steps (length of the returned tensor).
        start, end: beta range.
    Returns:
        1-D float tensor of length n_timesteps.
    Raises:
        ValueError: for an unknown `schedule`. Previously this fell through and raised an
            opaque UnboundLocalError on `return betas`.
    """
    if schedule == 'linear':
        betas = torch.linspace(start, end, n_timesteps)
    elif schedule == "quad":
        betas = torch.linspace(start ** 0.5, end ** 0.5, n_timesteps) ** 2
    elif schedule == "sigmoid":
        betas = torch.linspace(-6, 6, n_timesteps)
        betas = torch.sigmoid(betas) * (end - start) + start
    else:
        raise ValueError(
            f"unknown beta schedule {schedule!r}; expected 'linear', 'quad' or 'sigmoid'"
        )
    return betas

In [10]:
# Diffusion schedule shared by the forward (q_sample) and reverse (p_sample) processes.
n_steps = 30000
betas = make_beta_schedule(schedule='linear', n_timesteps=n_steps, start=1e-5, end=1e-2)
alphas = 1 - betas
# Cumulative product abar_t = prod_{s<=t} alpha_s, plus the same series shifted by one step
# with 1 prepended (abar_{t-1}).
alphas_prod = alphas.cumprod(dim=0)
alphas_prod_p = torch.cat([torch.ones(1), alphas_prod[:-1]], dim=0)
# Precomputed scalings used when noising: sqrt(abar_t), log(1 - abar_t), sqrt(1 - abar_t).
alphas_bar_sqrt = alphas_prod.sqrt()
one_minus_alphas_bar_log = (1 - alphas_prod).log()
one_minus_alphas_bar_sqrt = (1 - alphas_prod).sqrt()

In [11]:
# Inspect sqrt(abar_t): ~1 at the first steps, decaying to ~0 by the end of the schedule.
alphas_bar_sqrt

tensor([1.0000, 1.0000, 1.0000,  ..., 0.0000, 0.0000, 0.0000])

In [12]:
def extract(input, t, x):
    """Gather the schedule values input[t] and shape them to broadcast against `x`.

    Args:
        input: 1-D schedule tensor (e.g. alphas_bar_sqrt).
        t: 1-D integer tensor of step indices.
        x: tensor the result will be multiplied with; only its rank is used.
    Returns:
        Tensor of shape (len(t), 1, ..., 1) with x.dim() - 1 trailing singleton dims.
    """
    picked = input.gather(0, t.to(input.device))
    broadcast_shape = (t.shape[0],) + (1,) * (x.dim() - 1)
    return picked.reshape(broadcast_shape)
def q_sample(x_0, t, noise=None):
    """Sample x_t ~ q(x_t | x_0) in closed form: sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise.

    Args:
        x_0: clean data.
        t: 1-D integer tensor of step indices (see `extract` for broadcasting).
        noise: optional fixed noise; a fresh standard-normal draw like x_0 when None.
    """
    eps = torch.randn_like(x_0) if noise is None else noise
    signal_scale = extract(alphas_bar_sqrt, t, x_0)
    noise_scale = extract(one_minus_alphas_bar_sqrt, t, x_0)
    return signal_scale * x_0 + noise_scale * eps

In [89]:
# shape_6_q_max_list = []
# for i in range(100):
#     shape_6_q_max_list.append(q_sample(dataset[6,:], torch.tensor([10000-1])).detach().numpy())

In [94]:
# mean1 = np.mean(shape_6_q_max_list)
# std1=np.std(shape_6_q_max_list)


In [None]:
# Regular 100^3 grid of query points on [-1, 1]^3 at which the decoder is evaluated.
x = np.linspace(-1, 1, 100, dtype=np.float32)
y = np.linspace(-1, 1, 100, dtype=np.float32)
z = np.linspace(-1, 1, 100, dtype=np.float32)
# indexing='ij' makes row i*len(y)*len(z) + j*len(z) + k equal (x[i], y[j], z[k]), so the later
# `.view(len(x), len(y), len(z))` yields a volume indexed [x, y, z]. The default 'xy' indexing
# put y on axis 0, which mirrored every marching-cubes mesh across the x = y plane.
P = np.stack(np.meshgrid(x, y, z, indexing='ij'), axis=-1).reshape(-1, 3)  # [[x1, y1, z1], [x1, y1, z2], ...]
P = torch.from_numpy(P)


In [37]:
# Decode shape code 5 after forward-diffusing it to increasing noise levels (steps 0, 999, 4999).
# no_grad: pure inference; without it autograd keeps every activation of the 100^3-point pass.
with torch.no_grad():
    for step in (0, 1000 - 1, 5000 - 1):
        q_i = std * q_sample(dataset[5, :], torch.tensor([step])) + mean
        volume = mod(q_i, P).view(len(x), len(y), len(z)).numpy()
        # marching cubes on the 0 level set to visualize
        verts, faces, normals, values = measure.marching_cubes(volume, 0)
        mp.plot(verts, faces)


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(49.498744…

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(49.495571…

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(49.498734…

<meshplot.Viewer.Viewer at 0x7f3cbec40e80>

In [38]:

# Three independent forward-noising draws of shape code 5 at step 9999, to see how much the
# decoded shape varies between noise samples at the same level.
# no_grad: pure inference; without it autograd keeps every activation of the 100^3-point pass.
with torch.no_grad():
    for _ in range(3):
        q_i = std * q_sample(dataset[5, :], torch.tensor([10000 - 1])) + mean
        volume = mod(q_i, P).view(len(x), len(y), len(z)).numpy()
        # marching cubes on the 0 level set: plot shape after forward process
        verts, faces, normals, values = measure.marching_cubes(volume, 0)
        mp.plot(verts, faces)


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(49.500711…

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(49.497913…

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(49.500980…

<meshplot.Viewer.Viewer at 0x7f3ccd741f10>

In [13]:
import torch.nn.functional as F

class ConditionalLinear(nn.Module):
    """Linear layer whose output is shifted by a projection of a conditioning vector.

    out = lin(x) + temb_proj(SiLU(y))

    `n_steps` is accepted but not used (a leftover from an earlier per-step nn.Embedding
    variant of this layer).
    """

    def __init__(self, num_in, num_out, n_steps, temb_ch):
        super().__init__()
        self.num_out = num_out
        self.lin = nn.Linear(num_in, num_out)
        self.temb_proj = nn.Linear(temb_ch, num_out)
        self.nonlin = nn.SiLU()

    def forward(self, x, y):
        projected = self.lin(x)
        shift = self.temb_proj(self.nonlin(y))
        return shift + projected

In [14]:
class ConditionalModel(nn.Module):
    """Noise-prediction network eps_theta(x_t, t) for diffusion over shape codes.

    Args:
        n_steps: number of diffusion steps; forwarded to ConditionalLinear, which ignores it.
        ch: width of the sinusoidal timestep embedding; its MLP projection is 4 * ch wide.
        num_out: unused; kept only so existing call sites keep working.
        code_dim: length of the codes being denoised (previously hard-coded as 256).
        hidden_dim: width of the two conditional hidden layers (previously hard-coded as 512).

    Attribute names and creation order are unchanged, so saved state dicts still load and
    seeded initialisation is identical for the default sizes.
    """

    def __init__(self, n_steps, ch=32, num_out=64, code_dim=256, hidden_dim=512):
        super(ConditionalModel, self).__init__()
        self.ch = ch
        self.temb_ch = ch * 4
        self.lin1 = ConditionalLinear(code_dim, hidden_dim, n_steps, self.temb_ch)
        self.lin2 = ConditionalLinear(hidden_dim, hidden_dim, n_steps, self.temb_ch)
        self.lin3 = nn.Linear(hidden_dim, code_dim)

        # timestep embedding: sinusoidal features (width ch) -> 2-layer MLP (width 4 * ch)
        self.temb = nn.Sequential(
            torch.nn.Linear(ch, self.temb_ch),
            torch.nn.SiLU(),
            torch.nn.Linear(self.temb_ch, self.temb_ch),
        )

    def forward(self, x, y):
        """Predict the noise in `x`.

        Args:
            x: noisy codes whose last dimension is `code_dim`.
            y: 1-D tensor of integer timesteps (get_timestep_embedding requires 1-D).
        Returns:
            Predicted noise with the same trailing dimension as `x`.
        """
        temb = self.temb(get_timestep_embedding(y, self.ch))
        x = F.softplus(self.lin1(x, temb))
        x = F.softplus(self.lin2(x, temb))
        return self.lin3(x)

def p_sample(model, x, t):
    """One reverse (denoising) DDPM step: draw x_{t-1} ~ p_theta(x_{t-1} | x_t).

    Args:
        model: noise predictor called as model(x, t_tensor).
        x: current sample, at the noise level of step index `t`.
        t: integer step index into the schedule tensors (0 is the final denoising step).
    Returns:
        The next, less noisy sample.
    """
    step = int(t)
    t = torch.tensor([step])
    # Sampling only. Without no_grad, autograd links every step's output to the model
    # parameters and the graph (and memory) grows with the length of the reverse chain.
    with torch.no_grad():
        # Factor to the model output: beta_t / sqrt(1 - abar_t)
        eps_factor = (1 - extract(alphas, t, x)) / extract(one_minus_alphas_bar_sqrt, t, x)
        eps_theta = model(x, t)
        mean = (1 / extract(alphas, t, x).sqrt()) * (x - eps_factor * eps_theta)
        # DDPM sampling (Ho et al. 2020, Alg. 2) adds no noise on the final step.
        if step == 0:
            return mean
        # Fixed sigma_t = sqrt(beta_t)
        z = torch.randn_like(x)
        sigma_t = extract(betas, t, x).sqrt()
        return mean + sigma_t * z

In [15]:
def p_sample_loop(model, shape):
    """Run the full reverse chain starting from standard Gaussian noise of the given shape.

    Applies p_sample for t = n_steps - 1, ..., 0 and returns every intermediate sample,
    initial noise first, so the list has n_steps + 1 entries.
    """
    x_t = torch.randn(shape)
    trajectory = [x_t]
    for step in range(n_steps - 1, -1, -1):
        x_t = p_sample(model, x_t, step)
        trajectory.append(x_t)
    return trajectory

In [18]:
def get_timestep_embedding(timesteps, embedding_dim):
    """Sinusoidal embedding of timesteps (the DDPM / tensor2tensor variant).

    The first half of the columns are sin(t * f_k) and the second half cos(t * f_k), with
    f_k = 10000 ** (-k / (half_dim - 1)) for k = 0..half_dim-1. This differs slightly from
    Section 3.5 of "Attention Is All You Need", which interleaves sin/cos. An odd
    `embedding_dim` gets one trailing zero column.

    Args:
        timesteps: 1-D tensor of timesteps.
        embedding_dim: number of output columns.
    Returns:
        Float tensor of shape (len(timesteps), embedding_dim).
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    log_step = math.log(10000) / (half_dim - 1)
    freqs = torch.exp(-log_step * torch.arange(half_dim, dtype=torch.float32))
    freqs = freqs.to(device=timesteps.device)
    angles = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    embedding = torch.cat([angles.sin(), angles.cos()], dim=1)
    if embedding_dim % 2 == 1:
        embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0))  # zero pad
    return embedding

In [103]:
# b=[]
# for i in range(n_steps):
#     f=torch.randn_like(dataset[:5,:])
#     b.append(f[None,:,:])
# e = torch.Tensor(n_steps,5, 256)
# torch.cat(b, out=e)
# e.size()

torch.Size([30000, 5, 256])

In [46]:
# torch.manual_seed(0)
# print(torch.randn_like(torch.ones(1, 5)))
# torch.manual_seed(0)
# print(torch.randn_like(torch.ones(1, 5)))


tensor([[ 1.5410, -0.2934, -2.1788,  0.5684, -1.0845]])
tensor([[ 1.5410, -0.2934, -2.1788,  0.5684, -1.0845]])


In [109]:
# def noise_estimation_loss(model, x_0,noise_steps):
def noise_estimation_loss(model, x_0):
    """DDPM "simple" loss: mean squared error between injected noise and the model's estimate.

    Timesteps are sampled antithetically: batch_size // 2 + 1 are drawn uniformly from
    [0, n_steps), the rest are their mirror images n_steps - 1 - t, truncated to the batch size.
    """
    n = x_0.shape[0]
    drawn = torch.randint(0, n_steps, size=(n // 2 + 1,))
    t = torch.cat([drawn, n_steps - 1 - drawn], dim=0)[:n].long()
    signal_scale = extract(alphas_bar_sqrt, t, x_0)     # x0 multiplier
    noise_scale = extract(one_minus_alphas_bar_sqrt, t, x_0)  # eps multiplier
    noise = torch.randn_like(x_0)
    x_t = signal_scale * x_0 + noise_scale * noise
    predicted = model(x_t, t)
    return (noise - predicted).square().mean()


In [110]:
class EMA(object):
    """Exponential moving average of a module's trainable parameters.

    After every `update`: shadow <- (1 - mu) * param + mu * shadow.
    """

    def __init__(self, mu=0.999):
        self.mu = mu
        self.shadow = {}  # parameter name -> averaged tensor

    def register(self, module):
        """Snapshot the current trainable parameters of `module` as the initial averages."""
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, module):
        """Blend `module`'s current trainable parameters into the running averages."""
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data

    def ema(self, module):
        """Overwrite `module`'s trainable parameters in place with the averaged values."""
        for name, param in module.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.shadow[name].data)

    def ema_copy(self, module):
        """Return a deep copy of `module` carrying the averaged parameters; `module` is untouched.

        The previous version rebuilt the module via `type(module)(module.config)`. That raises
        AttributeError for any module without a `config` attribute (e.g. ConditionalModel).
        deepcopy keeps the device, buffers and constructor arguments without needing them.
        """
        import copy

        module_copy = copy.deepcopy(module)
        self.ema(module_copy)
        return module_copy

    def state_dict(self):
        """Return the averaged parameters (by reference)."""
        return self.shadow

    def load_state_dict(self, state_dict):
        """Replace the averaged parameters with `state_dict`."""
        self.shadow = state_dict

In [111]:
# Fresh noise-prediction network and its optimiser.
model = ConditionalModel(n_steps)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.9), eps=1e-07)
dataset = dataset.float()

# Exponential moving average of the weights (mu = 0.9), seeded with the initial parameters.
ema = EMA(0.9)
ema.register(model)

batch_size = 5

# Timestamp naming this run's loss log and checkpoint.
now = datetime.now().strftime("%m%d%Y_%H%M%S")


In [112]:
# import pandas as pd
# import matplotlib.pyplot as plt

# now = datetime.now() 
# now = now.strftime("%m%d%Y_%H%M%S")

# #log loss
# f = open(f'./diffusion logs/{now}.csv','a')
# f.write(str(loss)+'\n') 
# f.close()

# #load loss and plot
# df = pd.read_csv('./diffusion logs/test.csv', header=None)
# df.plot()
# plt.show()

In [138]:
LOG_DIR = './diffusion logs'
N_EPOCHS = 100000
LOG_EVERY = 10        # epochs between loss-log lines
CKPT_EVERY = 100      # epochs between checkpoint saves / printed losses
TARGET_LOSS = 0.0012  # stop once the last batch loss of an epoch drops below this

os.makedirs(LOG_DIR, exist_ok=True)
ckpt_path = f'{LOG_DIR}/conditional model_{now}'

# `with` guarantees the loss log is flushed and closed even when the loop is interrupted,
# which is the usual way a long notebook training cell ends.
with open(f'{LOG_DIR}/{now}.csv', 'a') as log_file:
    for epoch in range(N_EPOCHS):
        permutation = torch.randperm(dataset.size()[0])
        for i in range(0, dataset.size()[0], batch_size):
            # Retrieve current batch
            indices = permutation[i:i + batch_size]
            batch_x = dataset[indices]
            loss = noise_estimation_loss(model, batch_x)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
            optimizer.step()
            # Update the exponential moving average of the weights
            ema.update(model)
        if epoch % LOG_EVERY == 0:
            log_file.write(str(loss.item()) + '\n')
        if epoch % CKPT_EVERY == 0:
            torch.save(model.state_dict(), ckpt_path)
            print(f'epoch {epoch}: loss {loss.item():.4f}')
        if loss.item() < TARGET_LOSS:
            torch.save(model.state_dict(), ckpt_path)
            break



tensor(0.0017, grad_fn=<MeanBackward0>)
tensor(0.0022, grad_fn=<MeanBackward0>)
tensor(0.0016, grad_fn=<MeanBackward0>)
tensor(0.0040, grad_fn=<MeanBackward0>)
tensor(0.0534, grad_fn=<MeanBackward0>)
tensor(0.0016, grad_fn=<MeanBackward0>)
tensor(0.0022, grad_fn=<MeanBackward0>)


In [None]:
# `e` was the precomputed noise tensor from the commented-out cell above. It does not exist on
# a fresh kernel, so guard the lookup instead of raising NameError under Restart & Run All.
e.size() if 'e' in globals() else None

In [141]:
# Loss of the last training batch; the training loop stops once it drops below 0.0012.
loss

tensor(0.0011, grad_fn=<MeanBackward0>)

In [35]:
filename = './diffusion logs/conditional model_08182022_205351'
model = ConditionalModel(n_steps)
model.load_state_dict(torch.load(filename, map_location='cpu'))
model.eval()

x_0 = dataset
shape_idx = 5
timestep = 30000 - 1

# Inference only. no_grad avoids an autograd graph over the 100^3-point decoder passes and,
# above all, one chained across every step of the reverse process.
with torch.no_grad():
    noised_code = q_sample(x_0[shape_idx, :], torch.tensor([timestep]))
    # q_sample(., timestep) is at the noise level of step index `timestep`, so the reverse chain
    # starts at p_sample(., timestep) and runs down to 0 inclusive (timestep + 1 steps).
    denoised_code = noised_code
    for i in reversed(range(timestep + 1)):
        denoised_code = p_sample(model, denoised_code, i)

    # Render original, noisy and denoised codes, de-standardized back to decoder code space.
    for code in (x_0[shape_idx, :], noised_code, denoised_code):
        volume = mod(std * code + mean, P).view(len(x), len(y), len(z)).numpy()
        verts, faces, normals, values = measure.marching_cubes(volume, 0)
        mp.plot(verts, faces)

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(49.498733…

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(50.034333…

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(48.498877…

<meshplot.Viewer.Viewer at 0x7f3cd72c3af0>

In [36]:
# Re-run the stochastic reverse chain five times from the same noised code to see how much the
# outcome varies. Uses `noised_code`, `timestep` and `model` from the previous cell.
with torch.no_grad():
    for _ in range(5):
        denoised_code = noised_code
        # Start at `timestep` (the noise level of noised_code) and finish at step 0 inclusive.
        for i in reversed(range(timestep + 1)):
            denoised_code = p_sample(model, denoised_code, i)
        volume = mod(std * denoised_code + mean, P).view(len(x), len(y), len(z)).numpy()
        verts, faces, normals, values = measure.marching_cubes(volume, 0)
        mp.plot(verts, faces)

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(50.998323…

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(50.495885…

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(49.984013…

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(49.094991…

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(50.684421…

<meshplot.Viewer.Viewer at 0x7f3ccd688100>

In [47]:
filename = './diffusion logs/conditional model_08182022_205351'
model = ConditionalModel(n_steps)
model.load_state_dict(torch.load(filename, map_location='cpu'))
model.eval()

x_0 = dataset
shape_idx = 5
timestep = 10000 - 1

# Inference only: no_grad keeps autograd from chaining a graph through the reverse process.
with torch.no_grad():
    original_code = x_0[shape_idx, :]
    noised_code = q_sample(original_code, torch.tensor([timestep]))
    labelled_codes = [('original', original_code), ('noisy', noised_code)]
    # Two independent reverse chains from the same noised code. q_sample(., timestep) is at
    # noise level `timestep`, so denoise from step `timestep` down to 0 inclusive.
    for _ in range(2):
        denoised_code = noised_code
        for i in reversed(range(timestep + 1)):
            denoised_code = p_sample(model, denoised_code, i)
        labelled_codes.append(('denoised', denoised_code))

    # Render each code, de-standardized back to decoder code space.
    for label, code in labelled_codes:
        volume = mod(std * code + mean, P).view(len(x), len(y), len(z)).numpy()
        verts, faces, normals, values = measure.marching_cubes(volume, 0)
        print(label)
        mp.plot(verts, faces)

original


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(49.498733…

noisy


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(49.499081…

denoised


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(47.513990…

denoised


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(48.005763…

<meshplot.Viewer.Viewer at 0x7f3cd72adb80>

In [51]:
filename = './diffusion logs/conditional model_08182022_205351'
model = ConditionalModel(n_steps)
model.load_state_dict(torch.load(filename, map_location='cpu'))
model.eval()

x_0 = dataset
shape_idx = 0
shape_idx2 = 2
timestep = 5000

# Midpoint (linear interpolation) between two standardized shape codes.
q = (x_0[shape_idx, :] + x_0[shape_idx2, :]) / 2

# Inference only: no_grad keeps autograd from chaining a graph through the reverse process.
with torch.no_grad():
    noised_code = q_sample(q, torch.tensor([timestep]))
    labelled_codes = [('original', q), ('noisy', noised_code)]
    # Three independent reverse chains. q_sample(., timestep) is at noise level `timestep`,
    # so denoise from step `timestep` down to 0 inclusive.
    for run in range(1, 4):
        denoised_code = noised_code
        for i in reversed(range(timestep + 1)):
            denoised_code = p_sample(model, denoised_code, i)
        labelled_codes.append((f'denoised {run}', denoised_code))

    # Render each code, de-standardized back to decoder code space.
    for label, code in labelled_codes:
        volume = mod(std * code + mean, P).view(len(x), len(y), len(z)).numpy()
        verts, faces, normals, values = measure.marching_cubes(volume, 0)
        print(label)
        mp.plot(verts, faces)

original


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(49.219427…

noisy


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(49.920176…

denoised 1


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(49.040920…

denoised 2


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(48.992958…

denoised 3


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(50.500508…

<meshplot.Viewer.Viewer at 0x7f3ccdd7aa60>

In [45]:
filename = './diffusion logs/conditional model_08182022_205351'
model = ConditionalModel(n_steps)
model.load_state_dict(torch.load(filename, map_location='cpu'))
model.eval()

x_0 = dataset
shape_idx = 5
timestep = 5000

# Inference only: no_grad keeps autograd from chaining a graph through the reverse process.
with torch.no_grad():
    original_code = x_0[shape_idx, :]
    noised_code = q_sample(original_code, torch.tensor([timestep]))
    labelled_codes = [('original', original_code), ('noisy', noised_code)]
    # Two independent reverse chains. q_sample(., timestep) is at noise level `timestep`,
    # so denoise from step `timestep` down to 0 inclusive.
    for _ in range(2):
        denoised_code = noised_code
        for i in reversed(range(timestep + 1)):
            denoised_code = p_sample(model, denoised_code, i)
        labelled_codes.append(('denoised', denoised_code))

    # Render each code, de-standardized back to decoder code space.
    for label, code in labelled_codes:
        volume = mod(std * code + mean, P).view(len(x), len(y), len(z)).numpy()
        verts, faces, normals, values = measure.marching_cubes(volume, 0)
        print(label)
        mp.plot(verts, faces)

original


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(49.498733…

noisy


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(49.007408…

denoised


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(49.473046…

denoised


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(49.500925…

<meshplot.Viewer.Viewer at 0x7f3cbe8ea760>

In [None]:
# For one explicit noise draw `e`: noise the first five codes to step 600, denoise them back,
# and compare shape 0 before / noised / after.
x_0 = dataset[:5, :]
t = torch.tensor([600])
# `e` used to come from a since-commented-out cell, so this cell raised NameError on a fresh
# kernel. Draw it here.
e = torch.randn_like(x_0)

# Inference only: no_grad keeps autograd from chaining a graph through the reverse process.
with torch.no_grad():
    # Same as x_0 * sqrt(abar_t) + e * sqrt(1 - abar_t).
    xx = q_sample(x_0, t, noise=e)
    # xx is at the noise level of step 600, so denoise from step 600 down to 0 inclusive.
    cur_x = xx
    for i in reversed(range(t.item() + 1)):
        cur_x = p_sample(model, cur_x, i)

    # Render shape 0: original, noised, denoised (de-standardized back to decoder code space).
    for code in (x_0[0, :], xx[0, :], cur_x[0, :]):
        volume = mod(std * code + mean, P).view(len(x), len(y), len(z)).numpy()
        verts, faces, normals, values = measure.marching_cubes(volume, 0)
        mp.plot(verts, faces)