In [1]:
import sys
import os

# Root of the project checkout; the notebook's local imports are resolved from here.
PROJ_ROOT = '/data/lry/code/style3d_gen'

# Make both the repository root and its `src/` directory importable.
sys.path.extend([PROJ_ROOT, os.path.join(PROJ_ROOT, 'src')])

# Restrict CUDA to device 0 (set before any CUDA library is imported in later cells).
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# Test VAE (decoder only)

In [None]:
import torch
from network import AutoencoderKLFastDecode, AutoencoderKLFastEncode

import torch
from torchvision.utils import make_grid

from matplotlib import pyplot as plt

import random
import pickle

# Decoder-only sanity check: take cached surface latents for one random item,
# decode them with the VAE decoder and display the decoded channel groups.

# NOTE: pickle.load can execute arbitrary code; only open cache files produced by this project.
cache_fp = '/data/lry/code/style3d_gen/log/stylexd_vae_surf_256_xyz_uv_mask_unet6/cache/epoch_1800/surfz_train.pkl'
with open(cache_fp, 'rb') as f:
    data_cache = pickle.load(f)
print(data_cache.keys(), len(data_cache['item_idx']))

ckpt_fp = '/data/lry/code/style3d_gen/log/stylexd_vae_surf_256_xyz_uv_mask_unet6/ckpts/epoch_1800.pt'
block_dims = [16,32,32,64,64,128]
sample_size = 256
latent_channels = 8
# Side length used to reshape the flat cached latents back to (C, H, W).
latent_size = sample_size // 2**(len(block_dims)-1)
print('*** latent size: ', latent_size)

surf_vae_decoder = AutoencoderKLFastDecode(
    in_channels=6,
    out_channels=6,
    down_block_types=['DownEncoderBlock2D']*len(block_dims),
    up_block_types=['UpDecoderBlock2D']*len(block_dims),
    block_out_channels=block_dims,
    layers_per_block=2,
    act_fn='silu',
    latent_channels=latent_channels,
    norm_num_groups=8,
    sample_size=sample_size,
)
# The decoder stays on CPU in this cell, so map the checkpoint to CPU as well.
# strict=False tolerates key mismatches (presumably the checkpoint holds the
# full VAE while this module is decoder-only -- confirm).
surf_vae_decoder.load_state_dict(torch.load(ckpt_fp, map_location='cpu'), strict=False)
# Inference only: switch to eval mode, consistent with the other cells.
surf_vae_decoder.eval()

# Pick one random item; `item_idx` gives the [start, end) range of its surfaces.
sample_data_idx = random.randint(0, len(data_cache['item_idx']) - 1)
start_idx, end_idx = data_cache['item_idx'][sample_data_idx]

print(sample_data_idx, start_idx, end_idx)

surf_pos = data_cache['surf_pos'][start_idx:end_idx]
surf_latent = data_cache['latent'][start_idx:end_idx]
surf_cls = data_cache['surf_cls'][start_idx:end_idx]
caption = data_cache['caption'][sample_data_idx]

print(surf_pos.shape, surf_latent.shape, surf_cls.shape, caption)

print('*** surf_latent: ', surf_latent.shape, surf_latent.min(), surf_latent.max(), surf_latent.mean(), surf_latent.std())
with torch.no_grad():
    decoded_surf_pos = surf_vae_decoder(surf_latent.view(-1, latent_channels, latent_size, latent_size))
print('*** decoded: ', decoded_surf_pos.shape, decoded_surf_pos.min(), decoded_surf_pos.max())

# Tile all decoded surfaces into one grid, mapping the value range [-1, 1] to [0, 1].
pred_img = make_grid(decoded_surf_pos, nrow=8, normalize=True, value_range=(-1,1))
print(pred_img.shape, pred_img.min(), pred_img.max())

