In [None]:
import torch
from models.common import *
from models.vae_gaussian import *
from models.vae_flow import *
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from torch.utils.data import DataLoader, TensorDataset

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Load the pretrained models

In [None]:
# Restore the "sim" model from its checkpoint.
# NOTE(review): torch.load unpickles the file (the checkpoint also carries the
# `args` object used to build the model), so only load checkpoints you trust.
ckpt = torch.load('logs_gen/Line_sim/ckpt_0.000000_1000000.pt', map_location=device)
# This notebook only runs inference, so put the model in evaluation mode; this
# matters if the network contains layers that behave differently while training
# (e.g. batch norm or dropout) -- TODO confirm against the model definition.
model_sim = GaussianVAE(ckpt['args']).to(device).eval()
model_sim.load_state_dict(ckpt['state_dict'])

In [None]:
# Restore the "exp" model from its checkpoint.
# NOTE(review): torch.load unpickles the file (the checkpoint also carries the
# `args` object used to build the model), so only load checkpoints you trust.
ckpt = torch.load('logs_gen/Line_exp/ckpt_0.000000_1000000.pt', map_location=device)
# This notebook only runs inference, so put the model in evaluation mode; this
# matters if the network contains layers that behave differently while training
# (e.g. batch norm or dropout) -- TODO confirm against the model definition.
model_exp = GaussianVAE(ckpt['args']).to(device).eval()
model_exp.load_state_dict(ckpt['state_dict'])

## Load Datasets for translation

In [None]:
# Read both toy line datasets from disk and convert them to float32 tensors.
sim_data = torch.from_numpy(np.load('data/toy/line.npy')).float()
exp_data = torch.from_numpy(np.load('data/toy/line_noisy.npy')).float()

# Wrap each tensor in a dataset; every item is a 1-tuple holding one sample.
sim_dset = TensorDataset(sim_data)
exp_dset = TensorDataset(exp_data)

# Sequential (unshuffled) loaders that yield batches of 10 samples in the main process.
sim_loader = DataLoader(sim_dset, batch_size=10, num_workers=0)
exp_loader = DataLoader(exp_dset, batch_size=10, num_workers=0)

## Util Functions for Cycle Diffusion

In [None]:
def forward_diffusion(x_0, model):
    """
    Run the forward (noising) chain on `x_0` and record every intermediate state.

    For each step t = 1..num_steps the previous state is scaled by sqrt(1 - beta_t)
    and fresh Gaussian noise scaled by sqrt(beta_t) is added, where beta_t is read
    from `model.diffusion.var_sched.betas[t]`.

    Parameters:
        x_0 (torch.Tensor): The starting point cloud(s); noise is drawn with the same shape.
        model (object): Object exposing `diffusion.var_sched.num_steps` and
                        `diffusion.var_sched.betas`.

    Returns:
        list: `num_steps + 1` tensors `[x_0, x_1, ..., x_T]`; the first entry is `x_0` itself.
    """
    var_sched = model.diffusion.var_sched

    states = [x_0]
    for step in range(1, var_sched.num_steps + 1):
        beta_t = var_sched.betas[step]
        # Reshape the per-step scalars so they broadcast over a 3-D tensor.
        noise_scale = torch.sqrt(beta_t).view(-1, 1, 1)
        keep_scale = torch.sqrt(1 - beta_t).view(-1, 1, 1)
        noise = torch.randn_like(x_0)
        states.append(keep_scale * states[-1] + noise_scale * noise)

    return states

In [None]:
def forward_diffusion(x_0, model):
    """
    Simulate the forward diffusion process and return the whole trajectory.

    NOTE: the preceding notebook cell defines a function with this same name and
    the same behaviour; since this cell runs later, this is the definition that
    remains bound to `forward_diffusion`.

    Parameters:
        x_0 (torch.Tensor): The initial point cloud(s) the chain starts from.
        model (object): Diffusion model; `model.diffusion.var_sched` supplies
                        `num_steps` and the per-step `betas`.

    Returns:
        list: The states `[x_0, x_1, ..., x_T]` (length `num_steps + 1`), each
              obtained from the previous one as
              sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * noise.
    """
    schedule = model.diffusion.var_sched
    total_steps = schedule.num_steps

    trajectory = [x_0]
    current = x_0
    for t in range(1, total_steps + 1):
        beta = schedule.betas[t]
        signal_coef = torch.sqrt(1 - beta).view(-1, 1, 1)
        noise_coef = torch.sqrt(beta).view(-1, 1, 1)
        gaussian = torch.randn_like(x_0)
        current = signal_coef * current + noise_coef * gaussian
        trajectory.append(current)

    return trajectory

