In [1]:
import os
import sys

import numpy as np
import torch
from numpy.core.umath import isnan
from torch.utils.data import DataLoader, Dataset

from Model import PointNeXt
# from dataset.FFDshape import FFDshape_ptp
from Loss import reg_loss
from Transforms import PCDPretreatment, get_data_augment
from Parameters import *
# Make modules in the kernel's working directory importable.
# A notebook has no __file__; the previous `os.path.abspath(__name__)` resolved
# to "<cwd>/__main__", whose dirname is simply the working directory, so the
# working directory is used directly. The membership check keeps sys.path from
# growing by one duplicate entry every time this cell is re-run.
project_dir = os.getcwd()
if project_dir not in sys.path:
    sys.path.insert(1, project_dir)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model hyper-parameters for the 'basic_c' variant.
# NOTE(review): MODEL_CONFIG is presumably provided by `from Parameters import *` - confirm.
model_cfg = MODEL_CONFIG['basic_c']
max_input = model_cfg['max_input']
normal = model_cfg['normal']
model = PointNeXt(model_cfg).to(device=device)

# Restore the trained weights; the checkpoint stores the state dict under the 'model' key.
checkpoint_name = 'PointNeXt_shapesffd3_epoch1000.pth'
checkpoint_dir = 'result_train//PointNeXt_model=basic_c_ds=shapesffd3_aug=basic_lr=0.001_wd=0.0001_bs=16_AdamW_cosine//'
checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name)
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_file, map_location=device)
model.load_state_dict(checkpoint['model'])
# Switch to inference mode; as the last expression this also displays the module tree.
model.eval()

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