fig, ax = plt.subplots(3, 1, figsize=(40, 40))
ax[0].imshow(pred_img[:3, ...].permute(1, 2, 0).cpu().numpy())   # channels 0-2 shown as RGB
ax[1].imshow(pred_img[3:, ...].permute(1, 2, 0).cpu().numpy())   # channels 3-5 shown as RGB
ax[2].imshow(pred_img[-1:, ...].permute(1, 2, 0).cpu().numpy())  # last channel only

ax[0].set_title('Geometry Images')
ax[1].set_title('UV Images')
ax[2].set_title('Mask Images')

# plt.axis('off') would only affect the current (last) axes, so hide every panel explicitly.
for panel in ax:
    panel.axis('off')
plt.tight_layout()
plt.show()


# Test VAE with Encoder

In [None]:
import numpy as np
import torch
from network import AutoencoderKLFastDecode, AutoencoderKLFastEncode

from glob import glob

import torch
from torchvision.utils import make_grid

from matplotlib import pyplot as plt

import random
import pickle

# Encoder + decoder round trip: encode one random raw sample, decode the
# latents again and display the reconstructed channel groups.

_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

ckpt_fp = '/data/lry/code/style3d_gen/log/stylexd_vae_surf_256_xyz_uv_mask_unet6/ckpts/epoch_1800.pt'
block_dims = [16,32,32,64,64,128]
sample_size = 256
latent_channels = 8
# Side length used to reshape the flattened latents back to (C, H, W).
latent_size = sample_size // 2**(len(block_dims)-1)
print('*** latent size: ', latent_size, ' latent_channels: ', latent_channels)

################# Initialize the encoder and decoder #################
# Both halves are built from the same architecture arguments.
vae_kwargs = dict(
    in_channels=6,
    out_channels=6,
    down_block_types=['DownEncoderBlock2D']*len(block_dims),
    up_block_types=['UpDecoderBlock2D']*len(block_dims),
    block_out_channels=block_dims,
    layers_per_block=2,
    act_fn='silu',
    latent_channels=latent_channels,
    norm_num_groups=8,
    sample_size=sample_size,
)

# Read the checkpoint once and map it onto the selected device, so the CPU
# fallback of _DEVICE also works for checkpoints that were saved from CUDA.
vae_state_dict = torch.load(ckpt_fp, map_location=_DEVICE)

# strict=False tolerates key mismatches (presumably each module only consumes
# its own half of a full-VAE checkpoint -- confirm).
surf_vae_encoder = AutoencoderKLFastEncode(**vae_kwargs)
surf_vae_encoder.load_state_dict(vae_state_dict, strict=False)
surf_vae_encoder.eval().to(_DEVICE)

surf_vae_decoder = AutoencoderKLFastDecode(**vae_kwargs)
surf_vae_decoder.load_state_dict(vae_state_dict, strict=False)
surf_vae_decoder.eval().to(_DEVICE)
########################################################################


data_root = '/data/AIGP/brep_reso_256_edge_snap_with_caption'
data_files = glob(os.path.join(data_root, '*.pkl'))
if not data_files:
    # random.choice on an empty list raises an opaque IndexError; fail with context instead.
    raise FileNotFoundError(f'No .pkl files found under {data_root}')
data_fp = random.choice(data_files)
print(data_fp)

# NOTE: pickle.load can execute arbitrary code; only open files from a trusted dataset.
with open(data_fp, 'rb') as f:
    data = pickle.load(f)

# Stack the three arrays along the last axis; the mask is rescaled with *2-1
# (i.e. {0, 1} -> {-1, 1} if the stored mask is binary -- confirm).
surf_ncs = torch.FloatTensor(np.concatenate([
    data['surf_ncs'].astype(np.float32),
    data['surf_uv_ncs'].astype(np.float32),
    data['surf_mask'].astype(np.float32)*2.0-1.0
], axis=-1)).to(_DEVICE)

print('*** caption: ', data['caption'])