In [None]:
def visualize_point_cloud(point_cloud, sample_index=0):
    """
    Show one point cloud of a batch as a 3D scatter plot coloured by its fourth column.

    Parameters:
        point_cloud (torch.Tensor): Batch of point clouds. The selected sample is
            read as columns 0, 1, 2 (x, y, z) and column 3 (the value used for colouring).
        sample_index (int): Index of the point cloud in the batch to draw.
    """
    # detach + cpu first: Tensor.numpy() refuses tensors that live on the GPU
    # or that require grad.
    sample = point_cloud[sample_index].detach().cpu().numpy()
    x, y, z, c = sample[:, 0], sample[:, 1], sample[:, 2], sample[:, 3]

    # Create a 3D plot
    fig = plt.figure(figsize=(6, 4))
    ax = fig.add_subplot(111, projection='3d')

    # Name the colormap explicitly. The previous `cmap=plt.cool()` passed the
    # call's return value (None) and, as a side effect, switched matplotlib's
    # global default colormap for everything plotted afterwards.
    ax.scatter(x, y, z, c=c, s=2, cmap='cool')

    # Setting labels
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_zlabel('Z axis')

    # Show the plot
    plt.show()

In [None]:
def dpm_encoder(x_0, model, context, flexibility=0.1):
    """
    Implementation of the DPM_Encoder.

    Draws a noisy forward trajectory `x_0 -> ... -> x_T` with `forward_diffusion`
    and then, for every step t = T..1, solves the reverse update
    `x_{t-1} = mean_t(x_t) + sigma_t * eps_t` for `eps_t`, where `mean_t` is the
    posterior mean predicted by `model.diffusion.net`. Runs without gradient tracking.

    Parameters:
        x_0 (torch.Tensor): The starting point cloud batch; its first dimension is
            used as the batch size.
        model (object): The diffusion model containing the network and variance schedule.
        context (torch.Tensor): Conditioning latent passed to the network
            (in this notebook it comes from the model's encoder).
        flexibility (float): Forwarded to `var_sched.get_sigmas`. Defaults to 0.1,
            the value that was previously hard-coded and the default used by `sample`;
            use the same value for encoding and decoding.

    Returns:
        tuple: A tuple containing:
            - x_T (torch.Tensor): The final noisy point cloud of the forward trajectory.
            - eps_tensor (torch.Tensor): The per-step noise tensors stacked along
              dim 1. Index i along that dimension holds the noise for step
              t = T - i, which is the order `sample` reads them in.
    """
    diffusion = model.diffusion
    var_sched = diffusion.var_sched
    batch_size = x_0.shape[0]

    with torch.no_grad():
        traj = forward_diffusion(x_0, model)
        x_T = traj[-1]  # The final noisy point cloud

        epsilon_list = []
        for t in range(len(traj) - 1, 0, -1):
            alpha = var_sched.alphas[t]
            alpha_bar = var_sched.alpha_bars[t]
            sigma = var_sched.get_sigmas(t, flexibility=flexibility)

            c0 = 1.0 / torch.sqrt(alpha)
            c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)

            # One beta entry per batch element, as the network call expects a batch.
            beta = var_sched.betas[[t] * batch_size]
            e_theta = diffusion.net(traj[t], beta=beta, context=context)
            mean = c0 * (traj[t] - c1 * e_theta)

            # Noise that maps the predicted mean onto the state actually visited.
            epsilon_list.append((traj[t - 1] - mean) / sigma)

        # Convert the list of noise tensors to a single tensor
        eps_tensor = torch.stack(epsilon_list, dim=1)
    return x_T, eps_tensor