PointNeXt(
  (mlp): Conv1d(6, 32, kernel_size=(1,), stride=(1,))
  (stage): ModuleList(
    (0): Stage(
      (sa): SetAbstraction(
        (mlp): Sequential(
          (0): Conv2d(35, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
      )
      (irm): Sequential(
        (0): InvResMLP(
          (la): LocalAggregation(
            (mlp): Sequential(
              (0): Conv2d(67, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
          (pw_conv): Sequential(
            (0): Conv1d(64, 256, kernel_size=(1,), stride=(1,), bias=False)
            (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3):

In [2]:

class FFDshape_eval(Dataset):
    """Dataset that reads one comma-separated point-cloud text file per shape.

    Every ``*.txt`` file found in ``<root>/pc`` becomes one sample of the test
    list. The training list is left empty by this class.

    Args:
        root: Directory that contains the ``pc`` sub-directory.
        transforms: Optional callable applied as ``transforms(xyz_points, gts)``.
            ``train()`` / ``eval()`` additionally call its ``set_mode`` method.
        split: ``'train'`` or ``'test'``; selects the initially active file list.
        npoints: Stored on the instance; not used by the visible code
            (the resampling step in ``__getitem__`` is commented out).
        augment, dp, normalize: Accepted for interface compatibility; unused here.
    """

    def __init__(self, root, transforms=None, split='test', npoints=1024, augment=False, dp=False, normalize=False):
        assert(split == 'train' or split == 'test')
        self.npoints = npoints
        self.transforms = transforms
        self.train_files_list = []
        self.test_files_list = []
        self.training = (split == 'train')

        # Keep only '*.txt' entries (anything else would be turned into a path
        # to a non-existent '<stem>.txt'), and sort them so that sample indices
        # do not depend on the arbitrary order returned by os.listdir.
        name_list = sorted(
            stem
            for stem, ext in map(os.path.splitext, os.listdir(os.path.join(root, 'pc')))
            if ext == '.txt'
        )
        self.test_files_list = self.read_list_file(name_list, root)

        # self.train_files_list = train_files_list # train_files_list

        # Select the active list right away so that len() and indexing work
        # even when train()/eval() has not been called yet.
        self.pcd = self.train_files_list if self.training else self.test_files_list

        # Sample cache consulted by __getitem__; nothing in this class fills it.
        self.caches = {}
        print(
            f'Training {len(self.train_files_list)} shapes. Testing {len(self.test_files_list)} shapes '
        )

    def read_list_file(self, name_list, root):
        """Return the path ``<root>/pc/<name>.txt`` for every name in ``name_list``."""
        # base = os.path.dirname(file_path)
        return [os.path.join(root, 'pc', '{}.txt'.format(shape_name)) for shape_name in name_list]

    def __getitem__(self, index):
        """Load sample ``index`` from the active file list.

        Returns:
            ``(xyz_points, gts, s_mesh)``: columns 0-5 as a float tensor,
            column 6 as a float tensor, and the last column as a NumPy array.
            NOTE(review): columns 0-5 are presumably xyz + normals - confirm
            against the code that writes these files.
        """
        if index in self.caches:
            return self.caches[index]
        file = self.pcd[index]
        pc = np.loadtxt(file, delimiter=',').astype(np.float32)
        xyz_points = pc[:, :6]
        gts = pc[:, 6]
        s_mesh = pc[:, -1]

        # resample
        # choice = np.random.choice(len(xyz_points), self.npoints, replace=True)
        # xyz_points = xyz_points[choice, :]
        # gts = gts[choice]

        xyz_points = torch.from_numpy(xyz_points).float()
        gts = torch.from_numpy(gts).float()
        if self.transforms is not None:
            xyz_points, gts = self.transforms(xyz_points, gts)

        return xyz_points, gts, s_mesh

    def __len__(self):
        """Number of files in the active list."""
        return len(self.pcd)

    def train(self):
        """Activate the training file list and put the transforms in 'train' mode."""
        self.training = True
        self.pcd = self.train_files_list
        if self.transforms is not None:
            self.transforms.set_mode('train')

    def eval(self):
        """Activate the test file list and put the transforms in 'eval' mode."""
        self.training = False
        self.pcd = self.test_files_list
        if self.transforms is not None:
            self.transforms.set_mode('eval')

In [3]:
# read *.stl
def stl_read(shape_file, x_length=25.0):
    """Read a triangle mesh and rescale it to a fixed extent along X.

    Args:
        shape_file: Path passed to ``open3d.io.read_triangle_mesh``.
        x_length: Extent along the X axis after rescaling (default 25, as
            before; the original comment states the unit as metres).

    Returns:
        ``(Points, Connectivity, rate)``: the vertices uniformly scaled by
        ``rate`` with their X coordinates shifted to zero mean (Y and Z are
        scaled but not shifted), the triangle index array, and the scale
        factor that was applied.

    Raises:
        ValueError: If the mesh has no vertices or zero extent along X, in
            which case no scale factor can be computed.
    """
    import open3d as o3d

    my_mesh = o3d.io.read_triangle_mesh(shape_file)
    Points = np.asarray(my_mesh.vertices)
    Connectivity = np.asarray(my_mesh.triangles)

    if Points.size == 0:
        raise ValueError('no vertices could be read from {!r}'.format(shape_file))
    x_extent = Points[:, 0].max() - Points[:, 0].min()
    if x_extent == 0:
        raise ValueError('mesh {!r} has zero extent along X; cannot rescale'.format(shape_file))

    # Uniform rescale so that the X extent equals x_length.
    rate = x_length / x_extent
    Points = Points * rate  # new array, so the mesh's own vertex buffer is untouched
    Points[:, 0] = Points[:, 0] - Points[:, 0].mean()

    return Points, Connectivity, rate

# per-triangle features of the mesh: normals, face centroids and areas
def ex_feature(Points, Connectivity):
    """Compute per-triangle normals, centroids and areas of a triangle mesh.

    Args:
        Points: ``(n_vertices, 3)`` array of vertex coordinates.
        Connectivity: ``(n_triangles, 3)`` integer array of vertex indices.

    Returns:
        ``(normals, Points_tri, tri_area)``:
        * ``normals`` - unit vectors ``-(p2 - p0) x (p1 - p0) / |...|`` per
          triangle (same sign convention as before). Degenerate triangles,
          whose cross product is zero, get a zero vector instead of NaN.
        * ``Points_tri`` - mean of the three corner points of each triangle.
        * ``tri_area`` - triangle areas, computed as half the cross-product
          norm; NaN entries are replaced by 0.
    """
    Points = np.asarray(Points, dtype=float)
    Connectivity = np.asarray(Connectivity)

    corner_a = Points[Connectivity[:, 0]]
    corner_b = Points[Connectivity[:, 1]]
    corner_c = Points[Connectivity[:, 2]]

    # One cross product per triangle serves both the normal and the area.
    raw = np.cross(corner_c - corner_a, corner_b - corner_a)
    raw_len = np.linalg.norm(raw, axis=1)

    # Normalise only where the length is non-zero to avoid 0/0 -> NaN.
    normals = np.zeros(raw.shape)
    valid = raw_len > 0
    normals[valid] = -raw[valid] / raw_len[valid, None]

    # tri face points
    Points_tri = (corner_a + corner_b + corner_c) / 3.0

    # cal tri-area
    tri_area = 0.5 * raw_len
    tri_area[np.isnan(tri_area)] = 0
    return normals, Points_tri, tri_area

# rotate a point cloud about the Y-axis
def rotate_pc(pc, rotation_angle):
    """Right-multiply ``pc`` by a Y-axis rotation matrix.

    Args:
        pc: Array whose last dimension holds the three coordinates.
        rotation_angle: Rotation angle in degrees.

    Returns:
        ``np.dot(pc, R)``, where ``R`` is built from the cosine and sine of
        the angle and leaves the Y component unchanged.
    """
    theta = np.deg2rad(rotation_angle)
    c, s = np.cos(theta), np.sin(theta)
    y_rotation = np.array([
        [c, 0, s],
        [0, 1, 0],
        [-s, 0, c],
    ])
    return np.dot(pc, y_rotation)

In [4]:
# Location of the mesh files of the test shapes.
# Every backslash is escaped explicitly: sequences such as "\M", "\P", "\S",
# "\s" and "\w" are not valid escapes and trigger a warning on recent Python
# versions. The resulting string is unchanged.
# NOTE(review): machine-specific absolute path - adjust when running elsewhere.
shape_base = 'D:\\MyCode\\PCdeep_TL\\ShapeGenerator\\shape_set\\waverider\\testing_N200_D60_322\\'
shape_file = shape_base + 'shape_001.stl'

# Tri-reading (disabled). The former test `== 'stl' or 'ply'` was always true;
# a membership test is what was intended.
# if shape_file[-3:] in ('stl', 'ply'):
#     Points_init, Connectivity, rate = stl_read(shape_file)

In [5]:
# Root directory of the point-cloud text files (expects a 'pc' sub-directory).
# Raw string: the previous literal relied on invalid escapes such as "\M" and
# "\s" being passed through unchanged. The resulting string is identical.
# NOTE(review): machine-specific absolute path - adjust when running elsewhere.
pc_root = r'D:\MyCode\PCdeep_TL\shapesffd4\waverider\testing_N200_D60_322'
# transforms = PCDPretreatment(num=1024, down_sample='random', normal=model_cfg['normal'])
transforms = None
dataset = FFDshape_eval(root=pc_root, split='test', transforms=transforms)
dataset.eval()  # activate the test file list
# One shape per batch, in file order, loaded in the main process.
eval_dataloader = DataLoader(
    dataset=dataset,
    batch_size=1,
    num_workers=0,
    pin_memory=False,
    drop_last=False,
    shuffle=False,
)
criterion = reg_loss().to(device)


Training 0 shapes. Testing 13 shapes 


In [6]:
# Fetch only the first batch (one shape, since the loader uses batch_size=1).
count = 0
for data in eval_dataloader:
    pcd, gts, s = data
    pcd, gts = pcd.to(device, non_blocking=True), gts.to(device, non_blocking=True)
    count += 1
    break
if count == 0:
    # Without this check an empty loader would only surface later as a
    # NameError on `pcd`.
    raise RuntimeError('eval_dataloader yielded no batches; check the dataset root directory')

# Full point cloud of that shape as a NumPy array with the batch dimension removed.
Points_init = np.array(pcd.squeeze(0).to('cpu'))
import open3d as o3d
from math import floor

In [7]:
# # Disabled alternative to the random sampling cell below: uniform down-sampling with Open3D
# num_points = 1024
# pcd_o3d = o3d.geometry.PointCloud()
# pcd_o3d.points = o3d.utility.Vector3dVector(Points_init[:,:3])
# pcd_o3d.normals = o3d.utility.Vector3dVector(Points_init[:,3:6])
# pcd_new = o3d.geometry.PointCloud.uniform_down_sample(pcd_o3d, floor(len(Points_init)/num_points))# 508
# new_points = np.array(pcd_new.points,dtype=float)
# new_normals = np.array(pcd_new.normals,dtype=float)
# new_points, new_normals = new_points[:num_points], new_normals[:num_points]
# pcd_new.points = o3d.utility.Vector3dVector(new_points)
# # pcd_new = torch.tensor(np.concatenate((new_points,new_normals),axis=1))
# pcd_new = torch.from_numpy(np.concatenate((new_points,new_normals),axis=1)).float()

In [8]:
# random choose
# Draw up to `num_points` distinct point indices from the loaded cloud.
# NOTE: if the cloud holds fewer than `num_points` points, all of them are used.
num_points = 1024
# A dedicated, seeded generator makes the selection reproducible across runs
# without touching torch's global random state.
sample_seed = 0
sample_rng = torch.Generator().manual_seed(sample_seed)
choice_idx = torch.randperm(pcd.shape[1], generator=sample_rng)[:num_points]
# Index the NumPy array with a NumPy index array rather than a torch tensor.
pcd_new = torch.from_numpy(Points_init[choice_idx.numpy(), :]).float()

In [9]:
# Run the network on the sampled points without tracking gradients.
with torch.no_grad():
    # Transpose and add a leading batch dimension before moving to the model's device.
    model_input = pcd_new.T.unsqueeze(0).to(device)
    raw_pred = model(model_input)
# Drop dimension 1 if it has size one, transpose, and bring the result back to the CPU.
pred = torch.squeeze(raw_pred, 1).T.to('cpu')
pcd_new = pcd_new.to('cpu')
# loss = criterion(pred, gts)

In [None]:
from scipy.interpolate import griddata

# Spread the predictions made at the sampled points onto every point of the
# full cloud by linear interpolation over the first three coordinates.
# NOTE(review): with method='linear' and no fill_value, SciPy returns NaN for
# query points outside the convex hull of the samples - confirm that is acceptable.
sample_xyz = pcd_new[:, :3]
sample_values = pred.squeeze(1)
query_xyz = Points_init[:, :3]
interp_pred = griddata(sample_xyz, sample_values, query_xyz, method='linear')

In [None]:
# Inspect the shape of the interpolated prediction array.
interp_pred.shape

In [None]:
import plotly.express as px

# Prediction against ground truth at the randomly sampled points.
# Tensors are converted to NumPy explicitly before being handed to plotly.
pred_sampled = pred.squeeze(1).detach().cpu().numpy()
gts_sampled = gts.detach().cpu()[0, choice_idx].numpy()
fig = px.scatter(
    x=pred_sampled,
    y=gts_sampled,
    labels={'x': 'prediction (sampled points)', 'y': 'ground truth (sampled points)'},
    title='Prediction vs. ground truth at the sampled points',
)
fig.show()

In [None]:
import plotly.express as px

# Interpolated prediction against ground truth for every point of the full cloud.
# The ground-truth tensor is converted to NumPy explicitly before plotting.
gts_full = gts.detach().cpu()[0].numpy()
fig = px.scatter(
    x=interp_pred,
    y=gts_full,
    labels={'x': 'interpolated prediction (all points)', 'y': 'ground truth (all points)'},
    title='Interpolated prediction vs. ground truth on the full cloud',
)
fig.show()

In [None]:
# Inspect the shape of the ground-truth tensor of the loaded batch.
gts.shape

In [None]:
import plotly.graph_objects as go

# 3-D view of the sampled points, coloured by the ground truth at those points.
sampled_xyz = pcd_new[:, :3].detach().cpu().numpy()
sampled_gts = gts.detach().cpu()[0, choice_idx].numpy()
fig = go.Figure(data=[go.Scatter3d(
    x=sampled_xyz[:, 0],
    y=sampled_xyz[:, 1],
    z=sampled_xyz[:, 2],
    mode='markers',
    marker=dict(
        size=3,
        color=sampled_gts,   # alternative: pred[:,0].to('cpu') to colour by the prediction
        colorscale='Viridis',   # choose a colorscale
        colorbar=dict(title=dict(text='ground truth')),
        opacity=1
    )
)])

# tight layout
fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
fig.show()

In [None]:
import plotly.graph_objects as go

# 3-D view of the full point cloud, coloured by the interpolated prediction.
fig = go.Figure(data=[go.Scatter3d(
    x=Points_init[:, 0],
    y=Points_init[:, 1],
    z=Points_init[:, 2],
    mode='markers',
    marker=dict(
        size=1,
        color=interp_pred,                # one colour value per point
        colorscale='Viridis',   # choose a colorscale
        colorbar=dict(title=dict(text='interpolated prediction')),
        opacity=1
    )
)])

# tight layout
fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
fig.show()