with torch.no_grad():
    # Channels-last -> channels-first for the encoder, then flatten each latent to a vector.
    latent = surf_vae_encoder(surf_ncs.permute(0, 3, 1, 2)).flatten(start_dim=1)
    print('*** latent: ', latent.shape, latent.min(), latent.max(), latent.mean(), latent.std())
    decoded_surf_ncs = surf_vae_decoder(latent.view(-1, latent_channels, latent_size, latent_size))
    print('*** decoded: ', decoded_surf_ncs.shape, decoded_surf_ncs.min(), decoded_surf_ncs.max())

# Tile all decoded surfaces into one grid, mapping the value range [-1, 1] to [0, 1].
pred_img = make_grid(decoded_surf_ncs, nrow=8, normalize=True, value_range=(-1,1))
print(pred_img.shape, pred_img.min(), pred_img.max())

fig, ax = plt.subplots(3, 1, figsize=(40, 40))
ax[0].imshow(pred_img[:3, ...].permute(1, 2, 0).cpu().numpy())   # channels 0-2 shown as RGB
ax[1].imshow(pred_img[3:, ...].permute(1, 2, 0).cpu().numpy())   # channels 3-5 shown as RGB
ax[2].imshow(pred_img[-1:, ...].permute(1, 2, 0).cpu().numpy())  # last channel only

ax[0].set_title('Geometry Images')
ax[1].set_title('UV Images')
ax[2].set_title('Mask Images')

# plt.axis('off') would only affect the current (last) axes, so hide every panel explicitly.
for panel in ax:
    panel.axis('off')
plt.tight_layout()
plt.show()


# Test VAE && SurfZ

In [2]:
import random
import pickle
import numpy as np

from tqdm import tqdm
from tqdm import trange

import torch
from torchvision.utils import make_grid

from matplotlib import pyplot as plt
from matplotlib.colors import to_hex

from network import AutoencoderKLFastDecode, SurfZNet
from diffusers import DDPMScheduler, PNDMScheduler
from utils import randn_tensor
from vis import draw_bbox_geometry



def _denormalize_pts(pts, bbox):    
    pos_dim =  pts.shape[-1]
    bbox_min = bbox[..., :pos_dim][:, None, ...]
    bbox_max = bbox[..., pos_dim:][:, None, ...]
    bbox_scale = np.max(bbox_max - bbox_min, axis=-1, keepdims=True) * 0.5
    bbox_offset = (bbox_max + bbox_min) / 2.0
    return pts * bbox_scale + bbox_offset


# ---- Checkpoints -------------------------------------------------------------
vae_model_fp = '/data/lry/code/style3d_gen/log/stylexd_vae_surf_256_xyz_uv_mask_unet6_latent_1/ckpts/vae_e550.pt'
surfz_model_fp = '/data/lry/code/style3d_gen/log/stylexd_surfz_xyzuv_mask_latent1_mode/ckpts/surfz_e210000.pt'

# ---- Output location ---------------------------------------------------------
# NOTE(review): the directory is named after epoch 150000 while the SurfZ
# checkpoint above is e210000 -- confirm this is the intended destination.
output_dir = '/data/lry/code/style3d_gen/generated/surfz_e150000'
os.makedirs(output_dir, exist_ok=True)

# ---- VAE decoder -------------------------------------------------------------
block_dims = [16, 32, 32, 64, 64, 128]
sample_size = 256
# Side length used to reshape flat latents back to (C, H, W) before decoding.
latent_size = sample_size // (2 ** (len(block_dims) - 1))
latent_channels = 1

_num_blocks = len(block_dims)
surf_vae = AutoencoderKLFastDecode(
    in_channels=6,
    out_channels=6,
    down_block_types=['DownEncoderBlock2D'] * _num_blocks,
    up_block_types=['UpDecoderBlock2D'] * _num_blocks,
    block_out_channels=block_dims,
    layers_per_block=2,
    act_fn='silu',
    latent_channels=latent_channels,
    norm_num_groups=8,
    sample_size=sample_size,
)
_vae_state = torch.load(vae_model_fp)
# strict=False: key mismatches between checkpoint and module are tolerated.
surf_vae.load_state_dict(_vae_state, strict=False)
surf_vae.to('cuda')
surf_vae.eval()