In [None]:
def sample(model, num_points, context, x_T, eps, point_dim=4, flexibility=0.1, ret_traj=False):
    """
    Run the reverse diffusion chain of the DPM with externally supplied noise.

    Starting from `x_T`, every step t = T..1 computes
    `x_{t-1} = c0 * (x_t - c1 * e_theta) + sigma_t * z_t`, where `z_t` is taken
    from `eps` for t > 1 and is zero for the last step (t == 1).

    Parameters:
        model (object): The diffusion model containing the network and variance schedule.
        num_points (int): Not used by this function; kept for interface compatibility.
        context (torch.Tensor): The context tensor for conditioning the model; its
            first dimension is used as the batch size.
        x_T (torch.Tensor): The initial noisy point cloud.
        eps (torch.Tensor): Noise for the reverse steps, indexed as
            `eps[:, num_steps - t]` for step t.
        point_dim (int): Not used by this function; kept for interface compatibility.
        flexibility (float): Passed to `var_sched.get_sigmas` (default is 0.1).
        ret_traj (bool): Whether to return the trajectory of intermediate steps (default is False).

    Returns:
        Union[dict, torch.Tensor]: If `ret_traj` is True, a dict mapping step
            index -> tensor (steps >= 1 moved to CPU, step 0 left on its device);
            otherwise only the final tensor `x_0`.
    """
    batch_size = context.size(0)
    diffusion = model.diffusion
    var_sched = diffusion.var_sched
    num_steps = var_sched.num_steps

    traj = {num_steps: x_T}
    # Every stored state is detached anyway, so there is no reason to let
    # autograd record the network calls; disabling it saves memory and time.
    with torch.no_grad():
        for t in range(num_steps, 0, -1):
            z = eps[:, num_steps - t] if t > 1 else torch.zeros_like(x_T)

            alpha = var_sched.alphas[t]
            alpha_bar = var_sched.alpha_bars[t]
            sigma = var_sched.get_sigmas(t, flexibility)

            c0 = 1.0 / torch.sqrt(alpha)
            c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)

            x_t = traj[t]
            beta = var_sched.betas[[t] * batch_size]
            e_theta = diffusion.net(x_t, beta=beta, context=context)
            x_next = c0 * (x_t - c1 * e_theta) + sigma * z
            traj[t - 1] = x_next.detach()     # Stop gradient and save trajectory.

            if ret_traj:
                traj[t] = traj[t].cpu()       # Move previous output to CPU memory.
            else:
                # The state is discarded, so skip the device-to-host copy entirely.
                del traj[t]

    return traj if ret_traj else traj[0]

In [None]:
def visualize_point_clouds_side_by_side(point_cloud1, point_cloud2, index=None, sample_index=0):
    """
    Plot the same sample of two point-cloud batches next to each other in 3D.

    The left panel is titled 'Original' and the right one 'Translation'. Points
    are drawn in matplotlib's default colour: the earlier version passed
    `cmap`/`vmin`/`vmax` without a `c` array, which matplotlib ignores, so no
    per-point colouring was ever applied. Those no-op arguments are dropped here.

    Parameters:
        point_cloud1, point_cloud2 (torch.Tensor): Batches of point clouds; columns
            0, 1, 2 of the selected sample are used as x, y, z.
        index: Suffix for the output file `plot{index}.png`. When None (the
            default) the figure is only shown, not saved.
        sample_index (int): Index of the point cloud to visualize in each batch.
    """
    x_ticks = [-1, -0.5, 0, 0.5, 1]
    y_ticks = [0, 0.5, 1, 1.5, 2]
    z_ticks = [-2, -1, 0, 1, 2]

    # Create a figure and a set of subplots
    fig, axs = plt.subplots(1, 2, figsize=(10, 5), subplot_kw={'projection': '3d'})

    panels = (('Original', point_cloud1), ('Translation', point_cloud2))
    for ax, (title, cloud) in zip(axs, panels):
        # detach + cpu first: Tensor.numpy() refuses GPU or grad-tracking tensors.
        pts = cloud[sample_index].detach().cpu().numpy()
        ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2], s=2)

        ax.set_title(title, weight='bold')
        ax.set_xlabel('X', weight='bold')
        ax.set_ylabel('Y', weight='bold')
        ax.set_zlabel('Z', weight='bold')

        # Identical fixed limits and ticks on both panels keep them comparable.
        ax.set_xlim([-1, 1])
        ax.set_ylim([0, 2])
        ax.set_zlim([-2, 2])
        ax.set_xticks(x_ticks)
        ax.set_yticks(y_ticks)
        ax.set_zticks(z_ticks)
        ax.set_xticklabels(x_ticks, fontweight='bold')
        ax.set_yticklabels(y_ticks, fontweight='bold')
        ax.set_zticklabels(z_ticks, fontweight='bold')

    # Show the plot
    plt.tight_layout()
    if index is not None:
        plt.savefig(f"plot{index}.png", dpi=500)
    plt.show()