# ---- Noise schedulers --------------------------------------------------------
# Both schedulers are configured with the same linear-beta, epsilon-prediction schedule.
_noise_schedule = dict(
    num_train_timesteps=1000,
    beta_schedule='linear',
    prediction_type='epsilon',
    beta_start=0.0001,
    beta_end=0.02,
)
pndm_scheduler = PNDMScheduler(**_noise_schedule)
ddpm_scheduler = DDPMScheduler(clip_sample=False, **_noise_schedule)

# ---- SurfZ network -----------------------------------------------------------
# z_dim equals the flattened latent size (channels * height * width).
surfz_model = SurfZNet(
    p_dim=10,
    z_dim=latent_size**2*latent_channels,
    num_heads=12,
    num_cf=-1,
)
_surfz_ckpt = torch.load(surfz_model_fp)
surfz_model.load_state_dict(_surfz_ckpt['model_state_dict'])
surfz_model.to('cuda')
surfz_model.eval()

def inference_one(surf_pos, surf_cls=None, caption='', output_fp='', vis=False):
    """Sample surface latents with the SurfZ diffusion model and decode them with the VAE.

    Relies on the notebook-level globals ``surfz_model``, ``surf_vae``,
    ``ddpm_scheduler``, ``latent_channels`` and ``latent_size`` defined earlier
    in this cell, and runs on the ``'cuda'`` device.

    Args:
        surf_pos: Per-surface conditioning tensor; the first axis indexes the
            surfaces. Its last axis is split below into the first 6 values
            (surface bbox) and the remainder (UV bbox).
        surf_cls: Accepted for call compatibility; not used by this function.
        caption: Text stored in the result and used as the point-cloud plot title.
        output_fp: Destination ``.pkl`` path. Figure paths are derived from it by
            replacing ``'.pkl'``. When empty, nothing is pickled and the
            geometry-image figure is shown instead of saved.
        vis: When True, render the decoded image grid and the bbox/point-cloud plot.

    Returns:
        dict: The generated result (same object that is pickled to ``output_fp``).
        Earlier versions returned None; callers that ignore the value are unaffected.

    Raises:
        ValueError: If ``surf_pos`` holds more surfaces than the fixed sequence length (32).
    """
    max_surfs = 32  # fixed sequence length fed to the SurfZ model
    n_surfs = surf_pos.shape[0]
    if n_surfs > max_surfs:
        # Previously this surfaced as an opaque error from a negative tensor size.
        raise ValueError(f'inference_one supports at most {max_surfs} surfaces, got {n_surfs}')
    n_pads = max_surfs - n_surfs

    # Shuffle the surface order, then zero-pad up to max_surfs.
    shuffle_idx = torch.randperm(n_surfs)
    # Mask is True on padded slots and False on real surfaces.
    _surf_mask = torch.cat([
        torch.zeros(n_surfs, dtype=bool), torch.ones(n_pads, dtype=bool)
    ], dim=0)[None, ...]
    _surf_pos = torch.cat([
        surf_pos[shuffle_idx, ...],
        torch.zeros((n_pads, *surf_pos.shape[1:]), dtype=surf_pos.dtype, device=surf_pos.device)
    ], dim=0)[None, ...]
    valid = ~_surf_mask.squeeze(0)

    # Loop-invariant device transfers are done once, outside the sampling loop.
    surf_pos_dev = _surf_pos.to('cuda')
    surf_mask_dev = _surf_mask.to('cuda')

    # Diffusion Generation
    _surf_z = randn_tensor((1, max_surfs, latent_channels*latent_size*latent_size), device='cuda')
    ddpm_scheduler.set_timesteps(1000)
    # Inference only: without no_grad every one of the 1000 steps would extend
    # the autograd graph and keep all intermediate activations alive.
    with torch.no_grad():
        for t in ddpm_scheduler.timesteps:
            timesteps = t.reshape(-1).to('cuda')
            pred = surfz_model(_surf_z, timesteps, surf_pos_dev, surf_mask_dev, None)
            _surf_z = ddpm_scheduler.step(pred, t, _surf_z).prev_sample

        # Drop the padded slots before decoding.
        _surf_z = _surf_z.squeeze(0)[valid, ...]

        # VAE Decoding
        decoded_surf_pos = surf_vae(_surf_z.view(-1, latent_channels, latent_size, latent_size))

    if vis:
        # Tile the decoded surfaces into one grid, mapping [-1, 1] to [0, 1].
        pred_img = make_grid(decoded_surf_pos, nrow=6, normalize=True, value_range=(-1,1))

        fig, ax = plt.subplots(3, 1, figsize=(40, 40))
        ax[0].imshow(pred_img[:3, ...].permute(1, 2, 0).detach().cpu().numpy())
        ax[1].imshow(pred_img[3:, ...].permute(1, 2, 0).detach().cpu().numpy())
        ax[2].imshow(pred_img[-1:, ...].permute(1, 2, 0).detach().cpu().numpy())

        ax[0].set_title('Geometry Images')
        ax[1].set_title('UV Images')
        ax[2].set_title('Mask Images')

        # plt.axis('off') would only affect the last axes; hide all three panels.
        for panel in ax:
            panel.axis('off')
        plt.tight_layout()

        if output_fp:
            fig.savefig(output_fp.replace('.pkl', '_geo_img.png'), transparent=True, dpi=72)
        else:
            plt.show()
        plt.close(fig)

    # Convert to numpy: channels-last images flattened to one row of points per surface.
    _surf_bbox = _surf_pos.squeeze(0)[valid, :].detach().cpu().numpy()
    _decoded_surf_pos = decoded_surf_pos.permute(0, 2, 3, 1).detach().cpu().numpy()
    _surf_ncs_mask = _decoded_surf_pos[..., -1:].reshape(n_surfs, -1) > 0.0   # last channel thresholded at 0
    _surf_ncs = _decoded_surf_pos[..., :3].reshape(n_surfs, -1, 3)
    _surf_uv_ncs = _decoded_surf_pos[..., 3:5].reshape(n_surfs, -1, 2)

    # Split the conditioning vector: first 6 values -> surface bbox, rest -> UV bbox.
    _surf_uv_bbox = _surf_bbox[..., 6:]
    _surf_bbox = _surf_bbox[..., :6]

    if vis:
        colormap = plt.cm.rainbow
        colors = [to_hex(colormap(i)) for i in np.linspace(0, 1, n_surfs)]
        _surf_wcs = _denormalize_pts(_surf_ncs, _surf_bbox)
        draw_bbox_geometry(
            bboxes = _surf_bbox,
            bbox_colors = colors,
            points = _surf_wcs,
            point_masks = _surf_ncs_mask,
            point_colors = colors,
            num_point_samples = 5000,
            title = caption,
            output_fp = output_fp.replace('.pkl', '_pointcloud.png')
            )

    result = {
        'surf_bbox': _surf_bbox,        # (N, 6)
        'surf_uv_bbox': _surf_uv_bbox,  # (N, remaining bbox values)
        'surf_ncs': _surf_ncs,          # (N, H*W, 3)
        'surf_uv_ncs': _surf_uv_ncs,    # (N, H*W, 2)
        'surf_mask': _surf_ncs_mask,    # (N, H*W) => bool
        'caption': caption              # str
    }

    if output_fp:
        with open(output_fp, 'wb') as f:
            pickle.dump(result, f)

    return result
    
    # print('[DONE] save to:', output_fp)
    

# Run generation for every item in the cached validation set and write one
# result file per item into `output_dir`.
# NOTE: pickle.load can execute arbitrary code; only open cache files produced by this project.
cache_fp = '/data/lry/code/style3d_gen/log/stylexd_vae_surf_256_xyz_uv_mask_unet6_latent_1/cache/vae_e550/encoder_mode/surfpos_validate.pkl'
with open(cache_fp, 'rb') as f:
    data_cache = pickle.load(f)

for sample_data_idx in tqdm(range(len(data_cache['item_idx']))):
    # `item_idx` holds the [start, end) range of this item's surfaces.
    start_idx, end_idx = data_cache['item_idx'][sample_data_idx]
    item_range = slice(start_idx, end_idx)

    surf_pos = data_cache['surf_pos'][item_range].to('cuda')
    surf_cls = data_cache['surf_cls'][item_range].to('cuda')
    caption = data_cache['caption'][sample_data_idx]

    output_fp = os.path.join(output_dir, f'{sample_data_idx:04d}.pkl')
    inference_one(surf_pos, surf_cls=surf_cls, caption=caption, output_fp=output_fp, vis=True)

# Test SurfPos && SurfZ && VAE

In [None]:
import random
import pickle
import numpy as np

from tqdm import tqdm
from tqdm import trange