In [None]:
def visualize_point_clouds_reconstruction(point_cloud1, point_cloud2, point_cloud3, index, sample_index=0):
    """
    Plot original, translated and reconstructed point clouds in three 3D panels.

    For each cloud the selected sample is rescaled before plotting:
    x and y are multiplied by 250, z becomes 500 * (z - 1) and the fourth column
    is mapped through 10 ** c and used as the point colour (presumably undoing
    the dataset's normalisation -- TODO confirm). The y and z columns are swapped
    on the plot axes, as the axis labels indicate. The figure is saved to
    `plot{index}.png` and shown.

    Parameters:
        point_cloud1, point_cloud2, point_cloud3 (torch.Tensor): Batches holding the
            original, translated and reconstructed clouds, in that order.
        index: Suffix for the output file name `plot{index}.png`.
        sample_index (int): Index of the point cloud to visualize in each batch.
    """
    # Create a figure and a set of subplots
    fig, axs = plt.subplots(1, 3, figsize=(12, 6), subplot_kw={'projection': '3d'})

    panels = (
        ('Original', point_cloud1),
        ('Translation', point_cloud2),
        ('Reconstruction', point_cloud3),
    )
    last_scatter = None
    for ax, (title, cloud) in zip(axs, panels):
        # detach + cpu first: Tensor.numpy() refuses GPU or grad-tracking tensors.
        pts = cloud[sample_index].detach().cpu().numpy()

        # The array can share memory with the caller's tensor, so build new
        # arrays here; the previous in-place `*=` rescaled the input tensors.
        x = pts[:, 0] * 250
        y = pts[:, 1] * 250
        z = 500 * (pts[:, 2] - 1)
        charge = 10 ** pts[:, 3]

        # Each panel is coloured by its own fourth column (the third panel
        # previously reused the second panel's colours).
        last_scatter = ax.scatter(x, z, y, s=2, c=charge, cmap='cool', vmin=0, vmax=8000)
        ax.set_title(title, weight='bold')
        ax.set_xlabel('X')
        ax.set_ylabel('Z')
        ax.set_zlabel('Y')
        ax.set_xlim([-500, 500])
        ax.set_ylim([-500, 800])
        ax.set_zlim([-500, 500])

    # Show the plot and save; all panels share vmin/vmax, so one colorbar suffices.
    plt.subplots_adjust(wspace=0.3)
    fig.colorbar(last_scatter, ax=axs, shrink=0.5, aspect=10)
    plt.savefig(f"plot{index}.png", dpi=500)
    plt.show()


## Standard deviation check
Used to verify whether the model can recognize the noise distribution. This check is run on the toy line dataset.