import torch
from torchvision.utils import make_grid

from matplotlib import pyplot as plt
from matplotlib.colors import to_hex

from network import AutoencoderKLFastDecode, SurfZNet
from diffusers import DDPMScheduler, PNDMScheduler
from utils import randn_tensor
from vis import draw_bbox_geometry



def _denormalize_pts(pts, bbox):    
    pos_dim =  pts.shape[-1]
    bbox_min = bbox[..., :pos_dim][:, None, ...]
    bbox_max = bbox[..., pos_dim:][:, None, ...]
    bbox_scale = np.max(bbox_max - bbox_min, axis=-1, keepdims=True) * 0.5
    bbox_offset = (bbox_max + bbox_min) / 2.0
    return pts * bbox_scale + bbox_offset


# ---- Checkpoints -------------------------------------------------------------
vae_model_fp = '/data/lry/code/style3d_gen/log/stylexd_vae_surf_256_xyz_uv_mask_unet6_latent_1/ckpts/vae_e550.pt'
surf_pos_model = ''   # placeholder; not referenced anywhere else in this cell
surfz_model_fp = '/data/lry/code/style3d_gen/log/stylexd_surfz_xyzuv_mask_latent1_mode/ckpts/surfz_e150000.pt'

# ---- VAE decoder -------------------------------------------------------------
block_dims = [16, 32, 32, 64, 64, 128]
sample_size = 256
# Side length used to reshape flat latents back to (C, H, W) before decoding.
latent_size = sample_size // (2 ** (len(block_dims) - 1))
latent_channels = 1

_num_blocks = len(block_dims)
surf_vae = AutoencoderKLFastDecode(
    in_channels=6,
    out_channels=6,
    down_block_types=['DownEncoderBlock2D'] * _num_blocks,
    up_block_types=['UpDecoderBlock2D'] * _num_blocks,
    block_out_channels=block_dims,
    layers_per_block=2,
    act_fn='silu',
    latent_channels=latent_channels,
    norm_num_groups=8,
    sample_size=sample_size,
)
_vae_state = torch.load(vae_model_fp)
# strict=False: key mismatches between checkpoint and module are tolerated.
surf_vae.load_state_dict(_vae_state, strict=False)
surf_vae.to('cuda')
surf_vae.eval()

# ---- Noise schedulers --------------------------------------------------------
# Both schedulers are configured with the same linear-beta, epsilon-prediction schedule.
_noise_schedule = dict(
    num_train_timesteps=1000,
    beta_schedule='linear',
    prediction_type='epsilon',
    beta_start=0.0001,
    beta_end=0.02,
)
pndm_scheduler = PNDMScheduler(**_noise_schedule)
ddpm_scheduler = DDPMScheduler(clip_sample=False, **_noise_schedule)

# ---- SurfZ network -----------------------------------------------------------
# z_dim equals the flattened latent size (channels * height * width).
surfz_model = SurfZNet(
    p_dim=10,
    z_dim=latent_size**2*latent_channels,
    num_heads=12,
    num_cf=-1,
)
_surfz_ckpt = torch.load(surfz_model_fp)
surfz_model.load_state_dict(_surfz_ckpt['model_state_dict'])
surfz_model.to('cuda')
surfz_model.eval()