In [None]:
# exp to exp: a batch from exp_loader is encoded with model_exp and decoded with
# the same model (this cell uses the "exp" loader and model for both directions).
with torch.no_grad():  # inference only, no autograd graph needed
    for batch in exp_loader:
        data = batch[0].to(device)
        # The encoder's second output is passed on as the log-variance.
        mu_sim, sigma_sim = model_sim.encoder(data)
        context_sim = reparameterize_gaussian(mean=mu_sim, logvar=sigma_sim)

        mu_exp, sigma_exp = model_exp.encoder(data)
        context_exp = reparameterize_gaussian(mean=mu_exp, logvar=sigma_exp)

        x_T, eps = dpm_encoder(data, model_exp, context_exp)
        x_T = x_T.to(device)
        eps = eps.to(device)
        print("successfully encoded")
        x = sample(model_exp, 512, context_exp, x_T, eps)

        # Hand CPU copies to the plotting helper, which converts tensors to NumPy.
        data_cpu, x_cpu = data.cpu(), x.cpu()
        for sample_idx in range(len(x_cpu)):
            # `index` is a required argument of the helper (it names the saved
            # file plot{index}.png); omitting it raised a TypeError.
            visualize_point_clouds_side_by_side(data_cpu, x_cpu, sample_idx, sample_index=sample_idx)
        break  # only the first batch is visualised

## Translation of Point Clouds

In [None]:
# sim to exp: encode a batch from sim_loader with model_sim, decode with model_exp.
with torch.no_grad():  # inference only, no autograd graph needed
    for batch in sim_loader:
        data = batch[0].to(device)
        # The encoder's second output is passed on as the log-variance.
        mu_sim, sigma_sim = model_sim.encoder(data)
        context_sim = reparameterize_gaussian(mean=mu_sim, logvar=sigma_sim)

        mu_exp, sigma_exp = model_exp.encoder(data)
        context_exp = reparameterize_gaussian(mean=mu_exp, logvar=sigma_exp)

        x_T, eps = dpm_encoder(data, model_sim, context_sim)
        x_T = x_T.to(device)
        eps = eps.to(device)
        print("successfully encoded")
        x = sample(model_exp, 512, context_exp, x_T, eps)

        # Hand CPU copies to the plotting helper, which converts tensors to NumPy.
        data_cpu, x_cpu = data.cpu(), x.cpu()
        # A separate index name avoids shadowing an outer loop variable.
        # NOTE: figures are written to plot{sample_idx}.png; other cells in this
        # notebook write to the same file names.
        for sample_idx in range(len(x_cpu)):
            visualize_point_clouds_side_by_side(data_cpu, x_cpu, sample_idx, sample_index=sample_idx)
        break  # only the first batch is visualised

In [None]:
# exp to sim: encode a batch from exp_loader with model_exp, decode with model_sim.
with torch.no_grad():  # inference only, no autograd graph needed
    for batch in exp_loader:
        data = batch[0].to(device)
        # The encoder's second output is passed on as the log-variance.
        mu_sim, sigma_sim = model_sim.encoder(data)
        context_sim = reparameterize_gaussian(mean=mu_sim, logvar=sigma_sim)

        mu_exp, sigma_exp = model_exp.encoder(data)
        context_exp = reparameterize_gaussian(mean=mu_exp, logvar=sigma_exp)

        x_T, eps = dpm_encoder(data, model_exp, context_exp)
        x_T = x_T.to(device)
        eps = eps.to(device)
        print("successfully encoded")
        x = sample(model_sim, 512, context_sim, x_T, eps)

        # Hand CPU copies to the plotting helper, which converts tensors to NumPy.
        data_cpu, x_cpu = data.cpu(), x.cpu()
        # A separate index name avoids shadowing an outer loop variable.
        # NOTE: figures are written to plot{sample_idx}.png; other cells in this
        # notebook write to the same file names.
        for sample_idx in range(len(x_cpu)):
            visualize_point_clouds_side_by_side(data_cpu, x_cpu, sample_idx, sample_index=sample_idx)
        break  # only the first batch is visualised

## Translation of Point Clouds With Reconstruction Test

In [None]:
# Reconstruction test (exp -> sim -> exp) on the second batch of exp_loader.
with torch.no_grad():  # inference only, no autograd graph needed
    for batch_idx, batch in enumerate(exp_loader):
        if batch_idx == 1:
            data = batch[0].to(device)
            # The encoder's second output is passed on as the log-variance.
            mu_sim, sigma_sim = model_sim.encoder(data)
            context_sim = reparameterize_gaussian(mean=mu_sim, logvar=sigma_sim)

            mu_exp, sigma_exp = model_exp.encoder(data)
            context_exp = reparameterize_gaussian(mean=mu_exp, logvar=sigma_exp)

            # Translation: encode with model_exp, decode with model_sim.
            x_T, eps = dpm_encoder(data, model_exp, context_exp)
            x_T = x_T.to(device)
            eps = eps.to(device)
            print("successfully encoded")
            x = sample(model_sim, 512, context_sim, x_T, eps)

            # Contexts for the way back are computed from the translated batch.
            mu_sim2, sigma_sim2 = model_sim.encoder(x)
            context_sim2 = reparameterize_gaussian(mean=mu_sim2, logvar=sigma_sim2)

            mu_exp2, sigma_exp2 = model_exp.encoder(x)
            context_exp2 = reparameterize_gaussian(mean=mu_exp2, logvar=sigma_exp2)

            # Reconstruction: encode with model_sim, decode with model_exp.
            x_T2, eps2 = dpm_encoder(x, model_sim, context_sim2)
            x_T2 = x_T2.to(device)
            eps2 = eps2.to(device)
            print("successfully encoded reconstuction")
            x2 = sample(model_exp, 512, context_exp2, x_T2, eps2)

            # Hand CPU copies to the plotting helper, which converts tensors to NumPy.
            data_cpu, x_cpu, x2_cpu = data.cpu(), x.cpu(), x2.cpu()
            # A separate index name keeps the batch counter from being overwritten.
            for sample_idx in range(len(x_cpu)):
                visualize_point_clouds_reconstruction(data_cpu, x_cpu, x2_cpu, sample_idx, sample_index=sample_idx)
            break

In [None]:
# Reconstruction test (sim -> exp -> sim) on the first batch of sim_loader.
with torch.no_grad():  # inference only, no autograd graph needed
    for batch in sim_loader:
        data = batch[0].to(device)
        # The encoder's second output is passed on as the log-variance.
        mu_sim, sigma_sim = model_sim.encoder(data)
        context_sim = reparameterize_gaussian(mean=mu_sim, logvar=sigma_sim)

        mu_exp, sigma_exp = model_exp.encoder(data)
        context_exp = reparameterize_gaussian(mean=mu_exp, logvar=sigma_exp)

        # Translation: encode with model_sim, decode with model_exp.
        x_T, eps = dpm_encoder(data, model_sim, context_sim)
        x_T = x_T.to(device)
        eps = eps.to(device)
        print("successfully encoded")
        x = sample(model_exp, 512, context_exp, x_T, eps)

        # Contexts for the way back are computed from the translated batch.
        mu_sim2, sigma_sim2 = model_sim.encoder(x)
        context_sim2 = reparameterize_gaussian(mean=mu_sim2, logvar=sigma_sim2)

        mu_exp2, sigma_exp2 = model_exp.encoder(x)
        context_exp2 = reparameterize_gaussian(mean=mu_exp2, logvar=sigma_exp2)

        # Reconstruction: encode with model_exp, decode with model_sim.
        x_T2, eps2 = dpm_encoder(x, model_exp, context_exp2)
        x_T2 = x_T2.to(device)
        eps2 = eps2.to(device)
        print("successfully encoded reconstuction")
        # Decode with the context derived from the translated batch
        # (context_sim2), mirroring the exp -> sim -> exp cell which uses
        # context_exp2. The previous call reused context_sim, i.e. the latent of
        # the original data, and left context_sim2 unused.
        x2 = sample(model_sim, 512, context_sim2, x_T2, eps2)

        # Hand CPU copies to the plotting helper, which converts tensors to NumPy.
        data_cpu, x_cpu, x2_cpu = data.cpu(), x.cpu(), x2.cpu()
        # A separate index name avoids shadowing an outer loop variable.
        for sample_idx in range(len(x_cpu)):
            visualize_point_clouds_reconstruction(data_cpu, x_cpu, x2_cpu, sample_idx, sample_index=sample_idx)
        break  # only the first batch is visualised