# ---- Output location ---------------------------------------------------------
output_dir = '/data/lry/code/style3d_gen/generated/surfz_e150000'
os.makedirs(output_dir, exist_ok=True)

def inference_one(surf_pos, surf_cls=None, caption='', output_fp='', vis=False):
    """Sample surface latents with the SurfZ diffusion model and decode them with the VAE.

    Relies on the notebook-level globals ``surfz_model``, ``surf_vae``,
    ``ddpm_scheduler``, ``latent_channels`` and ``latent_size`` defined earlier
    in this cell, and runs on the ``'cuda'`` device.

    Args:
        surf_pos: Per-surface conditioning tensor; the first axis indexes the
            surfaces. Its last axis is split below into the first 6 values
            (surface bbox) and the remainder (UV bbox).
        surf_cls: Accepted for call compatibility; not used by this function.
        caption: Text stored in the result and used as the point-cloud plot title.
        output_fp: Destination ``.pkl`` path. Figure paths are derived from it by
            replacing ``'.pkl'``. When empty, nothing is pickled and the
            geometry-image figure is shown instead of saved.
        vis: When True, render the decoded image grid and the bbox/point-cloud plot.

    Returns:
        dict: The generated result (same object that is pickled to ``output_fp``).
        Earlier versions returned None; callers that ignore the value are unaffected.

    Raises:
        ValueError: If ``surf_pos`` holds more surfaces than the fixed sequence length (32).
    """
    max_surfs = 32  # fixed sequence length fed to the SurfZ model
    n_surfs = surf_pos.shape[0]
    if n_surfs > max_surfs:
        # Previously this surfaced as an opaque error from a negative tensor size.
        raise ValueError(f'inference_one supports at most {max_surfs} surfaces, got {n_surfs}')
    n_pads = max_surfs - n_surfs

    # Shuffle the surface order, then zero-pad up to max_surfs.
    shuffle_idx = torch.randperm(n_surfs)
    # Mask is True on padded slots and False on real surfaces.
    _surf_mask = torch.cat([
        torch.zeros(n_surfs, dtype=bool), torch.ones(n_pads, dtype=bool)
    ], dim=0)[None, ...]
    _surf_pos = torch.cat([
        surf_pos[shuffle_idx, ...],
        torch.zeros((n_pads, *surf_pos.shape[1:]), dtype=surf_pos.dtype, device=surf_pos.device)
    ], dim=0)[None, ...]
    valid = ~_surf_mask.squeeze(0)

    # Loop-invariant device transfers are done once, outside the sampling loop.
    surf_pos_dev = _surf_pos.to('cuda')
    surf_mask_dev = _surf_mask.to('cuda')

    # Diffusion Generation
    _surf_z = randn_tensor((1, max_surfs, latent_channels*latent_size*latent_size), device='cuda')
    ddpm_scheduler.set_timesteps(1000)
    # Inference only: without no_grad every one of the 1000 steps would extend
    # the autograd graph and keep all intermediate activations alive.
    with torch.no_grad():
        for t in ddpm_scheduler.timesteps:
            timesteps = t.reshape(-1).to('cuda')
            pred = surfz_model(_surf_z, timesteps, surf_pos_dev, surf_mask_dev, None)
            _surf_z = ddpm_scheduler.step(pred, t, _surf_z).prev_sample

        # Drop the padded slots before decoding.
        _surf_z = _surf_z.squeeze(0)[valid, ...]

        # VAE Decoding
        decoded_surf_pos = surf_vae(_surf_z.view(-1, latent_channels, latent_size, latent_size))

    if vis:
        # Tile the decoded surfaces into one grid, mapping [-1, 1] to [0, 1].
        pred_img = make_grid(decoded_surf_pos, nrow=6, normalize=True, value_range=(-1,1))

        fig, ax = plt.subplots(3, 1, figsize=(40, 40))
        ax[0].imshow(pred_img[:3, ...].permute(1, 2, 0).detach().cpu().numpy())
        ax[1].imshow(pred_img[3:, ...].permute(1, 2, 0).detach().cpu().numpy())
        ax[2].imshow(pred_img[-1:, ...].permute(1, 2, 0).detach().cpu().numpy())

        ax[0].set_title('Geometry Images')
        ax[1].set_title('UV Images')
        ax[2].set_title('Mask Images')

        # plt.axis('off') would only affect the last axes; hide all three panels.
        for panel in ax:
            panel.axis('off')
        plt.tight_layout()

        if output_fp:
            fig.savefig(output_fp.replace('.pkl', '_geo_img.png'), transparent=True, dpi=72)
        else:
            plt.show()
        plt.close(fig)

    # Convert to numpy: channels-last images flattened to one row of points per surface.
    _surf_bbox = _surf_pos.squeeze(0)[valid, :].detach().cpu().numpy()
    _decoded_surf_pos = decoded_surf_pos.permute(0, 2, 3, 1).detach().cpu().numpy()
    _surf_ncs_mask = _decoded_surf_pos[..., -1:].reshape(n_surfs, -1) > 0.0   # last channel thresholded at 0
    _surf_ncs = _decoded_surf_pos[..., :3].reshape(n_surfs, -1, 3)
    _surf_uv_ncs = _decoded_surf_pos[..., 3:5].reshape(n_surfs, -1, 2)

    # Split the conditioning vector: first 6 values -> surface bbox, rest -> UV bbox.
    _surf_uv_bbox = _surf_bbox[..., 6:]
    _surf_bbox = _surf_bbox[..., :6]

    if vis:
        colormap = plt.cm.rainbow
        colors = [to_hex(colormap(i)) for i in np.linspace(0, 1, n_surfs)]
        _surf_wcs = _denormalize_pts(_surf_ncs, _surf_bbox)
        draw_bbox_geometry(
            bboxes = _surf_bbox,
            bbox_colors = colors,
            points = _surf_wcs,
            point_masks = _surf_ncs_mask,
            point_colors = colors,
            num_point_samples = 5000,
            title = caption,
            output_fp = output_fp.replace('.pkl', '_pointcloud.png')
            )

    result = {
        'surf_bbox': _surf_bbox,        # (N, 6)
        'surf_uv_bbox': _surf_uv_bbox,  # (N, remaining bbox values)
        'surf_ncs': _surf_ncs,          # (N, H*W, 3)
        'surf_uv_ncs': _surf_uv_ncs,    # (N, H*W, 2)
        'surf_mask': _surf_ncs_mask,    # (N, H*W) => bool
        'caption': caption              # str
    }

    if output_fp:
        with open(output_fp, 'wb') as f:
            pickle.dump(result, f)

    return result
    
    # print('[DONE] save to:', output_fp)
    

# Run generation over the cached validation set and write one result file per
# item into `output_dir`.
# NOTE: pickle.load can execute arbitrary code; only open cache files produced by this project.
cache_fp = '/data/lry/code/style3d_gen/log/stylexd_vae_surf_256_xyz_uv_mask_unet6_latent_1/cache/vae_e550/encoder_mode/surfpos_validate.pkl'
with open(cache_fp, 'rb') as f:
    data_cache = pickle.load(f)

# NOTE(review): iteration starts at item 1100 -- presumably resuming an
# interrupted run; confirm before relying on a complete output set.
for sample_data_idx in tqdm(range(1100, len(data_cache['item_idx']))):
    # `item_idx` holds the [start, end) range of this item's surfaces.
    start_idx, end_idx = data_cache['item_idx'][sample_data_idx]
    item_range = slice(start_idx, end_idx)

    surf_pos = data_cache['surf_pos'][item_range].to('cuda')
    surf_cls = data_cache['surf_cls'][item_range].to('cuda')
    caption = data_cache['caption'][sample_data_idx]

    output_fp = os.path.join(output_dir, f'{sample_data_idx:04d}.pkl')
    inference_one(surf_pos, surf_cls=surf_cls, caption=caption, output_fp=output_fp, vis=True)